@@ -1774,6 +1774,17 @@ def find_module_by_name(self, model: nn.Module, regions: List[Region], prefix: s
17741774 full_name = prefix + '.' + name if prefix != '' else name
17751775 self .find_module_by_name (module , regions , full_name )
17761776
1777+ def transform_model (
1778+ self , model : nn .Module , rewriters : List [Transform ], delay_rewriters : bool ) -> nn .Module :
1779+ # In some circumstances, it might be useful to apply model transformations at a later moment
1780+ # The user should not be resposible for this in any case
1781+ if delay_rewriters :
1782+ return model
1783+ if is_model_offloaded_accelerate (model ):
1784+ return apply_rewriters_accelerate (model , rewriters )
1785+ else :
1786+ return apply_rewriters (model , rewriters )
1787+
17771788
17781789class GraphRotationEqualization (RotationEqualization ):
17791790
@@ -2005,16 +2016,6 @@ def apply(self,
20052016 else :
20062017 return graph_model
20072018
2008- def transform_model (self , model , rewriters , delay_rewriters ):
2009- # In some circumstances, it might be useful to apply model transformations at a later moment
2010- # The user should not be resposible for this in any case
2011- if delay_rewriters :
2012- return model
2013- if is_model_offloaded_accelerate (model ):
2014- return apply_rewriters_accelerate (model , rewriters )
2015- else :
2016- return apply_rewriters (model , rewriters )
2017-
20182019
20192020@torch .no_grad ()
20202021def apply_rewriters (
@@ -2114,15 +2115,17 @@ def __init__(
21142115 self .supported_sinks = (nn .Linear )
21152116
21162117 def apply (self , model : nn .Module ) -> nn .Module :
2118+ regions : List [Region ] = []
2119+ rewriters : List [Transform ] = []
21172120
21182121 blacklist_orphan_layers = self .blacklist_layers + self .layers_to_expand
2119- regions : List [Region ] = []
21202122 self .find_module (model , regions , blacklist_layers = blacklist_orphan_layers )
21212123 expanded_regions = []
21222124 self .find_module_by_name (model , expanded_regions )
21232125
21242126 if len (expanded_regions ) > 0 :
21252127 regions .extend (expanded_regions )
21262128 if len (regions ) > 0 :
2127- _compute_rotations (model , regions , expansion_step = self .expansion_step )
2129+ rewriters .extend (_compute_rotations (model , regions , expansion_step = self .expansion_step ))
2130+ model = self .transform_model (model , rewriters , delay_rewriters = False )
21282131 return model
0 commit comments