@@ -159,6 +159,9 @@ def __call__(self, graph: fx.Graph):
     6: MiB // 2,  # 512KB
     8: MiB // 2,  # 512KB
 }
+# opt for a more conservative default value
+# when world size is not in _FI_MAX_SIZES
+_DEFAULT_FI_MAX_SIZE = MiB // 2
 
 def call_trtllm_fused_allreduce_norm(
     allreduce_in: torch.Tensor,
@@ -173,12 +176,16 @@ def call_trtllm_fused_allreduce_norm(
     max_token_num: int,
     norm_out: Optional[torch.Tensor] = None,
 ) -> None:
-    use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[
-        1] * allreduce_in.element_size() <= min(
-            _FI_MAX_SIZES[world_size],
-            max_token_num * allreduce_in.shape[0] *
-            allreduce_in.element_size(),
-        )
+
+    num_tokens, hidden_size = allreduce_in.shape
+    element_size = allreduce_in.element_size()
+    current_tensor_size = num_tokens * hidden_size * element_size
+    max_fusion_size = max_token_num * hidden_size * element_size
+    use_flashinfer = current_tensor_size <= min(
+        _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
+        max_fusion_size,
+    )
+
     if use_flashinfer:
         assert (_FI_WORKSPACE_TENSOR is not None
                 ), "Flashinfer must be enabled when using flashinfer"