
Commit fa812a9
chore: Fixes required for LLM models (#3002)
1 parent 015f13b

31 files changed: +1029 -353 lines

examples/dynamo/torch_export_gpt2.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
"""
.. _torch_export_gpt2:

Compiling GPT2 using the Torch-TensorRT dynamo backend
==========================================================

This interactive script is intended as a sample of the Torch-TensorRT workflow with the dynamo backend on a GPT2 model."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import export_llm, generate

# %%

# Define the parameters and initialize the model
MAX_TOKENS = 32
DEVICE = torch.device("cuda:0")

# Define the GPT2 model from Hugging Face.
# kv_cache is not supported in Torch-TRT currently.
# The model stays on the CPU here so that GPU memory is reserved for TRT compilation.
with torch.no_grad():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        pad_token_id=tokenizer.eos_token_id,
        use_cache=False,
        attn_implementation="eager",
    ).eval()

# %%
# Tokenize a sample input prompt and get the PyTorch model outputs
prompt = "I enjoy walking with my cute dog"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs["input_ids"]

# Auto-regressive generation loop for greedy decoding using the PyTorch model.
# We use a custom generate function, which is very similar to the Hugging Face one.
pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)


# %%
# Compilation with `Torch-TensorRT` using the dynamo backend and generation of TensorRT outputs
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Export the GPT2 model into an ExportedProgram, which is the input to TRT compilation
gpt2_ep = export_llm(model, input_ids, max_seq_len=1024)
trt_model = torch_tensorrt.dynamo.compile(
    gpt2_ep,
    inputs=[input_ids],
    enabled_precisions={torch.float32},
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
)

# Auto-regressive generation loop for greedy decoding using the TensorRT model.
# We use a custom generate function, which is very similar to the Hugging Face one.
# Move inputs to GPU
input_ids = input_ids.to(DEVICE)
trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Decode the output sentences of PyTorch and TensorRT
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
print("=============================")
print(
    "Pytorch model generated text: ",
    tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
)
print("=============================")
print(
    "TensorRT model generated text: ",
    tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
)

# %%
# The output sentences should look like:
# =============================
# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# =============================
# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
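
Greedy decoding is deterministic, so the PyTorch and TensorRT token sequences should agree; a minimal sanity check one could append to this script (an editorial sketch, not part of the commit; in principle a small numeric difference could still flip a token):

# Hypothetical parity check: pyt_gen_tokens lives on the CPU and trt_gen_tokens
# on the GPU, so align devices before comparing token ids element-wise.
assert torch.equal(pyt_gen_tokens, trt_gen_tokens.cpu()), "PyTorch/TRT outputs diverge"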

examples/dynamo/torch_export_llama2.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
"""
.. _torch_export_llama2:

Compiling Llama2 using the Torch-TensorRT dynamo backend
==========================================================

This interactive script is intended as a sample of the Torch-TensorRT workflow with the dynamo backend on a Llama2 model."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import export_llm, generate

# %%
# Define the parameters and initialize the model
MAX_TOKENS = 32
DEVICE = torch.device("cuda:0")

# Define the Llama2 model from Hugging Face.
# kv_cache is not supported in Torch-TRT currently.
# The model stays on the CPU here so that GPU memory is reserved for TRT compilation.
llama_path = "meta-llama/Llama-2-7b-chat-hf"
with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(
        llama_path, use_cache=False, attn_implementation="eager"
    ).eval()

tokenizer = AutoTokenizer.from_pretrained(llama_path)

# %%
# Tokenize a sample input prompt and get the PyTorch model outputs
prompt = "What is dynamic programming?"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs.input_ids

# Auto-regressive generation loop for greedy decoding using the PyTorch model.
# We use a custom generate function, which is very similar to the Hugging Face one.
pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Compilation with `Torch-TensorRT` using the dynamo backend and generation of TensorRT outputs
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Export the Llama2 model into an ExportedProgram, which is the input to TRT compilation
llama2_ep = export_llm(model, input_ids, max_seq_len=64)
trt_model = torch_tensorrt.dynamo.compile(
    llama2_ep,
    inputs=[input_ids],
    enabled_precisions={torch.float32},
    min_block_size=1,
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
)

# Auto-regressive generation loop for greedy decoding using the TensorRT model.
# We use a custom generate function, which is very similar to the Hugging Face one.
# Move inputs to GPU
input_ids = input_ids.to(DEVICE)
trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Decode the output sentences of PyTorch and TensorRT
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
print("=============================")
print(
    "Pytorch model generated text: ",
    tokenizer.batch_decode(
        pyt_gen_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0],
)
print("=============================")
print(
    "TensorRT model generated text: ",
    tokenizer.batch_decode(
        trt_gen_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0],
)

# %%
# The output sentences should look like:
# =============================
# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# =============================
# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
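
A note on the `min_block_size=1` setting above: it directs the partitioner to turn even single-operator subgraphs into TensorRT engines. As a hedged variation (not part of this commit; `trt_model_coarse` and the value 5 are illustrative assumptions), one could instead leave small operator islands in PyTorch:

# Hypothetical variation: a larger min_block_size keeps small subgraphs in PyTorch
# rather than emitting many tiny TRT engines, reducing engine count and overhead.
trt_model_coarse = torch_tensorrt.dynamo.compile(
    llama2_ep,
    inputs=[input_ids],
    enabled_precisions={torch.float32},
    min_block_size=5,  # assumption: illustrative value; tune per model
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
)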

examples/dynamo/utils.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
import torch
from transformers import StoppingCriteriaList
from transformers.generation.stopping_criteria import (
    EosTokenCriteria,
    MaxLengthCriteria,
)


def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
    """
    Exports the LLM model into an ExportedProgram with dynamic shapes.
    In the case of guard failures caused by certain PyTorch kernel implementations,
    we also try to re-export the graph by expressing the guards as runtime assert nodes.
    """
    with torch.no_grad():
        # max=1024 raises a constraint violation error. https://github.com/pytorch/pytorch/issues/125604
        seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
        try:
            print("Trying to export the model using torch.export.export()..")
            # strict=False only enables AOTAutograd tracing and excludes dynamo.
            ep = torch.export.export(
                model, (inputs,), dynamic_shapes=({1: seq_len},), strict=False
            )
        except Exception:
            print(
                "Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
            )
            # This API is used to express the constraint violation guards as asserts in the graph.
            ep = torch.export._trace._export(
                model,
                (inputs,),
                dynamic_shapes=({1: seq_len},),
                strict=False,
                allow_complex_guards_as_runtime_asserts=True,
            )

    return ep


def generate(model, input_seq, max_tokens, eos_token_id):
    """
    Greedy decoding of the model. This generates up to max_tokens.
    """
    # Max length of output seq = current input_seq length + max_tokens allowed to generate
    max_output_seq_length = input_seq.shape[1] + max_tokens
    stopping_criteria = StoppingCriteriaList(
        [
            MaxLengthCriteria(max_length=max_output_seq_length),
            EosTokenCriteria(eos_token_id=eos_token_id),
        ]
    )

    while True:
        outputs = model(input_seq)
        logits = outputs.logits
        next_token_logits = logits[:, -1, :]
        next_tokens = torch.argmax(next_token_logits, dim=-1)
        input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1)
        # TODO: Handle batch in this check
        if stopping_criteria(input_seq, logits).item():
            break

    return input_seq
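
To make the `dynamic_shapes=({1: seq_len},)` spec above concrete: the model input has shape (batch, seq), and the mapping `{1: seq_len}` marks dimension 1, the sequence axis, as dynamic between min_seq_len and max_seq_len. A self-contained toy sketch (the TinyEmbed module is hypothetical, standing in for an LLM):

import torch

class TinyEmbed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(50257, 32)  # GPT2-sized vocab, tiny width

    def forward(self, ids):
        return self.emb(ids)

seq_len = torch.export.Dim("seq_len", min=1, max=16)
example = torch.randint(0, 50257, (1, 8))  # (batch=1, seq=8)
# {1: seq_len} ties dimension 1 of the first positional input to the Dim above.
ep = torch.export.export(TinyEmbed(), (example,), dynamic_shapes=({1: seq_len},))
print(ep)  # the printed graph shows a symbolic size on the sequence dimension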

py/torch_tensorrt/dynamo/_DryRunTracker.py

Lines changed: 34 additions & 25 deletions
@@ -20,18 +20,18 @@ class PerSubgraphData:
     Args:
         subgraph_name (str): Name of the subgraph in the GraphModule
         subgraph_op_count (int): Number of operations in the subgraph
-        subgraph_input_shapes (Any): Shapes of input Tensors of the subgraph
-        subgraph_input_dtypes (Any): Input data types of the subgraph
-        subgraph_output_shapes (Any): Shapes of output Tensors of the subgraph
-        subgraph_output_dtypes (Any): Output data types of the subgraph
+        input_shapes (Any): Shapes of input Tensors of the subgraph
+        input_dtypes (Any): Input data types of the subgraph
+        output_shapes (Any): Shapes of output Tensors of the subgraph
+        output_dtypes (Any): Output data types of the subgraph
     """
 
     subgraph_name: str = ""
     subgraph_op_count: int = 0
-    subgraph_input_shapes: Any = field(default_factory=list)
-    subgraph_input_dtypes: Any = field(default_factory=list)
-    subgraph_output_shapes: Any = field(default_factory=list)
-    subgraph_output_dtypes: Any = field(default_factory=list)
+    input_shapes: Any = field(default_factory=list)
+    input_dtypes: Any = field(default_factory=list)
+    output_shapes: Any = field(default_factory=list)
+    output_dtypes: Any = field(default_factory=list)
 
 
 @dataclass
@@ -41,10 +41,10 @@ class DryRunTracker:
     Args:
         total_ops_in_graph (int): Total number of operators in graph
         supported_ops_in_graph (int): Number of supported operators in graph
-        graph_input_shapes (Any): Shapes of input Tensors of the graph
-        graph_input_dtypes (Any): Input data types of the graph
-        graph_output_shapes (Any): Shapes of output Tensors of the graph
-        graph_output_dtypes (Any): Output data types of the graph
+        input_shapes (Any): Shapes of input Tensors of the graph
+        input_dtypes (Any): Input data types of the graph
+        output_shapes (Any): Shapes of output Tensors of the graph
+        output_dtypes (Any): Output data types of the graph
         per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
         tensorrt_graph_count (int): Number of TensorRT engines to be generated
         compilation_settings (CompilationSettings): User Compilation Settings
@@ -54,10 +54,10 @@ class DryRunTracker:
 
     total_ops_in_graph: int = 0
     supported_ops_in_graph: int = 0
-    graph_input_shapes: Any = field(default_factory=list)
-    graph_input_dtypes: Any = field(default_factory=list)
-    graph_output_shapes: Any = field(default_factory=list)
-    graph_output_dtypes: Any = field(default_factory=list)
+    input_shapes: Any = field(default_factory=list)
+    input_dtypes: Any = field(default_factory=list)
+    output_shapes: Any = field(default_factory=list)
+    output_dtypes: Any = field(default_factory=list)
     per_subgraph_data: List[PerSubgraphData] = field(default_factory=list)
     tensorrt_graph_count: int = 0
     compilation_settings: CompilationSettings = field(
@@ -111,7 +111,7 @@ def dryrun_stats_display(
         formatted_stats += " " * 2 + "Graph Structure:\n\n"
         formatted_stats += (
             " " * 3
-            + f"Inputs: {input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}\n"
+            + f"Inputs: {input_formatter(dryrun_tracker.input_shapes, dryrun_tracker.input_dtypes)}\n"
         )
 
         for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data):
@@ -122,21 +122,21 @@ def dryrun_stats_display(
             )
             formatted_stats += (
                 " " * 5
-                + f"Engine Inputs: {input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}\n"
+                + f"Engine Inputs: {input_formatter(trt_subgraph_data.input_shapes, trt_subgraph_data.input_dtypes)}\n"
             )
             formatted_stats += (
                 " " * 5
                 + f"Number of Operators in Engine: {trt_subgraph_data.subgraph_op_count}\n"
             )
             formatted_stats += (
                 " " * 5
-                + f"Engine Outputs: {input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}\n"
+                + f"Engine Outputs: {input_formatter(trt_subgraph_data.output_shapes, trt_subgraph_data.output_dtypes)}\n"
             )
 
         formatted_stats += " " * 4 + "...\n"
         formatted_stats += (
             " " * 3
-            + f"Outputs: {input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}\n"
+            + f"Outputs: {input_formatter(dryrun_tracker.output_shapes, dryrun_tracker.output_dtypes)}\n"
         )
 
         # Print aggregate statistics about the graph structure, including recommended "min_block_size" options
@@ -225,11 +225,20 @@ def input_formatter(shapes: Any, dtypes: Any) -> str:
 
     def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
         """Helper for input formatter"""
-        # Base case - single shape, single dtype
-        if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes):
-            return f"Tensor: {shapes}@{str(dtypes)[6:]}, "
-
-        # Base case - dynamic shape, single dtype
+        # Base case 1 - single static/dynamic shape, single dtype
+        if isinstance(shapes, tuple) and all(
+            isinstance(elt, (int, tuple)) for elt in shapes
+        ):
+            input_shape_string = "Tensor: ("
+            for elt in shapes:
+                if isinstance(elt, tuple):
+                    input_shape_string += f"(min={elt[0]}, max={elt[1]}), "
+                else:
+                    input_shape_string += f"{elt}, "
+            input_shape_string = input_shape_string[:-2] + ")" + f"@{str(dtypes)[6:]}, "
+            return input_shape_string
+
+        # Base case 2 - dynamic shape, single dtype
         elif (
             isinstance(shapes, dict)
             and len(shapes) == 3