Commit dd88017

Llama torchTRT lib and env initialization reorg
1 parent f00d349 commit dd88017

2 files changed, +12 −10 lines


examples/distributed_inference/tensor_parallel_llama3.py

Lines changed: 11 additions & 7 deletions

@@ -5,26 +5,30 @@
 import time
 
 import torch
-import torch_tensorrt
+import torch.distributed as dist
 from llama3_model import ModelArgs, ParallelTransformer
+from tensor_parallel_initialize_dist import (
+    cleanup_distributed_env,
+    initialize_distributed_env,
+)
 from torch.distributed._composable.fsdp import MixedPrecisionPolicy
 from torch.distributed._composable.fsdp.fully_shard import fully_shard
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
 )
+
+if not dist.is_initialized():
+    initialize_distributed_env()
+
+import torch_tensorrt
 from torch_tensorrt.dynamo.distributed.utils import (
-    cleanup_distributed_env,
     get_tensor_parallel_device_mesh,
-    initialize_distributed_env,
     initialize_logger,
 )
 
-if not dist.is_initialized():
-    initialize_distributed_env()
-
 device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
-logger = initialize_logger(_rank, "tensor_parallel_simple_example")
+logger = initialize_logger(_rank, "tensor_parallel_llama3")
 
 logger.info(f"Starting PyTorch TP example on rank {_rank}.")
 assert (
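For context, the example now sets up torch.distributed through a local tensor_parallel_initialize_dist module before torch_tensorrt is imported, presumably so that torch_tensorrt's distributed utilities see an already-initialized process group at import time. That module is not part of this commit, so the following is only a minimal sketch, under the assumption that the two imported helpers wrap the standard process-group setup and teardown driven by torchrun environment variables:

import os

import torch
import torch.distributed as dist


def initialize_distributed_env(backend: str = "nccl") -> None:
    # Assumed behavior: create the default process group from the RANK/WORLD_SIZE
    # environment variables set by torchrun, and pin this process to its local GPU.
    if not dist.is_initialized():
        dist.init_process_group(backend=backend)
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))


def cleanup_distributed_env() -> None:
    # Assumed behavior: destroy the process group once the example finishes.
    if dist.is_initialized():
        dist.destroy_process_group()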

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 1 addition & 3 deletions

@@ -359,10 +359,8 @@ def setup_input_tensors(
         need_cudagraphs_record: bool,
     ) -> None:
         for i, input_name in enumerate(self.input_names):
+            contiguous_inputs[i] = complex_to_ri_stacked_tensor(contiguous_inputs[i])
             if not contiguous_inputs[i].is_cuda:
-                contiguous_inputs[i] = complex_to_ri_stacked_tensor(
-                    contiguous_inputs[i]
-                )
                 logger.warning(
                     f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
                     "This tensor is being moved by the runtime but for performance considerations, "
