@@ -408,29 +408,32 @@ def __new__(cls, data, requires_grad=False, quant_type=None):
408408    def  as_tensor (self ):
409409        return  torch .Tensor ._make_subclass (torch .Tensor , self , self .requires_grad )
410410
411+     @staticmethod  
412+     def  _extract_quant_type (args ):
413+         # When converting from original format checkpoints we often use splits, cats etc on tensors 
414+         # this method ensures that the returned tensor type from those operations remains GGUFParameter 
415+         # so that we preserve quant_type information 
416+         for  arg  in  args :
417+             if  isinstance (arg , list ) and  isinstance (arg [0 ], GGUFParameter ):
418+                 return  arg [0 ].quant_type 
419+             if  isinstance (arg , GGUFParameter ):
420+                 return  arg .quant_type 
421+         return  None 
422+ 
411423    @classmethod  
412424    def  __torch_function__ (cls , func , types , args = (), kwargs = None ):
413425        if  kwargs  is  None :
414426            kwargs  =  {}
415427
416428        result  =  super ().__torch_function__ (func , types , args , kwargs )
417429
418-         # When converting from original format checkpoints we often use splits, cats etc on tensors 
419-         # this method ensures that the returned tensor type from those operations remains GGUFParameter 
420-         # so that we preserve quant_type information 
421-         quant_type  =  None 
422-         for  arg  in  args :
423-             if  isinstance (arg , list ) and  isinstance (arg [0 ], GGUFParameter ):
424-                 quant_type  =  arg [0 ].quant_type 
425-                 break 
426-             if  isinstance (arg , GGUFParameter ):
427-                 quant_type  =  arg .quant_type 
428-                 break 
429430        if  isinstance (result , torch .Tensor ):
431+             quant_type  =  cls ._extract_quant_type (args )
430432            return  cls (result , quant_type = quant_type )
431433        # Handle tuples and lists 
432-         elif  isinstance (result , ( tuple ,  list ) ):
434+         elif  type (result )  in  ( list ,  tuple ):
433435            # Preserve the original type (tuple or list) 
436+             quant_type  =  cls ._extract_quant_type (args )
434437            wrapped  =  [cls (x , quant_type = quant_type ) if  isinstance (x , torch .Tensor ) else  x  for  x  in  result ]
435438            return  type (result )(wrapped )
436439        else :
0 commit comments