Commit 015e149

[https://nvbugs/1234567][fix] Revert https://github.com/NVIDIA/TensorRT-LLM/pull/7768/files (#7813)
Signed-off-by: Tao Li
1 parent: 22c120e

File tree

1 file changed: +2, -8 lines

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 2 additions & 8 deletions
```diff
@@ -11,8 +11,7 @@
 from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
 
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             AllReduceParams, AllReduceStrategy,
-                                             MoEAllReduce)
+                                             AllReduceParams, MoEAllReduce)
 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
 from tensorrt_llm._utils import get_sm_version
@@ -647,12 +646,7 @@ def __init__(
             eps=config.rms_norm_eps,
             dtype=config.torch_dtype)
 
-        # TODO: This is a temporary fix to disable oneshot kernel for pre-Blackwell arch to avoid perf regressions
-        self.all_reduce = AllReduce(
-            strategy=model_config.allreduce_strategy
-            if get_sm_version() >= 100 else AllReduceStrategy.NCCL,
-            mapping=model_config.mapping,
-        )
+        self.all_reduce = AllReduce(mapping=model_config.mapping)
 
         self.next_layer_layernorm: RMSNorm = None
         self.next_attn: LlamaAttention = None
```
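
For context, here is a minimal sketch of the two construction paths this revert toggles, reconstructed only from the lines visible in the diff above; the helper function names and the bare `model_config` argument are illustrative assumptions, not TensorRT-LLM's actual API surface.

```python
# Minimal sketch of the behavior before and after this revert, based only on
# the code shown in the diff above. The helper names and the shape of
# `model_config` are illustrative assumptions for this example.

from tensorrt_llm._torch.distributed import AllReduce, AllReduceStrategy
from tensorrt_llm._utils import get_sm_version


def build_all_reduce_pre_revert(model_config):
    # PR #7768 forced the NCCL strategy on pre-Blackwell GPUs (SM < 100)
    # to keep the oneshot kernel from causing perf regressions there.
    return AllReduce(
        strategy=model_config.allreduce_strategy
        if get_sm_version() >= 100 else AllReduceStrategy.NCCL,
        mapping=model_config.mapping,
    )


def build_all_reduce_post_revert(model_config):
    # This commit drops the strategy override, so AllReduce is constructed
    # with only the mapping and falls back to its default strategy handling.
    return AllReduce(mapping=model_config.mapping)
```

The import change in the first hunk follows from the second: once the strategy override is removed, `AllReduceStrategy` is no longer referenced in this file, so it is dropped from the import list as well.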
