Skip to content

Commit c5b6202

Browse files
committed
feat(aten::matmul|aten::addmm): Adds support for aten::matmul and
aten::addmm Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent d945eb9 commit c5b6202

File tree

17 files changed

+197
-95
lines changed

17 files changed

+197
-95
lines changed

core/conversion/conversion.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,11 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
7373
LOG_DEBUG(ctx->logger, "Node input is a value that needs to be evaluated");
7474
auto eval = EvaluateNode(ctx, input_node);
7575
if (eval) {
76-
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
76+
if (!eval.value().isTensor()) {
77+
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
78+
} else {
79+
LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
80+
}
7781
ctx->evaluated_value_map[input] = std::move(eval.value());
7882
node_args.push_back(&(ctx->evaluated_value_map[input]));
7983
} else {

core/conversion/converters/Arg.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ std::string Arg::type_name() const {
8989
}
9090

9191
const torch::jit::IValue* Arg::IValue() const {
92+
TRTORCH_CHECK(isIValue(), "Requested IValue from Arg, however arg type is " << type_name());
9293
if (type_ == Type::kIValue) {
9394
return ptr_.ivalue;
9495
} else {
@@ -97,6 +98,7 @@ const torch::jit::IValue* Arg::IValue() const {
9798
}
9899

99100
nvinfer1::ITensor* Arg::ITensor() const {
101+
TRTORCH_CHECK(isITensor(), "Requested ITensor from Arg, however arg type is " << type_name());
100102
if (type_ == Type::kITensor) {
101103
return ptr_.tensor;
102104
} else {

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ cc_library(
1515
"impl/conv_deconv.cpp",
1616
"impl/element_wise.cpp",
1717
"impl/linear.cpp",
18+
"impl/matrix_multiply.cpp",
1819
"impl/pooling.cpp",
1920
"impl/reduce.cpp",
2021
"impl/shuffle.cpp",

core/conversion/converters/impl/element_wise.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@ nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOpera
1414

1515
TRTORCH_CHECK(util::volume(self_dims) == util::volume(other_dims), "Found inputs to elementwise operation do not have the same number of elements:\n Found: self " << self_dims << " other " << other_dims);
1616

17+
if (self_dims != other_dims) {
18+
LOG_DEBUG("Input shape dont match inserting shuffle layers to reshape to " << self_dims);
19+
auto other_shuffle = ctx->net->addShuffle(*other);
20+
other_shuffle->setReshapeDimensions(self_dims);
21+
other_shuffle->setName(std::string("[Reshape other to " + util::toStr(self_dims) + ']').c_str());
22+
other = other_shuffle->getOutput(0);
23+
}
24+
25+
1726
nvinfer1::ILayer* ele;
1827
if (scalar != 1) {
1928
LOG_WARNING("Please verify scalar handling in add converter, channel axis set to 3 but scaling is uniform");

core/conversion/converters/impl/linear.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,6 @@ namespace impl {
99
namespace {
1010

1111
auto linear_registrations = RegisterNodeConversionPatterns()
12-
// .pattern({
13-
// "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)",
14-
// [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> {
15-
// auto in = args[0].ITensor();
16-
17-
// }
18-
// })
1912
.pattern({
2013
"aten::linear(Tensor input, Tensor weight, Tensor? bias = None) -> (Tensor)",
2114
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -71,4 +64,4 @@ auto linear_registrations = RegisterNodeConversionPatterns()
7164
} // namespace converters
7265
} // namespace conversion
7366
} // namespace core
74-
} // trtorch
67+
} // namespace trtorch

core/conversion/converters/impl/matrix_multiply.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include "core/util/prelude.h"
2+
#include "core/conversion/converters/converters.h"
3+
4+
namespace trtorch {
5+
namespace core {
6+
namespace conversion {
7+
namespace converters {
8+
namespace impl {
9+
namespace {
10+
11+
auto mm_registrations = RegisterNodeConversionPatterns()
12+
.pattern({
13+
"aten::matmul(Tensor self, Tensor other) -> (Tensor)",
14+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
15+
nvinfer1::ITensor* self;
16+
if (args[0].isIValue()) {
17+
auto t = args[0].unwrapToTensor();
18+
auto t_weights = Weights(ctx, t);
19+
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
20+
TRTORCH_CHECK(const_layer, "Unable to freeze tensor self for node: " << *n);
21+
const_layer->setName((util::node_info(n) + " [Freeze Tensor(self)]").c_str());
22+
self = const_layer->getOutput(0);
23+
} else {
24+
self = args[0].ITensor();
25+
}
26+
LOG_DEBUG("self tensor shape: " << self->getDimensions());
27+
28+
nvinfer1::ITensor* other;
29+
if (args[1].isIValue()) {
30+
auto t = args[1].unwrapToTensor();
31+
auto t_weights = Weights(ctx, t);
32+
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
33+
TRTORCH_CHECK(const_layer, "Unable to freeze tensor other for node: " << *n);
34+
const_layer->setName((util::node_info(n) + " [Freeze Tensor(other)]").c_str());
35+
other = const_layer->getOutput(0);
36+
} else {
37+
other = args[1].ITensor();
38+
}
39+
LOG_DEBUG("other tensor shape: " << other->getDimensions());
40+
41+
auto mm_layer = ctx->net->addMatrixMultiply(*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
42+
TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
43+
mm_layer->setName(util::node_info(n).c_str());
44+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
45+
46+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
47+
return true;
48+
}
49+
});
50+
} // namespace
51+
} // namespace impl
52+
} // namespace converters
53+
} // namespace conversion
54+
} // namespace core
55+
} // namespace trtorch

core/lowering/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ cc_library(
88
srcs = [
99
"lowering.cpp",
1010
"drop_unused_nodes.cpp",
11+
"register_const_op.cpp"
1112
],
1213
deps = [
1314
"@libtorch//:libtorch",
1415
"//core/lowering/passes",
1516
"//core/util:prelude"
16-
]
17+
],
18+
alwayslink = True
1719
)
1820

1921
load("@rules_pkg//:pkg.bzl", "pkg_tar")

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
2525
torch::jit::FuseLinear(g);
2626
passes::RemoveDropout(g);
2727
passes::FuseFlattenLinear(g);
28+
passes::UnpackAddMM(g);
2829
passes::ExpandLogSoftmax(g);
2930
//passes::RemoveDimExeception(g);
3031
//irfusers::UnpackBatchNorm(g);

core/lowering/passes/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ cc_library(
1010
"expand_log_softmax.cpp",
1111
"remove_dropout.cpp",
1212
"unpack_batch_norm.cpp",
13-
"exception_elimination.cpp"
13+
"exception_elimination.cpp",
14+
"unpack_addmm.cpp"
1415
],
1516
deps = [
1617
"//core/util:prelude",

core/lowering/passes/fuse_flatten_linear.cpp

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -40,39 +40,6 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
4040
flatten_linear_bias_none_to_linear.runOnGraph(graph);
4141
}
4242

43-
void FuseFlattenAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
44-
//TensorRT implicitly adds a flatten layer infront of FC layers if necessary
45-
std::string flatten_linear_pattern = R"IR(
46-
graph(%input, %6, %7, %weight, %bias):
47-
%flat = aten::flatten(%input, %6, %7)
48-
%res = aten::linear(%flat, %weight, %bias)
49-
return (%res))IR";
50-
std::string flatten_linear_bias_none_pattern = R"IR(
51-
graph(%input, %6, %7, %weight):
52-
%flat = aten::flatten(%input, %6, %7)
53-
%bias: Tensor? = prim::Constant()
54-
%res = aten::linear(%flat, %weight, %bias)
55-
return (%res))IR";
56-
std::string fused_linear = R"IR(
57-
graph(%input, %6, %7, %weight, %bias):
58-
%res = aten::linear(%input, %weight, %bias)
59-
return (%res))IR";
60-
61-
std::string fused_linear_bias_none = R"IR(
62-
graph(%input, %6, %7, %weight):
63-
%bias: Tensor? = prim::Constant()
64-
%res = aten::linear(%input, %weight, %bias)
65-
return (%res))IR";
66-
67-
torch::jit::SubgraphRewriter flatten_linear_to_linear;
68-
flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
69-
flatten_linear_to_linear.runOnGraph(graph);
70-
71-
torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
72-
flatten_linear_bias_none_to_linear.RegisterRewritePattern(
73-
flatten_linear_bias_none_pattern, fused_linear_bias_none);
74-
flatten_linear_bias_none_to_linear.runOnGraph(graph);
75-
}
7643
} // namespace passes
7744
} // namespace lowering
7845
} // namespace core

0 commit comments

Comments
 (0)