
Commit 22c0e17

support multiple indices for aten::index.Tensor (#1309)

Signed-off-by: Ruoqian Guo <[email protected]>

1 parent 4d32d47 commit 22c0e17
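
For context: aten::index.Tensor implements PyTorch advanced indexing. An expression such as x[i0, :, i2] is lowered by TorchScript to aten::index(x, [i0, None, i2]), with a None entry standing in for each ':' dimension. Before this commit the converter accepted exactly one index tensor (so only dimension 0 could be indexed); it now handles several index tensors plus None placeholders. A minimal libtorch sketch of the op semantics being converted (our illustration, not code from this commit):

// Hedged illustration of aten::index.Tensor semantics (not commit code).
#include <torch/torch.h>
#include <iostream>

int main() {
  using torch::indexing::Slice;
  auto x = torch::arange(24).reshape({2, 3, 4});
  auto i0 = torch::tensor({0, 1});
  auto i2 = torch::tensor({1, 3});
  // x[i0, :, i2]: advanced indexing on dims 0 and 2, ':' on dim 1.
  auto y = x.index({i0, Slice(), i2});
  std::cout << y.sizes() << std::endl; // [2, 3]
  return 0;
}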

File tree

3 files changed: +340 -24 lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 214 additions & 22 deletions
@@ -271,37 +271,229 @@ auto select_registrations TORCHTRT_UNUSED =
 auto ts = args[1].IValue()->toListRef();

 std::vector<nvinfer1::ITensor*> tensors;
-for (auto t : ts) {
+std::vector<int32_t> adv_idx_indices;
+for (auto i = 0; i < ts.size(); i++) {
+  auto t = ts[i];
   if (t.isTensor()) {
-    auto torch_tensor = t.toTensor();
+    auto torch_tensor = t.toTensor().to(torch::kInt32);
     tensors.push_back(tensor_to_const(ctx, torch_tensor));
+    adv_idx_indices.push_back(i);
   } else {
-    auto cont = t.toCustomClass<TensorContainer>();
-    tensors.push_back(cont->tensor());
+    // IValue
+    if (!t.isNone()) {
+      adv_idx_indices.push_back(i);
+      auto cont = t.toCustomClass<TensorContainer>();
+      // Set datatype for indices tensor to INT32
+      auto identity = ctx->net->addIdentity(*cont->tensor());
+      identity->setOutputType(0, nvinfer1::DataType::kINT32);
+      tensors.push_back(identity->getOutput(0));
+    }
   }
 }

-// In TorchScript, aten::index.Tensor indexes the self tensor along its each dimension by several
-// indexes. In this version of Torch-TensorRT, it can only receive one index tensor which means it only
-// indexes the self tensor along dimension 0.
-TORCHTRT_CHECK(
-    tensors.size() == 1,
-    "In this version of Torch-TensorRT, aten::index.Tensor can only receive one index tensor which means it only indexes the self tensor along dimension 0.");
-auto indicesTensor = tensors[0];
-// Set datatype for indices tensor to INT32
-auto identity = ctx->net->addIdentity(*indicesTensor);
-identity->setOutputType(0, nvinfer1::DataType::kINT32);
-indicesTensor = identity->getOutput(0);
+if (tensors.size() == 0) {
+  auto identity_out = ctx->net->addIdentity(*in)->getOutput(0);
+  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity_out);
+  LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+} else if (tensors.size() == 1) {
+  auto indicesTensor = tensors[0];
+  // Set datatype for indices tensor to INT32
+  auto identity = ctx->net->addIdentity(*indicesTensor);
+  identity->setOutputType(0, nvinfer1::DataType::kINT32);
+  indicesTensor = identity->getOutput(0);
+
+  // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
+  // from
+  auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
+  TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
+  auto gather_out = gather_layer->getOutput(0);
+
+  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
+  LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+} else {
+  auto inDims = in->getDimensions();
+  int rank = inDims.nbDims;
+  LOG_WARNING("If indices include negative values, the exported graph will produce incorrect results.");
+  int adv_idx_count = adv_idx_indices.size();
+  auto in_shape_itensor = ctx->net->addShape(*in)->getOutput(0);
+
+  std::vector<nvinfer1::ITensor*> dim_tensor_list;
+  for (int i = 0; i < rank; i++) {
+    auto dim_tensor =
+        ctx->net
+            ->addGather(*in_shape_itensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
+            ->getOutput(0);
+    dim_tensor_list.push_back(dim_tensor);
+  }

-// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
-// from
-auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
-TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
-auto gather_out = gather_layer->getOutput(0);
+  // t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
+  // where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
+  // for ":".
+  auto in_transpose_layer = ctx->net->addShuffle(*in);
+  TORCHTRT_CHECK(in_transpose_layer, "Unable to create shuffle layer from node: " << *n);
+  nvinfer1::Permutation permute;
+  std::vector<int32_t> new_order;
+  for (int i = 0; i < adv_idx_count; i++) {
+    new_order.push_back(adv_idx_indices[i]);
+  }
+  for (int i = 0; i < rank; i++) {
+    if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+      new_order.push_back(i);
+    }
+  }
+  std::copy(new_order.begin(), new_order.end(), permute.order);
+  in_transpose_layer->setSecondTranspose(permute);
+  auto shuffle_out = in_transpose_layer->getOutput(0);
+
+  // t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] -> t: [x_1*x_2* ...*x_m, y_1*y_2* ...*y_n]
+  nvinfer1::ITensor* flatten_tensor = NULL;
+  {
+    auto shuffle_shape_tensor = ctx->net->addShape(*shuffle_out)->getOutput(0);
+    auto d0 = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
+    for (int i = 0; i < adv_idx_count; i++) {
+      auto dim_tensor =
+          ctx->net
+              ->addGather(*shuffle_shape_tensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
+              ->getOutput(0);
+      d0 = add_elementwise(
+               ctx,
+               nvinfer1::ElementWiseOperation::kPROD,
+               d0,
+               dim_tensor,
+               std::string("compute_dim0_") + std::to_string(i))
+               ->getOutput(0);
+    }

-auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
+    auto d1 = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
+    for (int i = adv_idx_count; i < rank; i++) {
+      auto dim_tensor =
+          ctx->net
+              ->addGather(*shuffle_shape_tensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
+              ->getOutput(0);
+      d1 = add_elementwise(
+               ctx,
+               nvinfer1::ElementWiseOperation::kPROD,
+               d1,
+               dim_tensor,
+               std::string("compute_dim1_") + std::to_string(i))
+               ->getOutput(0);
+    }

-LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+    std::vector<nvinfer1::ITensor*> concat_tensors;
+    concat_tensors.push_back(d0);
+    concat_tensors.push_back(d1);
+    auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
+
+    auto shuffle = ctx->net->addShuffle(*shuffle_out);
+    shuffle->setInput(1, *concat_layer->getOutput(0));
+    flatten_tensor = shuffle->getOutput(0);
+    LOG_DEBUG(flatten_tensor->getDimensions());
+  }
+
+  // tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
+  // j dimension of input x.
+  nvinfer1::ITensor* multiplier = dim_tensor_list[adv_idx_indices[adv_idx_count - 1]];
+  nvinfer1::ITensor* cum_adv_index = tensors[adv_idx_count - 1];
+  for (int i = adv_idx_count - 2; i >= 0; i--) {
+    nvinfer1::ITensor* adv_index = add_elementwise(
+                                       ctx,
+                                       nvinfer1::ElementWiseOperation::kPROD,
+                                       tensors[i],
+                                       multiplier,
+                                       std::string("adv_index_") + std::to_string(i))
+                                       ->getOutput(0);
+    cum_adv_index = add_elementwise(
+                        ctx,
+                        nvinfer1::ElementWiseOperation::kSUM,
+                        cum_adv_index,
+                        adv_index,
+                        std::string("cum_adv_index_") + std::to_string(i))
+                        ->getOutput(0);
+    multiplier = add_elementwise(
+                     ctx,
+                     nvinfer1::ElementWiseOperation::kPROD,
+                     multiplier,
+                     dim_tensor_list[adv_idx_indices[i]],
+                     std::string("multiplier_") + std::to_string(i))
+                     ->getOutput(0);
+  }
+
+  // perform gather
+  auto gather_out = ctx->net->addGather(*flatten_tensor, *cum_adv_index, 0)->getOutput(0);
+
+  nvinfer1::ITensor* reshape_output = NULL;
+  {
+    auto cum_adv_index_shape_tensor = ctx->net->addShape(*cum_adv_index)->getOutput(0);
+    // check if all advanced indices are consecutive.
+    if (adv_idx_count == (adv_idx_indices[adv_idx_count - 1] - adv_idx_indices[0] + 1)) {
+      // unfold regular index axes
+      std::vector<nvinfer1::ITensor*> concat_tensors;
+      concat_tensors.push_back(tensor_to_const(ctx, torch::tensor({-1}, torch::kInt32)));
+      for (int i = 0; i < rank; i++) {
+        if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+          nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+          concat_tensors.push_back(current_dim);
+        }
+      }
+      auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
+      auto regular_index_shuffle_layer = ctx->net->addShuffle(*gather_out);
+      regular_index_shuffle_layer->setInput(1, *concat_layer->getOutput(0));
+      auto unfold_tensor = regular_index_shuffle_layer->getOutput(0);
+
+      // Transpose folded advanced indexed axis to its original location.
+      auto transpose_advanced_shuffle_layer = ctx->net->addShuffle(*unfold_tensor);
+      nvinfer1::Permutation permute;
+      std::vector<int32_t> new_order;
+      for (int i = 1; i < adv_idx_indices[0] + 1; i++) {
+        new_order.push_back(i);
+      }
+      new_order.push_back(0);
+      for (int i = adv_idx_indices[0] + 1; i < rank - adv_idx_count + 1; i++) {
+        new_order.push_back(i);
+      }
+      std::copy(new_order.begin(), new_order.end(), permute.order);
+      transpose_advanced_shuffle_layer->setSecondTranspose(permute);
+      auto shuffle_out = transpose_advanced_shuffle_layer->getOutput(0);
+
+      // unfold advanced index axes
+      std::vector<nvinfer1::ITensor*> concat_final_tensors;
+      for (int i = 0; i < adv_idx_indices[0]; i++) {
+        nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+        concat_final_tensors.push_back(current_dim);
+      }
+      concat_final_tensors.push_back(cum_adv_index_shape_tensor);
+      for (int i = adv_idx_indices[0]; i < rank; i++) {
+        if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+          nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+          concat_final_tensors.push_back(current_dim);
+        }
+      }
+      auto concat_final_shape_layer =
+          ctx->net->addConcatenation(concat_final_tensors.data(), concat_final_tensors.size());
+      auto unfold_advanced_shuffle_layer = ctx->net->addShuffle(*shuffle_out);
+      unfold_advanced_shuffle_layer->setInput(1, *concat_final_shape_layer->getOutput(0));
+      reshape_output = unfold_advanced_shuffle_layer->getOutput(0);
+    } else {
+      std::vector<nvinfer1::ITensor*> concat_tensors;
+      concat_tensors.push_back(cum_adv_index_shape_tensor);
+      for (int i = 0; i < rank; i++) {
+        if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+          nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+          concat_tensors.push_back(current_dim);
+        }
+      }
+      auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
+      auto shuffle_layer = ctx->net->addShuffle(*gather_out);
+      shuffle_layer->setInput(1, *concat_layer->getOutput(0));
+      reshape_output = shuffle_layer->getOutput(0);
+    }
+  }
+
+  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], reshape_output);
+  LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+}
 return true;
 }})
 .pattern(
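
The multi-index branch above reduces advanced indexing to a single gather: transpose the advanced-indexed axes to the front, flatten them into one axis (and the remaining ':' axes into another), fold the index tensors into one linear index via cum_adv_index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), gather along axis 0, then reshape, unfolding the flattened axes back (with an extra transpose when the advanced indices are consecutive, so the folded axis returns to its original location). The libtorch sketch below replays the same arithmetic for x[i0, :, i2] on a [2, 3, 4] input; it is our illustration of the strategy, not code from the commit:

// Hedged sketch of the converter's flatten-and-gather strategy (not commit code).
#include <torch/torch.h>
#include <iostream>

int main() {
  using torch::indexing::Slice;
  auto x = torch::arange(24).reshape({2, 3, 4});
  auto i0 = torch::tensor({0, 1});
  auto i2 = torch::tensor({1, 3});

  // Reference result from the standard indexing API.
  auto ref = x.index({i0, Slice(), i2});

  // Transpose advanced axes {0, 2} to the front ([2, 4, 3]), then flatten
  // them into one axis ([2*4, 3]) -- the converter's shuffle + reshape.
  auto flat = x.permute({0, 2, 1}).reshape({2 * 4, 3});
  // Linear index: ind_0 * x_2 + ind_2; the multiplier is the size of the
  // trailing advanced axis (4). This is the cum_adv_index computation.
  auto lin = i0 * 4 + i2; // {1, 7}
  // Gather along axis 0, the analogue of IGatherLayer on flatten_tensor.
  auto out = flat.index_select(0, lin);
  std::cout << out.equal(ref) << std::endl; // 1
  return 0;
}

Because dims 0 and 2 are not consecutive here, the gathered axis stays in front of the remaining ':' axis, which is exactly the shape the non-consecutive reshape branch produces.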

core/conversion/evaluators/prim.cpp

Lines changed: 6 additions & 1 deletion
@@ -100,7 +100,12 @@ auto prim_registrations =
     auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
     list.emplace_back(std::move(ival));
   } else {
-    list.emplace_back(std::move(args.at(in).unwrapToTensor()));
+    if (args.at(in).IValue()->isNone()) {
+      auto ival = torch::jit::IValue();
+      list.emplace_back(std::move(ival));
+    } else {
+      list.emplace_back(std::move(args.at(in).unwrapToTensor()));
+    }
   }
 }
 return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
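
This evaluator change matters because prim::ListConstruct builds the Tensor?[] argument of aten::index: a ':' slot arrives as a None IValue, and the old code called unwrapToTensor() on it unconditionally. A small sketch of the None handling the new branch relies on (our illustration, not commit code):

// Hedged sketch (not commit code): a default-constructed IValue is None,
// the placeholder the evaluator now emplaces for ':' slots.
#include <torch/torch.h>
#include <cassert>

int main() {
  c10::IValue none_slot; // default constructor yields None
  assert(none_slot.isNone());

  // The index list handed to aten::index may therefore mix None and
  // tensors, e.g. [None, idx] for x[:, idx].
  c10::impl::GenericList list(c10::AnyType::get());
  list.emplace_back(c10::IValue());
  list.emplace_back(torch::tensor({0, 2}));
  assert(list.get(0).isNone() && list.get(1).isTensor());
  return 0;
}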
