@@ -19,6 +19,41 @@ namespace conversion {
19
19
namespace evaluators {
20
20
namespace {
21
21
22
+ nvinfer1::ITensor* index_layer (){
23
+
24
+ }
25
+
26
+ c10::IValue dynamic_size_layer (ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args){
27
+ LOG_DEBUG (" Using dynamic version of aten::size evaluator" );
28
+ auto in = args.at (n->input (0 )).ITensorOrFreeze (ctx);
29
+ LOG_DEBUG (" Input dimensions: " << in->getDimensions ());
30
+ auto shape_layer = ctx->net ->addShape (*in);
31
+ auto shape_1d_tensor = shape_layer->getOutput (0 );
32
+
33
+ if (n->inputs ().size () != 1 ){
34
+ auto maxDim = static_cast <int64_t >(in->getDimensions ().nbDims );
35
+ auto dim = args.at (n->input (1 )).unwrapToInt ();
36
+ // Handle negative axis by refering to nbDims of input Tensor
37
+ dim = dim < 0 ? dim + maxDim : dim;
38
+ LOG_DEBUG (" Dimension to select: " << dim);
39
+
40
+ // index to access needs to be an at::Tensor
41
+ at::Tensor indices = torch::tensor ({dim}).to (torch::kI32 );
42
+ auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const (ctx, indices);
43
+
44
+ auto gather_layer = ctx->net ->addGather (*shape_1d_tensor, *indices_out, 0 );
45
+ shape_1d_tensor = gather_layer->getOutput (0 );
46
+ }
47
+
48
+ LOG_DEBUG (" Output tensor shape: " << shape_1d_tensor->getDimensions ());
49
+
50
+ auto tensor_holder = TensorContainer ();
51
+ tensor_holder.hold_tensor (shape_1d_tensor);
52
+ auto shape_1d_ivalue = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
53
+
54
+ return shape_1d_ivalue;
55
+ }
56
+
22
57
DEFINE_GENERIC_TWO_INPUT_EVALUATOR (
23
58
eq,
24
59
" aten::eq" ,
@@ -176,7 +211,7 @@ auto aten_registrations TORCHTRT_UNUSED =
176
211
{c10::Symbol::fromQualString (" aten::full_like" ),
177
212
// aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None,
178
213
// Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> (Tensor)
179
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
214
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
180
215
// Override options related to layout and device for TensorRT
181
216
auto options = torch::TensorOptions ().layout (torch::kStrided ).device (torch::kCUDA );
182
217
auto input_tensor_var = args.at (n->input (0 ));
@@ -262,67 +297,80 @@ auto aten_registrations TORCHTRT_UNUSED =
262
297
return static_cast <int64_t >(list.size ());
263
298
},
264
299
EvalOptions ().validSchemas ({" aten::len.t(t[] a) -> (int)" })})
265
- // .evaluator(
266
- // {c10::Symbol::fromQualString("aten::size"),
267
- // [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
268
- // LOG_WARNING("There may be undefined behavior using dynamic shape and aten::size");
269
- // auto tensor_var = args.at(n->input(0));
270
- // if (n->inputs().size() == 1) {
271
- // if (tensor_var.isITensor()) {
272
- // auto tensor = tensor_var.ITensor();
273
- // return util::toVec(tensor->getDimensions());
274
- // } else if (tensor_var.IValue()->isTensor()) {
275
- // auto tensor = tensor_var.unwrapToTensor();
276
- // return tensor.sizes();
277
- // } else if (tensor_var.IValue()->isCustomClass()) {
278
- // auto tensor = tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor();
279
- // return util::toVec(tensor->getDimensions());
280
- // } else {
281
- // TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type());
282
- // }
283
- // } else {
284
- // auto dim = args.at(n->input(1)).unwrapToInt();
285
- // if (tensor_var.isITensor()) {
286
- // auto tensor = tensor_var.ITensor();
287
- // auto dims = util::toVec(tensor->getDimensions());
288
- // auto nbDims = tensor->getDimensions().nbDims;
289
- // if (dim < 0) {
290
- // dim += nbDims;
291
- // }
292
- // return dims[dim];
293
- // } else if (tensor_var.IValue()->isTensor()) {
294
- // auto tensor = tensor_var.unwrapToTensor();
295
- // auto nbDims = tensor.sizes().size();
296
- // if (dim < 0) {
297
- // dim += nbDims;
298
- // }
299
- // return tensor.sizes()[dim];
300
- // } else if (tensor_var.IValue()->isCustomClass()) {
301
- // auto tensor = tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor();
302
- // auto dims = util::toVec(tensor->getDimensions());
303
- // auto nbDims = tensor->getDimensions().nbDims;
304
- // if (dim < 0) {
305
- // dim += nbDims;
306
- // }
307
- // return dims[dim];
308
- // } else {
309
- // TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type());
310
- // }
311
- // }
312
- // },
313
- // EvalOptions().validSchemas(
314
- // {"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})})
300
+ .evaluator(
301
+ {c10::Symbol::fromQualString (" aten::size" ),
302
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
303
+ auto tensor_var = args.at (n->input (0 ));
304
+ if (n->inputs ().size () == 1 ) {
305
+ if (tensor_var.isITensor ()) {
306
+ auto tensor = tensor_var.ITensor ();
307
+ if (ctx->input_is_dynamic ){
308
+ return dynamic_size_layer (ctx, n, args);
309
+ }
310
+ return util::toVec (tensor->getDimensions ());
311
+ } else if (tensor_var.IValue ()->isTensor ()) {
312
+ auto tensor = tensor_var.unwrapToTensor ();
313
+ return tensor.sizes ();
314
+ } else if (tensor_var.IValue ()->isCustomClass ()) {
315
+ auto tensor = tensor_var.IValue ()->toCustomClass <TensorContainer>()->tensor ();
316
+ return util::toVec (tensor->getDimensions ());
317
+ } else {
318
+ TORCHTRT_THROW_ERROR (" IValue is not some class of Tensor. Found: " << tensor_var.IValue ()->type ());
319
+ }
320
+ } else {
321
+ auto dim = args.at (n->input (1 )).unwrapToInt ();
322
+ if (tensor_var.isITensor ()) {
323
+ if (ctx->input_is_dynamic ){
324
+ return dynamic_size_layer (ctx, n, args);
325
+ }
326
+ auto tensor = tensor_var.ITensor ();
327
+ auto dims = util::toVec (tensor->getDimensions ());
328
+ auto nbDims = tensor->getDimensions ().nbDims ;
329
+ if (dim < 0 ) {
330
+ dim += nbDims;
331
+ }
332
+ return dims[dim];
333
+ } else if (tensor_var.IValue ()->isTensor ()) {
334
+ auto tensor = tensor_var.unwrapToTensor ();
335
+ auto nbDims = tensor.sizes ().size ();
336
+ if (dim < 0 ) {
337
+ dim += nbDims;
338
+ }
339
+ return tensor.sizes ()[dim];
340
+ } else if (tensor_var.IValue ()->isCustomClass ()) {
341
+ auto tensor = tensor_var.IValue ()->toCustomClass <TensorContainer>()->tensor ();
342
+ auto dims = util::toVec (tensor->getDimensions ());
343
+ auto nbDims = tensor->getDimensions ().nbDims ;
344
+ if (dim < 0 ) {
345
+ dim += nbDims;
346
+ }
347
+ return dims[dim];
348
+ } else {
349
+ TORCHTRT_THROW_ERROR (" IValue is not some class of Tensor. Found: " << tensor_var.IValue ()->type ());
350
+ }
351
+ }
352
+ },
353
+ EvalOptions ().validSchemas (
354
+ {" aten::size(Tensor self) -> (int[])" , " aten::size.int(Tensor self, int dim) -> (int)" })})
315
355
.evaluator(
316
356
{c10::Symbol::fromQualString (" aten::__getitem__" ),
317
357
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
318
- auto list = args.at (n->input (0 )). IValue ()-> to <c10::List<c10::IValue>>( );
358
+ auto list_input = args.at (n->input (0 ));
319
359
auto idx = args.at (n->input (1 )).unwrapToInt ();
360
+ if (list_input.isIValue ()){
361
+ auto list = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
362
+ const int64_t list_size = list.size ();
363
+ const int64_t normalized_idx = normalizeIndex (idx, list_size);
364
+ TORCHTRT_CHECK (
365
+ normalized_idx >= 0 || normalized_idx < list_size, " List index out of range (aten::__getitem__)" );
366
+ return list.get (normalized_idx);
367
+ } elif (list_input.isITensor ()){
368
+ return dynamic_size_layer (ctx, n, args);
369
+ }
370
+
371
+
320
372
321
- const int64_t list_size = list.size ();
322
- const int64_t normalized_idx = normalizeIndex (idx, list_size);
323
- TORCHTRT_CHECK (
324
- normalized_idx >= 0 || normalized_idx < list_size, " List index out of range (aten::__getitem__)" );
325
- return list.get (normalized_idx);
373
+
326
374
},
327
375
EvalOptions ().validSchemas ({
328
376
" aten::__getitem__.t(t[](a) list, int idx) -> (t(*))" ,
0 commit comments