feat: Add FP4 TRTLLM-Gen throughput MOE batched gemms #1882
base: main
```diff
@@ -894,7 +894,9 @@ def __init__(
         self.gated_act_type = gated_act_type
         self.tile_tokens_dim = tile_tokens_dim

-    def get_tile_tokens_dim(self, num_tokens: int, top_k: int):
+    def get_tile_tokens_dim(
+        self, num_tokens: int, top_k: int, max_tile_tokens_dim: int = 128
+    ):
         # Factor to account for the imbalance of the experts.
         # factor equals to the
         # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
```
```diff
@@ -910,10 +912,10 @@ def get_tile_tokens_dim(self, num_tokens: int, top_k: int):
         num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
         # And pad the number to the next power of 2.
         tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile
-        # as it's the range supported by the kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+        if num_tokens_per_expert > 128 and num_tokens_per_expert < 256:
+            tile_tokens_dim = 192
+        # Cap to 8-max_tile_tokens_dim tokens per CTA tile as it's the range supported by the kernel.
+        tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
         return tile_tokens_dim
```
Comment on lines 914 to 919:

The logic for calculating `tile_tokens_dim` could be extracted into a shared helper function. For example:

```python
# in flashinfer/utils.py
def _calculate_tile_dim_from_tokens_per_expert(num_tokens_per_expert: int, max_tile_tokens_dim: int = 128) -> int:
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    if 128 < num_tokens_per_expert < 256:
        tile_tokens_dim = 192
    tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
    return tile_tokens_dim
```

Then both call sites can delegate to this helper.
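If the suggestion were adopted, each `get_tile_tokens_dim` implementation could shrink to a thin wrapper. Below is a sketch assuming the proposed helper lands in `flashinfer/utils.py` as the reviewer suggests; the per-expert estimate shown (an even split scaled by a placeholder imbalance factor of 1.3) is a stand-in for the class's actual calculation, not taken from the source.

```python
import math

# Proposed helper from the review comment above; it does not exist in
# flashinfer/utils.py yet, so this import only works after the refactor.
from flashinfer.utils import _calculate_tile_dim_from_tokens_per_expert


def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int,
                        max_tile_tokens_dim: int = 128) -> int:
    # Hypothetical per-expert estimate: even split across experts, scaled by
    # an imbalance factor, mirroring the pattern visible in the diff above.
    imbalance_factor = 1.3  # placeholder value, not from the source
    num_tokens_per_expert = math.ceil(num_tokens * top_k / num_experts)
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    return _calculate_tile_dim_from_tokens_per_expert(
        num_tokens_per_expert, max_tile_tokens_dim
    )
```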
```diff
@@ -931,7 +933,7 @@ def get_valid_tactics(
         ) = inputs
         num_tokens = routing_logits.shape[0]
         tile_tokens_dim = (
-            self.get_tile_tokens_dim(num_tokens, self.top_k)
+            self.get_tile_tokens_dim(num_tokens, self.top_k, 128)
             if self.tile_tokens_dim is None
             else self.tile_tokens_dim
         )
```
```diff
@@ -975,7 +977,7 @@ def forward(
         ) = inputs
         num_tokens = routing_logits.shape[0]
         tile_tokens_dim = (
-            self.get_tile_tokens_dim(num_tokens, self.top_k)
+            self.get_tile_tokens_dim(num_tokens, self.top_k, 128)
             if self.tile_tokens_dim is None
             else self.tile_tokens_dim
         )
```
This block for calculating `max_num_padded_tokens_gemm1` and `max_num_padded_tokens_gemm2` is duplicated in `trtllm_fp4_block_scale_moe_launcher` (lines 774-779). In fact, the entire function `trtllm_fp8_block_scale_moe_launcher` is very similar to `trtllm_fp4_block_scale_moe_launcher`. To improve maintainability and reduce redundancy, consider refactoring the common logic into a templated helper function.
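A minimal sketch of the kind of shared helper this comment asks for, kept in Python to match the diff above. The names (`_max_num_padded_tokens`, `_moe_padded_token_counts`) and the worst-case padding formula are assumptions for illustration, not the actual computation from lines 774-779.

```python
def _max_num_padded_tokens(num_tokens: int, top_k: int, num_experts: int,
                           tile_tokens_dim: int) -> int:
    # Illustrative worst-case bound: routing expands num_tokens * top_k token
    # slots across the experts, and each expert's slice may be padded up to a
    # whole CTA tile, wasting at most tile_tokens_dim - 1 slots per expert.
    # The real formula in the launchers may differ.
    expanded = num_tokens * top_k
    return expanded + num_experts * (tile_tokens_dim - 1)


def _moe_padded_token_counts(num_tokens: int, top_k: int, num_experts: int,
                             tile_tokens_dim: int) -> tuple[int, int]:
    # Hypothetical shared entry point that both launchers could call instead
    # of duplicating the block; this sketch reuses the same bound for both
    # GEMMs, which the real code may not do.
    max_num_padded_tokens_gemm1 = _max_num_padded_tokens(
        num_tokens, top_k, num_experts, tile_tokens_dim
    )
    max_num_padded_tokens_gemm2 = max_num_padded_tokens_gemm1
    return max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2
```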