Skip to content

Commit ade51b4

Browse files
authored
chore: example fixes (#3176)
1 parent a66684c commit ade51b4

File tree

5 files changed

+132
-7
lines changed

5 files changed

+132
-7
lines changed

docsrc/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ Model Zoo
135135
* :ref:`torch_compile_resnet`
136136
* :ref:`torch_compile_transformer`
137137
* :ref:`torch_compile_stable_diffusion`
138+
* :ref:`torch_compile_gpt2`
138139
* :ref:`torch_export_gpt2`
139140
* :ref:`torch_export_llama2`
140141
* :ref:`torch_export_sam2`
@@ -150,6 +151,7 @@ Model Zoo
150151
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
151152
tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
152153
tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
154+
tutorials/_rendered_examples/dynamo/torch_compile_gpt2
153155
tutorials/_rendered_examples/dynamo/torch_export_gpt2
154156
tutorials/_rendered_examples/dynamo/torch_export_llama2
155157
tutorials/_rendered_examples/dynamo/torch_export_sam2

examples/dynamo/README.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Model Zoo
1717
* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
1818
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
1919
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
20+
* :ref:`torch_compile_gpt2`: Compiling a GPT2 model using ``torch.compile``
2021
* :ref:`torch_export_gpt2`: Compiling a GPT2 model using the AOT workflow (``ir=dynamo``)
2122
* :ref:`torch_export_llama2`: Compiling a Llama2 model using the AOT workflow (``ir=dynamo``)
2223
* :ref:`torch_export_sam2`: Compiling the SAM2 model using the AOT workflow (``ir=dynamo``)

examples/dynamo/torch_compile_gpt2.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""
2+
.. _torch_compile_gpt2:
3+
4+
Compiling GPT2 using the Torch-TensorRT ``torch.compile`` frontend
5+
==================================================================
6+
7+
This example illustrates the state-of-the-art model `GPT2 <https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf>`_ optimized using the
8+
``torch.compile`` frontend of Torch-TensorRT. Install the following dependencies before compilation:
9+
10+
.. code-block:: bash
11+
12+
pip install -r requirements.txt
13+
14+
GPT2 is a causal (unidirectional) transformer pretrained using language modeling on a very large corpus of text data. In this example, we use the GPT2 model available at `HuggingFace <https://huggingface.co/docs/transformers/en/model_doc/gpt2>`_ and apply ``torch.compile`` to it to
15+
get the graph module representation of the model. Torch-TensorRT converts this graph into an optimized TensorRT engine.
16+
"""
17+
18+
# %%
19+
# Import necessary libraries
20+
# -----------------------------
21+
import torch
22+
import torch_tensorrt
23+
from transformers import AutoModelForCausalLM, AutoTokenizer
24+
25+
# %%
26+
# Define the necessary parameters
27+
# -------------------------------
28+
# Torch-TensorRT requires a GPU for successful compilation of the model.
29+
# ``MAX_LENGTH`` is the maximum length the generated tokens can have. This corresponds to the length of the input prompt +
30+
# number of new tokens generated
31+
MAX_LENGTH = 32
32+
DEVICE = torch.device("cuda:0")
33+
34+
# %%
35+
# Model definition
36+
# -----------------------------
37+
# We use the ``AutoModelForCausalLM`` class to load the pretrained GPT2 model from Hugging Face. KV caching is not currently supported in Torch-TensorRT, so we set ``use_cache=False``.
38+
with torch.no_grad():
39+
tokenizer = AutoTokenizer.from_pretrained("gpt2")
40+
model = (
41+
AutoModelForCausalLM.from_pretrained(
42+
"gpt2",
43+
pad_token_id=tokenizer.eos_token_id,
44+
use_cache=False,
45+
attn_implementation="eager",
46+
)
47+
.eval()
48+
.cuda()
49+
)
50+
51+
# %%
52+
# PyTorch inference
53+
# -----------------------------
54+
# Tokenize a sample input prompt and get pytorch model outputs
55+
prompt = "I enjoy walking with my cute dog"
56+
model_inputs = tokenizer(prompt, return_tensors="pt")
57+
input_ids = model_inputs["input_ids"].cuda()
58+
59+
# %%
60+
# The ``generate()`` API of the ``AutoModelForCausalLM`` class is used for auto-regressive generation with greedy decoding.
61+
pyt_gen_tokens = model.generate(
62+
input_ids,
63+
max_length=MAX_LENGTH,
64+
use_cache=False,
65+
pad_token_id=tokenizer.eos_token_id,
66+
)
67+
68+
# %%
69+
# Torch-TensorRT compilation and inference
70+
# ----------------------------------------
71+
# The input sequence length is dynamic, so we mark it using ``torch._dynamo.mark_dynamic`` API.
72+
# We provide a (min, max) range of this value so that TensorRT knows in advance what values to optimize for.
73+
# Usually, this would be the context length for the model. We start with ``min=2`` due to the `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?tab=t.0#heading=h.ez923tomjvyk>`_.
74+
torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
75+
model.forward = torch.compile(
76+
model.forward,
77+
backend="tensorrt",
78+
dynamic=None,
79+
options={
80+
"enabled_precisions": {torch.float32},
81+
"disable_tf32": True,
82+
"min_block_size": 1,
83+
},
84+
)
85+
86+
# %%
87+
# Auto-regressive generation loop for greedy decoding using TensorRT model
88+
# The first token generation compiles the model using TensorRT and the second token
89+
# encounters recompilation (a known issue that is expected to be resolved in a future release).
90+
trt_gen_tokens = model.generate(
91+
inputs=input_ids,
92+
max_length=MAX_LENGTH,
93+
use_cache=False,
94+
pad_token_id=tokenizer.eos_token_id,
95+
)
96+
97+
# %%
98+
# Decode the output sentences of PyTorch and TensorRT
99+
# ---------------------------------------------------
100+
print(
101+
"Pytorch model generated text: ",
102+
tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
103+
)
104+
print("=============================")
105+
print(
106+
"TensorRT model generated text: ",
107+
tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
108+
)
109+
110+
# %%
111+
# The output sentences should look like
112+
113+
"""
114+
Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
115+
=============================
116+
TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
117+
"""

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def _pretraced_backend(
8080
repair_input_aliasing(gm, settings)
8181

8282
# Remove sym_int placeholders and inputs
83-
remove_sym_nodes(gm, settings)
83+
remove_sym_nodes(gm, sample_inputs, settings)
84+
8485
torch_inputs = [
8586
input for input in sample_inputs if isinstance(input, torch.Tensor)
8687
]
@@ -91,7 +92,7 @@ def _pretraced_backend(
9192
# Invoke AOTAutograd to translate operators to aten
9293
gm = aot_export_joint_simple(
9394
gm,
94-
torch_inputs,
95+
sample_inputs,
9596
trace_joint=False,
9697
decompositions=get_decompositions(
9798
settings.enable_experimental_decompositions

py/torch_tensorrt/dynamo/lowering/passes/remove_sym_nodes.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from typing import Any, Sequence
23

34
import torch
45
from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -7,15 +8,17 @@
78

89

910
def remove_sym_nodes(
10-
gm: torch.fx.GraphModule, settings: CompilationSettings
11+
gm: torch.fx.GraphModule,
12+
sample_inputs: Sequence[Any],
13+
settings: CompilationSettings,
1114
) -> torch.fx.GraphModule:
1215
"""Remove sym_int placeholders which get inserted due to torch.compile's
1316
dynamic=True behavior
1417
"""
1518
# Extract SymInt placeholder Tensors
16-
placeholder_sym_ints = [
17-
node
18-
for node in gm.graph.nodes
19+
placeholder_idx_sym_ints = [
20+
(idx, node)
21+
for idx, node in enumerate(gm.graph.nodes)
1922
if (
2023
node.op == "placeholder"
2124
and isinstance(node.type, type)
@@ -24,8 +27,9 @@ def remove_sym_nodes(
2427
)
2528
]
2629

27-
for node in placeholder_sym_ints:
30+
for idx, node in placeholder_idx_sym_ints:
2831
gm.graph.erase_node(node)
32+
sample_inputs.pop(idx)
2933

3034
gm.graph.lint()
3135
gm.recompile()

0 commit comments

Comments
 (0)