Skip to content

Commit 1dbe602

Browse files
committed
Fix Params4bit tensor subclass handling
1 parent e54dc12 commit 1dbe602

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

bitsandbytes/nn/modules.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,46 @@ def to(self, *args, **kwargs):
356356

357357
return new_param
358358

359+
@classmethod
360+
def __torch_function__(cls, func, types, args=(), kwargs=None):
361+
if kwargs is None:
362+
kwargs = {}
363+
364+
if func in [torch.chunk, torch.split]:
365+
tensor = args[0]
366+
367+
result = super().__torch_function__(func, types, args, kwargs)
368+
369+
if isinstance(result, tuple):
370+
return tuple(
371+
cls(
372+
data=chunk,
373+
requires_grad=tensor.requires_grad,
374+
quant_state=tensor.quant_state,
375+
blocksize=tensor.blocksize,
376+
compress_statistics=tensor.compress_statistics,
377+
quant_type=tensor.quant_type,
378+
quant_storage=tensor.quant_storage,
379+
module=tensor.module,
380+
bnb_quantized=tensor.bnb_quantized,
381+
)
382+
for chunk in result
383+
)
384+
else:
385+
return cls(
386+
data=result,
387+
requires_grad=tensor.requires_grad,
388+
quant_state=tensor.quant_state,
389+
blocksize=tensor.blocksize,
390+
compress_statistics=tensor.compress_statistics,
391+
quant_type=tensor.quant_type,
392+
quant_storage=tensor.quant_storage,
393+
module=tensor.module,
394+
bnb_quantized=tensor.bnb_quantized,
395+
)
396+
397+
return super().__torch_function__(func, types, args, kwargs)
398+
359399

360400
def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
361401
if getattr(module.weight, "quant_state", None) is not None:

0 commit comments

Comments
 (0)