17 changes: 15 additions & 2 deletions flashinfer/fused_moe/core.py
@@ -1157,7 +1157,20 @@ def forward(
)
elif self.fp8_quantization_type == Fp8QuantizationType.MxFp8:
current_hidden_states_scale = extra_inputs[0]
Collaborator commented:
Suggested change
- current_hidden_states_scale = extra_inputs[0]
+ current_hidden_states_scale = hidden_states_scale


# During autotuner profiling, DynamicTensorSpec
# creates an undersized 1D scale tensor (it
# replaces dim 0 with the bucket value, but the
# MXFP8 scale is 1D with size num_tokens *
# hidden_size // 32). Recreate with correct size
# to avoid OOB reads in the MoE GEMM kernel.
padded_k = (current_hidden_size + 31) // 32 * 32
sf_size = current_num_tokens * padded_k // 32
if current_hidden_states_scale.numel() < sf_size:
current_hidden_states_scale = torch.ones(
(sf_size,),
dtype=torch.uint8,
device=hidden_states.device,
)
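As context for the sizing logic above, here is a standalone sketch of the same computation (the helper name and the worked example are illustrative, not part of the PR): MXFP8 stores one uint8 e8m0 scale per 32-element block along the hidden dimension, so the 1D scale tensor needs `num_tokens * padded_hidden_size // 32` entries.

```python
# Sketch: MXFP8 1D scale-tensor sizing, mirroring the diff above.
# One uint8 e8m0 scale byte covers a block of 32 elements along K.
def mxfp8_scale_numel(num_tokens: int, hidden_size: int) -> int:
    padded_k = (hidden_size + 31) // 32 * 32  # round K up to a multiple of 32
    return num_tokens * padded_k // 32        # one scale byte per 32-element block

# Example: 4 tokens, hidden_size 100 -> padded to 128 -> 4 * 128 // 32 = 16 scales
assert mxfp8_scale_numel(4, 100) == 16
```

An undersized tensor here is what triggers the out-of-bounds reads the diff guards against: the kernel indexes the scale buffer by this formula regardless of the buffer's actual length.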
Comment on lines +1168 to +1173
Contributor commented (severity: medium):
This is a good catch to fix the illegal memory access error during autotuning. The logic to check and resize the current_hidden_states_scale tensor is correct.

However, as you noted, this fix may cause a performance regression. This is likely because filling the tensor with torch.ones does not provide realistic scale values for the autotuner. The autotuner might be selecting a suboptimal kernel based on this non-representative data.

To potentially resolve the performance regression, I suggest initializing the tensor with random data, which better simulates real-world scale values. This should help the autotuner find a more performant kernel.

Suggested change
- if current_hidden_states_scale.numel() < sf_size:
-     current_hidden_states_scale = torch.ones(
-         (sf_size,),
-         dtype=torch.uint8,
-         device=hidden_states.device,
-     )
+ if current_hidden_states_scale.numel() < sf_size:
+     current_hidden_states_scale = torch.randint(
+         0, 256, (sf_size,),
+         dtype=torch.uint8,
+         device=hidden_states.device,
+     )

Comment on lines +1169 to +1173
Collaborator commented:
Considering the e8m0 data, maybe use 126-127 (i.e., scales of 0.5 and 1.0).

Suggested change
- current_hidden_states_scale = torch.ones(
-     (sf_size,),
-     dtype=torch.uint8,
-     device=hidden_states.device,
- )
+ current_hidden_states_scale = torch.randint(
+     126, 128, (sf_size,),
+     dtype=torch.uint8,
+     device=hidden_states.device,
+ )
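For background on this suggestion: e8m0 stores only an exponent with bias 127, so a byte value `v` decodes to `2 ** (v - 127)`; bytes 126 and 127 therefore decode to scales 0.5 and 1.0, and `torch.randint(126, 128, ...)` (upper bound exclusive) draws exactly those two values. A minimal decode sketch (the helper is illustrative, not a FlashInfer API):

```python
def e8m0_to_float(v: int) -> float:
    # e8m0: 8 exponent bits, no mantissa, bias 127
    # (the all-ones encoding 255 is reserved for NaN in the MX spec)
    return 2.0 ** (v - 127)

assert e8m0_to_float(126) == 0.5
assert e8m0_to_float(127) == 1.0
```

This also hints at why `randint(0, 256)` may be less ideal for profiling: it can produce extreme scales (and the reserved 255 encoding), while 126-127 keeps the profiled data near realistic magnitudes.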

else:
raise ValueError(
f"Unsupported FP8 quantization type: {self.fp8_quantization_type}"
@@ -1734,7 +1747,7 @@ def trtllm_fp8_block_scale_moe_op(
_, tactic = tuner.choose_one(
"flashinfer::trtllm_fp8_block_scale_moe",
[moe_runner],
- MoERunner.tuning_config_with_hidden_states_scales,  # FP8 block-scale uses hidden_states_scale
+ MoERunner.tuning_config_with_hidden_states_scales,
inputs,
routing_bias=routing_bias,
gemm1_weights=gemm1_weights,