Commit 489a7b8

fix device error

1 parent ed7b7c7 · commit 489a7b8
File tree: 1 file changed (+3 −2)

convert_hf_to_gguf.py

Lines changed: 3 additions & 2 deletions
@@ -335,7 +335,6 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
 
         # ref: https://github.com/vllm-project/compressed-tensors/blob/52792be02ec09e59f3517104e755a02d0e003fbb/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
         def dequant_compressed_tensor(weight: Tensor, scale: Tensor) -> Tensor:
-            scale = scale.float()
             weights_config = quant_config["config_groups"]["group_0"]["weights"]
             group_size = weights_config["group_size"]
             num_bits = weights_config["num_bits"]
@@ -349,15 +348,17 @@ def dequant_compressed_tensor(weight: Tensor, scale: Tensor) -> Tensor:
             mask = (1 << num_bits) - 1
             unpacked = torch.zeros(
                 (weight.shape[0], weight.shape[1] * pack_factor),
-                device=weight.device,
                 dtype=torch.int32,
             )
             if self.lazy:
                 unpacked = LazyTorchTensor.from_eager(unpacked)
+            else:
+                unpacked = unpacked.to(weight.device)  # is this needed?
             for i in range(pack_factor):
                 unpacked[:, i::pack_factor] = (weight >> (num_bits * i)) & mask
             # TODO: may need to unpad
             unpacked = unpacked - (mask + 1) // 2  # convert uint4 to int4 (shift scale)
+            scale = scale.to(torch.float32)
             scale = scale.unsqueeze(2)
             unpacked = unpacked.to(torch.float32)
             unpacked = unpacked.reshape(-1, unpacked.shape[1] // group_size, group_size)
