Commit 0754bb5

xinli-git authored and npanpaliya committed
Fix Flashinfer Allreduce+Norm enable disable calculation based on fi_allreduce_fusion_max_token_num (vllm-project#21325)
Signed-off-by: XIn Li <[email protected]>
1 parent c58d816 commit 0754bb5
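
In short, the change fixes two things in the gating check for the FlashInfer fused allreduce+norm path: the direct lookup _FI_MAX_SIZES[world_size] becomes _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), so a world size missing from the table falls back to a conservative 512KB cap instead of raising KeyError; and the cap derived from max_token_num is now computed from the tensor's hidden size (allreduce_in.shape[1]) rather than its token count (allreduce_in.shape[0]).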

File tree

1 file changed (+13, -6 lines)


vllm/compilation/collective_fusion.py

Lines changed: 13 additions & 6 deletions

@@ -159,6 +159,9 @@ def __call__(self, graph: fx.Graph):
         6: MiB // 2,  # 512KB
         8: MiB // 2,  # 512KB
     }
+    # opt for a more conservative default value
+    # when world size is not in _FI_MAX_SIZES
+    _DEFAULT_FI_MAX_SIZE = MiB // 2
 
     def call_trtllm_fused_allreduce_norm(
         allreduce_in: torch.Tensor,
@@ -173,12 +176,16 @@ def call_trtllm_fused_allreduce_norm(
         max_token_num: int,
         norm_out: Optional[torch.Tensor] = None,
     ) -> None:
-        use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[
-            1] * allreduce_in.element_size() <= min(
-                _FI_MAX_SIZES[world_size],
-                max_token_num * allreduce_in.shape[0] *
-                allreduce_in.element_size(),
-            )
+
+        num_tokens, hidden_size = allreduce_in.shape
+        element_size = allreduce_in.element_size()
+        current_tensor_size = num_tokens * hidden_size * element_size
+        max_fusion_size = max_token_num * hidden_size * element_size
+        use_flashinfer = current_tensor_size <= min(
+            _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
+            max_fusion_size,
+        )
+
         if use_flashinfer:
             assert (_FI_WORKSPACE_TENSOR is not None
                     ), "Flashinfer must be enabled when using flashinfer"
