
Commit 8acaf5f

Changes to the TRT-LLM download tool for the multi-GPU distributed case
1 parent 86d39ad commit 8acaf5f

7 files changed: 343 additions, 363 deletions

examples/distributed_inference/tensor_parallel_initialize_dist.py

Lines changed: 0 additions & 81 deletions
This file was deleted.
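
The distributed-setup helpers this per-example module used to provide now ship with the library: the examples below import them from torch_tensorrt.dynamo.distributed.utils instead. A minimal sketch of the new initialization pattern, assembled from the hunks in this commit (the helper bodies are not shown in this diff and are assumed to read their configuration from the launcher environment):

import torch.distributed as dist

from torch_tensorrt.dynamo.distributed.utils import (
    cleanup_distributed_env,
    get_tensor_parallel_device_mesh,
    initialize_distributed_env,
    initialize_logger,
)

# Initialize the process group only if a launcher has not already done so.
if not dist.is_initialized():
    initialize_distributed_env()

# One device mesh and one per-rank logger for the tensor-parallel example.
device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_logger(_rank, "tensor_parallel_simple_example")

# ... build, parallelize, and compile the model with torch_tensorrt ...

# Tear the distributed environment down when the example finishes (assumed usage).
cleanup_distributed_env()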

examples/distributed_inference/tensor_parallel_rotary_embedding.py

Lines changed: 8 additions & 4 deletions
@@ -16,15 +16,19 @@
 import torch
 import torch_tensorrt
 from rotary_embedding import RotaryAttention, parallel_rotary_block
-from tensor_parallel_initialize_dist import (
+from torch.distributed import dist
+from torch_tensorrt.dynamo.distributed.utils import (
     cleanup_distributed_env,
+    get_tensor_parallel_device_mesh,
     initialize_distributed_env,
+    initialize_logger,
 )
 
-device_mesh, _world_size, _rank, logger = initialize_distributed_env(
-    "./tensor_parallel_rotary_embedding"
-)
+if not dist.is_initialized():
+    initialize_distributed_env()
 
+device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
+logger = initialize_logger(_rank, "tensor_parallel_simple_example")
 
 """
 This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
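
Guarding the setup with dist.is_initialized() makes it idempotent: when an outer launcher (or an earlier import) has already created the process group, the example reuses it; otherwise initialize_distributed_env() performs the setup itself.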

examples/distributed_inference/tensor_parallel_simple_example.py

Lines changed: 12 additions & 3 deletions
@@ -36,11 +36,20 @@
     RowwiseParallel,
     parallelize_module,
 )
-
-device_mesh, _world_size, _rank, logger = initialize_distributed_env(
-    "./tensor_parallel_simple_example"
+from torch_tensorrt.dynamo.distributed.utils import (
+    cleanup_distributed_env,
+    get_tensor_parallel_device_mesh,
+    initialize_distributed_env,
+    initialize_logger,
 )
 
+if not dist.is_initialized():
+    initialize_distributed_env()
+
+device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
+logger = initialize_logger(_rank, "tensor_parallel_simple_example")
+
+
 """
 This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
 """

py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py

Lines changed: 1 addition & 1 deletion
@@ -11,11 +11,11 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     dynamo_tensorrt_converter,
 )
+from torch_tensorrt.dynamo.distributed.utils import load_tensorrt_llm_for_nccl
 from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
     tensorrt_fused_nccl_all_gather_op,
     tensorrt_fused_nccl_reduce_scatter_op,
 )
-from torch_tensorrt.dynamo.utils import load_tensorrt_llm_for_nccl
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
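Only the import path changes here: load_tensorrt_llm_for_nccl moves from torch_tensorrt.dynamo.utils to the new torch_tensorrt.dynamo.distributed.utils module. A hedged sketch of how such a guard is typically consumed in this converter module, assuming registration stays conditional on the TRT-LLM NCCL plugins being loadable (the converter bodies are illustrative, not part of this diff):

from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.distributed.utils import load_tensorrt_llm_for_nccl
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    tensorrt_fused_nccl_all_gather_op,
    tensorrt_fused_nccl_reduce_scatter_op,
)

# Register the fused NCCL converters only when the TRT-LLM plugin library
# that backs them can actually be loaded (guard structure is illustrative).
if load_tensorrt_llm_for_nccl():

    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
    def fused_nccl_gather(ctx, target, args, kwargs, name):
        ...  # lower the fused all_gather op to the NCCL plugin

    @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
    def fused_nccl_reduce_scatter(ctx, target, args, kwargs, name):
        ...  # lower the fused reduce_scatter op to the NCCL plugin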
