Skip to content

Commit f997b1d

Browse files
authored
Update utils.py to remove dupe replace_linear
1 parent 18e827d commit f997b1d

File tree

1 file changed

+0
-39
lines changed

1 file changed

+0
-39
lines changed

bitsandbytes/utils.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -99,45 +99,6 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
9999

100100
return idx
101101

102-
def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
103-
"""
104-
Replace linear modules with a new Linear module.
105-
106-
Parameters:
107-
model (`torch.nn.Module`):
108-
Input model or `torch.nn.Module` as the function is run recursively.
109-
linear_replacement (`torch.nn.Module`):
110-
The linear module that replaces the old one. Only expects standard arguments.
111-
If other arguments need to be passed, use a lambda.
112-
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
113-
List of modules names not to convert. Defaults to `lm_head`.
114-
copy_weights (`bool`):
115-
Copy the weights from the old linear module to the new one
116-
post_processing_fun_name (`str`):
117-
A function name of the replacement linear class that is called
118-
after processing.
119-
"""
120-
for name, module in model.named_children():
121-
if len(list(module.children())) > 0:
122-
replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
123-
124-
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
125-
old_module = model._modules[name]
126-
model._modules[name] = linear_replacement(
127-
module.in_features,
128-
module.out_features,
129-
module.bias is not None,
130-
)
131-
if copy_weights:
132-
model._modules[name].weight = old_module.weight
133-
model._modules[name].bias = old_module.bias
134-
135-
if post_processing_function is not None:
136-
func = getattr(module, post_processing_function, None)
137-
if func is not None: func(module)
138-
return model
139-
140-
141102

142103
def execute_and_return(command_string: str) -> Tuple[str, str]:
143104
def _decode(subprocess_err_out_tuple):

0 commit comments

Comments
 (0)