@@ -408,29 +408,32 @@ def __new__(cls, data, requires_grad=False, quant_type=None):
408408 def as_tensor (self ):
409409 return torch .Tensor ._make_subclass (torch .Tensor , self , self .requires_grad )
410410
411+ @staticmethod
412+ def _extract_quant_type (args ):
413+ # When converting from original format checkpoints we often use splits, cats etc on tensors
414+ # this method ensures that the returned tensor type from those operations remains GGUFParameter
415+ # so that we preserve quant_type information
416+ for arg in args :
417+ if isinstance (arg , list ) and isinstance (arg [0 ], GGUFParameter ):
418+ return arg [0 ].quant_type
419+ if isinstance (arg , GGUFParameter ):
420+ return arg .quant_type
421+ return None
422+
411423 @classmethod
412424 def __torch_function__ (cls , func , types , args = (), kwargs = None ):
413425 if kwargs is None :
414426 kwargs = {}
415427
416428 result = super ().__torch_function__ (func , types , args , kwargs )
417429
418- # When converting from original format checkpoints we often use splits, cats etc on tensors
419- # this method ensures that the returned tensor type from those operations remains GGUFParameter
420- # so that we preserve quant_type information
421- quant_type = None
422- for arg in args :
423- if isinstance (arg , list ) and isinstance (arg [0 ], GGUFParameter ):
424- quant_type = arg [0 ].quant_type
425- break
426- if isinstance (arg , GGUFParameter ):
427- quant_type = arg .quant_type
428- break
429430 if isinstance (result , torch .Tensor ):
431+ quant_type = cls ._extract_quant_type (args )
430432 return cls (result , quant_type = quant_type )
431433 # Handle tuples and lists
432- elif isinstance (result , ( tuple , list ) ):
434+ elif type (result ) in ( list , tuple ):
433435 # Preserve the original type (tuple or list)
436+ quant_type = cls ._extract_quant_type (args )
434437 wrapped = [cls (x , quant_type = quant_type ) if isinstance (x , torch .Tensor ) else x for x in result ]
435438 return type (result )(wrapped )
436439 else :
0 commit comments