Commit 9845569

hiworldwzj and wangzaijun authored
fix fp8 weight quant need contiguous tensor (#632)
Co-authored-by: wangzaijun <[email protected]>
1 parent bd0712e · commit 9845569

File tree: 1 file changed (+6 −2 lines)


lightllm/common/quantization/vllm_quant.py

Lines changed: 6 additions & 2 deletions
@@ -62,7 +62,9 @@ def __init__(self):
     def quantize(self, weight: torch.Tensor):
         if self.is_moe:
             return self.quantize_moe(weight)
-        qweight, weight_scale = ops.scaled_fp8_quant(weight.cuda(), scale=None, use_per_token_if_dynamic=True)
+        qweight, weight_scale = ops.scaled_fp8_quant(
+            weight.contiguous().cuda(), scale=None, use_per_token_if_dynamic=True
+        )
         return qweight.transpose(0, 1), weight_scale
 
     def quantize_moe(self, weight):
@@ -71,7 +73,9 @@ def quantize_moe(self, weight):
         weight_scales = []
         qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda()
         for i in range(num_experts):
-            qweight, weight_scale = ops.scaled_fp8_quant(weight[i].cuda(), scale=None, use_per_token_if_dynamic=False)
+            qweight, weight_scale = ops.scaled_fp8_quant(
+                weight[i].contiguous().cuda(), scale=None, use_per_token_if_dynamic=False
+            )
             qweights[i] = qweight
             weight_scales.append(weight_scale)
         weight_scale = torch.cat(weight_scales, dim=0).reshape(-1)
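
The change itself is small: both call sites now run .contiguous() on the weight before handing it to vLLM's ops.scaled_fp8_quant, since the quant kernel expects a dense row-major tensor. Below is a minimal sketch of the underlying issue in plain PyTorch; the transposed view is a hypothetical stand-in for however a non-contiguous weight might reach quantize, and no vLLM call is made:

import torch

# A transposed weight shares storage with the original but has swapped
# strides, so it is NOT laid out as dense row-major memory.
weight = torch.randn(128, 256, dtype=torch.float16)
view = weight.t()

print(view.is_contiguous())       # False: unsafe for kernels that index
                                  # memory as if it were row-major
safe = view.contiguous()          # materializes a dense copy
print(safe.is_contiguous())       # True
assert torch.equal(safe, view)    # same values, different memory layout

Calling .contiguous() on an already-contiguous tensor returns it unchanged, so the extra call is free in the common case and only copies when a copy is actually needed.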
