File tree: tensorrt_llm/_torch/models — 1 file changed, +2 −8 lines changed

 from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector

 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             AllReduceParams, AllReduceStrategy,
-                                             MoEAllReduce)
+                                             AllReduceParams, MoEAllReduce)
 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
 from tensorrt_llm._utils import get_sm_version
@@ -647,12 +646,7 @@ def __init__(
             eps=config.rms_norm_eps,
             dtype=config.torch_dtype)

-        # TODO: This is a temporary fix to disable oneshot kernel for pre-Blackwell arch to avoid perf regressions
-        self.all_reduce = AllReduce(
-            strategy=model_config.allreduce_strategy
-            if get_sm_version() >= 100 else AllReduceStrategy.NCCL,
-            mapping=model_config.mapping,
-        )
+        self.all_reduce = AllReduce(mapping=model_config.mapping)

         self.next_layer_layernorm: RMSNorm = None
         self.next_attn: LlamaAttention = None
You can’t perform that action at this time.
0 commit comments