
Commit b610fe3

[fix] fix fp8 bug when loading MoE model
1 parent 32f7db7 commit b610fe3

File tree

1 file changed: +2 −2 lines changed


lightllm/common/quantization/w8a8_quant.py

Lines changed: 2 additions & 2 deletions
@@ -93,11 +93,11 @@ def quantize_moe(self, weight):
         qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_)
         for i in range(num_experts):
             qweight, weight_scale = scaled_fp8_quant(
-                weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False
+                weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
             )
             qweights[i] = qweight
             weight_scales.append(weight_scale)
-        weight_scale = torch.cat(weight_scales, dim=0).reshape(-1)
+        weight_scale = torch.stack(weight_scales, dim=0).contiguous()
         return qweights, weight_scale

     def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
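
For context, the change switches scaled_fp8_quant to per-token dynamic scaling and keeps the per-expert structure of the resulting scales. Below is a minimal, self-contained sketch in plain PyTorch (not the lightllm kernels) of why the scale bookkeeping changes; the shapes and the helper per_token_scales are assumptions inferred from the diff, with each expert weight taken as [rows, cols].

import torch

num_experts, rows, cols = 4, 8, 16
weights = torch.randn(num_experts, rows, cols)
FP8_E4M3_MAX = 448.0  # max representable magnitude of float8_e4m3fn

def per_token_scales(w: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the per-token branch of scaled_fp8_quant:
    # one scale per row of the expert weight, shape [rows, 1].
    return w.abs().amax(dim=-1, keepdim=True) / FP8_E4M3_MAX

weight_scales = [per_token_scales(weights[i]) for i in range(num_experts)]

# Old code: cat + reshape(-1) flattens everything into a 1-D tensor, so the
# per-expert boundaries are lost once the scales are no longer scalars.
flat = torch.cat(weight_scales, dim=0).reshape(-1)
print(flat.shape)      # torch.Size([32])

# New code: stack keeps the expert axis, so downstream MoE code can still
# index the scales per expert.
stacked = torch.stack(weight_scales, dim=0).contiguous()
print(stacked.shape)   # torch.Size([4, 8, 1])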
