@@ -8,24 +8,66 @@ namespace converters {
8
8
namespace impl {
9
9
namespace {
10
10
11
- auto mm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
12
- {" aten::matmul(Tensor self, Tensor other) -> (Tensor)" ,
13
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
14
- auto self = args[0 ].ITensorOrFreeze (ctx);
15
- LOG_DEBUG (" self tensor shape: " << self->getDimensions ());
16
-
17
- auto other = args[1 ].ITensorOrFreeze (ctx);
18
- LOG_DEBUG (" other tensor shape: " << other->getDimensions ());
19
-
20
- auto mm_layer = ctx->net ->addMatrixMultiply (
21
- *self, nvinfer1::MatrixOperation::kNONE , *other, nvinfer1::MatrixOperation::kNONE );
22
- TRTORCH_CHECK (mm_layer, " Unable to create matrix multiplication node: " << *n);
23
- mm_layer->setName (util::node_info (n).c_str ());
24
- auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], mm_layer->getOutput (0 ));
25
-
26
- LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
27
- return true ;
28
- }});
11
+ auto mm_registrations TRTORCH_UNUSED =
12
+ RegisterNodeConversionPatterns ()
13
+ .pattern({" aten::matmul(Tensor self, Tensor other) -> (Tensor)" ,
14
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
15
+ auto self = args[0 ].ITensorOrFreeze (ctx);
16
+ LOG_DEBUG (" self tensor shape: " << self->getDimensions ());
17
+
18
+ auto other = args[1 ].ITensorOrFreeze (ctx);
19
+ LOG_DEBUG (" other tensor shape: " << other->getDimensions ());
20
+
21
+ auto mm_layer = ctx->net ->addMatrixMultiply (
22
+ *self, nvinfer1::MatrixOperation::kNONE , *other, nvinfer1::MatrixOperation::kNONE );
23
+ TRTORCH_CHECK (mm_layer, " Unable to create matrix multiplication node: " << *n);
24
+ mm_layer->setName (util::node_info (n).c_str ());
25
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], mm_layer->getOutput (0 ));
26
+
27
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
28
+ return true ;
29
+ }})
30
+ .pattern(
31
+ {" aten::bmm(Tensor self, Tensor mat2) -> (Tensor)" ,
32
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
33
+ auto self = args[0 ].ITensorOrFreeze (ctx);
34
+ nvinfer1::Dims selfDims = self->getDimensions ();
35
+ auto mat2 = args[1 ].ITensorOrFreeze (ctx);
36
+ nvinfer1::Dims mat2Dims = mat2->getDimensions ();
37
+
38
+ // check dimensions
39
+ TRTORCH_CHECK (
40
+ selfDims.nbDims == 3 ,
41
+ " Expected 3-dimensional tensor, but got "
42
+ << selfDims.nbDims
43
+ << " -dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)" );
44
+ TRTORCH_CHECK (
45
+ mat2Dims.nbDims == 3 ,
46
+ " Expected 3-dimensional tensor, but got "
47
+ << mat2Dims.nbDims
48
+ << " -dimensional tensor for argument #2 'batch2' (while checking arguments for bmm)" );
49
+
50
+ // Self and mat2 should have same size at dimension 0
51
+ TRTORCH_CHECK (
52
+ selfDims.d [0 ] == mat2Dims.d [0 ],
53
+ " Expected tensor to have size " << selfDims.d [0 ] << " at dimension 0, but got size " << mat2Dims.d [0 ]
54
+ << " for argument #2 'batch2' (while checking arguments for bmm)" );
55
+ // The size of mat2 at dimension 1 should be the same as that of self at dimension 2.
56
+ TRTORCH_CHECK (
57
+ selfDims.d [2 ] == mat2Dims.d [1 ],
58
+ " Expected tensor to have size " << selfDims.d [2 ] << " at dimension 1, but got size " << mat2Dims.d [1 ]
59
+ << " for argument #2 'batch2' (while checking arguments for bmm)" );
60
+
61
+ auto mm_layer = ctx->net ->addMatrixMultiply (
62
+ *self, nvinfer1::MatrixOperation::kNONE , *mat2, nvinfer1::MatrixOperation::kNONE );
63
+ TRTORCH_CHECK (mm_layer, " Unable to create matrix multiplication node: " << *n);
64
+
65
+ mm_layer->setName (util::node_info (n).c_str ());
66
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], mm_layer->getOutput (0 ));
67
+
68
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
69
+ return true ;
70
+ }});
29
71
} // namespace
30
72
} // namespace impl
31
73
} // namespace converters
0 commit comments