
Commit e499ed1

Merge branch 'main' into aten_size_fix
2 parents: adc29b0 + e9da9b0

File tree: 112 files changed, +350 −252 lines


core/lowering/passes/convNd_to_convolution.cpp

Lines changed: 49 additions & 35 deletions

@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include "torch/csrc/jit/ir/irparser.h"

 #include "core/util/prelude.h"

@@ -7,78 +8,91 @@ namespace core {
 namespace lowering {
 namespace passes {

-void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv1d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
+void replaceConv(
+    torch::jit::Block* block,
+    const std::string& node_kind,
+    const std::string& unwrapped_conv,
+    const size_t num_input_args) {
+  // Iterate through nodes in block, searching for aten::conv*
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
+    auto n = *it;
+
+    // Recursively explore nested blocks, such as those arising from prim::If
+    for (auto nested_block : n->blocks()) {
+      replaceConv(nested_block, node_kind, unwrapped_conv, num_input_args);
+    }
+
+    // If node matches desired kind and number of input arguments, replace it
+    if ((n->kind().toQualString() == node_kind) && (n->inputs().size() == num_input_args)) {
+      // Establish insert point within block
+      torch::jit::WithInsertPoint guard(*it);
+
+      // Initialize new fused subgraph from IR code provided
+      auto fused_g = std::make_shared<torch::jit::Graph>();
+      torch::jit::parseIR(unwrapped_conv, fused_g.get());
+
+      // Insert subgraph in place of aten::conv*, replacing inputs and outputs accordingly
+      torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *fused_g, it->inputs()).at(0);
+      new_output->setType(it->output()->type());
+      it->output()->replaceAllUsesWith(new_output);
+      it.destroyCurrent();
+    }
+  }
+}

-  std::string convolution_pattern = R"IR(
+void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  const std::string conv1d_node_kind = "aten::conv1d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
       %2 : int[] = prim::Constant[value=[0]]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
       return (%4))IR";

-  torch::jit::SubgraphRewriter map_conv1d_to_convolution;
-  map_conv1d_to_convolution.RegisterRewritePattern(conv1d_pattern, convolution_pattern);
-  map_conv1d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv1d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv1d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv1d -> _convolution: " << *graph);
 }

 void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv_transpose1d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %o, %g, %d):
-      %4 : Tensor = aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv_transpose1d_node_kind = "aten::conv_transpose1d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %o, %g, %d):
       %1 : bool = prim::Constant[value=1]()
       %2 : bool = prim::Constant[value=1]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
       return (%4))IR";

-  torch::jit::SubgraphRewriter map_conv_transpose1d_to_convolution;
-  map_conv_transpose1d_to_convolution.RegisterRewritePattern(conv_transpose1d_pattern, convolution_pattern);
-  map_conv_transpose1d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
+  replaceConv(graph->block(), conv_transpose1d_node_kind, convolution_pattern, 8);
   LOG_GRAPH("Post map conv_transpose1d -> _convolution: " << *graph);
 }

 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv2d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv2d_node_kind = "aten::conv2d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
       %2 : int[] = prim::Constant[value=[0, 0]]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
       return (%4))IR";

-  // replace matmul + add pattern to linear
-  torch::jit::SubgraphRewriter map_conv2d_to_convolution;
-  map_conv2d_to_convolution.RegisterRewritePattern(conv2d_pattern, convolution_pattern);
-  map_conv2d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv2d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv2d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
 }

 void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv3d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv3d_node_kind = "aten::conv3d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
       %2 : int[] = prim::Constant[value=[0, 0, 0]]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
       return (%4))IR";

-  // replace matmul + add pattern to linear
-  torch::jit::SubgraphRewriter map_conv3d_to_convolution;
-  map_conv3d_to_convolution.RegisterRewritePattern(conv3d_pattern, convolution_pattern);
-  map_conv3d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv3d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv3d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
 }
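
The change above swaps the pattern-based torch::jit::SubgraphRewriter for a shared replaceConv helper that recurses into nested blocks (such as prim::If branches, per its comments) and matches nodes by kind plus input count before splicing in the aten::_convolution subgraph. Below is a minimal sketch, not part of this commit, of how such a pass could be exercised on a graph where the conv node only appears inside nested blocks; the core/lowering/passes/passes.h include path and the torch_tensorrt::core::lowering::passes namespace are assumptions about the surrounding project layout.

// Sketch only: parse a small graph with aten::conv1d inside both branches of
// a prim::If, run the lowering pass, and print the result for inspection.
#include <iostream>
#include <memory>
#include <string>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

#include "core/lowering/passes/passes.h"  // assumed header declaring Conv1DToConvolution

int main() {
  // conv1d appears only inside nested blocks, the case the recursive
  // traversal in replaceConv is written to handle.
  const std::string source_ir = R"IR(
    graph(%x : Tensor, %w : Tensor, %b : Tensor, %s : int[], %p : int[], %d : int[], %g : int, %cond : bool):
      %out : Tensor = prim::If(%cond)
        block0():
          %y : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
          -> (%y)
        block1():
          %z : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
          -> (%z)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, g.get());

  // After the pass, both nested aten::conv1d nodes should have been replaced
  // by aten::_convolution subgraphs; the printed IR makes that easy to check.
  torch_tensorrt::core::lowering::passes::Conv1DToConvolution(g);
  std::cout << *g << std::endl;
  return 0;
}

Matching on node kind and input count, rather than on a full IR pattern, keeps the helper agnostic to the value names at each call site, which is why all four ConvNdToConvolution functions can share it.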

docs/_cpp_api/classtorch__tensorrt_1_1DataType.html

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
   <meta name="viewport" content="width=device-width, initial-scale=1.0">
-  <title>Class DataType &mdash; Torch-TensorRT v1.4.0.dev0+b388010 documentation</title>
+  <title>Class DataType &mdash; Torch-TensorRT v1.4.0.dev0+2ea9f00 documentation</title>

@@ -215,7 +215,7 @@
   <div class="version">
-    v1.4.0.dev0+b388010
+    v1.4.0.dev0+2ea9f00
   </div>

docs/_cpp_api/classtorch__tensorrt_1_1Device_1_1DeviceType.html

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
   <meta name="viewport" content="width=device-width, initial-scale=1.0">
-  <title>Class Device::DeviceType &mdash; Torch-TensorRT v1.4.0.dev0+b388010 documentation</title>
+  <title>Class Device::DeviceType &mdash; Torch-TensorRT v1.4.0.dev0+2ea9f00 documentation</title>

@@ -215,7 +215,7 @@
   <div class="version">
-    v1.4.0.dev0+b388010
+    v1.4.0.dev0+2ea9f00
   </div>

docs/_cpp_api/classtorch__tensorrt_1_1TensorFormat.html

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
   <meta name="viewport" content="width=device-width, initial-scale=1.0">
-  <title>Class TensorFormat &mdash; Torch-TensorRT v1.4.0.dev0+b388010 documentation</title>
+  <title>Class TensorFormat &mdash; Torch-TensorRT v1.4.0.dev0+2ea9f00 documentation</title>

@@ -215,7 +215,7 @@
   <div class="version">
-    v1.4.0.dev0+b388010
+    v1.4.0.dev0+2ea9f00
   </div>

docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8CacheCalibrator.html

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
   <meta name="viewport" content="width=device-width, initial-scale=1.0">
-  <title>Template Class Int8CacheCalibrator &mdash; Torch-TensorRT v1.4.0.dev0+b388010 documentation</title>
+  <title>Template Class Int8CacheCalibrator &mdash; Torch-TensorRT v1.4.0.dev0+2ea9f00 documentation</title>

@@ -215,7 +215,7 @@
   <div class="version">
-    v1.4.0.dev0+b388010
+    v1.4.0.dev0+2ea9f00
   </div>

docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8Calibrator.html

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
   <meta name="viewport" content="width=device-width, initial-scale=1.0">
-  <title>Template Class Int8Calibrator &mdash; Torch-TensorRT v1.4.0.dev0+b388010 documentation</title>
+  <title>Template Class Int8Calibrator &mdash; Torch-TensorRT v1.4.0.dev0+2ea9f00 documentation</title>

@@ -215,7 +215,7 @@
   <div class="version">
-    v1.4.0.dev0+b388010
+    v1.4.0.dev0+2ea9f00
   </div>

docs/_cpp_api/define_macros_8h_1a18d295a837ac71add5578860b55e5502.html

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
   <meta name="viewport" content="width=device-width, initial-scale=1.0">
-  <title>Define STR &mdash; Torch-TensorRT v1.4.0.dev0+b388010 documentation</title>
+  <title>Define STR &mdash; Torch-TensorRT v1.4.0.dev0+2ea9f00 documentation</title>

@@ -215,7 +215,7 @@
   <div class="version">
-    v1.4.0.dev0+b388010
+    v1.4.0.dev0+2ea9f00
   </div>

docs/_cpp_api/define_macros_8h_1a282fd3c0b1c3a215148ae372070e1268.html

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
   <meta name="viewport" content="width=device-width, initial-scale=1.0">
-  <title>Define TORCH_TENSORRT_PATCH_VERSION &mdash; Torch-TensorRT v1.4.0.dev0+b388010 documentation</title>
+  <title>Define TORCH_TENSORRT_PATCH_VERSION &mdash; Torch-TensorRT v1.4.0.dev0+2ea9f00 documentation</title>

@@ -215,7 +215,7 @@
   <div class="version">
-    v1.4.0.dev0+b388010
+    v1.4.0.dev0+2ea9f00
   </div>

docs/_cpp_api/define_macros_8h_1a31398a6d4d27e28817afb0f0139e909e.html

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
   <meta name="viewport" content="width=device-width, initial-scale=1.0">
-  <title>Define TORCH_TENSORRT_MAJOR_VERSION &mdash; Torch-TensorRT v1.4.0.dev0+b388010 documentation</title>
+  <title>Define TORCH_TENSORRT_MAJOR_VERSION &mdash; Torch-TensorRT v1.4.0.dev0+2ea9f00 documentation</title>

@@ -215,7 +215,7 @@
   <div class="version">
-    v1.4.0.dev0+b388010
+    v1.4.0.dev0+2ea9f00
   </div>

docs/_cpp_api/define_macros_8h_1a35703561b26b1a9d2738ad7d58b27827.html

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
   <meta name="viewport" content="width=device-width, initial-scale=1.0">
-  <title>Define TORCH_TENSORRT_MINOR_VERSION &mdash; Torch-TensorRT v1.4.0.dev0+b388010 documentation</title>
+  <title>Define TORCH_TENSORRT_MINOR_VERSION &mdash; Torch-TensorRT v1.4.0.dev0+2ea9f00 documentation</title>

@@ -215,7 +215,7 @@
   <div class="version">
-    v1.4.0.dev0+b388010
+    v1.4.0.dev0+2ea9f00
  </div>
