Skip to content

Commit 070f45d

Browse files
committed
cleanup commented out deletions
1 parent 781fcd5 commit 070f45d

File tree

1 file changed

+5
-39
lines changed

1 file changed

+5
-39
lines changed

bitsandbytes/nn/modules.py

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -163,30 +163,6 @@ def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='c
163163
self.compress_statistics = self.quant_state.nested
164164
self.quant_type = self.quant_state.quant_type
165165
return self
166-
167-
# @classmethod
168-
# def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
169-
# data = state_dict.pop(prefix.rstrip('.'))
170-
171-
# # extracting components for QuantState from state_dict
172-
# qs_dict = {}
173-
# for k, v in state_dict.items():
174-
# if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
175-
# qs_dict[k] = v
176-
# state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
177-
# qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
178-
179-
# if data.device.type != "cuda":
180-
# raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
181-
182-
# cls.requires_grad = requires_grad
183-
# cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
184-
# cls.blocksize = cls.quant_state.blocksize # this attribute can be deprecated - it duplicates same one in quant_state
185-
# cls.compress_statistics = cls.quant_state.nested # this attribute can be deprecated - it duplicates quant_state.nested
186-
# cls.quant_type = cls.quant_state.quant_type # this attribute can be deprecated - it duplicates same one in quant_state
187-
188-
# self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
189-
# return self, state_dict
190166

191167
def cuda(self, device):
192168
w = self.data.contiguous().half().cuda(device)
@@ -227,7 +203,7 @@ def to(self, *args, **kwargs):
227203

228204
class Linear4bit(nn.Linear):
229205

230-
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
206+
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
231207
super().__init__(input_features, output_features, bias, device)
232208
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
233209
# self.persistent_buffers = [] # TODO consider as way to save quant state
@@ -261,18 +237,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
261237
for k, v in self.weight.quant_state.as_dict(packed=True).items():
262238
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
263239

264-
# def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
265-
# missing_keys, unexpected_keys, error_msgs):
266-
# # Note: super()._load_from_state_dict() is not called here intentionally.
267-
# if self.bias is not None:
268-
# bias_data = state_dict.pop(prefix + "bias", None)
269-
# self.bias.data = bias_data.to(self.bias.data.device)
270-
271-
# self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
272-
# state_dict, prefix=prefix + "weight" + ".", requires_grad=False
273-
# )
274-
# unexpected_keys.extend(state_dict.keys())
275-
276240
def forward(self, x: torch.Tensor):
277241
# weights are cast automatically as Int8Params, but the bias has to be cast manually
278242
if self.bias is not None and self.bias.dtype != x.dtype:
@@ -295,10 +259,12 @@ def forward(self, x: torch.Tensor):
295259

296260
return out
297261

262+
298263
class LinearFP4(Linear4bit):
299-
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
264+
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
300265
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
301266

267+
302268
class LinearNF4(Linear4bit):
303269
''' Implements the NF4 data type.
304270
@@ -310,7 +276,7 @@ class LinearNF4(Linear4bit):
310276
Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
311277
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
312278
'''
313-
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
279+
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
314280
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
315281

316282

0 commit comments

Comments
 (0)