Skip to content

Commit 5d27a64

Browse files
committed
rename mxfp scale format transformation function
Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent 41484f9 commit 5d27a64

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

vllm_ascend/device/device_op.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def npu_moe_init_routing(
5959
)
6060

6161
@staticmethod
62-
def normalize_mxfp8_scale_layout(scale: torch.Tensor | None) -> torch.Tensor | None:
62+
def maybe_normalize_mxfp_scale_layout(scale: torch.Tensor | None) -> torch.Tensor | None:
6363
return scale
6464

6565
@staticmethod
@@ -233,7 +233,7 @@ def npu_moe_init_routing(
233233
)
234234

235235
@staticmethod
236-
def normalize_mxfp8_scale_layout(scale: torch.Tensor | None) -> torch.Tensor | None:
236+
def maybe_normalize_mxfp_scale_layout(scale: torch.Tensor | None) -> torch.Tensor | None:
237237
if scale is None or scale.ndim != 2:
238238
return scale
239239
if scale.shape[-1] % 2 != 0:
@@ -291,7 +291,7 @@ def npu_dynamic_quant(
291291
if dynamic_scale is None:
292292
hidden_states, dynamic_scale = torch_npu.npu_dynamic_mx_quant(hidden_states, dst_type=act_quant_type)
293293

294-
return hidden_states, A5DeviceAdaptor.normalize_mxfp8_scale_layout(dynamic_scale)
294+
return hidden_states, A5DeviceAdaptor.maybe_normalize_mxfp_scale_layout(dynamic_scale)
295295

296296
@staticmethod
297297
def npu_grouped_matmul_swiglu_quant(
@@ -328,7 +328,7 @@ def npu_grouped_matmul_swiglu_quant(
328328
weight_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
329329
x_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
330330
)
331-
return out, A5DeviceAdaptor.normalize_mxfp8_scale_layout(out_scale), None
331+
return out, A5DeviceAdaptor.maybe_normalize_mxfp_scale_layout(out_scale), None
332332

333333
@staticmethod
334334
def get_quant_gmm2_kwargs(

vllm_ascend/ops/fused_moe/moe_mlp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ def quant_apply_mlp(
128128
quantized_hidden_states = None
129129
else:
130130
unquantized_hidden_states = None
131-
pertoken_scale = DeviceOperator.normalize_mxfp8_scale_layout(dynamic_scale) if use_mxfp_quant else dynamic_scale
131+
pertoken_scale = (
132+
DeviceOperator.maybe_normalize_mxfp_scale_layout(dynamic_scale) if use_mxfp_quant else dynamic_scale
133+
)
132134
quantized_hidden_states = hidden_states
133135

134136
bias1, bias2 = None, None

0 commit comments

Comments (0)