Skip to content

Commit df9dd6a

Browse files
committed
Merge branch 'master' of https://github.com/NVIDIA/Torch-TensorRT into support_aten_format
2 parents 7ec5c73 + a10613e commit df9dd6a

File tree

123 files changed

+36283
-24829
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

123 files changed

+36283
-24829
lines changed

.bazelrc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ build:pre_cxx11_abi --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"
3535
build:pre_cxx11_abi --linkopt="-D_GLIBCXX_USE_CXX11_ABI=0"
3636
build:pre_cxx11_abi --define=abi=pre_cxx11_abi
3737

38-
build:ci_testing --define=torchtrt_src=pre_built --cxxopt="-DDISABLE_TEST_IN_CI" --action_env "NVIDIA_TF32_OVERRIDE=0"
39-
build:use_precompiled_torchtrt --define=torchtrt_src=pre_built
38+
build:ci_testing --define=torchtrt_src=prebuilt --cxxopt="-DDISABLE_TEST_IN_CI" --action_env "NVIDIA_TF32_OVERRIDE=0"
39+
build:use_precompiled_torchtrt --define=torchtrt_src=prebuilt
4040

41-
test:ci_testing --define=torchtrt_src=pre_built --cxxopt="-DDISABLE_TEST_IN_CI" --action_env "NVIDIA_TF32_OVERRIDE=0"
42-
test:use_precompiled_torchtrt --define=torchtrt_src=pre_built
41+
test:ci_testing --define=torchtrt_src=prebuilt --cxxopt="-DDISABLE_TEST_IN_CI" --action_env "NVIDIA_TF32_OVERRIDE=0"
42+
test:use_precompiled_torchtrt --define=torchtrt_src=prebuilt

.github/workflows/docgen.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ jobs:
3636
- name: Generate New Docs
3737
run: |
3838
cd docsrc
39+
pip3 install -r requirements.txt
3940
python3 -c "import torch_tensorrt; print(torch_tensorrt.__version__)"
4041
make html
4142
- uses: stefanzweifel/git-auto-commit-action@v4

core/conversion/converters/impl/activation.cpp

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -166,39 +166,7 @@ auto acthardtanh TORCHTRT_UNUSED =
166166
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
167167
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
168168
return true;
169-
}})
170-
.pattern({"aten::gelu(Tensor self) -> (Tensor)",
171-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
172-
auto in = args[0].ITensorOrFreeze(ctx);
173-
nvinfer1::DataType type = in->getType();
174-
TORCHTRT_CHECK(
175-
type == nvinfer1::DataType::kFLOAT || type == nvinfer1::DataType::kHALF,
176-
"gelu only supports kFLOAT and kHALF");
177-
std::string pluginName = "CustomGeluPluginDynamic";
178-
nvinfer1::PluginFieldCollection fc;
179-
std::vector<nvinfer1::PluginField> f;
180-
// REVIEW is this right?
181-
int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) ==
182-
ctx->settings.enabled_precisions.end()
183-
? 0
184-
: 1; // Integer encoding the DataType (0: FP32, 1: FP16)
185-
f.emplace_back(nvinfer1::PluginField("type_id", &type_id, nvinfer1::PluginFieldType::kINT32, 1));
186-
fc.nbFields = f.size();
187-
fc.fields = f.data();
188-
189-
auto creator = getPluginRegistry()->getPluginCreator("CustomGeluPluginDynamic", "1", "");
190-
auto gelu_plugin = creator->createPlugin("gelu", &fc);
191-
192-
TORCHTRT_CHECK(gelu_plugin, "Unable to create gelu plugin from TensorRT plugin registry" << *n);
193-
auto new_layer =
194-
ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *gelu_plugin);
195-
new_layer->setName(util::node_info(n).c_str());
196-
auto out_tensor = new_layer->getOutput(0);
197-
out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
198-
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
199-
return true;
200169
}});
201-
202170
} // namespace
203171
} // namespace impl
204172
} // namespace converters

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
4343
passes::UnpackHardSwish(g);
4444
passes::EliminateExceptionOrPassPattern(g);
4545
passes::ReduceToOperation(g);
46+
passes::ReduceGelu(g);
4647
passes::RemoveContiguous(g);
4748
passes::RemoveDropout(g);
4849
passes::LinearToAddMM(g);

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ cc_library(
1717
"module_fallback.cpp",
1818
"op_aliasing.cpp",
1919
"reduce_to.cpp",
20+
"reduce_gelu.cpp",
2021
"remove_bn_dim_check.cpp",
2122
"remove_contiguous.cpp",
2223
"remove_dropout.cpp",

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
2020
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
2121
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
2222
void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
23+
void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);
2324
void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims);
2425
void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
2526
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);

core/lowering/passes/reduce_gelu.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
2+
#include "core/util/prelude.h"
3+
4+
namespace torch_tensorrt {
5+
namespace core {
6+
namespace lowering {
7+
namespace passes {
8+
9+
// Lowering pass: rewrites every `aten::gelu(%x)` node in `graph` into a chain of
// pointwise aten ops (mul / add / tanh) using torch::jit::SubgraphRewriter.
// The graph is modified in place; nothing is returned.
//
// The replacement subgraph below evaluates
//   (x * 0.5) * (tanh((x * 0.7978845608) * ((x * 0.044715) * x + 1)) + 1)
// which matches the tanh-based GELU formula (0.7978845608 ~= sqrt(2 / pi)).
// NOTE(review): if the original aten::gelu evaluates the exact erf-based form,
// this rewrite changes results slightly -- confirm that tolerance is acceptable.
void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
10+
// Pattern to match: a single-input aten::gelu call.
std::string gelu_pattern = R"IR(
11+
graph(%x):
12+
%out : Tensor = aten::gelu(%x)
13+
return (%out))IR";
14+
15+
// Replacement: the same single input expanded into mul/add/tanh nodes.
// The trailing int constant (%2 = 1) is passed as the third argument of
// aten::add -- presumably the `alpha` scale; verify against the aten::add schema.
std::string gelu_reduce_pattern = R"IR(
16+
graph(%x.1 : Tensor):
17+
%6 : float = prim::Constant[value=0.044714999999999998]()
18+
%5 : float = prim::Constant[value=0.79788456080000003]()
19+
%4 : float = prim::Constant[value=1.]()
20+
%3 : float = prim::Constant[value=0.5]()
21+
%2 : int = prim::Constant[value=1]()
22+
%7 : Tensor = aten::mul(%x.1, %3)
23+
%8 : Tensor = aten::mul(%x.1, %5)
24+
%9 : Tensor = aten::mul(%x.1, %6)
25+
%10 : Tensor = aten::mul(%9, %x.1)
26+
%11 : Tensor = aten::add(%10, %4, %2)
27+
%12 : Tensor = aten::mul(%8, %11)
28+
%13 : Tensor = aten::tanh(%12)
29+
%14 : Tensor = aten::add(%13, %4, %2)
30+
%15 : Tensor = aten::mul(%7, %14)
31+
return (%15))IR";
32+
33+
// replace aten::gelu with pointwise operations
34+
torch::jit::SubgraphRewriter map_gelu_to_pointwise_ops;
35+
map_gelu_to_pointwise_ops.RegisterRewritePattern(gelu_pattern, gelu_reduce_pattern);
36+
map_gelu_to_pointwise_ops.runOnGraph(graph);
37+
38+
// Dump the rewritten graph at graph-level logging for debugging.
LOG_GRAPH("Post lowering of [aten::gelu] -> " << *graph);
39+
}
40+
41+
} // namespace passes
42+
} // namespace lowering
43+
} // namespace core
44+
} // namespace torch_tensorrt

core/partitioning/README.md

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,9 @@ To enable automatic fallback feature, you can set following attributes in Python
3434
ts_model = torch.jit.script(model)
3535
trt_model = torchtrt.ts.compile(model, **{
3636
...
37-
"torch_fallback" : {
38-
"enabled" : True,
39-
"min_block_size" : 3,
40-
"forced_fallback_ops": ["aten::add"],
41-
}
37+
"min_block_size" : 3,
38+
"torch_executed_ops": ["aten::add"],
39+
"torch_executed_modules": [],
4240
})
4341
```
4442
- `enabled`: By default automatic fallback will be off. It is enabled by setting it to True.
@@ -59,9 +57,8 @@ auto in = torch::randn({1, 3, 224, 224}, {torch::kCUDA});
5957
auto mod = torch::jit::load("trt_ts_module.ts");
6058
auto input_sizes = std::vector<torchtrt::InputRange>{{in.sizes()}};
6159
torchtrt::ts::CompileSpec cfg(input_sizes);
62-
cfg.torch_fallback = torchtrt::CompileSpec::TorchFallback(true);
63-
cfg.torch_fallback.min_block_size = 2;
64-
cfg.torch_fallback.forced_fallback_ops.push_back("aten::relu");
60+
cfg.min_block_size = 2;
61+
cfg.torch_executed_ops.push_back("aten::relu");
6562
auto trt_mod = torchtrt::ts::compile(mod, cfg);
6663
auto out = trt_mod.forward({in});
6764
```

0 commit comments

Comments
 (0)