@@ -59,7 +59,7 @@ def npu_moe_init_routing(
5959 )
6060
6161 @staticmethod
62- def normalize_mxfp8_scale_layout (scale : torch .Tensor | None ) -> torch .Tensor | None :
62+ def maybe_normalize_mxfp_scale_layout (scale : torch .Tensor | None ) -> torch .Tensor | None :
6363 return scale
6464
6565 @staticmethod
@@ -233,7 +233,7 @@ def npu_moe_init_routing(
233233 )
234234
235235 @staticmethod
236- def normalize_mxfp8_scale_layout (scale : torch .Tensor | None ) -> torch .Tensor | None :
236+ def maybe_normalize_mxfp_scale_layout (scale : torch .Tensor | None ) -> torch .Tensor | None :
237237 if scale is None or scale .ndim != 2 :
238238 return scale
239239 if scale .shape [- 1 ] % 2 != 0 :
@@ -291,7 +291,7 @@ def npu_dynamic_quant(
291291 if dynamic_scale is None :
292292 hidden_states , dynamic_scale = torch_npu .npu_dynamic_mx_quant (hidden_states , dst_type = act_quant_type )
293293
294- return hidden_states , A5DeviceAdaptor .normalize_mxfp8_scale_layout (dynamic_scale )
294+ return hidden_states , A5DeviceAdaptor .maybe_normalize_mxfp_scale_layout (dynamic_scale )
295295
296296 @staticmethod
297297 def npu_grouped_matmul_swiglu_quant (
@@ -328,7 +328,7 @@ def npu_grouped_matmul_swiglu_quant(
328328 weight_scale_dtype = FLOAT8_E8M0FNU_DTYPE ,
329329 x_scale_dtype = FLOAT8_E8M0FNU_DTYPE ,
330330 )
331- return out , A5DeviceAdaptor .normalize_mxfp8_scale_layout (out_scale ), None
331+ return out , A5DeviceAdaptor .maybe_normalize_mxfp_scale_layout (out_scale ), None
332332
333333 @staticmethod
334334 def get_quant_gmm2_kwargs (
0 commit comments