diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py
index f9cf67c0ee9..af895fb3eee 100644
--- a/invokeai/backend/quantization/gguf/ggml_tensor.py
+++ b/invokeai/backend/quantization/gguf/ggml_tensor.py
@@ -17,21 +17,32 @@ def dequantize_and_run(func, args, kwargs):
     Also casts other floating point tensors to match the compute_dtype of GGMLTensors
     to avoid dtype mismatches in matrix operations.
     """
-    # Find the compute_dtype from any GGMLTensor in the args
+    # Find the compute_dtype and target_device from any GGMLTensor in the args
     compute_dtype = None
+    target_device = None
     for a in args:
         if hasattr(a, "compute_dtype"):
             compute_dtype = a.compute_dtype
+        if isinstance(a, torch.Tensor) and target_device is None:
+            target_device = a.device
+        if compute_dtype is not None and target_device is not None:
             break
-    if compute_dtype is None:
+    if compute_dtype is None or target_device is None:
         for v in kwargs.values():
-            if hasattr(v, "compute_dtype"):
+            if hasattr(v, "compute_dtype") and compute_dtype is None:
                 compute_dtype = v.compute_dtype
+            if isinstance(v, torch.Tensor) and target_device is None:
+                target_device = v.device
+            if compute_dtype is not None and target_device is not None:
                 break
 
     def process_tensor(t):
         if hasattr(t, "get_dequantized_tensor"):
-            return t.get_dequantized_tensor()
+            result = t.get_dequantized_tensor()
+            # Ensure the dequantized tensor is on the target device
+            if target_device is not None and result.device != target_device:
+                result = result.to(target_device)
+            return result
         elif isinstance(t, torch.Tensor) and compute_dtype is not None and t.is_floating_point():
             # Cast other floating point tensors to match the GGUF compute_dtype
             return t.to(compute_dtype)
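
For context (not part of the patch): a condensed, self-contained sketch of the behavior this change produces. FakeGGML is a hypothetical stand-in, not InvokeAI's real GGMLTensor, and the condensed dequantize_and_run mirrors only the patched logic shown above (args-only scan; the real code also scans kwargs).

# --- Illustrative sketch (assumptions: FakeGGML and the usage are made up) --
import torch


class FakeGGML:
    """Hypothetical wrapper mimicking the two members the patch relies on."""

    def __init__(self, data: torch.Tensor, compute_dtype: torch.dtype):
        self._data = data  # stands in for quantized storage, kept on the CPU
        self.compute_dtype = compute_dtype

    def get_dequantized_tensor(self) -> torch.Tensor:
        # Dequantization yields a tensor on the storage device (CPU here);
        # that is exactly the mismatch the target_device check corrects.
        return self._data.to(self.compute_dtype)


def dequantize_and_run(func, args, kwargs):
    # Detect the GGUF compute dtype and the device the op should run on.
    compute_dtype = next((a.compute_dtype for a in args if hasattr(a, "compute_dtype")), None)
    target_device = next((a.device for a in args if isinstance(a, torch.Tensor)), None)

    def process_tensor(t):
        if hasattr(t, "get_dequantized_tensor"):
            result = t.get_dequantized_tensor()
            # Move the dequantized tensor onto the device of the other operands.
            if target_device is not None and result.device != target_device:
                result = result.to(target_device)
            return result
        if isinstance(t, torch.Tensor) and compute_dtype is not None and t.is_floating_point():
            # Cast other floating point tensors to the GGUF compute_dtype.
            return t.to(compute_dtype)
        return t

    return func(*(process_tensor(a) for a in args), **{k: process_tensor(v) for k, v in kwargs.items()})


# Usage: the weight dequantizes on the CPU, then follows the activation's
# device and dtype before matmul, so no device/dtype mismatch is raised.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 4, device=device, dtype=torch.bfloat16)
w = FakeGGML(torch.randn(4, 3), compute_dtype=torch.bfloat16)
y = dequantize_and_run(torch.matmul, (x, w), {})
print(y.shape, y.device, y.dtype)  # torch.Size([2, 3]), x's device, torch.bfloat16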