@@ -356,6 +356,46 @@ def to(self, *args, **kwargs):

         return new_param

+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        if func in [torch.chunk, torch.split]:
+            tensor = args[0]
+
+            result = super().__torch_function__(func, types, args, kwargs)
+
+            if isinstance(result, tuple):
+                return tuple(
+                    cls(
+                        data=chunk,
+                        requires_grad=tensor.requires_grad,
+                        quant_state=tensor.quant_state,
+                        blocksize=tensor.blocksize,
+                        compress_statistics=tensor.compress_statistics,
+                        quant_type=tensor.quant_type,
+                        quant_storage=tensor.quant_storage,
+                        module=tensor.module,
+                        bnb_quantized=tensor.bnb_quantized,
+                    )
+                    for chunk in result
+                )
+            else:
+                return cls(
+                    data=result,
+                    requires_grad=tensor.requires_grad,
+                    quant_state=tensor.quant_state,
+                    blocksize=tensor.blocksize,
+                    compress_statistics=tensor.compress_statistics,
+                    quant_type=tensor.quant_type,
+                    quant_storage=tensor.quant_storage,
+                    module=tensor.module,
+                    bnb_quantized=tensor.bnb_quantized,
+                )
+
+        return super().__torch_function__(func, types, args, kwargs)
+

 def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
     if getattr(module.weight, "quant_state", None) is not None:
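
For context: the `__torch_function__` override above intercepts `torch.chunk` and `torch.split` so that the resulting chunks are rebuilt as instances of the parameter subclass, carrying over the parent's quantization metadata (`quant_state`, `blocksize`, `quant_type`, and so on) instead of decaying to plain tensors. A minimal usage sketch of the intended behavior follows; the `Linear4bit` construction is an assumption for illustration and is not part of this diff:

import torch
import bitsandbytes as bnb

# Assumed setup for illustration: a Linear4bit layer whose .weight is a
# Params4bit (actual 4-bit quantization happens lazily, e.g. on .cuda()).
layer = bnb.nn.Linear4bit(64, 64, quant_type="nf4")
param = layer.weight

# With the override in place, chunking no longer strips the subclass:
chunks = torch.chunk(param, 2, dim=0)
for chunk in chunks:
    # Each chunk keeps the subclass type and the parent's metadata.
    assert type(chunk) is type(param)
    assert chunk.quant_type == param.quant_type
    assert chunk.blocksize == param.blocksize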