Commit 543e2a5 — Merge branch 'master' into peri044/dtype_layout

2 parents: bde8ee0 + b652045
40 files changed: +2613 −1671 lines

README.md

Lines changed: 5 additions & 4 deletions

@@ -6,10 +6,11 @@
 Torch-TensorRT is a compiler for PyTorch/TorchScript, targeting NVIDIA GPUs via NVIDIA's TensorRT Deep Learning Optimizer and Runtime. Unlike PyTorch's Just-In-Time (JIT) compiler, Torch-TensorRT is an Ahead-of-Time (AOT) compiler, meaning that before you deploy your TorchScript code, you go through an explicit compile step to convert a standard TorchScript program into a module targeting a TensorRT engine. Torch-TensorRT operates as a PyTorch extension and compiles modules that integrate into the JIT runtime seamlessly. After compilation, using the optimized graph should feel no different from running a TorchScript module. You also have access to TensorRT's suite of configurations at compile time, so you are able to specify operating precision (FP32/FP16/INT8) and other settings for your module.

-More Information / System Architecture:
-
-- [GTC 2020 Talk](https://developer.nvidia.com/gtc/2020/video/s21671)
-
+Resources:
+- [Documentation](https://nvidia.github.io/Torch-TensorRT/)
+- [Torch-TensorRT Explained in 2 minutes!](https://www.youtube.com/watch?v=TU5BMU6iYZ0&ab_channel=NVIDIADeveloper)
+- [Comprehensive Discussion (GTC Event)](https://www.nvidia.com/en-us/on-demand/session/gtcfall21-a31107/)
+- [Pre-built Docker Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch). To use this container, make an NGC account and sign in to NVIDIA's registry with an API key; see [this guide](https://docs.nvidia.com/ngc/ngc-catalog-user-guide/index.html#registering-activating-ngc-account) for details.

 ## Building a docker container for Torch-TensorRT

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::ReduceToOperation(g);
   passes::ReduceGelu(g);
   passes::RemoveContiguous(g);
+  passes::ViewToReshape(g);
   passes::RemoveDropout(g);
   passes::LinearToAddMM(g);
   passes::Conv1DToConvolution(g);
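The hunk above registers the new ViewToReshape pass in LowerGraph's ordered pass list, where each pass rewrites the TorchScript graph in place before conversion. As a rough sketch of that pipeline shape (plain Python with hypothetical pass names modeled on the real ones, and a graph reduced to a flat list of op names — not the Torch-TensorRT API):

```python
# Toy lowering pipeline: each pass maps graph -> graph, applied in order.
# Names and the flat-list "graph" are illustrative, not the real IR.
def remove_contiguous(graph):
    # Drop aten::contiguous calls; they are no-ops for conversion purposes.
    return [op for op in graph if op != "aten::contiguous"]

def view_to_reshape(graph):
    # Map aten::view to aten::reshape (what the new C++ pass does).
    return ["aten::reshape" if op == "aten::view" else op for op in graph]

def lower_graph(graph, passes):
    # Run the registered passes in order, as LowerGraph does in C++.
    for p in passes:
        graph = p(graph)
    return graph

graph = ["aten::contiguous", "aten::view", "aten::linear"]
print(lower_graph(graph, [remove_contiguous, view_to_reshape]))
# -> ['aten::reshape', 'aten::linear']
```

Ordering matters in such a pipeline: a later pass may depend on earlier passes having already normalized the graph.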

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@ cc_library(
   "reduce_gelu.cpp",
   "remove_bn_dim_check.cpp",
   "remove_contiguous.cpp",
+  "view_to_reshape.cpp",
   "remove_dropout.cpp",
   "remove_nops.cpp",
   "silu_to_sigmoid_multiplication.cpp",

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@ void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);
 void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims);
 void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
+void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
core/lowering/passes/view_to_reshape.cpp (new file)

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string view_pattern = R"IR(
+    graph(%x, %1):
+      %out : Tensor = aten::view(%x, %1)
+      return (%out))IR";
+
+  std::string reshape_pattern = R"IR(
+    graph(%x, %1):
+      %out : Tensor = aten::reshape(%x, %1)
+      return (%out))IR";
+
+  // replace aten::view with aten::reshape
+  torch::jit::SubgraphRewriter map_view_to_reshape;
+  map_view_to_reshape.RegisterRewritePattern(view_pattern, reshape_pattern);
+  map_view_to_reshape.runOnGraph(graph);
+
+  LOG_GRAPH("Post lowering of aten::view -> " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
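The pass above uses torch::jit::SubgraphRewriter to match every `aten::view` call and swap it for `aten::reshape` with the same inputs and output. (In PyTorch, `view` requires contiguous storage while `reshape` copies when necessary, so the substitution is a safe generalization.) A minimal single-op analogue of that rewriter in plain Python — the real SubgraphRewriter matches whole IR subgraphs, and all names here are illustrative:

```python
# Toy subgraph rewriter over a graph represented as (op, inputs, output) tuples.
class ToyRewriter:
    def __init__(self):
        self.patterns = []  # (pattern_op, replacement_op) pairs

    def register_rewrite_pattern(self, pattern_op, replacement_op):
        self.patterns.append((pattern_op, replacement_op))

    def run_on_graph(self, graph):
        rewritten = []
        for op, inputs, output in graph:
            for pat, rep in self.patterns:
                if op == pat:
                    op = rep  # same operands, different op name
                    break
            rewritten.append((op, inputs, output))
        return rewritten

graph = [("aten::view", ("%x", "%1"), "%out"),
         ("aten::relu", ("%out",), "%y")]
rw = ToyRewriter()
rw.register_rewrite_pattern("aten::view", "aten::reshape")
print(rw.run_on_graph(graph))
```

As in the C++ pass, the rewrite preserves the dataflow (`%x`, `%1`, `%out`) and only renames the matched op.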

core/partitioning/partitioning.cpp

Lines changed: 1 addition & 1 deletion

@@ -176,7 +176,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
   // if the segment that produce this nonTensor value is kTensorRT but consumed in kTorch, inject nodes in the first
   // kTorch segment.
   if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
-    auto first_torch_id = use_info.torch_use_id.front();
+    auto first_torch_id = use_info.torch_use_id.back();
     if (!updated_segments.count(first_torch_id)) {
       // Segmented Blocks with non-tensor inputs will have to be re-segmented as
       // Torch-TensorRT doesn't support non-tensor inputs for a module.
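The one-character fix above changes which Torch segment receives the injected nodes: `front()` picked the first Torch segment consuming the non-tensor value, `back()` picks the last. A sketch of just that selection (segment ids are made up for illustration):

```python
# Torch segment ids that consume a non-tensor value produced by a
# TensorRT segment, in graph order (illustrative values).
torch_use_id = [3, 5, 8]

old_choice = torch_use_id[0]   # front(): first consuming Torch segment
new_choice = torch_use_id[-1]  # back():  last consuming Torch segment
print(old_choice, new_choice)  # -> 3 8
```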
(Remaining changed files are binary assets; contents not shown.)

0 commit comments