Skip to content

Commit 3b2d2ef

Browse files
fix(gguf): ensure dequantized tensors are on correct device for MPS (#8713)
When using GGUF-quantized models on MPS (Apple Silicon), the dequantized tensors could end up on a different device than the other operands in math operations, causing "Expected all tensors to be on the same device" errors. This fix ensures that after dequantization, tensors are moved to the same device as the other tensors in the operation. Co-authored-by: Lincoln Stein <[email protected]>
1 parent 6697484 commit 3b2d2ef

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

invokeai/backend/quantization/gguf/ggml_tensor.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,21 +17,32 @@ def dequantize_and_run(func, args, kwargs):
1717
Also casts other floating point tensors to match the compute_dtype of GGMLTensors
1818
to avoid dtype mismatches in matrix operations.
1919
"""
20-
# Find the compute_dtype from any GGMLTensor in the args
20+
# Find the compute_dtype and target_device from any GGMLTensor in the args
2121
compute_dtype = None
22+
target_device = None
2223
for a in args:
2324
if hasattr(a, "compute_dtype"):
2425
compute_dtype = a.compute_dtype
26+
if isinstance(a, torch.Tensor) and target_device is None:
27+
target_device = a.device
28+
if compute_dtype is not None and target_device is not None:
2529
break
26-
if compute_dtype is None:
30+
if compute_dtype is None or target_device is None:
2731
for v in kwargs.values():
28-
if hasattr(v, "compute_dtype"):
32+
if hasattr(v, "compute_dtype") and compute_dtype is None:
2933
compute_dtype = v.compute_dtype
34+
if isinstance(v, torch.Tensor) and target_device is None:
35+
target_device = v.device
36+
if compute_dtype is not None and target_device is not None:
3037
break
3138

3239
def process_tensor(t):
3340
if hasattr(t, "get_dequantized_tensor"):
34-
return t.get_dequantized_tensor()
41+
result = t.get_dequantized_tensor()
42+
# Ensure the dequantized tensor is on the target device
43+
if target_device is not None and result.device != target_device:
44+
result = result.to(target_device)
45+
return result
3546
elif isinstance(t, torch.Tensor) and compute_dtype is not None and t.is_floating_point():
3647
# Cast other floating point tensors to match the GGUF compute_dtype
3748
return t.to(compute_dtype)

0 commit comments

Comments (0)