@@ -20,40 +20,34 @@ auto select_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
20
20
.pattern({
21
21
" aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))" ,
22
22
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
23
- std::cout << " select.int converter recognized" << std::endl;
24
-
25
23
auto in = args[0 ].ITensor ();
26
24
auto axis = args[1 ].unwrapToInt ();
27
25
auto ind = (int32_t ) args[2 ].unwrapToInt ();
28
26
29
- // tried: vector for input
30
- // std::vector<int32_t> indices_input = {ind};
31
-
32
- auto options = torch::TensorOptions ().device (torch::kCUDA , 1 ).dtype (torch::kInt32 );
33
- at::Tensor indices = torch::tensor (torch::detail::TensorDataContainer (ind), options);
34
-
27
+ // index to access needs to be an at::Tensor
28
+ at::Tensor indices = torch::tensor ({ind}).to (torch::kI32 );
35
29
auto weights = Weights (ctx, indices);
36
- // manually setting weights
37
- // weights.data.type = nvinfer1::DataType::kINT32;
38
30
31
+ // IConstantLayer to convert indices from Weights to ITensor
39
32
auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
40
- const_layer->setName (util::node_info (n).c_str ());
41
- // manually setting output type
42
- // const_layer->setOutputType(0, nvinfer1::DataType::kINT32);
43
-
44
- auto const_out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], const_layer->getOutput (0 ));
33
+ TRTORCH_CHECK (const_layer, " Unable to create constant layer from node: " << *n);
34
+ auto const_out = const_layer->getOutput (0 );
45
35
36
+ // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
46
37
auto gather_layer = ctx->net ->addGather (*in, *const_out, axis);
47
- gather_layer->setName (util::node_info (n).c_str ());
48
- // manually setting output type
49
- // gather_layer->setOutputType(0, nvinfer1::DataType::kINT32);
38
+ TRTORCH_CHECK (gather_layer, " Unable to create gather layer from node: " << *n);
39
+ auto gather_out = gather_layer->getOutput (0 );
50
40
51
- auto gather_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], gather_layer->getOutput (0 ));
41
+ // IShuffleLayer removes redundant dimensions
42
+ auto shuffle_layer = ctx->net ->addShuffle (*gather_out);
43
+ TRTORCH_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
44
+ shuffle_layer->setReshapeDimensions (util::unpadDims (gather_out->getDimensions ()));
45
+ shuffle_layer->setName (util::node_info (n).c_str ());
46
+ auto shuffle_out = shuffle_layer->getOutput (0 );
52
47
53
- LOG_DEBUG (" Output tensor shape: " << gather_output->getDimensions ());
54
-
55
- // for debugging
56
- // std::raise(SIGTRAP);
48
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle_out);
49
+
50
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
57
51
58
52
return true ;
59
53
}
0 commit comments