@@ -99,45 +99,6 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
9999
100100 return idx
101101
102- def replace_linear (model , linear_replacement , skip_modules = ["lm_head" ], copy_weights = False , post_processing_function = None ):
103- """
104- Replace linear modules with a new Linear module.
105-
106- Parameters:
107- model (`torch.nn.Module`):
108- Input model or `torch.nn.Module` as the function is run recursively.
109- linear_replacement (`torch.nn.Module`):
110- The linear module that replaces the old one. Only expects standard arguments.
111- If other arguments need to be passed, use a lambda.
112- skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
113- List of modules names not to convert. Defaults to `lm_head`.
114- copy_weights (`bool`):
115- Copy the weights from the old linear module to the new one
116- post_processing_fun_name (`str`):
117- A function name of the replacement linear class that is called
118- after processing.
119- """
120- for name , module in model .named_children ():
121- if len (list (module .children ())) > 0 :
122- replace_linear (module , linear_replacement , skip_modules , copy_weights , post_processing_function )
123-
124- if isinstance (module , torch .nn .Linear ) and name not in skip_modules :
125- old_module = model ._modules [name ]
126- model ._modules [name ] = linear_replacement (
127- module .in_features ,
128- module .out_features ,
129- module .bias is not None ,
130- )
131- if copy_weights :
132- model ._modules [name ].weight = old_module .weight
133- model ._modules [name ].bias = old_module .bias
134-
135- if post_processing_function is not None :
136- func = getattr (module , post_processing_function , None )
137- if func is not None : func (module )
138- return model
139-
140-
141102
142103def execute_and_return (command_string : str ) -> Tuple [str , str ]:
143104 def _decode (subprocess_err_out_tuple ):
0 commit comments