
Commit 2d5687c

Distributed llama3 example
1 parent 8acaf5f commit 2d5687c

13 files changed: +603 -234 lines changed

examples/distributed_inference/llama3_model.py

Lines changed: 496 additions & 0 deletions
Large diffs are not rendered by default.
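
The 496-line model definition is not rendered in this view. Judging from how the new example later in this commit consumes it, its public surface looks roughly like the sketch below; this is an inferred interface only, not the file's actual contents, and the field defaults simply mirror the values the example passes in.

# Inferred interface sketch only -- not the contents of llama3_model.py.
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh


@dataclass
class ModelArgs:
    # Defaults mirror the values used by the example later in this commit.
    vocab_size: int = 32000
    dim: int = 1024
    n_layers: int = 4
    n_heads: int = 8
    n_kv_heads: int = 8
    rope_theta: float = 500000.0
    device: str = "cuda"


class ParallelTransformer(nn.Module):
    """Llama3-style decoder whose projections are sharded over a tensor-parallel mesh."""

    def __init__(self, args: ModelArgs, tp_mesh: DeviceMesh) -> None:
        super().__init__()
        self.args = args
        self.tp_mesh = tp_mesh
        # Layer construction and the parallelize_module() calls live in the real file.

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Assumed to take integer token ids of shape (batch, seq_len) and return logits.
        ...

The example below builds ModelArgs with exactly these values and feeds a (8, 256) tensor of token ids through ParallelTransformer, so the real class presumably follows this calling convention.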

examples/distributed_inference/rotary_embedding.py

Lines changed: 3 additions & 3 deletions
@@ -84,20 +84,20 @@ def parallel_rotary_block(rotary_block, tp_mesh):
         "wk": ColwiseParallel(),
         "wo": RowwiseParallel(output_layouts=Shard(0)),
     }
-    rotary_block.n_parallel = 1 # this is for single GPU, to do remove this hardcode
+    rotary_block.n_parallel = tp_mesh.size()
 
     parallelize_module(rotary_block, tp_mesh, plan)
 
 
 class RotaryAttention(nn.Module):
-    def __init__(self, dim: int, seq_len: int):
+    def __init__(self, dim: int, seq_len: int, n_parallel: int = 1):
         super().__init__()
         self.dim = dim
         self.wq = nn.Linear(dim, dim)
         self.wk = nn.Linear(dim, dim)
         self.wo = nn.Linear(dim, dim)
         self.seq_len = seq_len
-        self.n_parallel = 1
+        self.n_parallel = n_parallel
         self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
         self.init_weights()
 
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+# Taken and modified from the PyTorch Lightning tensor-parallelism studio:
+# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
+import logging
+import os
+import time
+
+import torch
+import torch.distributed as dist  # not in the original commit; required for dist.is_initialized() below
+import torch_tensorrt
+from llama3_model import ModelArgs, ParallelTransformer
+from torch.distributed._composable.fsdp import MixedPrecisionPolicy
+from torch.distributed._composable.fsdp.fully_shard import fully_shard
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    checkpoint_wrapper,
+)
+from torch_tensorrt.dynamo.distributed.utils import (
+    cleanup_distributed_env,
+    get_tensor_parallel_device_mesh,
+    initialize_distributed_env,
+    initialize_logger,
+)
+
+if not dist.is_initialized():
+    initialize_distributed_env()
+
+device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
+logger = initialize_logger(_rank, "tensor_parallel_simple_example")
+
+logger.info(f"Starting PyTorch TP example on rank {_rank}.")
+assert (
+    _world_size % 2 == 0
+), f"TP examples require an even number of GPUs, but got {_world_size} gpus"
+
+model_args = ModelArgs(
+    vocab_size=32000,
+    dim=1024,
+    n_layers=4,
+    n_heads=8,
+    rope_theta=500000.0,
+    n_kv_heads=8,
+    device="cuda",
+)
+
+with torch.no_grad():
+    model = ParallelTransformer(model_args, device_mesh)
+    torch.manual_seed(0)
+    inp = torch.randint(32000, (8, 256), device="cuda")
+    python_result = model(inp)
+    torch_tensorrt.runtime.set_multi_device_safe_mode(True)
+    model = torch.compile(
+        model,
+        fullgraph=True,
+        backend="torch_tensorrt",
+        options={
+            "use_python_runtime": True,
+            "use_distributed_mode_trace": True,
+            "debug": True,
+        },
+        dynamic=False,
+    )
+
+    start = time.time()
+    output = model(inp)
+    end = time.time()
+    logger.info(f"Compilation time is {end-start}")
+    assert (python_result - output).std() < 0.01, "Compilation result is not correct."
+
+cleanup_distributed_env()
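
The helpers imported from torch_tensorrt.dynamo.distributed.utils are not part of this diff. For readers following what the example does, here is a minimal sketch of an equivalent setup built from plain torch.distributed APIs; the function names and behavior below are assumptions for illustration, not the library's implementation.

# Illustrative stand-ins for the torch_tensorrt.dynamo.distributed.utils helpers.
# These are assumptions for readability, not the library's actual implementation.
import logging
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


def initialize_distributed_env_sketch() -> None:
    # Assumes launch via torchrun, which sets RANK / WORLD_SIZE / LOCAL_RANK.
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))


def get_tensor_parallel_device_mesh_sketch() -> tuple[DeviceMesh, int, int]:
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    # A 1-D mesh: every rank belongs to the single tensor-parallel group.
    mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))
    return mesh, world_size, rank


def initialize_logger_sketch(rank: int, name: str) -> logging.Logger:
    logger = logging.getLogger(f"{name}_rank{rank}")
    logger.setLevel(logging.INFO)
    return logger


def cleanup_distributed_env_sketch() -> None:
    if dist.is_initialized():
        dist.destroy_process_group()

Under a setup like this, the example is launched with one process per GPU (for instance via torchrun), which is what makes the rank and world-size environment variables available before the script runs.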

examples/distributed_inference/tensor_parallel_rotary_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@
 DIM = 128
 
 with torch.no_grad():
-    model = RotaryAttention(DIM, SEQ_LEN)
+    model = RotaryAttention(DIM, SEQ_LEN, device_mesh.size())
     parallel_rotary_block(model, device_mesh)
     device = torch.device("cuda", device_mesh.get_rank())
     model.to(device)

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 1 addition & 8 deletions
@@ -13,7 +13,6 @@
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import EngineCapability, dtype
 from torch_tensorrt._features import needs_cross_compile
-from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults, partitioning
 from torch_tensorrt.dynamo._DryRunTracker import (
     DryRunTracker,
@@ -296,7 +295,6 @@ def cross_compile_for_windows(
         arg_inputs = [arg_inputs]  # type: ignore
 
     # Prepare torch_trt inputs
-    trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
     trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
     device = to_torch_tensorrt_device(device)
     enabled_precisions = {dtype._from(p) for p in enabled_precisions}
@@ -386,7 +384,6 @@ def cross_compile_for_windows(
     )
     trt_gm = compile_module(
         gm,
-        trt_arg_inputs,
         trt_kwarg_inputs,
         settings,
     )
@@ -632,7 +629,6 @@ def compile(
         arg_inputs = [arg_inputs]  # type: ignore
 
     # Prepare torch_trt inputs
-    trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
     trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
     device = to_torch_tensorrt_device(device)
     enabled_precisions = {dtype._from(p) for p in enabled_precisions}
@@ -723,16 +719,13 @@ def compile(
        logger.warning(
            "Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
        )
-    trt_gm = compile_module(
-        gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
-    )
+    trt_gm = compile_module(gm, trt_kwarg_inputs, settings, engine_cache)
     return trt_gm
 
 
 @fn_supports_debugger
 def compile_module(
     gm: torch.fx.GraphModule,
-    sample_arg_inputs: Sequence[Input],
     sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
     settings: CompilationSettings = CompilationSettings(),
     engine_cache: Optional[BaseEngineCache] = None,

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 0 additions & 5 deletions
@@ -22,7 +22,6 @@
 from torch_tensorrt.dynamo.utils import (
     is_tegra_platform,
     parse_dynamo_kwargs,
-    prepare_inputs,
     set_log_level,
 )
 
@@ -150,9 +149,6 @@ def _pretraced_backend(
 
        logger.debug("Lowered Input graph:\n " + str(gm.graph))
 
-        torchtrt_inputs = prepare_inputs(
-            torch_inputs, disable_memory_format_check=True
-        )
        if settings.require_full_compilation:
            logger.warning(
                "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
@@ -163,7 +159,6 @@ def _pretraced_backend(
        )
        trt_compiled = compile_module(
            gm,
-            torchtrt_inputs,
            settings=settings,
            engine_cache=engine_cache,
        )

py/torch_tensorrt/dynamo/lowering/passes/_modify_reshape_complex_nodes.py

Lines changed: 0 additions & 105 deletions
This file was deleted.

py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py

Lines changed: 0 additions & 112 deletions
This file was deleted.
