@@ -46,6 +46,77 @@ auto select_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
46
46
47
47
LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
48
48
49
+ return true ;
50
+ }
51
+ }).pattern({
52
+ " aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)" ,
53
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
54
+ auto in = args[0 ].ITensor ();
55
+ auto axis = args[1 ].unwrapToInt ();
56
+ auto start = (int32_t ) args[2 ].unwrapToInt ();
57
+ auto length = (int32_t ) args[3 ].unwrapToInt ();
58
+
59
+ // index to access needs to be an at::Tensor
60
+ at::Tensor indices = torch::arange (start, start + length, 1 ).to (torch::kI32 );
61
+ auto weights = Weights (ctx, indices);
62
+
63
+ // IConstantLayer to convert indices from Weights to ITensor
64
+ auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
65
+ TRTORCH_CHECK (const_layer, " Unable to create constant layer from node: " << *n);
66
+ auto const_out = const_layer->getOutput (0 );
67
+
68
+ // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
69
+ auto gather_layer = ctx->net ->addGather (*in, *const_out, axis);
70
+ TRTORCH_CHECK (gather_layer, " Unable to create gather layer from node: " << *n);
71
+ auto gather_out = gather_layer->getOutput (0 );
72
+
73
+ // IShuffleLayer removes redundant dimensions
74
+ auto shuffle_layer = ctx->net ->addShuffle (*gather_out);
75
+ TRTORCH_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
76
+ shuffle_layer->setReshapeDimensions (util::unpadDims (gather_out->getDimensions ()));
77
+ shuffle_layer->setName (util::node_info (n).c_str ());
78
+ auto shuffle_out = shuffle_layer->getOutput (0 );
79
+
80
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle_out);
81
+
82
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
83
+
84
+ return true ;
85
+ }
86
+ }).pattern({
87
+ " aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)" ,
88
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
89
+ auto in = args[0 ].ITensor ();
90
+ auto axis = args[1 ].unwrapToInt ();
91
+ torch::Tensor start = args[2 ].IValue ()->toTensor ().to (torch::kI32 );
92
+ int32_t startIdx = start.item ().to <int32_t >();
93
+ auto length = (int32_t ) args[3 ].unwrapToInt ();
94
+
95
+ // index to access needs to be an at::Tensor
96
+ at::Tensor indices = torch::arange (startIdx, startIdx + length, 1 ).to (torch::kI32 );
97
+ auto weights = Weights (ctx, indices);
98
+
99
+ // IConstantLayer to convert indices from Weights to ITensor
100
+ auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
101
+ TRTORCH_CHECK (const_layer, " Unable to create constant layer from node: " << *n);
102
+ auto const_out = const_layer->getOutput (0 );
103
+
104
+ // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
105
+ auto gather_layer = ctx->net ->addGather (*in, *const_out, axis);
106
+ TRTORCH_CHECK (gather_layer, " Unable to create gather layer from node: " << *n);
107
+ auto gather_out = gather_layer->getOutput (0 );
108
+
109
+ // IShuffleLayer removes redundant dimensions
110
+ auto shuffle_layer = ctx->net ->addShuffle (*gather_out);
111
+ TRTORCH_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
112
+ shuffle_layer->setReshapeDimensions (util::unpadDims (gather_out->getDimensions ()));
113
+ shuffle_layer->setName (util::node_info (n).c_str ());
114
+ auto shuffle_out = shuffle_layer->getOutput (0 );
115
+
116
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle_out);
117
+
118
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
119
+
49
120
return true ;
50
121
}
51
122
});
0 commit comments