Commit 70c8393

feat: Saving modules using the AOTI format (#3567)

1 parent 78b56b4 commit 70c8393

File tree

12 files changed: +545 additions, -23 deletions


.gitignore

Lines changed: 3 additions & 1 deletion
@@ -78,4 +78,6 @@ MODULE.bazel.lock
 *.whl
 .coverage
 coverage.xml
-*.log
+*.log
+*.pt2
+examples/torchtrt_aoti_example/torchtrt_aoti_example

docsrc/user_guide/runtime.rst

Lines changed: 130 additions & 6 deletions
@@ -24,7 +24,7 @@ programs just as you would otherwise via PyTorch API.
 
 .. note:: If you are linking ``libtorchtrt_runtime.so``, likely using the following flags will help ``-Wl,--no-as-needed -ltorchtrt -Wl,--as-needed`` as there's no direct symbol dependency to anything in the Torch-TensorRT runtime for most Torch-TensorRT runtime applications
 
-An example of how to use ``libtorchtrt_runtime.so`` can be found here: https://github.com/pytorch/TensorRT/tree/master/examples/torchtrt_runtime_example
+An example of how to use ``libtorchtrt_runtime.so`` can be found here: https://github.com/pytorch/TensorRT/tree/master/examples/torchtrt_aoti_example
 
 Plugin Library
 ---------------
@@ -87,8 +87,8 @@ Cudagraphs can accelerate certain models by reducing kernel overheads, as docume
     with torch_tensorrt.runtime.enable_cudagraphs(trt_module):
         ...
 
-In the current implementation, use of a new input shape (for instance in dynamic shape
-cases), will cause the cudagraph to be re-recorded. Cudagraph recording is generally
+In the current implementation, use of a new input shape (for instance in dynamic shape
+cases), will cause the cudagraph to be re-recorded. Cudagraph recording is generally
 not latency intensive, and future improvements include caching cudagraphs for multiple input shapes.
 
 Dynamic Output Allocation Mode
@@ -101,11 +101,11 @@ Without dynamic output allocation, the output buffer is allocated based on the i
 
 There are two scenarios in which dynamic output allocation is enabled:
 
-1. The model has been identified at compile time to require dynamic output allocation for at least one TensorRT subgraph.
-   These models will engage the runtime mode automatically (with logging) and are incompatible with other runtime modes
+1. The model has been identified at compile time to require dynamic output allocation for at least one TensorRT subgraph.
+   These models will engage the runtime mode automatically (with logging) and are incompatible with other runtime modes
    such as CUDA Graphs.
 
-Converters can declare that subgraphs that they produce will require the output allocator using `requires_output_allocator=True`
+Converters can declare that subgraphs that they produce will require the output allocator using `requires_output_allocator=True`
 there by forcing any model which utilizes the converter to automatically use the output allocator runtime mode. e.g.,
 
 .. code-block:: python
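The Python example referenced by this directive is unchanged context and therefore not shown in the diff. As a rough sketch only, a converter opting into the output allocator might be registered along these lines; the ``dynamo_tensorrt_converter`` decorator, its import path, the chosen op, and the signature are assumptions rather than details confirmed by this commit:

    import torch
    from torch_tensorrt.dynamo.conversion import dynamo_tensorrt_converter  # assumed import path

    # Registering the converter with requires_output_allocator=True marks any TensorRT
    # subgraph containing this op as needing dynamic output allocation, so models that
    # use it automatically run in the output allocator runtime mode.
    @dynamo_tensorrt_converter(
        torch.ops.aten.nonzero.default,  # illustrative choice of a data-dependent op
        requires_output_allocator=True,
    )
    def aten_ops_nonzero(ctx, target, args, kwargs, name):
        # Conversion logic producing TensorRT layers would go here.
        ...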
@@ -131,3 +131,127 @@ there by forcing any model which utilizes the converter to automatically use the
     # Enables Dynamic Output Allocation Mode, then resets the mode to its prior setting
     with torch_tensorrt.runtime.enable_output_allocator(trt_module):
         ...
+
+Deploying Torch-TensorRT Programs without Python
+--------------------------------------------------------
+
+AOT-Inductor
+~~~~~~~~~~~~~~~~
+
+AOTInductor is a specialized version of TorchInductor, designed to process exported PyTorch models, optimize them, and produce shared
+libraries as well as other relevant artifacts. These compiled artifacts are specifically crafted for deployment in non-Python environments,
+which are frequently employed for inference deployments on the server side.
+
+Torch-TensorRT is able to accelerate subgraphs within AOTInductor exports in the same way it does in Python.
+
+.. code-block:: py
+
+    dynamo_model = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=[...])
+    torch_tensorrt.save(
+        dynamo_model,
+        file_path=os.path.join(os.getcwd(), "model.pt2"),
+        output_format="aot_inductor",
+        retrace=True,
+        arg_inputs=[...],
+    )
+
+This artifact can then be loaded in a C++ application and executed without a Python dependency.
+
+.. code-block:: c++
+
+    #include <iostream>
+    #include <vector>
+
+    #include "torch/torch.h"
+    #include "torch/csrc/inductor/aoti_package/model_package_loader.h"
+
+    int main(int argc, const char* argv[]) {
+        // Use model.pt2 by default; optionally take the package path from the command line
+        std::string trt_aoti_module_path = "model.pt2";
+
+        if (argc == 2) {
+            trt_aoti_module_path = argv[1];
+        }
+
+        std::cout << trt_aoti_module_path << std::endl;
+
+        // Run under inference mode (no autograd bookkeeping)
+        c10::InferenceMode mode;
+
+        torch::inductor::AOTIModelPackageLoader loader(trt_aoti_module_path);
+        // Assume running on CUDA
+        std::vector<torch::Tensor> inputs = {torch::randn({8, 10}, at::kCUDA)};
+        std::vector<torch::Tensor> outputs = loader.run(inputs);
+        std::cout << "Result from the first inference:" << std::endl;
+        std::cout << outputs << std::endl;
+
+        // The second inference uses a different batch size and it works because we
+        // specified that dimension as dynamic when compiling model.pt2.
+        std::cout << "Result from the second inference:" << std::endl;
+        // Assume running on CUDA
+        std::cout << loader.run({torch::randn({1, 10}, at::kCUDA)}) << std::endl;
+
+        return 0;
+    }
+
+Note: Similar to Python, at runtime, no Torch-TensorRT APIs are used to operate the model. Therefore, additional linker
+flags are typically needed to make sure that ``libtorchtrt_runtime.so`` does not get optimized out (see above).
+
+See: ``//examples/torchtrt_aoti_example`` for a full end-to-end demo of this workflow
+
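The second inference in the C++ sample above only works because the batch dimension of ``model.pt2`` was declared dynamic at compile time, which this diff does not show. A minimal Python-side sketch, assuming ``torch_tensorrt.Input`` with ``min_shape``/``opt_shape``/``max_shape`` and a toy [N, 10] model standing in for the real one; whether the retraced export needs further dynamic-shape arguments is not covered here:

    import os

    import torch
    import torch_tensorrt

    # Toy model taking [N, 10] inputs, standing in for the model used by the example
    model = torch.nn.Sequential(torch.nn.Linear(10, 5)).eval().cuda()

    # Allow the batch dimension to vary between 1 and 16, with 8 as the optimization point,
    # matching the randn({8, 10}) / randn({1, 10}) inputs in the C++ sample
    batch_dynamic_input = torch_tensorrt.Input(
        min_shape=(1, 10),
        opt_shape=(8, 10),
        max_shape=(16, 10),
        dtype=torch.float32,
    )

    dynamo_model = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=[batch_dynamic_input])
    torch_tensorrt.save(
        dynamo_model,
        file_path=os.path.join(os.getcwd(), "model.pt2"),
        output_format="aot_inductor",
        retrace=True,
        arg_inputs=[torch.randn(8, 10).cuda()],
    )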
+TorchScript
+~~~~~~~~~~~~~~
+
+TorchScript is a legacy compiler stack for PyTorch that includes a Python-less interpreter for TorchScript programs.
+It has historically been used by Torch-TensorRT to execute models without Python. Even after the transition to TorchDynamo,
+the TorchScript interpreter can continue to be used to run PyTorch models with TensorRT engines outside of Python.
+
+.. code-block:: py
+
+    dynamo_model = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=[...])
+    ts_model = torch.jit.trace(dynamo_model, inputs=[...])
+    torch.jit.save(ts_model, os.path.join(os.getcwd(), "model.ts"))
+
+This artifact can then be loaded in a C++ application and executed without a Python dependency.
+
+.. code-block:: c++
+
+    #include <fstream>
+    #include <iostream>
+    #include <memory>
+    #include <sstream>
+    #include <vector>
+    #include "torch/script.h"
+
+    int main(int argc, const char* argv[]) {
+        if (argc < 2) {
+            std::cerr << "usage: samplertapp <path-to-pre-built-trt-ts module>\n";
+            return -1;
+        }
+
+        std::string trt_ts_module_path = argv[1];
+
+        torch::jit::Module trt_ts_mod;
+        try {
+            // Deserialize the ScriptModule from a file using torch::jit::load().
+            trt_ts_mod = torch::jit::load(trt_ts_module_path);
+        } catch (const c10::Error& e) {
+            std::cerr << "error loading the model from: " << trt_ts_module_path << std::endl;
+            return -1;
+        }
+
+        std::cout << "Running TRT engine" << std::endl;
+        std::vector<torch::jit::IValue> trt_inputs_ivalues;
+        trt_inputs_ivalues.push_back(at::randint(-5, 5, {1, 3, 5, 5}, {at::kCUDA}).to(torch::kFloat32));
+        torch::jit::IValue trt_results_ivalues = trt_ts_mod.forward(trt_inputs_ivalues);
+        std::cout << "==================TRT outputs================" << std::endl;
+        std::cout << trt_results_ivalues << std::endl;
+        std::cout << "=============================================" << std::endl;
+        std::cout << "TRT engine execution completed." << std::endl;
+    }
+
+Note: Similar to Python, at runtime, no Torch-TensorRT APIs are used to operate the model. Therefore, additional linker
+flags are typically needed to make sure that ``libtorchtrt_runtime.so`` does not get optimized out (see above).
+
+See: ``//examples/torchtrt_runtime_example`` for a full end-to-end demo of this workflow
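As a quick check before wiring up the C++ application, the TorchScript artifact saved above can also be loaded back with stock PyTorch APIs. A minimal sketch, assuming the module was saved as ``model.ts`` and takes a single ``(1, 3, 224, 224)`` CUDA input like the other examples in this commit:

    import torch
    import torch_tensorrt  # noqa: F401 -- importing registers the TensorRT engine execution op the module calls

    # Load the TorchScript artifact produced by torch.jit.save above
    ts_model = torch.jit.load("model.ts").cuda()

    example_input = torch.randn(1, 3, 224, 224).cuda()  # assumed input shape
    with torch.no_grad():
        print(ts_model(example_input).shape)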

docsrc/user_guide/saving_models.rst

Lines changed: 29 additions & 5 deletions
@@ -14,12 +14,13 @@ Saving models compiled with Torch-TensorRT can be done using `torch_tensorrt.sav
 Dynamo IR
 -------------
 
-The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.fx.GraphModule` object by default.
-We can save this object in either `TorchScript` (`torch.jit.ScriptModule`) or `ExportedProgram` (`torch.export.ExportedProgram`) formats by
+The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.fx.GraphModule` object by default.
+We can save this object in either `TorchScript` (`torch.jit.ScriptModule`), `ExportedProgram` (`torch.export.ExportedProgram`) or `PT2` formats by
 specifying the `output_format` flag. Here are the options `output_format` will accept
 
 * `exported_program` : This is the default. We perform transformations on the graphmodule first and use `torch.export.save` to save the module.
 * `torchscript` : We trace the graphmodule via `torch.jit.trace` and save it via `torch.jit.save`.
+* `aot_inductor` : Saves the module in the PT2 format using AOTInductor, a next generation runtime for PyTorch models, allowing them to run in Python and in C++.
 
 a) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -52,8 +53,8 @@ b) Torchscript
     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
     # trt_gm is a torch.fx.GraphModule object
-    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
-    torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)
+    torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", arg_inputs=inputs)
 
     # Later, you can load it and run inference
     model = torch.jit.load("trt.ts").cuda()
@@ -73,7 +74,7 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
 
     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
+    trt_ts = torch_tensorrt.compile(model, ir="ts", arg_inputs=inputs) # Output is a ScriptModule object
     torch.jit.save(trt_ts, "trt_model.ts")
 
     # Later, you can load it and run inference
@@ -98,3 +99,26 @@ Here's an example usage
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
     model = torch_tensorrt.load(<file_path>).module()
     model(*inputs)
+
+b) PT2 Format
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+PT2 is a new format that allows models to be run outside of Python. It utilizes `AOTInductor <https://docs.pytorch.org/docs/main/torch.compiler_aot_inductor.html>`_
+to generate kernels for components that will not be run in TensorRT.
+
+Here's an example of how to save and load a Torch-TensorRT module using AOTInductor in Python
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    model = MyModel().eval().cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    # trt_gm is a torch.fx.GraphModule object
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+    torch_tensorrt.save(trt_gm, "trt.pt2", arg_inputs=inputs, output_format="aot_inductor", retrace=True)
+
+    # Later, you can load it and run inference
+    model = torch._inductor.aoti_load_package("trt.pt2")
+    model(*inputs)
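A usage sketch that follows naturally from the example above: reload the saved package and run it next to the in-memory compiled module. The stand-in model is an assumption, since the docs' ``MyModel`` is not defined in this diff:

    import torch
    import torch_tensorrt

    # Small stand-in model; the docs' MyModel is not defined in this diff
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3)).eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)
    torch_tensorrt.save(trt_gm, "trt.pt2", arg_inputs=inputs, output_format="aot_inductor", retrace=True)

    # Reload the PT2 package and run it alongside the in-memory compiled module;
    # the two results should agree up to small numerical differences.
    loaded = torch._inductor.aoti_load_package("trt.pt2")
    with torch.no_grad():
        print(trt_gm(*inputs))
        print(loaded(*inputs))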

examples/torchtrt_aoti_example/BUILD

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+load("@rules_cc//cc:defs.bzl", "cc_binary")
+
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+    name = "use_torch_whl",
+    flag_values = {
+        "//toolchains/dep_src:torch": "whl"
+    },
+)
+
+config_setting(
+    name = "jetpack",
+    constraint_values = [
+        "@platforms//cpu:aarch64",
+    ],
+    flag_values = {
+        "//toolchains/dep_collection:compute_libs": "jetpack"
+    },
+)
+config_setting(
+    name = "windows",
+    constraint_values = [
+        "@platforms//os:windows",
+    ],
+)
+
+cc_binary(
+    name = "torchtrt_aoti_example",
+    srcs = [
+        "inference.cpp"
+    ],
+    linkopts = [
+        "-ldl",
+    ],
+    deps = [
+        "//cpp:torch_tensorrt",
+    ] + select({
+        ":windows": [
+            "@libtorch_win//:caffe2",
+            "@libtorch_win//:libtorch"
+        ],
+        ":use_torch_whl": [
+            "@torch_whl//:caffe2",
+            "@torch_whl//:libtorch"
+        ],
+        ":jetpack": [
+            "@torch_l4t//:caffe2",
+            "@torch_l4t//:libtorch"
+        ],
+        "//conditions:default": [
+            "@libtorch",
+            "@libtorch//:caffe2",
+        ],
+    }),
+)

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
+project(torchtrt_aoti_example LANGUAGES CXX)
+
+find_package(Torch REQUIRED)
+find_package(torchtrt REQUIRED)
+
+add_executable(torchtrt_aoti_example inference.cpp model.pt2)
+
+add_custom_command(
+    OUTPUT model.pt2
+    COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/model.py
+    DEPENDS model.py
+)
+
+target_link_libraries(torchtrt_aoti_example "${TORCH_LIBRARIES}" "-Wl,--no-as-needed" torchtrt_runtime "-Wl,--as-needed")
+set_property(TARGET torchtrt_aoti_example PROPERTY CXX_STANDARD 17)

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+CXX=g++
+SITE_PACKAGES=$(shell python -c 'import site; print(site.getsitepackages()[0])')
+CUDA_HOME=/usr/local/cuda-12.8
+
+INCLUDE_DIRS=-I$(SITE_PACKAGES)/torch/include -I$(SITE_PACKAGES)/torch_tensorrt/include -I$(CUDA_HOME)/include -I$(SITE_PACKAGES)/torch/include/torch/csrc/api/include
+
+LIB_DIRS=-L$(SITE_PACKAGES)/torch_tensorrt/lib -L$(SITE_PACKAGES)/torch/lib -Wl,-rpath $(SITE_PACKAGES)/tensorrt_libs -L/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/lib
+LIBS=-Wl,--no-as-needed -ltorchtrt_runtime -ltorchtrt_plugins -Wl,--as-needed -ltorch -ltorch_cuda -ltorch_cpu -ltorch_global_deps -ltorch_cuda_linalg -lc10 -lc10_cuda -lshm -ltorch_global_deps -ltorch_python
+
+SRCS=inference.cpp
+
+TARGET=torchtrt_aoti_example
+
+$(TARGET): *cpp
+	$(CXX) $(SRCS) $(INCLUDE_DIRS) $(LIB_DIRS) $(LIBS) -o $(TARGET)
+	echo "\n\nAdd to LD_LIBRARY_PATH: $(SITE_PACKAGES)/torch_tensorrt/lib:$(SITE_PACKAGES)/torch/lib:$(SITE_PACKAGES)/tensorrt_libs:$(CUDA_HOME)/lib64"
+
+generate_pt2:
+	python model.py
+
+clean:
+	$(RM) $(TARGET)
