File tree Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Original file line number Diff line number Diff line change 1515PyTorch utilities: Utilities related to PyTorch 
1616""" 
1717
18+ import  re 
1819from  typing  import  List , Optional , Tuple , Union 
1920
2021from  . import  logging 
@@ -195,3 +196,17 @@ def device_synchronize(device_type: Optional[str] = None):
195196        device_type  =  get_device ()
196197    device_mod  =  getattr (torch , device_type , torch .cuda )
197198    device_mod .synchronize ()
199+ 
200+ 
201+ def  _find_modules_by_class_name (module : "torch.nn.Module" , class_name : str ) ->  List [Tuple [str , "torch.nn.Module" ]]:
202+     """ 
203+     Recursively find all modules in a PyTorch module that match the specified class name. The class 
204+     name could be partial/full name or a regex pattern. 
205+     """ 
206+     pattern  =  re .compile (class_name )
207+     matching_name_module_pairs  =  []
208+     for  name , submodule  in  module .named_modules ():
209+         submodule_cls  =  unwrap_module (submodule ).__class__ 
210+         if  pattern .search (submodule_cls .__name__ ):
211+             matching_name_module_pairs .append ((name , submodule ))
212+     return  matching_name_module_pairs 
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments