Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions invokeai/backend/quantization/gguf/ggml_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,32 @@ def dequantize_and_run(func, args, kwargs):
Also casts other floating point tensors to match the compute_dtype of GGMLTensors
to avoid dtype mismatches in matrix operations.
"""
# Find the compute_dtype from any GGMLTensor in the args
# Find the compute_dtype and target_device from any GGMLTensor in the args
compute_dtype = None
target_device = None
for a in args:
if hasattr(a, "compute_dtype"):
compute_dtype = a.compute_dtype
if isinstance(a, torch.Tensor) and target_device is None:
target_device = a.device
if compute_dtype is not None and target_device is not None:
break
if compute_dtype is None:
if compute_dtype is None or target_device is None:
for v in kwargs.values():
if hasattr(v, "compute_dtype"):
if hasattr(v, "compute_dtype") and compute_dtype is None:
compute_dtype = v.compute_dtype
if isinstance(v, torch.Tensor) and target_device is None:
target_device = v.device
if compute_dtype is not None and target_device is not None:
break

def process_tensor(t):
if hasattr(t, "get_dequantized_tensor"):
return t.get_dequantized_tensor()
result = t.get_dequantized_tensor()
# Ensure the dequantized tensor is on the target device
if target_device is not None and result.device != target_device:
result = result.to(target_device)
return result
elif isinstance(t, torch.Tensor) and compute_dtype is not None and t.is_floating_point():
# Cast other floating point tensors to match the GGUF compute_dtype
return t.to(compute_dtype)
Expand Down