@@ -46,6 +46,77 @@ auto select_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
4646
4747 LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
4848
49+ return true ;
50+ }
51+ }).pattern({
52+ " aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)" ,
53+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
54+ auto in = args[0 ].ITensor ();
55+ auto axis = args[1 ].unwrapToInt ();
56+ auto start = (int32_t ) args[2 ].unwrapToInt ();
57+ auto length = (int32_t ) args[3 ].unwrapToInt ();
58+
59+ // index to access needs to be an at::Tensor
60+ at::Tensor indices = torch::arange (start, start + length, 1 ).to (torch::kI32 );
61+ auto weights = Weights (ctx, indices);
62+
63+ // IConstantLayer to convert indices from Weights to ITensor
64+ auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
65+ TRTORCH_CHECK (const_layer, " Unable to create constant layer from node: " << *n);
66+ auto const_out = const_layer->getOutput (0 );
67+
68+ // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
69+ auto gather_layer = ctx->net ->addGather (*in, *const_out, axis);
70+ TRTORCH_CHECK (gather_layer, " Unable to create gather layer from node: " << *n);
71+ auto gather_out = gather_layer->getOutput (0 );
72+
73+ // IShuffleLayer removes redundant dimensions
74+ auto shuffle_layer = ctx->net ->addShuffle (*gather_out);
75+ TRTORCH_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
76+ shuffle_layer->setReshapeDimensions (util::unpadDims (gather_out->getDimensions ()));
77+ shuffle_layer->setName (util::node_info (n).c_str ());
78+ auto shuffle_out = shuffle_layer->getOutput (0 );
79+
80+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle_out);
81+
82+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
83+
84+ return true ;
85+ }
86+ }).pattern({
87+ " aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)" ,
88+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
89+ auto in = args[0 ].ITensor ();
90+ auto axis = args[1 ].unwrapToInt ();
91+ torch::Tensor start = args[2 ].IValue ()->toTensor ().to (torch::kI32 );
92+ int32_t startIdx = start.item ().to <int32_t >();
93+ auto length = (int32_t ) args[3 ].unwrapToInt ();
94+
95+ // index to access needs to be an at::Tensor
96+ at::Tensor indices = torch::arange (startIdx, startIdx + length, 1 ).to (torch::kI32 );
97+ auto weights = Weights (ctx, indices);
98+
99+ // IConstantLayer to convert indices from Weights to ITensor
100+ auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
101+ TRTORCH_CHECK (const_layer, " Unable to create constant layer from node: " << *n);
102+ auto const_out = const_layer->getOutput (0 );
103+
104+ // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
105+ auto gather_layer = ctx->net ->addGather (*in, *const_out, axis);
106+ TRTORCH_CHECK (gather_layer, " Unable to create gather layer from node: " << *n);
107+ auto gather_out = gather_layer->getOutput (0 );
108+
109+ // IShuffleLayer removes redundant dimensions
110+ auto shuffle_layer = ctx->net ->addShuffle (*gather_out);
111+ TRTORCH_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
112+ shuffle_layer->setReshapeDimensions (util::unpadDims (gather_out->getDimensions ()));
113+ shuffle_layer->setName (util::node_info (n).c_str ());
114+ auto shuffle_out = shuffle_layer->getOutput (0 );
115+
116+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle_out);
117+
118+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
119+
49120 return true ;
50121 }
51122 });
0 commit comments