
Commit d6f551e

refactor: refactor aten::slice converter
Signed-off-by: inocsin <[email protected]>
1 parent: 88717f2

3 files changed: +81, -50 lines

core/conversion/converters/converter_util.cpp

Lines changed: 66 additions & 25 deletions
@@ -199,67 +199,108 @@ nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t, const std::
   return out;
 }
 
-nvinfer1::ITensor* toITensor(ConversionCtx* ctx, const torch::jit::Node* n, at::Tensor* input) {
-
-  auto weights = Weights(ctx, *input);
-  // IConstantLayer to convert indices from Weights to ITensor
-  auto const_layer = ctx->net->addConstant(weights.shape, weights.data); // shouln't use constant
-  TORCHTRT_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
-  auto const_out = const_layer->getOutput(0);
-  return const_out;
-}
-
 // clamp x to [lower_bound, upper_bound]
-nvinfer1::ITensor* clamp(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* x,
+nvinfer1::ITensor* clamp(ConversionCtx* ctx, nvinfer1::ITensor* x,
     nvinfer1::ITensor* lower_bound, nvinfer1::ITensor* upper_bound) {
   auto max_layer = ctx->net->addElementWise(*x, *lower_bound, nvinfer1::ElementWiseOperation::kMAX);
+  TORCHTRT_CHECK(max_layer, "Unable to create max layer for clamp");
+  LOG_DEBUG(ctx->logger, "Create " << max_layer->getName() << " for clamp");
   auto max_itensor = max_layer->getOutput(0);
+
   auto min_layer = ctx->net->addElementWise(*max_itensor, *upper_bound, nvinfer1::ElementWiseOperation::kMIN);
+  TORCHTRT_CHECK(min_layer, "Unable to create min layer for clamp");
+  LOG_DEBUG(ctx->logger, "Create " << min_layer->getName() << " for clamp");
   auto min_itensor = min_layer->getOutput(0);
   return min_itensor;
 }
 
 // clamp x to [0, input_dim]
-nvinfer1::ITensor* clamp_to_input_dim(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* x,
+nvinfer1::ITensor* clamp_to_input_dim(ConversionCtx* ctx, nvinfer1::ITensor* x,
     nvinfer1::ITensor* input_dim) {
   auto nbdims = input_dim->getDimensions().d[0];
   auto zero = torch::zeros({nbdims}).to(torch::kI32);
-  auto zero_itensor = toITensor(ctx, n, &zero);
+  auto zero_itensor = tensor_to_const(ctx, zero);
   auto one = torch::ones({nbdims}).to(torch::kI32);
-  auto one_itensor = toITensor(ctx, n, &one);
+  auto one_itensor = tensor_to_const(ctx, one);
+
   auto upper_bound_layer = ctx->net->addElementWise(*input_dim, *one_itensor, nvinfer1::ElementWiseOperation::kSUB);
+  TORCHTRT_CHECK(upper_bound_layer, "Unable to create sub layer for clamp to inputDim");
+  LOG_DEBUG(ctx->logger, "Create " << upper_bound_layer->getName() << " for clamp to inputDim");
   auto upper_bound = upper_bound_layer->getOutput(0);
+
   auto max_layer = ctx->net->addElementWise(*x, *zero_itensor, nvinfer1::ElementWiseOperation::kMAX);
+  TORCHTRT_CHECK(max_layer, "Unable to create max_layer for clamp to inputDim");
+  LOG_DEBUG(ctx->logger, "Create " << max_layer->getName() << " for clamp to inputDim");
   auto max_itensor = max_layer->getOutput(0);
+
   auto min_layer = ctx->net->addElementWise(*max_itensor, *upper_bound, nvinfer1::ElementWiseOperation::kMIN);
+  TORCHTRT_CHECK(min_layer, "Unable to create min_layer for clamp to inputDim");
+  LOG_DEBUG(ctx->logger, "Create " << min_layer->getName() << " for clamp to inputDim");
   auto min_itensor = min_layer->getOutput(0);
   return min_itensor;
 }
 
 
 // return indices < 0 ? inputDims + indices : indices
-nvinfer1::ITensor* bump_if_negtive(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input_dim,
+nvinfer1::ITensor* bump_if_negtive(ConversionCtx* ctx, nvinfer1::ITensor* input_dim,
     nvinfer1::ITensor* indices) {
   auto nbdims = input_dim->getDimensions().d[0];
   auto zero = torch::zeros({nbdims}).to(torch::kI32);
   auto neg = - torch::ones({nbdims}).to(torch::kI32);
-  auto zero_itensor = toITensor(ctx, n, &zero);
-  auto neg_itensor = toITensor(ctx, n, &neg);
-  auto signs = clamp(ctx, n, indices, neg_itensor, zero_itensor);
+  auto zero_itensor = tensor_to_const(ctx, zero);
+  auto neg_itensor = tensor_to_const(ctx, neg);
+  // find the indices that = -1
+  auto signs = clamp(ctx, indices, neg_itensor, zero_itensor);
+
+  // get the inputDim value where indices == -1, else 0
   auto mul = ctx->net->addElementWise(*signs, *input_dim, nvinfer1::ElementWiseOperation::kPROD);
+  TORCHTRT_CHECK(mul, "Unable to create mul layer in bump_if_negtive");
+  LOG_DEBUG(ctx->logger, "Create " << mul->getName() << " for bump_if_negtive");
   auto mul_itensor = mul->getOutput(0);
+
+  // add the inputDim value to indices where indices == -1
   auto sub = ctx->net->addElementWise(*indices, *mul_itensor, nvinfer1::ElementWiseOperation::kSUB);
+  TORCHTRT_CHECK(sub, "Unable to create sub layer in bump_if_negtive");
+  LOG_DEBUG(ctx->logger, "Create " << sub->getName() << " for bump_if_negtive");
   auto sub_itensor = sub->getOutput(0);
   return sub_itensor;
 }
 
-void update_start_and_end(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in_shape,
-    nvinfer1::ITensor* in_start, nvinfer1::ITensor* in_end,
-    nvinfer1::ITensor** out_start, nvinfer1::ITensor** out_end) {
-  auto start = bump_if_negtive(ctx, n, in_shape, in_start);
-  *out_start = clamp_to_input_dim(ctx, n, start, in_shape);
-  auto end = bump_if_negtive(ctx, n, in_shape, in_end);
-  *out_end = clamp_to_input_dim(ctx, n, end, in_shape);
+std::vector<nvinfer1::ITensor*> update_start_and_end(ConversionCtx* ctx, nvinfer1::ITensor* in_shape,
+    nvinfer1::ITensor* in_start, nvinfer1::ITensor* in_end) {
+  auto start = bump_if_negtive(ctx, in_shape, in_start);
+  auto out_start = clamp_to_input_dim(ctx, start, in_shape);
+  auto end = bump_if_negtive(ctx, in_shape, in_end);
+  auto out_end = clamp_to_input_dim(ctx, end, in_shape);
+  std::vector<nvinfer1::ITensor*> outputs;
+  outputs.push_back(out_start);
+  outputs.push_back(out_end);
+  return outputs;
+}
+
+// size = (end - start) / stride + 1, where range is [start, end], end is included
+nvinfer1::ITensor* calculate_output_size(ConversionCtx* ctx, nvinfer1::ITensor* start, nvinfer1::ITensor* end,
+    nvinfer1::ITensor* stride, int nbdims) {
+
+  at::Tensor one_tensor = torch::ones({nbdims}).to(torch::kI32);
+  auto one_itensor = tensor_to_const(ctx, one_tensor);
+
+  auto sub_layer = ctx->net->addElementWise(*end, *start, nvinfer1::ElementWiseOperation::kSUB);
+  TORCHTRT_CHECK(sub_layer, "Unable to create sub layer in calculate_output_size");
+  LOG_DEBUG(ctx->logger, "Create " << sub_layer->getName() << " for calculate_output_size");
+  auto sub_itensor = sub_layer->getOutput(0);
+
+  auto div_layer = ctx->net->addElementWise(*sub_itensor, *stride, nvinfer1::ElementWiseOperation::kDIV);
+  TORCHTRT_CHECK(div_layer, "Unable to create div layer in calculate_output_size");
+  LOG_DEBUG(ctx->logger, "Create " << div_layer->getName() << " for calculate_output_size");
+  auto div_itensor = div_layer->getOutput(0);
+
+  auto add_layer = ctx->net->addElementWise(*div_itensor, *one_itensor, nvinfer1::ElementWiseOperation::kSUM);
+  TORCHTRT_CHECK(add_layer, "Unable to create add layer in calculate_output_size");
+  LOG_DEBUG(ctx->logger, "Create " << add_layer->getName() << " for calculate_output_size");
+  auto size_itensor = add_layer->getOutput(0);
+
+  return size_itensor;
 }
 
 bool is_dynamic_shape(nvinfer1::ITensor* tensor) {
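
For intuition, the arithmetic that bump_if_negtive assembles out of elementwise layers reduces to a simple per-dimension computation: signs = clamp(indices, -1, 0) marks the negative entries, and indices - signs * dim shifts exactly those entries up by the dimension size. A minimal host-side scalar sketch of that logic (illustration only, not part of the commit; the name bump_if_negative_scalar is ours):

#include <algorithm>
#include <cassert>
#include <cstdint>

// Scalar equivalent of the kMAX/kMIN/kPROD/kSUB layer pipeline above.
int32_t bump_if_negative_scalar(int32_t dim, int32_t index) {
  int32_t sign = std::min(std::max(index, int32_t{-1}), int32_t{0}); // -1 if index < 0, else 0
  return index - sign * dim; // adds dim only to negative indices
}

int main() {
  assert(bump_if_negative_scalar(4, -1) == 3); // -1 selects the last element
  assert(bump_if_negative_scalar(4, 2) == 2);  // non-negative indices pass through
  return 0;
}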

core/conversion/converters/converter_util.h

Lines changed: 7 additions & 7 deletions
@@ -50,17 +50,17 @@ nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nv
 // Freeze an at::Tensor in a IConstant layer
 nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t, const std::string& name = std::string());
 
-nvinfer1::ITensor* toITensor(ConversionCtx* ctx, const torch::jit::Node* n, at::Tensor* input);
-
-nvinfer1::ITensor* clamp(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* x,
+nvinfer1::ITensor* clamp(ConversionCtx* ctx, nvinfer1::ITensor* x,
     nvinfer1::ITensor* lower_bound, nvinfer1::ITensor* upper_bound);
 
-nvinfer1::ITensor* bump_if_negtive(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input_dim,
+nvinfer1::ITensor* bump_if_negtive(ConversionCtx* ctx, nvinfer1::ITensor* input_dim,
     nvinfer1::ITensor* indices);
 
-void update_start_and_end(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in_shape,
-    nvinfer1::ITensor* in_start, nvinfer1::ITensor* in_end,
-    nvinfer1::ITensor** out_start, nvinfer1::ITensor** out_end);
+std::vector<nvinfer1::ITensor*> update_start_and_end(ConversionCtx* ctx, nvinfer1::ITensor* in_shape,
+    nvinfer1::ITensor* in_start, nvinfer1::ITensor* in_end);
+
+nvinfer1::ITensor* calculate_output_size(ConversionCtx* ctx, nvinfer1::ITensor* start, nvinfer1::ITensor* end,
+    nvinfer1::ITensor* stride, int nbdims);
 
 bool is_dynamic_shape(nvinfer1::ITensor* tensor);
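
The out-parameter signature of update_start_and_end is gone; callers now receive the clamped start/end as a vector and feed them to the new calculate_output_size. A sketch of the resulting call pattern (mirroring the select.cpp change below; the wrapper name slice_size_for is hypothetical, and includes, namespaces, and error handling are elided):

// Hypothetical wrapper: normalize and clamp start/end against the runtime
// shape tensor, then derive the slice size tensor from them.
nvinfer1::ITensor* slice_size_for(ConversionCtx* ctx, nvinfer1::ITensor* shape,
    nvinfer1::ITensor* start, nvinfer1::ITensor* end,
    nvinfer1::ITensor* stride, int nbdims) {
  auto start_end = update_start_and_end(ctx, shape, start, end);
  return calculate_output_size(ctx, start_end[0], start_end[1], stride, nbdims);
}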

core/conversion/converters/impl/select.cpp

Lines changed: 8 additions & 18 deletions
@@ -332,6 +332,7 @@ auto select_registrations TORCHTRT_UNUSED =
 }
 auto step = args[4].unwrapToInt();
 
+// update start, end, stride for static shape
 auto nbdims = in->getDimensions().nbDims;
 nvinfer1::Dims start_, size_, stride_;
 start_.nbDims = nbdims;
@@ -355,12 +356,12 @@
   // start tensor
   at::Tensor start_tensor = torch::zeros({nbdims}).to(torch::kI32);;
   start_tensor[axis] = startIdx;
-  auto start_itensor = toITensor(ctx, n, &start_tensor);
+  auto start_itensor = tensor_to_const(ctx, start_tensor);
 
   // step tensor
   at::Tensor stride_tensor = torch::ones({nbdims}).to(torch::kI32);
   stride_tensor[axis] = step;
-  auto stride_itensor = toITensor(ctx, n, &stride_tensor);
+  auto stride_itensor = tensor_to_const(ctx, stride_tensor);
 
   // end tensor
   at::Tensor end_tensor = torch::zeros({nbdims}).to(torch::kI32);
@@ -371,32 +372,21 @@
       end_tensor[i] = input_dim.d[i] == -1 ? -1 : input_dim.d[i]-1;
     }
   }
-  auto end_itensor = toITensor(ctx, n, &end_tensor);
-
-  // one itensor
-  at::Tensor one_tensor = torch::ones({nbdims}).to(torch::kI32);
-  auto one_itensor = toITensor(ctx, n, &one_tensor);
+  auto end_itensor = tensor_to_const(ctx, end_tensor);
 
   // update start and end
   nvinfer1::ITensor* out_start;
   nvinfer1::ITensor* out_end;
-  update_start_and_end(ctx, n, ishape_tensor,
-      start_itensor, end_itensor,
-      &out_start, &out_end);
+  auto start_end = update_start_and_end(ctx, ishape_tensor, start_itensor, end_itensor);
+  out_start = start_end[0];
+  out_end = start_end[1];
 
   // calculate size
-  auto sub_layer = ctx->net->addElementWise(*out_end, *out_start, nvinfer1::ElementWiseOperation::kSUB);
-  auto sub_itensor = sub_layer->getOutput(0);
-  auto div_layer = ctx->net->addElementWise(*sub_itensor, *stride_itensor, nvinfer1::ElementWiseOperation::kDIV);
-  auto div_itensor = div_layer->getOutput(0);
-  auto add_layer = ctx->net->addElementWise(*div_itensor, *one_itensor, nvinfer1::ElementWiseOperation::kSUM);
-  auto size_itensor = add_layer->getOutput(0);
-
+  auto size_itensor = calculate_output_size(ctx, out_start, out_end, stride_itensor, nbdims);
 
   // update slice layer
   slice_layer->setInput(1, *out_start); // start
   slice_layer->setInput(2, *size_itensor); // size, must be set if input is dynamic
-
 }
 auto slice_out = slice_layer->getOutput(0);
 auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_out);
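
Taken together, the dynamic-shape branch now performs three per-dimension steps: bump negative indices, clamp them into [0, dim - 1], and compute size = (end - start) / stride + 1 with an inclusive end. A standalone scalar walk-through of one dimension (a sketch for checking the arithmetic, not commit code):

#include <algorithm>
#include <cassert>

int main() {
  int dim = 5, start = -4, end = -1, stride = 2; // end index is inclusive

  // bump_if_negtive: negative indices count from the back
  if (start < 0) start += dim; // -4 -> 1
  if (end < 0) end += dim;     // -1 -> 4

  // clamp_to_input_dim: clamp into [0, dim - 1]
  start = std::min(std::max(start, 0), dim - 1);
  end = std::min(std::max(end, 0), dim - 1);

  // calculate_output_size: size = (end - start) / stride + 1
  int size = (end - start) / stride + 1; // slice selects elements 1 and 3
  assert(start == 1 && end == 4 && size == 2);
  return 0;
}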
