Skip to content

Commit e7cf5b2

Browse files
committed
formatting - integration test passes
1 parent b2534e5 commit e7cf5b2

File tree

6 files changed

+10
-11
lines changed

6 files changed

+10
-11
lines changed

cpp/tensorrt_llm/kernels/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,8 @@
1818
file(GLOB_RECURSE SRC_CPP *.cpp)
1919
file(GLOB_RECURSE SRC_CU *.cu)
2020

21-
# Explicitly add newly added kernel files to ensure they're included
22-
# (GLOB only runs at configure time, not build time)
21+
# Explicitly add newly added kernel files to ensure they're included (GLOB only
22+
# runs at configure time, not build time)
2323
list(APPEND SRC_CU ${CMAKE_CURRENT_SOURCE_DIR}/helixAllToAll.cu)
2424
list(REMOVE_DUPLICATES SRC_CU)
2525

cpp/tensorrt_llm/kernels/helixAllToAll.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -203,7 +203,8 @@ __device__ __forceinline__ uint64_t* getFifoBasePtr(HelixAllToAllParams const& p
203203
return mappedMemory + fifoOffset;
204204
}
205205

206-
__device__ __forceinline__ HelixFifoInfo* getSenderHelixFifoInfo(HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
206+
__device__ __forceinline__ HelixFifoInfo* getSenderHelixFifoInfo(
207+
HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
207208
{
208209
// SenderSideHelixFifoInfo is physically located at sender rank
209210
int mappedMemoryRank = pairInfo.senderRank;

tensorrt_llm/_mnnvl_utils.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -251,9 +251,7 @@ def open_mnnvl_memory(cls, mapping: Mapping, size: int):
251251

252252
for i, remote_handle_data in enumerate(all_handles_data):
253253
rank_ptr = (
254-
cls.current_start_address
255-
+ cls.current_rank_stride * i
256-
+ cls.current_mem_offset
254+
cls.current_start_address + cls.current_rank_stride * i + cls.current_mem_offset
257255
)
258256
if i == comm_rank:
259257
# Local memory mapping

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import torch
88
from torch import nn
99

10+
from tensorrt_llm._mnnvl_utils import HelixCpMnnvlMemory, MnnvlMemory
1011
from tensorrt_llm._torch.distributed.symm_mem_allreduce import \
1112
SymmetricMemoryAllReduce
1213
from tensorrt_llm._utils import mpi_comm, mpi_disabled
@@ -15,7 +16,6 @@
1516
AllReduceStrategy, MoEAllReduceParams)
1617
from tensorrt_llm.logger import logger
1718
from tensorrt_llm.mapping import Mapping
18-
from tensorrt_llm._mnnvl_utils import HelixCpMnnvlMemory, MnnvlMemory
1919
from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
2020

2121
_thread_local = threading.local()
@@ -442,7 +442,6 @@ def alltoall_native(self, field0: torch.Tensor, field1: torch.Tensor):
442442
return field0_out, field1_out
443443

444444

445-
446445
def reducescatter(
447446
input: Union[torch.Tensor, List[torch.Tensor]],
448447
mapping: Mapping,

tensorrt_llm/_torch/modules/attention.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,12 @@
11
import math
22
import os
33
import weakref
4-
from typing import Dict, Optional, Union, cast
4+
from typing import Optional, Union, cast
55

66
import torch
77
from torch import nn
88

99
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
10-
from tensorrt_llm._mnnvl_utils import HelixCpMnnvlMemory, MnnvlMemory
1110
from tensorrt_llm._utils import (get_sm_version, is_sm_100f, nvtx_range,
1211
nvtx_range_debug)
1312
from tensorrt_llm.logger import logger

tests/unittest/_torch/modules/test_mla_helix.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -878,7 +878,9 @@ def test_mla_helix_distributed(
878878
for use_nccl in [False, True]:
879879
nccl_mode = "NCCL" if use_nccl else "FIFO"
880880
print(f"\n{'=' * 60}")
881-
print(f"Testing with TRTLLM_USE_NCCL_FOR_HELIX={'1' if use_nccl else '0'} ({nccl_mode} mode)")
881+
print(
882+
f"Testing with TRTLLM_USE_NCCL_FOR_HELIX={'1' if use_nccl else '0'} ({nccl_mode} mode)"
883+
)
882884
print(f"{'=' * 60}\n")
883885
for scenario in all_scenarios[:11]:
884886
timing_steps = 256

0 commit comments

Comments (0)