@@ -19,7 +19,11 @@ namespace conversion {
19
19
namespace evaluators {
20
20
namespace {
21
21
22
- nvinfer1::ITensor* index_layer (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input_tensor, int64_t index){
22
+ nvinfer1::ITensor* index_layer (
23
+ ConversionCtx* ctx,
24
+ const torch::jit::Node* n,
25
+ nvinfer1::ITensor* input_tensor,
26
+ int64_t index) {
23
27
// index to access needs to be an at::Tensor
24
28
at::Tensor indices = torch::tensor ({index}).to (torch::kI32 );
25
29
auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const (ctx, indices);
@@ -30,15 +34,15 @@ nvinfer1::ITensor* index_layer(ConversionCtx* ctx, const torch::jit::Node* n, nv
30
34
return indexed_tensor;
31
35
}
32
36
33
- c10::IValue dynamic_size_layer (ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args){
37
+ c10::IValue dynamic_size_layer (ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
34
38
LOG_DEBUG (" Using dynamic version of aten::size evaluator" );
35
39
auto in = args.at (n->input (0 )).ITensorOrFreeze (ctx);
36
40
LOG_DEBUG (" Input dimensions: " << in->getDimensions ());
37
41
auto shape_layer = ctx->net ->addShape (*in);
38
42
TORCHTRT_CHECK (shape_layer, " Unable to create shape layer from node: " << *n);
39
43
auto shape_1d_tensor = shape_layer->getOutput (0 );
40
44
41
- if (n->inputs ().size () != 1 ){
45
+ if (n->inputs ().size () != 1 ) {
42
46
auto maxDim = static_cast <int64_t >(in->getDimensions ().nbDims );
43
47
auto dim = args.at (n->input (1 )).unwrapToInt ();
44
48
// Handle negative axis by refering to nbDims of input Tensor
@@ -306,7 +310,7 @@ auto aten_registrations TORCHTRT_UNUSED =
306
310
if (n->inputs ().size () == 1 ) {
307
311
if (tensor_var.isITensor ()) {
308
312
auto tensor = tensor_var.ITensor ();
309
- if (ctx->input_is_dynamic ){
313
+ if (ctx->input_is_dynamic ) {
310
314
return dynamic_size_layer (ctx, n, args);
311
315
}
312
316
return util::toVec (tensor->getDimensions ());
@@ -322,7 +326,7 @@ auto aten_registrations TORCHTRT_UNUSED =
322
326
} else {
323
327
auto dim = args.at (n->input (1 )).unwrapToInt ();
324
328
if (tensor_var.isITensor ()) {
325
- if (ctx->input_is_dynamic ){
329
+ if (ctx->input_is_dynamic ) {
326
330
return dynamic_size_layer (ctx, n, args);
327
331
}
328
332
auto tensor = tensor_var.ITensor ();
@@ -359,14 +363,14 @@ auto aten_registrations TORCHTRT_UNUSED =
359
363
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
360
364
auto list_input = args.at (n->input (0 ));
361
365
auto idx = args.at (n->input (1 )).unwrapToInt ();
362
- if (list_input.isIValue ()){
363
- auto list = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
364
- const int64_t list_size = list.size ();
365
- const int64_t normalized_idx = normalizeIndex (idx, list_size);
366
- TORCHTRT_CHECK (
367
- normalized_idx >= 0 || normalized_idx < list_size, " List index out of range (aten::__getitem__)" );
368
- return list.get (normalized_idx);
369
- } else if (list_input.isITensor ()){
366
+ if (list_input.isIValue ()) {
367
+ auto list = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
368
+ const int64_t list_size = list.size ();
369
+ const int64_t normalized_idx = normalizeIndex (idx, list_size);
370
+ TORCHTRT_CHECK (
371
+ normalized_idx >= 0 || normalized_idx < list_size, " List index out of range (aten::__getitem__)" );
372
+ return list.get (normalized_idx);
373
+ } else if (list_input.isITensor ()) {
370
374
auto indexed_tensor = index_layer (ctx, n, list_input.ITensorOrFreeze (ctx), idx);
371
375
auto tensor_holder = TensorContainer ();
372
376
tensor_holder.hold_tensor (indexed_tensor);
0 commit comments