Skip to content

Commit 9ae0eb2

Browse files
committed
Support aten::bmm converter
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent aec4e1a commit 9ae0eb2

File tree

2 files changed

+81
-18
lines changed

2 files changed

+81
-18
lines changed

core/conversion/converters/impl/matrix_multiply.cpp

Lines changed: 60 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -8,24 +8,66 @@ namespace converters {
88
namespace impl {
99
namespace {
1010

// Removed ("-") side of the diff: registers one converter, for aten::matmul.
// The converter feeds both arguments (as ITensors) into addMatrixMultiply with
// MatrixOperation::kNONE on each operand, names the layer after the node, and
// associates the layer's output 0 with the node's first output value.
11-
auto mm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
12-
{"aten::matmul(Tensor self, Tensor other) -> (Tensor)",
13-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
14-
auto self = args[0].ITensorOrFreeze(ctx);
15-
LOG_DEBUG("self tensor shape: " << self->getDimensions());
16-
17-
auto other = args[1].ITensorOrFreeze(ctx);
18-
LOG_DEBUG("other tensor shape: " << other->getDimensions());
19-
20-
auto mm_layer = ctx->net->addMatrixMultiply(
21-
*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
22-
TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
23-
mm_layer->setName(util::node_info(n).c_str());
24-
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
25-
26-
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
27-
return true;
28-
}});
// Added ("+") side of the diff: registers two converters in one chained
// RegisterNodeConversionPatterns() expression:
//   1. aten::matmul - same body as before, only re-indented for the chain.
//   2. aten::bmm    - new; validates operand ranks/extents, then emits the same
//                     addMatrixMultiply layer as the matmul converter.
// Each converter returns true after associating the layer output with the node's
// first output value; failures are reported through TRTORCH_CHECK.
11+
auto mm_registrations TRTORCH_UNUSED =
12+
RegisterNodeConversionPatterns()
13+
.pattern({"aten::matmul(Tensor self, Tensor other) -> (Tensor)",
14+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
15+
auto self = args[0].ITensorOrFreeze(ctx);
16+
LOG_DEBUG("self tensor shape: " << self->getDimensions());
17+
18+
auto other = args[1].ITensorOrFreeze(ctx);
19+
LOG_DEBUG("other tensor shape: " << other->getDimensions());
20+
// Plain product: neither operand is transposed (MatrixOperation::kNONE on both).
21+
auto mm_layer = ctx->net->addMatrixMultiply(
22+
*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
23+
TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
24+
mm_layer->setName(util::node_info(n).c_str());
25+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
26+
27+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
28+
return true;
29+
}})
// aten::bmm converter. Unlike the matmul converter above, it checks the operand
// shapes itself before building the layer; the messages mirror the wording of
// PyTorch's own bmm argument errors ('batch1' / 'batch2').
30+
.pattern(
31+
{"aten::bmm(Tensor self, Tensor mat2) -> (Tensor)",
32+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
33+
auto self = args[0].ITensorOrFreeze(ctx);
34+
nvinfer1::Dims selfDims = self->getDimensions();
35+
auto mat2 = args[1].ITensorOrFreeze(ctx);
36+
nvinfer1::Dims mat2Dims = mat2->getDimensions();
37+
// Both operands must be rank 3.
38+
// check dimensions
39+
TRTORCH_CHECK(
40+
selfDims.nbDims == 3,
41+
"Expected 3-dimensional tensor, but got "
42+
<< selfDims.nbDims
43+
<< "-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)");
44+
TRTORCH_CHECK(
45+
mat2Dims.nbDims == 3,
46+
"Expected 3-dimensional tensor, but got "
47+
<< mat2Dims.nbDims
48+
<< "-dimensional tensor for argument #2 'batch2' (while checking arguments for bmm)");
49+
// NOTE(review): the two checks below compare the extents reported by
// getDimensions() directly. If a dimension can be reported as -1 (dynamic
// shape), a dynamic-vs-static pair would fail these checks even though the
// shapes may agree at runtime - confirm whether dynamic inputs reach this path.
50+
// Self and mat2 should have same size at dimension 0
51+
TRTORCH_CHECK(
52+
selfDims.d[0] == mat2Dims.d[0],
53+
"Expected tensor to have size " << selfDims.d[0] << " at dimension 0, but got size " << mat2Dims.d[0]
54+
<< " for argument #2 'batch2' (while checking arguments for bmm)");
55+
// The size of mat2 at dimension 1 should be the same as that of self at dimension 2.
56+
TRTORCH_CHECK(
57+
selfDims.d[2] == mat2Dims.d[1],
58+
"Expected tensor to have size " << selfDims.d[2] << " at dimension 1, but got size " << mat2Dims.d[1]
59+
<< " for argument #2 'batch2' (while checking arguments for bmm)");
60+
// Same layer construction as the matmul converter: no transpose on either operand.
61+
auto mm_layer = ctx->net->addMatrixMultiply(
62+
*self, nvinfer1::MatrixOperation::kNONE, *mat2, nvinfer1::MatrixOperation::kNONE);
63+
TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
64+
65+
mm_layer->setName(util::node_info(n).c_str());
66+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
67+
68+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
69+
return true;
70+
}});
2971
} // namespace
3072
} // namespace impl
3173
} // namespace converters

tests/core/conversion/converters/test_matrix_multiply.cpp

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,3 +25,24 @@ TEST(Converters, ATenMMConvertsCorrectly) {
2525

2626
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
2727
}
28+
// Checks the aten::bmm converter: builds a one-node graph computing
// aten::bmm(%0, %1), runs it through RunGraph and through RunGraphEngine on the
// same inputs, and asserts the two results agree within 2e-6.
29+
TEST(Converters, ATenBMMConvertsCorrectly) {
30+
const auto graph = R"IR(
31+
graph(%0 : Tensor, %1 : Tensor):
32+
%2 : Tensor = aten::bmm(%0, %1)
33+
return (%2))IR";
34+
35+
auto g = std::make_shared<torch::jit::Graph>();
36+
torch::jit::parseIR(graph, g.get());
37+
// CUDA inputs shaped {4, 64, 128} and {4, 128, 64}: equal batch size, and the
// last extent of in1 matches the middle extent of in2, as the converter requires.
38+
auto in1 = at::randint(0, 5, {4, 64, 128}, {at::kCUDA});
39+
auto in2 = at::randint(0, 5, {4, 128, 64}, {at::kCUDA});
40+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
41+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1, in2});
42+
// params is rebuilt before the second run, following the pattern used here;
// presumably the first run may consume or alter it - TODO confirm.
43+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
44+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1, in2});
// Reshape the engine output to the reference result's shape before comparing.
45+
auto trt = trt_results[0].reshape_as(jit_results[0]);
46+
47+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
48+
}

0 commit comments

Comments
 (0)