@@ -92,8 +92,40 @@ static auto shuffle_registrations TRTORCH_UNUSED =
92
92
auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle->getOutput (0 ));
93
93
LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
94
94
95
+ return true ;
96
+ }})
97
+ .pattern({" aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> (Tensor(a))" ,
98
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
99
+ auto in = args[0 ].ITensorOrFreeze (ctx);
100
+ auto in_shape = util::toVec (in->getDimensions ());
101
+ auto ndims = in_shape.size ();
102
+ auto dim0 = args[1 ].unwrapToInt ();
103
+ auto dim1 = args[2 ].unwrapToInt ();
104
+
105
+ std::vector<int64_t > new_order;
106
+ for (size_t i = 0 ; i < ndims; i++) {
107
+ new_order.push_back (i);
108
+ }
109
+ auto tmp = dim0;
110
+ new_order[dim0] = new_order[dim1];
111
+ new_order[dim1] = tmp;
112
+
113
+ LOG_DEBUG (" Shuffle to: " << util::toDims (new_order));
114
+
115
+ auto shuffle = ctx->net ->addShuffle (*in);
116
+ TRTORCH_CHECK (shuffle, " Unable to create shuffle layer from node: " << *n);
117
+ nvinfer1::Permutation permute;
118
+ std::copy (new_order.begin (), new_order.end (), permute.order );
119
+
120
+ shuffle->setSecondTranspose (permute);
121
+ shuffle->setName (util::node_info (n).c_str ());
122
+
123
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle->getOutput (0 ));
124
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
125
+
95
126
return true ;
96
127
}});
128
+
97
129
} // namespace
98
130
} // namespace impl
99
131
} // namespace converters
0 commit comments