 import time

 import torch
-import torch_tensorrt
+import torch.distributed as dist
 from llama3_model import ModelArgs, ParallelTransformer
+from tensor_parallel_initialize_dist import (
+    cleanup_distributed_env,
+    initialize_distributed_env,
+)
 from torch.distributed._composable.fsdp import MixedPrecisionPolicy
 from torch.distributed._composable.fsdp.fully_shard import fully_shard
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
 )
+
+if not dist.is_initialized():
+    initialize_distributed_env()
+
+import torch_tensorrt
 from torch_tensorrt.dynamo.distributed.utils import (
-    cleanup_distributed_env,
     get_tensor_parallel_device_mesh,
-    initialize_distributed_env,
     initialize_logger,
 )

-if not dist.is_initialized():
-    initialize_distributed_env()
-
 device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
-logger = initialize_logger(_rank, "tensor_parallel_simple_example")
+logger = initialize_logger(_rank, "tensor_parallel_llama3")

 logger.info(f"Starting PyTorch TP example on rank {_rank}.")
 assert (
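The bodies of the helpers imported from tensor_parallel_initialize_dist are not part of this hunk. As a rough guide only, a minimal sketch of what such an initializer and cleanup pair could look like is shown below; the actual helpers in the repository may differ, and the nccl backend choice and the LOCAL_RANK environment variable used here are assumptions based on the usual torchrun setup.

import os

import torch
import torch.distributed as dist


def initialize_distributed_env() -> None:
    # Assumed sketch: bind this process to its local GPU using the
    # torchrun-provided LOCAL_RANK, then create the default process group.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")


def cleanup_distributed_env() -> None:
    # Assumed sketch: tear down the process group at the end of the run.
    if dist.is_initialized():
        dist.destroy_process_group()

A script structured like the one in this diff is typically launched with torchrun (for example, torchrun --nproc_per_node=<num_gpus> <script>.py), which sets the rank and world-size environment variables that init_process_group reads when no explicit store is given.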