Skip to content

Commit 5151c34

Browse files
committed
feat(//core/conversion/converters/impl): select converter, which adds support for aten::select.int
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 164a1a6 commit 5151c34

File tree

1 file changed

+17
-23
lines changed

1 file changed

+17
-23
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,40 +20,34 @@ auto select_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
2020
.pattern({
2121
"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
2222
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
23-
std::cout << "select.int converter recognized" << std::endl;
24-
2523
auto in = args[0].ITensor();
2624
auto axis = args[1].unwrapToInt();
2725
auto ind = (int32_t) args[2].unwrapToInt();
2826

29-
// tried: vector for input
30-
//std::vector<int32_t> indices_input = {ind};
31-
32-
auto options = torch::TensorOptions().device(torch::kCUDA, 1).dtype(torch::kInt32);
33-
at::Tensor indices = torch::tensor(torch::detail::TensorDataContainer(ind), options);
34-
27+
// index to access needs to be an at::Tensor
28+
at::Tensor indices = torch::tensor({ind}).to(torch::kI32);
3529
auto weights = Weights(ctx, indices);
36-
// manually setting weights
37-
// weights.data.type = nvinfer1::DataType::kINT32;
3830

31+
// IConstantLayer to convert indices from Weights to ITensor
3932
auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
40-
const_layer->setName(util::node_info(n).c_str());
41-
// manually setting output type
42-
// const_layer->setOutputType(0, nvinfer1::DataType::kINT32);
43-
44-
auto const_out = ctx->AssociateValueAndTensor(n->outputs()[0], const_layer->getOutput(0));
33+
TRTORCH_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
34+
auto const_out = const_layer->getOutput(0);
4535

36+
// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
4637
auto gather_layer = ctx->net->addGather(*in, *const_out, axis);
47-
gather_layer->setName(util::node_info(n).c_str());
48-
// manually setting output type
49-
// gather_layer->setOutputType(0, nvinfer1::DataType::kINT32);
38+
TRTORCH_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
39+
auto gather_out = gather_layer->getOutput(0);
5040

51-
auto gather_output = ctx->AssociateValueAndTensor(n->outputs()[0], gather_layer->getOutput(0));
41+
// IShuffleLayer removes redundant dimensions
42+
auto shuffle_layer = ctx->net->addShuffle(*gather_out);
43+
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
44+
shuffle_layer->setReshapeDimensions(util::unpadDims(gather_out->getDimensions()));
45+
shuffle_layer->setName(util::node_info(n).c_str());
46+
auto shuffle_out = shuffle_layer->getOutput(0);
5247

53-
LOG_DEBUG("Output tensor shape: " << gather_output->getDimensions());
54-
55-
// for debugging
56-
// std::raise(SIGTRAP);
48+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_out);
49+
50+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
5751

5852
return true;
5953
}

0 commit comments

Comments
 (0)