@@ -293,7 +293,7 @@ class WalkRegionState:
293293
294294 supported_srcs : set = _supported_layers
295295 supported_sinks : set = _supported_layers
296- scale_invariant_function : set = _scale_invariant_op
296+ scale_invariant_functions : set = _scale_invariant_op
297297 scale_invariant_layers : set = _scale_invariant_layers
298298 residual_fns : set = _residual_fns
299299 residual_methods : set = _residual_methods
@@ -1027,7 +1027,7 @@ def find_srcs_channel_dim(state, model, inp_node):
10271027 return total_channels
10281028 elif _is_scale_invariant_module (model , inp_node ,
10291029 state .scale_invariant_layers ) or _is_scale_invariant_function (
1030- inp_node , state .scale_invariant_function ):
1030+ inp_node , state .scale_invariant_functions ):
10311031 return find_srcs_channel_dim (state , model , inp_node .all_input_nodes [0 ])
10321032 else :
10331033 return _UNSUPPORTED_OP
@@ -1078,7 +1078,7 @@ def find_srcs(graph_model: GraphModule, starting_node: Node,
10781078 0 ]
10791079 elif _is_scale_invariant_module (
10801080 graph_model , node , state .scale_invariant_layers ) or _is_scale_invariant_function (
1081- node , state .scale_invariant_function ):
1081+ node , state .scale_invariant_functions ):
10821082 find_sinks (graph_model , node , state )
10831083 find_srcs (graph_model , node , state )
10841084 elif _is_add (node , state .residual_fns , state .residual_methods ):
@@ -1126,7 +1126,7 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
11261126
11271127 elif _is_scale_invariant_module (
11281128 graph_model , node , state .scale_invariant_layers ) or _is_scale_invariant_function (
1129- node , state .scale_invariant_function ):
1129+ node , state .scale_invariant_functions ):
11301130 find_sinks (graph_model , node , state )
11311131 elif _is_add (node , state .residual_fns , state .residual_methods ):
11321132 state .update_offset = False
@@ -1785,10 +1785,43 @@ def _merge_ln(layer_norm, next_module, scale_bias_by_weight):
17851785 _replace_bias (next_module , new_bias )
17861786
17871787
1788+ class RegionWalkMixin :
1789+
1790+ def __init__ (
1791+ self ,
1792+ supported_srcs : Tuple [Type [nn .Module ]] = _supported_layers ,
1793+ supported_sinks : Tuple [Type [nn .Module ]] = _supported_layers ,
1794+ scale_invariant_layers : Tuple [Type [nn .Module ]] = _scale_invariant_layers ,
1795+ scale_invariant_functions : Tuple [Callable ] = _scale_invariant_op ,
1796+ residual_fns : Tuple [Callable ] = _residual_fns ,
1797+ residual_methods : Tuple [str ] = _residual_methods ,
1798+ extra_state_kwargs : Optional [Dict [str , Tuple [Type [nn .Module ]]]] = None ):
1799+ self .supported_srcs = supported_srcs
1800+ self .supported_sinks = supported_sinks
1801+ self .scale_invariant_layers = scale_invariant_layers
1802+ self .scale_invariant_functions = scale_invariant_functions
1803+ self .residual_fns = residual_fns
1804+ self .residual_methods = residual_methods
1805+
1806+ if extra_state_kwargs is not None :
1807+ for attr_name , value in extra_state_kwargs .items ():
1808+ combined_value = value + getattr (self , attr_name )
1809+ setattr (self , attr_name , combined_value )
1810+
1811+ @property
1812+ def full_state_kwargs (self ) -> Dict [str , Tuple [Type [nn .Module ]]]:
1813+ return {
1814+ 'supported_srcs' : self .supported_srcs ,
1815+ 'supported_sinks' : self .supported_sinks ,
1816+ 'scale_invariant_layers' : self .scale_invariant_layers ,
1817+ 'scale_invariant_functions' : self .scale_invariant_functions ,
1818+ 'residual_fns' : self .residual_fns ,
1819+ 'residual_methods' : self .residual_methods }
1820+
1821+
17881822class RotationEqualization (GraphTransform ):
17891823
17901824 def __init__ (self , blacklist_layers , layers_to_expand ) -> None :
1791- super (RotationEqualization , self ).__init__ ()
17921825 if blacklist_layers is not None :
17931826 self .blacklist_layers = blacklist_layers
17941827 else :
@@ -1797,19 +1830,19 @@ def __init__(self, blacklist_layers, layers_to_expand) -> None:
17971830 self .layers_to_expand = layers_to_expand
17981831 else :
17991832 self .layers_to_expand = []
1800- self .supported_sinks = ()
18011833
18021834 def find_module (
18031835 self ,
18041836 model : nn .Module ,
18051837 regions : List [Region ],
1838+ supported_sinks : Tuple [nn .Module ],
18061839 prefix : str = '' ,
18071840 blacklist_layers : Optional [List [str ]] = None ):
18081841 """
18091842 Iterate through the model looking at immediate children of every module to look for supported modules.
18101843 This allows us to stop the search when we meet a top-level module that is supported.
18111844 """
1812- if isinstance (model , self . supported_sinks ):
1845+ if isinstance (model , supported_sinks ):
18131846 if prefix in blacklist_layers :
18141847 return
18151848 weight = get_weight_sink (model )
@@ -1820,7 +1853,7 @@ def find_module(
18201853 else :
18211854 for name , module in model .named_children ():
18221855 full_name = prefix + '.' + name if prefix != '' else name
1823- self .find_module (module , regions , full_name , blacklist_layers )
1856+ self .find_module (module , regions , supported_sinks , full_name , blacklist_layers )
18241857
18251858 def find_module_by_name (self , model : nn .Module , regions : List [Region ], prefix : str = '' ):
18261859 """
@@ -1852,7 +1885,7 @@ def transform_model(
18521885 return apply_rewriters (model , rewriters )
18531886
18541887
1855- class GraphRotationEqualization (RotationEqualization ):
1888+ class GraphRotationEqualization (RotationEqualization , RegionWalkMixin ):
18561889
18571890 def __init__ (
18581891 self ,
@@ -1866,16 +1899,20 @@ def __init__(
18661899 layers_to_expand : Optional [List [str ]] = None ,
18671900 expansion_step : int = None ,
18681901 delay_rewriters : bool = False ,
1869- return_rewriters : bool = False ) -> None :
1870- super (GraphRotationEqualization , self ).__init__ (blacklist_layers , layers_to_expand )
1902+ return_rewriters : bool = False ,
1903+ extra_state_kwargs : Optional [Dict [str , Tuple ]] = None ) -> None :
1904+ RotationEqualization .__init__ (self , blacklist_layers , layers_to_expand )
18711905
1872- self .supported_srcs = (nn .Linear , nn .Embedding )
1873- self .supported_sinks = (nn .Linear )
18741906 common_scale_invariant = list (_scale_invariant_layers )
18751907 common_scale_invariant .remove (torch .nn .ReLU )
18761908 common_scale_invariant .remove (torch .nn .LeakyReLU )
1877- self .scale_invariant_layers = tuple (common_scale_invariant ) + (RMSNorm ,)
1878- self .scale_invariant_function = ()
1909+ base_state_kwargs = {
1910+ 'supported_srcs' : (nn .Linear , nn .Embedding ),
1911+ 'supported_sinks' : (nn .Linear ,),
1912+ 'scale_invariant_layers' : tuple (common_scale_invariant ) + (RMSNorm ,),
1913+ 'scale_invariant_functions' : ()}
1914+ RegionWalkMixin .__init__ (self , ** base_state_kwargs , extra_state_kwargs = extra_state_kwargs )
1915+
18791916 self .orphan_sink = orphan_sink
18801917 self .rotate_matmul = rotate_matmul
18811918 self .full_rotation_method = full_rotation_method
@@ -1992,13 +2029,7 @@ def find_sink(node):
19922029 def apply (self ,
19932030 graph_model : GraphModule ) -> Union [Tuple [GraphModule , List [Transform ]], GraphModule ]:
19942031 rewriters = []
1995- regions = _extract_regions (
1996- graph_model ,
1997- state_impl_kwargs = {
1998- 'supported_srcs' : self .supported_srcs ,
1999- 'supported_sinks' : self .supported_sinks ,
2000- 'scale_invariant_layers' : self .scale_invariant_layers ,
2001- 'scale_invariant_function' : self .scale_invariant_function })
2032+ regions = _extract_regions (graph_model , state_impl_kwargs = self .full_state_kwargs )
20022033
20032034 expanded_regions = []
20042035 self .find_module_by_name (graph_model , expanded_regions )
@@ -2007,7 +2038,11 @@ def apply(self,
20072038
20082039 if self .orphan_sink :
20092040 blacklist_orphan_layers = self .blacklist_layers + self .layers_to_expand
2010- self .find_module (graph_model , orphan_regions , blacklist_layers = blacklist_orphan_layers )
2041+ self .find_module (
2042+ graph_model ,
2043+ orphan_regions ,
2044+ self .full_state_kwargs ['supported_sinks' ],
2045+ blacklist_layers = blacklist_orphan_layers )
20112046
20122047 if len (expanded_regions ) > 0 :
20132048 parameter_number_pre = 0
@@ -2095,20 +2130,23 @@ def apply_rewriters(
20952130 return model
20962131
20972132
2098- class LayerNormToRMS (GraphTransform ):
2133+ class LayerNormToRMS (GraphTransform , RegionWalkMixin ):
2134+
2135+ def __init__ (
2136+ self ,
2137+ return_rewriters : bool = False ,
2138+ extra_state_kwargs : Optional [Dict [str , Tuple ]] = None ) -> None :
2139+ GraphTransform .__init__ (self )
2140+
2141+ base_state_kwargs = {
2142+ 'supported_srcs' : (nn .Linear , nn .Embedding ), 'supported_sinks' : (nn .LayerNorm ,)}
2143+ RegionWalkMixin .__init__ (self , ** base_state_kwargs , extra_state_kwargs = extra_state_kwargs )
20992144
2100- def __init__ (self , return_rewriters = False ) -> None :
2101- super (LayerNormToRMS , self ).__init__ ()
2102- self .supported_srcs = (nn .Linear , nn .Embedding )
2103- self .supported_sinks = (nn .LayerNorm )
21042145 self .return_rewriters = return_rewriters
21052146 assert RMSNorm is not object , 'Update your Pytorch version to 2.4+'
21062147
21072148 def apply (self , graph_model : GraphModule ) -> GraphModule :
2108- regions = _extract_regions (
2109- graph_model ,
2110- state_impl_kwargs = {
2111- 'supported_srcs' : self .supported_srcs , 'supported_sinks' : self .supported_sinks })
2149+ regions = _extract_regions (graph_model , state_impl_kwargs = self .full_state_kwargs )
21122150
21132151 rewriters = []
21142152 if len (regions ) > 0 :
@@ -2141,18 +2179,17 @@ def apply(self, graph_model: GraphModule) -> GraphModule:
21412179 return graph_model
21422180
21432181
2144- class MergeLnAffine (GraphTransform ):
2182+ class MergeLnAffine (GraphTransform , RegionWalkMixin ):
21452183
2146- def __init__ (self ) -> None :
2147- super ( MergeLnAffine , self ) .__init__ ()
2184+ def __init__ (self , extra_state_kwargs : Optional [ Dict [ str , Tuple ]] = None ) -> None :
2185+ GraphTransform .__init__ (self )
21482186 self .supported_srcs = (RMSNorm , nn .LayerNorm )
2149- self .supported_sinks = (nn .Linear )
2187+ base_state_kwargs = {
2188+ 'supported_srcs' : (RMSNorm , nn .LayerNorm ), 'supported_sinks' : (nn .Linear ,)}
2189+ RegionWalkMixin .__init__ (self , ** base_state_kwargs , extra_state_kwargs = extra_state_kwargs )
21502190
21512191 def apply (self , graph_model : GraphModule ) -> GraphModule :
2152- regions = _extract_regions (
2153- graph_model ,
2154- state_impl_kwargs = {
2155- 'supported_srcs' : self .supported_srcs , 'supported_sinks' : self .supported_sinks })
2192+ regions = _extract_regions (graph_model , state_impl_kwargs = self .full_state_kwargs )
21562193
21572194 if len (regions ) > 0 :
21582195 scaled_biases = set ()
@@ -2180,18 +2217,21 @@ def __init__(
21802217 blacklist_layer : Optional [List ] = None ,
21812218 layers_to_expand : Optional [List ] = None ,
21822219 expansion_step : int = 0 ,
2183- block_rotation_dim : Optional [int ] = None ):
2184- super ().__init__ (blacklist_layer , layers_to_expand )
2220+ block_rotation_dim : Optional [int ] = None ,
2221+ extra_state_kwargs : Optional [Dict [str , Tuple ]] = None ):
2222+
2223+ RotationEqualization .__init__ (self , blacklist_layer , layers_to_expand )
21852224 self .expansion_step = expansion_step
2186- self .supported_sinks = (nn .Linear )
21872225 self .block_rotation_dim = block_rotation_dim
2226+ self .supported_sinks = (nn .Linear ,)
21882227
21892228 def apply (self , model : nn .Module ) -> nn .Module :
21902229 regions : List [Region ] = []
21912230 rewriters : List [Transform ] = []
21922231
21932232 blacklist_orphan_layers = self .blacklist_layers + self .layers_to_expand
2194- self .find_module (model , regions , blacklist_layers = blacklist_orphan_layers )
2233+ self .find_module (
2234+ model , regions , self .supported_sinks , blacklist_layers = blacklist_orphan_layers )
21952235 expanded_regions = []
21962236 self .find_module_by_name (model , expanded_regions )
21972237
0 commit comments