Skip to content

Commit da0c004

Browse files
author
Lance Liao
committed
Merge branch 'pr-12315'
2 parents 2c8b44c + ec5bdf0 commit da0c004

File tree

16 files changed

+938
-37
lines changed

16 files changed

+938
-37
lines changed

.github/workflows/blossom-ci.yml

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -191,6 +191,7 @@ jobs:
191191
"litaotju",
192192
"liyuhannnnn",
193193
"lkomali",
194+
"longcheng-nv",
194195
"longlee0622",
195196
"lowsfer",
196197
"lucaslie",
@@ -293,6 +294,7 @@ jobs:
293294
"tcherckez-nvidia",
294295
"thorjohnsen",
295296
"tianyuxbear",
297+
"tianyuz-nv",
296298
"tiffany940107",
297299
"tijyojwad",
298300
"timlee0212",
@@ -332,11 +334,13 @@ jobs:
332334
"xueweilnvidia",
333335
"xupinjie",
334336
"xuwchen",
337+
"xwang233",
335338
"xxi-nv",
336339
"yali-arch",
337340
"yechank-nvidia",
338341
"yibinl-nvidia",
339342
"yifeizhang-c",
343+
"YihuiLu512",
340344
"yihwang-nv",
341345
"yijingl-nvidia",
342346
"yilin-void",

cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ std::optional<tensorrt_llm::runtime::ITensor::UniquePtr> from_torch(std::optiona
6868
class PyKvCacheManager : public tbk::BaseKVCacheManager
6969
{
7070
public:
71-
NB_TRAMPOLINE(tbk::BaseKVCacheManager, 30);
71+
NB_TRAMPOLINE(tbk::BaseKVCacheManager, 36);
7272

7373
// using BaseKVCacheManager::BaseKVCacheManager; // Inherit constructors
7474
void allocatePools(bool useUvm = false) override
@@ -255,6 +255,12 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
255255
{
256256
NB_OVERRIDE_PURE(flushIterationEvents);
257257
}
258+
259+
SizeType32 countReusableBlocks(VecUniqueTokens const& uniqueTokens, tb::LlmRequest const& llmRequest,
260+
bool onlyAllocated = false) const override
261+
{
262+
NB_OVERRIDE_PURE(countReusableBlocks, uniqueTokens, llmRequest, onlyAllocated);
263+
}
258264
};
259265

260266
// TODO: Deduplicate executor bindings KvCacheStats

tensorrt_llm/_torch/models/modeling_nemotron_h.py

Lines changed: 117 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131
from tensorrt_llm.logger import logger
3232

3333
from ..attention_backend import AttentionMetadata
34-
from ..distributed import AllReduce
34+
from ..distributed import AllReduce, AllReduceFusionOp, AllReduceParams
3535
from ..model_config import ModelConfig
3636
from ..modules.attention import Attention
3737
from ..modules.decoder_layer import DecoderLayer
@@ -59,6 +59,7 @@ def __init__(
5959
self,
6060
model_config: ModelConfig[NemotronHConfig],
6161
layer_idx: int,
62+
reduce_output: bool = True,
6263
):
6364
config = model_config.pretrained_config
6465
if isinstance(config.intermediate_size, list):
@@ -76,6 +77,7 @@ def __init__(
7677
activation=relu2,
7778
dtype=config.torch_dtype,
7879
config=model_config,
80+
reduce_output=reduce_output,
7981
)
8082
self.layer_idx = layer_idx
8183

@@ -119,7 +121,8 @@ def forward(
119121
) -> torch.Tensor:
120122
return super().forward(position_ids=None,
121123
hidden_states=hidden_states,
122-
attn_metadata=attn_metadata)
124+
attn_metadata=attn_metadata,
125+
**kwargs)
123126

124127

125128
# Ref code: https://huggingface.co/nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024/blob/main/modeling_nemotron_h.py#L818
@@ -130,6 +133,7 @@ def __init__(
130133
model_config: ModelConfig[PretrainedConfig],
131134
layer_idx: int,
132135
aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream],
136+
reduce_output: bool = False,
133137
):
134138
super().__init__()
135139

@@ -226,8 +230,7 @@ def __init__(
226230
activation_type=self.activation_type,
227231
)
228232

229-
if not model_config.mapping.enable_attention_dp:
230-
# AllReduce for combining shared and routed expert outputs in multi-GPU settings.
233+
if reduce_output:
231234
self.allreduce = AllReduce(
232235
mapping=model_config.mapping,
233236
strategy=model_config.allreduce_strategy,
@@ -324,8 +327,10 @@ def _compute_routed_output():
324327
final_hidden_states = shared_output + routed_output
325328

326329
# Perform all-reduce after combining outputs for multi-GPU support.
327-
if not self.enable_attention_dp and self.mapping.tp_size > 1:
328-
final_hidden_states = self.allreduce(final_hidden_states)
330+
if self.allreduce is not None:
331+
final_hidden_states = self.allreduce(
332+
final_hidden_states,
333+
all_reduce_params=kwargs.get('all_reduce_params'))
329334

330335
return final_hidden_states.view(orig_shape)
331336

@@ -341,6 +346,7 @@ def __init__(
341346
# * -> TransformerLayer
342347
layer_type: str,
343348
aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream],
349+
fuse_allreduce_norm: bool = False,
344350
):
345351
super().__init__()
346352

@@ -373,6 +379,13 @@ def __init__(
373379
)
374380
self.is_nvfp4 = False
375381

382+
# fuse_allreduce_norm is the model-level flag. When enabled, ALL
383+
# layers defer mixer AllReduce to the next layer's pre_allreduce (or
384+
# the model's final_allreduce). Only layers 1+ create a pre_allreduce
385+
# module; layer 0's input is already reduced from the embedding.
386+
self.fuse_allreduce_norm = fuse_allreduce_norm
387+
self.is_moe_layer = (layer_type == "E")
388+
376389
self.norm = RMSNorm(
377390
hidden_size=config.hidden_size,
378391
eps=config.rms_norm_eps,
@@ -382,9 +395,22 @@ def __init__(
382395
quantize_type="nvfp4" if self.is_nvfp4 else None,
383396
# Enable high precision output for MoE layer (only with NVFP4).
384397
# It might be overridden in `_try_attach_nvfp4_scale` function.
385-
return_hp_output=layer_type == "E" and self.is_nvfp4,
398+
return_hp_output=self.is_moe_layer and self.is_nvfp4,
386399
)
387400

401+
if fuse_allreduce_norm and layer_idx > 0:
402+
self.pre_allreduce = AllReduce(
403+
mapping=model_config.mapping,
404+
strategy=model_config.allreduce_strategy,
405+
)
406+
407+
# Mixer creation. The fuse_allreduce_norm optimization is orthogonal
408+
# to AllReduce topology: Transformer/MoE gate it at forward time via
409+
# AllReduceParams; MLP/Mamba gate it at init time via reduce_output
410+
# (their base classes don't thread all_reduce_params through forward).
411+
has_tp_allreduce = (not model_config.mapping.enable_attention_dp
412+
and model_config.mapping.tp_size > 1)
413+
388414
if layer_type == "M":
389415
self.mixer = Mamba2Mixer(
390416
d_model=config.hidden_size,
@@ -399,19 +425,27 @@ def __init__(
399425
dtype=config.torch_dtype,
400426
config=model_config,
401427
)
428+
if fuse_allreduce_norm:
429+
self.mixer.out_proj.reduce_output = False
402430
elif layer_type == "-":
403-
self.mixer = MLPLayer(model_config, layer_idx)
431+
self.mixer = MLPLayer(
432+
model_config,
433+
layer_idx,
434+
reduce_output=not fuse_allreduce_norm,
435+
)
404436
elif layer_type == "*":
405437
self.mixer = TransformerLayer(
406438
model_config,
407439
layer_idx,
408-
reduce_output=not model_config.mapping.enable_attention_dp
409-
and model_config.mapping.tp_size > 1,
440+
reduce_output=has_tp_allreduce,
410441
)
411442
elif layer_type == "E":
412-
self.mixer = NemotronHMOE(model_config,
413-
layer_idx=layer_idx,
414-
aux_stream_dict=aux_stream_dict)
443+
self.mixer = NemotronHMOE(
444+
model_config,
445+
layer_idx=layer_idx,
446+
aux_stream_dict=aux_stream_dict,
447+
reduce_output=has_tp_allreduce,
448+
)
415449
else:
416450
raise ValueError(f"{layer_type} is not supported")
417451

@@ -436,7 +470,7 @@ def _try_attach_nvfp4_scale(self):
436470

437471
# Special handling for MoE layer: fetch shared_expert.up_proj.input_scale
438472
# as representation of the input scale.
439-
if self.layer_type == "E":
473+
if self.is_moe_layer:
440474
if (hasattr(self.mixer, "shared_experts")
441475
and self.mixer.shared_experts is not None
442476
and hasattr(self.mixer.shared_experts, "up_proj")
@@ -463,16 +497,50 @@ def forward(
463497
if residual is None:
464498
residual = torch.zeros_like(hidden_states)
465499

466-
if self.norm.return_hp_output:
500+
if hasattr(self, 'pre_allreduce'):
501+
norm = self.norm
502+
has_nvfp4_scale = hasattr(norm, 'nvfp4_scale')
503+
if norm.is_nvfp4 and has_nvfp4_scale and norm.return_hp_output:
504+
fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4
505+
elif norm.is_nvfp4 and has_nvfp4_scale:
506+
fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
507+
else:
508+
fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
509+
all_reduce_params = AllReduceParams(
510+
fusion_op=fusion_op,
511+
residual=residual,
512+
norm_weight=norm.weight,
513+
eps=norm.variance_epsilon,
514+
trigger_completion_at_end=False,
515+
**(dict(scale=norm.nvfp4_scale)
516+
if has_nvfp4_scale and norm.is_nvfp4 else {}),
517+
)
518+
result = self.pre_allreduce(hidden_states,
519+
all_reduce_params=all_reduce_params)
520+
if fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4:
521+
norm_out, act_fp4, act_sf, residual = result
522+
hidden_states = (Fp4QuantizedTensor(act_fp4, act_sf), norm_out)
523+
elif fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4:
524+
act_fp4, act_sf, residual = result
525+
hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
526+
else:
527+
hidden_states, residual = result
528+
elif self.norm.return_hp_output:
467529
hidden_states, residual, high_precision_normed_output = self.norm(
468530
hidden_states, residual)
469531
hidden_states = (hidden_states, high_precision_normed_output)
470532
else:
471533
hidden_states, residual = self.norm(hidden_states, residual)
472-
hidden_states = self.mixer(hidden_states,
473-
attn_metadata,
474-
spec_metadata=spec_metadata,
475-
**kwargs)
534+
535+
# When fuse_allreduce_norm is active, tell Transformer/MoE mixers to
536+
# skip their own AllReduce (it is handled by pre_allreduce /
537+
# final_allreduce instead). MLP/Mamba ignore this kwarg; their
538+
# reduce_output was set at init time.
539+
mixer_kwargs = dict(spec_metadata=spec_metadata, **kwargs)
540+
if self.fuse_allreduce_norm:
541+
mixer_kwargs['all_reduce_params'] = AllReduceParams(
542+
enable_allreduce=False)
543+
hidden_states = self.mixer(hidden_states, attn_metadata, **mixer_kwargs)
476544

477545
if spec_metadata is not None and spec_metadata.is_layer_capture(
478546
self.layer_idx):
@@ -519,14 +587,20 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
519587
gather_output=True,
520588
)
521589

590+
self.fuse_allreduce_norm = (not model_config.mapping.enable_attention_dp
591+
and model_config.mapping.tp_size > 1)
592+
522593
# create layers
523594
layers = []
524595
for layer_idx, layer_type in enumerate(config.hybrid_override_pattern):
525596
layers.append(
526-
NemotronHLayer(model_config,
527-
layer_idx,
528-
layer_type,
529-
aux_stream_dict=self.aux_stream_dict))
597+
NemotronHLayer(
598+
model_config,
599+
layer_idx,
600+
layer_type,
601+
aux_stream_dict=self.aux_stream_dict,
602+
fuse_allreduce_norm=self.fuse_allreduce_norm,
603+
))
530604
self.layers = nn.ModuleList(layers)
531605
self.num_hidden_layers = config.num_hidden_layers
532606

@@ -537,6 +611,13 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
537611
dtype=config.torch_dtype,
538612
)
539613

614+
# AllReduce for fusing with final norm (after last layer's mixer)
615+
if self.fuse_allreduce_norm:
616+
self.final_allreduce = AllReduce(
617+
mapping=model_config.mapping,
618+
strategy=model_config.allreduce_strategy,
619+
)
620+
540621
def forward(
541622
self,
542623
attn_metadata: AttentionMetadata,
@@ -567,7 +648,19 @@ def forward(
567648
spec_metadata=spec_metadata,
568649
mamba_metadata=mamba_metadata,
569650
)
570-
hidden_states, _ = self.norm_f(hidden_states, residual)
651+
652+
if self.fuse_allreduce_norm:
653+
hidden_states, _ = self.final_allreduce(
654+
hidden_states,
655+
all_reduce_params=AllReduceParams(
656+
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
657+
residual=residual,
658+
norm_weight=self.norm_f.weight,
659+
eps=self.norm_f.variance_epsilon,
660+
trigger_completion_at_end=False,
661+
))
662+
else:
663+
hidden_states, _ = self.norm_f(hidden_states, residual)
571664
return hidden_states
572665

573666

tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -139,8 +139,13 @@ def __init__(
139139

140140
# Choose between flashinfer and native implementation. (default to flashinfer)
141141
self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype
142-
supported_head_dim_in_flashinfer = [64, 128]
143-
self._use_flashinfer = head_dim in supported_head_dim_in_flashinfer
142+
# TODO: Update head_dims and head_group_ratios once flashinfer is updated.
143+
supported_head_dims = [64, 128]
144+
supported_head_group_ratios = [1, 8, 16]
145+
head_group_ratio = (self.tp_nheads //
146+
self.tp_ngroups if self.tp_ngroups > 0 else 0)
147+
self._use_flashinfer = (head_dim in supported_head_dims and
148+
head_group_ratio in supported_head_group_ratios)
144149
# Stochastic rounding requires FlashInfer and fp16 cache
145150
self._use_stochastic_rounding = (
146151
config.quant_config.mamba_ssm_stochastic_rounding

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1335,6 +1335,7 @@ def create_py_executor_instance(
13351335
waiting_queue_policy = (scheduler_config.waiting_queue_policy
13361336
if scheduler_config is not None else
13371337
WaitingQueuePolicy.FCFS)
1338+
13381339
return PyExecutor(
13391340
resource_manager,
13401341
scheduler,

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@
6767
from .scheduler import (RequestScheduler, ScheduledRequests,
6868
SerializableSchedulerOutput, WaitingQueue,
6969
create_waiting_queue)
70-
from .scheduler.adp_router import ADPRouter, DefaultADPRouter
70+
from .scheduler.adp_router import ADPRouter
7171

7272
# Environment variable to specify iteration ranges for profiling start/stop.
7373
# Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..."
@@ -285,8 +285,7 @@ def __init__(
285285
virtual_memory_pools: Optional[dict] = None,
286286
hang_detection_timeout: Optional[int] = None,
287287
execution_stream: Optional[torch.cuda.Stream] = None,
288-
waiting_queue_policy: WaitingQueuePolicy = WaitingQueuePolicy.FCFS,
289-
adp_router: Optional[ADPRouter] = None):
288+
waiting_queue_policy: WaitingQueuePolicy = WaitingQueuePolicy.FCFS):
290289
super(PyExecutor, self).__init__()
291290
self.device_id = torch.cuda.current_device()
292291
self.global_rank = dist.rank
@@ -313,7 +312,6 @@ def __init__(
313312
self.model_engine = model_engine
314313
self.enable_attention_dp = model_engine.enable_attention_dp
315314
self.dist = dist
316-
self.adp_router: ADPRouter = (adp_router or DefaultADPRouter(dist=dist))
317315
self.sampler = sampler
318316
self.drafter = drafter
319317
self.draft_model_engine = getattr(self.drafter, "draft_model_engine",
@@ -387,6 +385,12 @@ def __init__(
387385
self.enable_kv_cache_reuse
388386
and self.kv_cache_manager.enable_partial_reuse)
389387

388+
self.adp_router: ADPRouter = ADPRouter.create(
389+
dist=self.dist,
390+
kv_cache_manager=self.kv_cache_manager,
391+
attention_dp_config=self.llm_args.attention_dp_config,
392+
)
393+
390394
self.max_input_len = max_input_len
391395
# _executor_loop private data
392396
self.max_num_active_requests = model_engine.get_max_num_sequences()
@@ -2573,6 +2577,9 @@ def _fetch_new_requests(
25732577

25742578
# 6. Schedule requests across ranks (DP only)
25752579
if self.enable_attention_dp:
2580+
if self.adp_router.needs_prefix_matches:
2581+
self.adp_router.gather_prefix_matches(new_requests)
2582+
25762583
all_ranks_new_requests, self.expected_num_active_requests = \
25772584
self.adp_router.route_requests(
25782585
all_rank_states, new_requests,

0 commit comments

Comments
 (0)