|
| 1 | +""" |
| 2 | +.. _torch_compile_gpt2: |
| 3 | +
|
Compiling GPT2 using the Torch-TensorRT ``torch.compile`` frontend
==================================================================
| 6 | +
|
| 7 | +This example illustrates the state of the art model `GPT2 <https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf>`_ optimized using |
| 8 | +``torch.compile`` frontend of Torch-TensorRT. Install the following dependencies before compilation |
| 9 | +
|
.. code-block:: bash
| 11 | +
|
| 12 | + pip install -r requirements.txt |
| 13 | +
|
| 14 | +GPT2 is a causal (unidirectional) transformer pretrained using language modeling on a very large corpus of text data. In this example, we use the GPT2 model available at `HuggingFace <https://huggingface.co/docs/transformers/en/model_doc/gpt2>`_ and apply torch.compile on it to |
| 15 | +get the graph module representation of the graph. Torch-TensorRT converts this graph into an optimized TensorRT engine. |
| 16 | +""" |
| 17 | + |
| 18 | +# %% |
| 19 | +# Import necessary libraries |
| 20 | +# ----------------------------- |
| 21 | +import torch |
| 22 | +import torch_tensorrt |
| 23 | +from transformers import AutoModelForCausalLM, AutoTokenizer |
| 24 | + |
# %%
# Define the necessary parameters
# -----------------------------
# Torch-TensorRT requires a GPU for successful compilation of the model.
# ``MAX_LENGTH`` is the maximum length the generated tokens can have. This corresponds to the length of the input prompt +
# number of new tokens generated
MAX_LENGTH = 32
# Target CUDA device for the model and its inputs (the example pins GPU 0).
DEVICE = torch.device("cuda:0")
| 33 | + |
# %%
# Model definition
# -----------------------------
# We use the ``AutoModelForCausalLM`` class to load the pretrained GPT2 model
# from Hugging Face. ``kv_cache`` is not supported in Torch-TRT currently, so
# ``use_cache=False``. ``attn_implementation="eager"`` selects the eager
# attention path (presumably for converter compatibility — confirm upstream).
# Weights are loaded under ``torch.no_grad()`` since this is inference-only.
# The model is moved to the ``DEVICE`` constant declared above instead of a
# hard-coded ``.cuda()`` (behaviorally identical — ``DEVICE`` is ``cuda:0`` —
# but keeps the target GPU configured in one place).
with torch.no_grad():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = (
        AutoModelForCausalLM.from_pretrained(
            "gpt2",
            pad_token_id=tokenizer.eos_token_id,
            use_cache=False,
            attn_implementation="eager",
        )
        .eval()
        .to(DEVICE)
    )
| 50 | + |
# %%
# PyTorch inference
# -----------------------------
# Tokenize a sample input prompt and get pytorch model outputs
prompt = "I enjoy walking with my cute dog"
model_inputs = tokenizer(prompt, return_tensors="pt")
# Move the token ids onto the same device as the model. We use the ``DEVICE``
# constant rather than a hard-coded ``.cuda()`` (identical behavior —
# ``DEVICE`` is ``cuda:0``) so the target GPU is configured in one place.
input_ids = model_inputs["input_ids"].to(DEVICE)

# %%
# The ``generate()`` API of the ``AutoModelForCausalLM`` class is used for
# auto-regressive generation with greedy decoding. ``use_cache=False`` matches
# the model's configuration above; padding uses the EOS token id since GPT2
# has no dedicated pad token.
pyt_gen_tokens = model.generate(
    input_ids,
    max_length=MAX_LENGTH,
    use_cache=False,
    pad_token_id=tokenizer.eos_token_id,
)
| 67 | + |
# %%
# Torch-TensorRT compilation and inference
# -----------------------------
# The input sequence length is dynamic, so it is marked via the
# ``torch._dynamo.mark_dynamic`` API with a (min, max) range so TensorRT
# knows in advance which shapes to optimize for — usually the model's
# context length. ``min=2`` is required because of 0/1 specialization
# (see the PyTorch "dynamic shapes" 0/1-specialization design note).
torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)

# Compilation options for the Torch-TensorRT ``torch.compile`` backend:
# FP32 precision with TF32 disabled, and a minimum block size of 1 so
# every convertible op is offloaded to TensorRT.
compilation_options = {
    "enabled_precisions": {torch.float32},
    "disable_tf32": True,
    "min_block_size": 1,
}
model.forward = torch.compile(
    model.forward,
    backend="tensorrt",
    dynamic=None,
    options=compilation_options,
)
| 85 | + |
# %%
# Auto-regressive generation loop for greedy decoding using TensorRT model
# The first token generation compiles the model using TensorRT and the second token
# encounters recompilation (which is an issue currently that would be resolved in the future)
# Arguments mirror the PyTorch ``generate`` call above so the two outputs are
# directly comparable (same max length, no kv-cache, EOS used as pad token).
trt_gen_tokens = model.generate(
    inputs=input_ids,
    max_length=MAX_LENGTH,
    use_cache=False,
    pad_token_id=tokenizer.eos_token_id,
)
| 96 | + |
# %%
# Decode the output sentences of PyTorch and TensorRT
# -----------------------------
# Decode each generated token sequence back to text, skipping special tokens,
# then print both so the results can be compared side by side.
pyt_text = tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True)
trt_text = tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True)
print("Pytorch model generated text: ", pyt_text)
print("=============================")
print("TensorRT model generated text: ", trt_text)
| 109 | + |
# %%
# The output sentences should look like
# (the bare string below is display-only sample output; it is never executed)

"""
Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
=============================
TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
"""
0 commit comments