@@ -352,6 +352,187 @@ auto element_wise_registrations TRTORCH_UNUSED =
352
352
pow->setName (util::node_info (n).c_str ());
353
353
auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], pow->getOutput (0 ));
354
354
355
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
356
+ return true ;
357
+ }})
358
+ .pattern({" aten::gt.Tensor(Tensor self, Tensor other) -> (Tensor)" ,
359
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
360
+ // TODO: Remove with functionalization
361
+ auto self = args[0 ].ITensorOrFreeze (ctx);
362
+ auto other = args[1 ].ITensorOrFreeze (ctx);
363
+ auto gt =
364
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kGREATER , self, other, util::node_info (n));
365
+ TRTORCH_CHECK (gt, " Unable to create greater layer from node: " << *n);
366
+
367
+ gt->setName (util::node_info (n).c_str ());
368
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], gt->getOutput (0 ));
369
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
370
+ return true ;
371
+ }})
372
+ .pattern({" aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
373
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
374
+ // TODO: Remove with functionalization
375
+ auto self = args[0 ].ITensorOrFreeze (ctx);
376
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
377
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
378
+ auto gt =
379
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kGREATER , self, other, util::node_info (n));
380
+ TRTORCH_CHECK (gt, " Unable to create greater layer from node: " << *n);
381
+
382
+ gt->setName (util::node_info (n).c_str ());
383
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], gt->getOutput (0 ));
384
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
385
+ return true ;
386
+ }})
387
+ .pattern({" aten::lt.Tensor(Tensor self, Tensor other) -> (Tensor)" ,
388
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
389
+ // TODO: Remove with functionalization
390
+ auto self = args[0 ].ITensorOrFreeze (ctx);
391
+ auto other = args[1 ].ITensorOrFreeze (ctx);
392
+ auto lt =
393
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kLESS , self, other, util::node_info (n));
394
+ TRTORCH_CHECK (lt, " Unable to create less layer from node: " << *n);
395
+
396
+ lt->setName (util::node_info (n).c_str ());
397
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], lt->getOutput (0 ));
398
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
399
+ return true ;
400
+ }})
401
+ .pattern({" aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
402
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
403
+ // TODO: Remove with functionalization
404
+ auto self = args[0 ].ITensorOrFreeze (ctx);
405
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
406
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
407
+ auto lt =
408
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kLESS , self, other, util::node_info (n));
409
+ TRTORCH_CHECK (lt, " Unable to create less layer from node: " << *n);
410
+
411
+ lt->setName (util::node_info (n).c_str ());
412
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], lt->getOutput (0 ));
413
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
414
+ return true ;
415
+ }})
416
+ .pattern({" aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor)" ,
417
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
418
+ // TODO: Remove with functionalization
419
+ auto self = args[0 ].ITensorOrFreeze (ctx);
420
+ auto other = args[1 ].ITensorOrFreeze (ctx);
421
+ auto eq =
422
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kEQUAL , self, other, util::node_info (n));
423
+ TRTORCH_CHECK (eq, " Unable to create equal layer from node: " << *n);
424
+
425
+ eq->setName (util::node_info (n).c_str ());
426
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], eq->getOutput (0 ));
427
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
428
+ return true ;
429
+ }})
430
+ .pattern({" aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
431
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
432
+ // TODO: Remove with functionalization
433
+ auto self = args[0 ].ITensorOrFreeze (ctx);
434
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
435
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
436
+ auto eq =
437
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kEQUAL , self, other, util::node_info (n));
438
+ TRTORCH_CHECK (eq, " Unable to create equal layer from node: " << *n);
439
+
440
+ eq->setName (util::node_info (n).c_str ());
441
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], eq->getOutput (0 ));
442
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
443
+ return true ;
444
+ }})
445
+ .pattern({" aten::ge.Tensor(Tensor self, Tensor other) -> (Tensor)" ,
446
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
447
+ // TODO: Remove with functionalization
448
+ auto self = args[0 ].ITensorOrFreeze (ctx);
449
+ auto other = args[1 ].ITensorOrFreeze (ctx);
450
+
451
+ auto greater =
452
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kGREATER , self, other, util::node_info (n)+" _greater" );
453
+ TRTORCH_CHECK (greater, " Unable to create Greater layer from node: " << *n);
454
+
455
+ auto equal =
456
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kEQUAL , self, other, util::node_info (n)+" _equal" );
457
+ TRTORCH_CHECK (equal, " Unable to create Equal layer from node: " << *n);
458
+
459
+ auto or_op = ctx->net ->addElementWise (*greater->getOutput (0 ), *equal->getOutput (0 ), nvinfer1::ElementWiseOperation::kOR );
460
+
461
+ TRTORCH_CHECK (or_op, " Unable to create Or layer from node: " << *n);
462
+ or_op->setName (util::node_info (n).c_str ());
463
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], or_op->getOutput (0 ));
464
+
465
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
466
+ return true ;
467
+ }})
468
+ .pattern({" aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
469
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
470
+ // TODO: Remove with functionalization
471
+ auto self = args[0 ].ITensorOrFreeze (ctx);
472
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
473
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
474
+
475
+ auto greater =
476
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kGREATER , self, other, util::node_info (n)+" _greater" );
477
+ TRTORCH_CHECK (greater, " Unable to create Greater layer from node: " << *n);
478
+
479
+ auto equal =
480
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kEQUAL , self, other, util::node_info (n)+" _equal" );
481
+ TRTORCH_CHECK (equal, " Unable to create Equal layer from node: " << *n);
482
+
483
+ auto or_op = ctx->net ->addElementWise (*greater->getOutput (0 ), *equal->getOutput (0 ), nvinfer1::ElementWiseOperation::kOR );
484
+
485
+ TRTORCH_CHECK (or_op, " Unable to create Or layer from node: " << *n);
486
+ or_op->setName (util::node_info (n).c_str ());
487
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], or_op->getOutput (0 ));
488
+
489
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
490
+ return true ;
491
+ }})
492
+ .pattern({" aten::le.Tensor(Tensor self, Tensor other) -> (Tensor)" ,
493
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
494
+ // TODO: Remove with functionalization
495
+ auto self = args[0 ].ITensorOrFreeze (ctx);
496
+ auto other = args[1 ].ITensorOrFreeze (ctx);
497
+
498
+ auto less =
499
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kLESS , self, other, util::node_info (n)+" _less" );
500
+ TRTORCH_CHECK (less, " Unable to create Less layer from node: " << *n);
501
+
502
+ auto equal =
503
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kEQUAL , self, other, util::node_info (n)+" _equal" );
504
+ TRTORCH_CHECK (equal, " Unable to create Equal layer from node: " << *n);
505
+
506
+ auto or_op = ctx->net ->addElementWise (*less->getOutput (0 ), *equal->getOutput (0 ), nvinfer1::ElementWiseOperation::kOR );
507
+
508
+ TRTORCH_CHECK (or_op, " Unable to create Or layer from node: " << *n);
509
+ or_op->setName (util::node_info (n).c_str ());
510
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], or_op->getOutput (0 ));
511
+
512
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
513
+ return true ;
514
+ }})
515
+ .pattern({" aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
516
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
517
+ // TODO: Remove with functionalization
518
+ auto self = args[0 ].ITensorOrFreeze (ctx);
519
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
520
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
521
+
522
+ auto less =
523
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kLESS , self, other, util::node_info (n)+" _less" );
524
+ TRTORCH_CHECK (less, " Unable to create Less layer from node: " << *n);
525
+
526
+ auto equal =
527
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kEQUAL , self, other, util::node_info (n)+" _equal" );
528
+ TRTORCH_CHECK (equal, " Unable to create Equal layer from node: " << *n);
529
+
530
+ auto or_op = ctx->net ->addElementWise (*less->getOutput (0 ), *equal->getOutput (0 ), nvinfer1::ElementWiseOperation::kOR );
531
+
532
+ TRTORCH_CHECK (or_op, " Unable to create Or layer from node: " << *n);
533
+ or_op->setName (util::node_info (n).c_str ());
534
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], or_op->getOutput (0 ));
535
+
355
536
LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
356
537
return true ;
357
538
}});
0 commit comments