diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index de82dcab079e..531fd612739c 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -408,6 +408,18 @@ def __new__(cls, data, requires_grad=False, quant_type=None):
     def as_tensor(self):
         return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
 
+    @staticmethod
+    def _extract_quant_type(args):
+        # When converting from original format checkpoints we often use splits, cats etc on tensors
+        # this method ensures that the returned tensor type from those operations remains GGUFParameter
+        # so that we preserve quant_type information
+        for arg in args:
+            if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
+                return arg[0].quant_type
+            if isinstance(arg, GGUFParameter):
+                return arg.quant_type
+        return None
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
@@ -415,22 +427,13 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
 
         result = super().__torch_function__(func, types, args, kwargs)
 
-        # When converting from original format checkpoints we often use splits, cats etc on tensors
-        # this method ensures that the returned tensor type from those operations remains GGUFParameter
-        # so that we preserve quant_type information
-        quant_type = None
-        for arg in args:
-            if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
-                quant_type = arg[0].quant_type
-                break
-            if isinstance(arg, GGUFParameter):
-                quant_type = arg.quant_type
-                break
         if isinstance(result, torch.Tensor):
+            quant_type = cls._extract_quant_type(args)
             return cls(result, quant_type=quant_type)
         # Handle tuples and lists
-        elif isinstance(result, (tuple, list)):
+        elif type(result) in (list, tuple):
             # Preserve the original type (tuple or list)
+            quant_type = cls._extract_quant_type(args)
             wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
             return type(result)(wrapped)
         else:
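
For context, here is a minimal sketch (not part of the diff) of the behaviour this refactor preserves. It assumes GGUFParameter is importable from diffusers.quantizers.gguf.utils as in the file above, and uses a plain string as a stand-in for the real GGML quantization type value:

# Illustrative sketch only -- the string quant_type is a stand-in value,
# not the enum used by actual GGUF checkpoints.
import torch

from diffusers.quantizers.gguf.utils import GGUFParameter

a = GGUFParameter(torch.randn(4, 8), quant_type="Q4_K")
b = GGUFParameter(torch.randn(4, 8), quant_type="Q4_K")

# torch.cat receives a list of tensors, which is why _extract_quant_type
# also inspects list arguments for a GGUFParameter element.
cat = torch.cat([a, b], dim=0)
assert isinstance(cat, GGUFParameter) and cat.quant_type == "Q4_K"

# torch.chunk returns a tuple; __torch_function__ rewraps each tensor element
# and keeps the container type, so quant_type survives split-style ops too.
chunks = torch.chunk(cat, 2, dim=0)
assert type(chunks) is tuple
assert all(isinstance(c, GGUFParameter) and c.quant_type == "Q4_K" for c in chunks)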