Commit 38771e6

peri044, chohk88, and zewenli98 authored
feat: Implement FP32 accumulation for matmul (#3110)
Co-authored-by: Hoonkyung Cho <[email protected]>
Co-authored-by: Zewen (Evan) Li <[email protected]>
1 parent 451179b commit 38771e6

22 files changed: 314 additions, 341 deletions


docsrc/index.rst

Lines changed: 4 additions & 0 deletions
@@ -37,6 +37,7 @@ User Guide
 * :ref:`saving_models`
 * :ref:`runtime`
 * :ref:`using_dla`
+* :ref:`mixed_precision`

 .. toctree::
    :caption: User Guide
@@ -48,6 +49,7 @@ User Guide
    user_guide/saving_models
    user_guide/runtime
    user_guide/using_dla
+   user_guide/mixed_precision
    tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
    tutorials/_rendered_examples/dynamo/vgg16_ptq
    tutorials/_rendered_examples/dynamo/engine_caching_example
@@ -119,6 +121,8 @@ Tutorials
    tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
    tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
    tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
+   tutorials/_rendered_examples/dynamo/torch_export_gpt2
+   tutorials/_rendered_examples/dynamo/torch_export_llama2

 Python API Documentation
 ------------------------
docsrc/user_guide/mixed_precision.rst

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
.. _mixed_precision:

Compile Mixed Precision models with Torch-TensorRT
===================================================

.. currentmodule:: torch_tensorrt.dynamo

.. automodule:: torch_tensorrt.dynamo
    :members:
    :undoc-members:
    :show-inheritance:

Consider the following PyTorch model which explicitly casts an intermediate layer to run in FP16.

.. code-block:: python

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(10, 10)
            self.linear2 = torch.nn.Linear(10, 30).half()
            self.linear3 = torch.nn.Linear(30, 40)

        def forward(self, x):
            x = self.linear1(x)
            x = x.to(torch.float16)
            x = self.linear2(x)
            x = x.to(torch.float32)
            x = self.linear3(x)
            return x


If we compile the above model using Torch-TensorRT, layer profiling logs indicate that all the layers are
run in FP32. This is because TensorRT picks the kernels for the layers which result in the best performance.

.. code-block:: python

    inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
    mod = MyModule().eval().cuda()
    ep = torch.export.export(mod, tuple(inputs))
    with torch_tensorrt.logging.debug():
        trt_gm = torch_tensorrt.dynamo.compile(ep,
                                               inputs=inputs,
                                               debug=True)

    # Debug log info
    # Layers:
    # Name: __myl_MulSum_myl0_0, LayerType: kgen, Inputs: [ { Name: __mye116_dconst, Dimensions: [10,10], Format/Datatype: Float }, { Name: x, Dimensions: [10,1], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Float }], TacticName: __myl_MulSum_0xfa6c1858aea1b13b03f90165d7149ec6, StreamId: 0, Metadata:
    # Name: __myl_AddResMulSum_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye131_dconst, Dimensions: [10,30], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Float }, { Name: linear1/addmm_constant_0 _ linear1/addmm_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,10], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_AddResMulSum_0xb3915d7ebfe48be45b6d49083479e12f, StreamId: 0, Metadata:
    # Name: __myl_AddResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye146_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_AddResMulSumAdd_0xcdd0085ad25f5f45ac5fafb72acbffd6, StreamId: 0, Metadata:


In order to respect the types specified by the user in the model (e.g., in this case, the ``linear2`` layer should run in FP16), users can enable
the compilation setting ``use_explicit_typing=True``. Compiling with this option results in the following TensorRT logs:

.. note:: If you enable ``use_explicit_typing=True``, only ``torch.float32`` is supported in ``enabled_precisions``.

.. code-block:: python

    inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
    mod = MyModule().eval().cuda()
    ep = torch.export.export(mod, tuple(inputs))
    with torch_tensorrt.logging.debug():
        trt_gm = torch_tensorrt.dynamo.compile(ep,
                                               inputs=inputs,
                                               use_explicit_typing=True,
                                               debug=True)

    # Debug log info
    # Layers:
    # Name: __myl_MulSumAddCas_myl0_0, LayerType: kgen, Inputs: [ { Name: linear1/addmm_constant_0 _ linear1/addmm_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,10], Format/Datatype: Float }, { Name: __mye112_dconst, Dimensions: [10,10], Format/Datatype: Float }, { Name: x, Dimensions: [10,1], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], TacticName: __myl_MulSumAddCas_0xacf8f5dd9be2f3e7bb09cdddeac6c936, StreamId: 0, Metadata:
    # Name: __myl_ResMulSumAddCas_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye127_dconst, Dimensions: [10,30], Format/Datatype: Half }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantHalf, Dimensions: [1,30], Format/Datatype: Half }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_ResMulSumAddCas_0x5a3b318b5a1c97b7d5110c0291481337, StreamId: 0, Metadata:
    # Name: __myl_ResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye142_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_ResMulSumAdd_0x3fad91127c640fd6db771aa9cde67db0, StreamId: 0, Metadata:

Now the ``linear2`` layer runs in FP16 as shown in the above logs.
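
The ``use_fp32_acc`` flag introduced by this commit (see ``py/torch_tensorrt/dynamo/_compiler.py`` below) can be combined with explicit typing so that FP16 matmuls accumulate in FP32. The following is a minimal sketch of compiling the same model with both flags, assuming the options behave as documented in this commit:

    import torch
    import torch_tensorrt

    # Reuses the MyModule definition from the code above.
    inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
    mod = MyModule().eval().cuda()
    ep = torch.export.export(mod, tuple(inputs))

    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=inputs,
        enabled_precisions={torch.float32},  # only FP32 is allowed here when use_explicit_typing=True
        use_explicit_typing=True,            # respect the FP16/FP32 casts authored in the model
        use_fp32_acc=True,                   # wrap matmuls with FP32 casts so TensorRT accumulates in FP32
    )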

examples/dynamo/README.rst

Lines changed: 24 additions & 7 deletions
@@ -1,15 +1,24 @@
 .. _torch_compile:

-Dynamo / ``torch.compile``
-----------------------------
+Torch-TensorRT Examples
+====================================

-Torch-TensorRT provides a backend for the new ``torch.compile`` API released in PyTorch 2.0. In the following examples we describe
-a number of ways you can leverage this backend to accelerate inference.
+Please refer to the following examples, which demonstrate the usage of different features of Torch-TensorRT. We also provide
+examples of Torch-TensorRT compilation of select computer vision and language models.

-* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
-* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
+Dependencies
+------------------------------------
+
+Please install the following external dependencies (assuming you already have the correct ``torch``, ``torch_tensorrt`` and ``tensorrt`` libraries installed (`dependencies <https://github.com/pytorch/TensorRT?tab=readme-ov-file#dependencies>`_))
+
+.. code-block:: sh
+
+    pip install -r requirements.txt
+
+
+Compiler Features
+------------------------------------
 * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
-* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
 * :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with `ir="dynamo"`
 * :ref:`converter_overloading`: How to write custom converters and overload existing ones
 * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
@@ -18,3 +27,11 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
 * :ref:`engine_caching_example`: Utilizing engine caching to speed up compilation times
 * :ref:`engine_caching_bert_example`: Demonstrating engine caching on BERT
+
+Model Zoo
+------------------------------------
+* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
+* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
+* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
+* :ref:`torch_export_gpt2`: Compiling a GPT2 model using the AOT workflow (`ir="dynamo"`)
+* :ref:`torch_export_llama2`: Compiling a Llama2 model using the AOT workflow (`ir="dynamo"`)
examples/dynamo/requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 cupy==13.1.0
-torch>=2.4.0.dev20240503+cu121
-torch-tensorrt>=2.4.0.dev20240503+cu121
 triton==2.3.0
+diffusers==0.30.3
+transformers==4.44.2

examples/dynamo/torch_export_gpt2.py

Lines changed: 22 additions & 8 deletions
@@ -25,12 +25,16 @@
 # CPU is used here so that GPU memory is reserved for TRT compilation.
 with torch.no_grad():
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    model = AutoModelForCausalLM.from_pretrained(
-        "gpt2",
-        pad_token_id=tokenizer.eos_token_id,
-        use_cache=False,
-        attn_implementation="eager",
-    ).eval()
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            "gpt2",
+            pad_token_id=tokenizer.eos_token_id,
+            use_cache=False,
+            attn_implementation="eager",
+        )
+        .eval()
+        .half()
+    )

 # %%
 # Tokenize a sample input prompt and get pytorch model outputs
@@ -48,6 +52,10 @@
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 # Export the GPT2 model into an ExportedProgram which is the input of TRT compilation
+# To compile the model in FP16, we do the following
+# 1) Cast the model to FP16 via model.half()
+# 2) Enable use_explicit_typing=True. Certain layers are explicitly cast to FP32 within the PyTorch model and this flag respects that behavior during TRT compilation
+# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch)
 gpt2_ep = export_llm(model, input_ids, max_seq_len=1024)
 trt_model = torch_tensorrt.dynamo.compile(
     gpt2_ep,
@@ -56,6 +64,8 @@
     truncate_double=True,
     device=DEVICE,
     disable_tf32=True,
+    use_explicit_typing=True,
+    use_fp32_acc=True,
 )

 # Auto-regressive generation loop for greedy decoding using TensorRT model
@@ -81,6 +91,10 @@
 # %%
 # The output sentences should look like
 # =============================
-# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
+# Pytorch model generated text: What is parallel programming ?
+
+# The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that
 # =============================
-# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
+# TensorRT model generated text: What is parallel programming ?
+
+# The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that
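
The auto-regressive greedy-decoding loop referenced above is not part of this diff. The sketch below shows one way such a loop could look; the ``MAX_TOKENS`` bound and the assumption that the compiled module returns raw logits of shape ``(1, seq_len, vocab_size)`` are illustrative, not taken from the example:

    # Hypothetical greedy-decoding sketch using the compiled TRT model.
    MAX_TOKENS = 64  # illustrative bound, not from the example
    seq = input_ids  # (1, prompt_len) int64 tensor on the GPU

    for _ in range(MAX_TOKENS):
        logits = trt_model(seq)  # assumed to return logits of shape (1, seq_len, vocab_size)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # greedy pick
        seq = torch.cat([seq, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:  # stop on end-of-sequence
            break

    print(tokenizer.decode(seq[0], skip_special_tokens=True))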

examples/dynamo/torch_export_llama2.py

Lines changed: 15 additions & 6 deletions
@@ -24,9 +24,13 @@
 # CPU is used here so that GPU memory is reserved for TRT compilation.
 llama_path = "meta-llama/Llama-2-7b-chat-hf"
 with torch.no_grad():
-    model = AutoModelForCausalLM.from_pretrained(
-        llama_path, use_cache=False, attn_implementation="eager"
-    ).eval()
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            llama_path, use_cache=False, attn_implementation="eager"
+        )
+        .eval()
+        .half()
+    )

 tokenizer = AutoTokenizer.from_pretrained(llama_path)

@@ -45,15 +49,20 @@
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 # Export the llama2 model into an ExportedProgram which is the input of TRT compilation
+# To compile the model in FP16, we do the following
+# 1) Cast the model to FP16 via model.half()
+# 2) Enable use_explicit_typing=True. Certain layers are explicitly cast to FP32 within the PyTorch model and this flag respects that behavior during TRT compilation
+# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch)
 llama2_ep = export_llm(model, input_ids, max_seq_len=64)
 trt_model = torch_tensorrt.dynamo.compile(
     llama2_ep,
     inputs=[input_ids],
     enabled_precisions={torch.float32},
-    min_block_size=1,
     truncate_double=True,
     device=DEVICE,
     disable_tf32=True,
+    use_explicit_typing=True,
+    use_fp32_acc=True,
 )

 # Auto-regressive generation loop for greedy decoding using TensorRT model
@@ -85,6 +94,6 @@
 # %%
 # The output sentences should look like
 # =============================
-# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
+# Pytorch model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and
 # =============================
-# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
+# TensorRT model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and
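
Both examples call an ``export_llm`` helper that is defined elsewhere in the repository and is not shown in this diff. A rough, hypothetical sketch of what such a helper might do, assuming it wraps ``torch.export.export`` with a dynamic sequence dimension bounded by ``max_seq_len``:

    # Hypothetical sketch only; the real export_llm helper may differ.
    import torch

    def export_llm(model, inputs, max_seq_len=64):
        # Mark the sequence dimension (dim 1 of the (batch, seq) input_ids) as dynamic
        # so the exported program accepts growing sequences during decoding.
        seq_len = torch.export.Dim("seq_len", min=1, max=max_seq_len)
        return torch.export.export(model, (inputs,), dynamic_shapes=({1: seq_len},))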

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 27 additions & 1 deletion
@@ -88,6 +88,8 @@ def compile(
     engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
     engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
     custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
+    use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
+    use_fp32_acc: bool = _defaults.USE_FP32_ACC,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -158,6 +160,8 @@ def compile(
         engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
         engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
         custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
+        use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the PyTorch model. This is useful when users have mixed precision graphs.
+        use_fp32_acc (bool): This option inserts FP32 cast nodes around matmul layers so that TensorRT accumulates the matmul in FP32. Use this only when FP16 precision is configured in enabled_precisions.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -197,6 +201,20 @@ def compile(
             "\nThis feature is unimplemented in Torch-TRT Dynamo currently."
         )

+    if use_explicit_typing:
+        if len(enabled_precisions) != 1 or not any(
+            x in enabled_precisions for x in {torch.float32, dtype.f32}
+        ):
+            raise AssertionError(
+                f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
+            )
+
+    if use_fp32_acc:
+        logger.debug(
+            "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
+            This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation."
+        )
+
     # Aliasing inputs to arg_inputs for better understanding
     if not arg_inputs and not inputs:
         raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
@@ -232,7 +250,7 @@ def compile(
     logger.debug("Input graph: " + str(gm.graph))

     # Apply lowering on the graph module
-    gm = post_lowering(gm)
+    gm = post_lowering(gm, use_fp32_acc=use_fp32_acc)
     logger.debug("Lowered Input graph: " + str(gm.graph))

     engine_cache = None
@@ -281,6 +299,8 @@ def compile(
         "lazy_engine_init": lazy_engine_init,
         "cache_built_engines": cache_built_engines,
         "reuse_cached_engines": reuse_cached_engines,
+        "use_explicit_typing": use_explicit_typing,
+        "use_fp32_acc": use_fp32_acc,
     }

     settings = CompilationSettings(**compilation_options)
@@ -522,6 +542,8 @@ def convert_exported_program_to_serialized_trt_engine(
     calibrator: object = None,
     allow_shape_tensors: bool = False,
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
+    use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
+    use_fp32_acc: bool = _defaults.USE_FP32_ACC,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -580,6 +602,8 @@ def convert_exported_program_to_serialized_trt_engine(
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
         allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
+        use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the PyTorch model. This is useful when users have mixed precision graphs.
+        use_fp32_acc (bool): This option inserts FP32 cast nodes around matmul layers so that TensorRT accumulates the matmul in FP32. Use this only when FP16 precision is configured in enabled_precisions.
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
@@ -653,6 +677,8 @@ def convert_exported_program_to_serialized_trt_engine(
         "dla_local_dram_size": dla_local_dram_size,
         "dla_global_dram_size": dla_global_dram_size,
         "timing_cache_path": timing_cache_path,
+        "use_explicit_typing": use_explicit_typing,
+        "use_fp32_acc": use_fp32_acc,
     }

     exported_program = pre_export_lowering(exported_program)
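
The ``use_fp32_acc`` docstring above describes inserting FP32 casts around matmul layers. The eager-mode sketch below illustrates the numeric pattern those casts produce; it is not the actual FX lowering pass added by this commit:

    import torch

    def matmul_fp32_acc(a_fp16: torch.Tensor, b_fp16: torch.Tensor) -> torch.Tensor:
        # Upcast the FP16 operands, multiply-accumulate in FP32, then cast the result
        # back to FP16. TensorRT recognizes this cast-matmul-cast pattern and runs an
        # FP16 matmul kernel that accumulates in FP32.
        return (a_fp16.to(torch.float32) @ b_fp16.to(torch.float32)).to(torch.float16)

    a = torch.randn(128, 64, dtype=torch.float16)
    b = torch.randn(64, 32, dtype=torch.float16)
    out = matmul_fp32_acc(a, b)  # FP16 result, accumulation performed in FP32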

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@
 ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
 ENGINE_CACHE_SIZE = 1073741824
 CUSTOM_ENGINE_CACHE = None
+USE_EXPLICIT_TYPING = False
+USE_FP32_ACC = False


 def default_device() -> Device:

0 commit comments
