@@ -3688,6 +3688,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
36883688 patterns.onOp (
36893689 " NonMaxSuppression" , 10 ,
36903690 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
3691+ Location loc = binder.getLoc ();
36913692 Torch::ValueTensorType resultType;
36923693 SmallVector<Value> operands;
36933694 int64_t centerPointBox;
@@ -3702,96 +3703,132 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
37023703 binder.op , " unimplemented: expected center_point_box "
37033704 " attribute value to be 0" );
37043705
3705- // TODO: Add support for optional arguments to be absent.
3706- if (operands.size () < 4 )
3707- return rewriter.notifyMatchFailure (
3708- binder.op , " unimplemented: expected at least 4 arguments" );
3709-
3706+ // TODO: Support multiple batches and classes
37103707 // Squeeze the boxes and scores tensor.
37113708 // In Onnx, the shape of boxes is [BxNx4] while the
37123709 // torchvision expects it to be of shape [Nx4]. Similarly, for
37133710 // the scores tensor shape in Onnx is [BxCxN] while the
37143711 // torchvision expects it to be of shape [N].
37153712 Value boxes = operands[0 ], scores = operands[1 ];
3716- FailureOr<Value> squeezedBoxes = Torch::squeezeTensor (
3717- rewriter, binder.op , binder. getLoc () , 0 , boxes);
3713+ FailureOr<Value> squeezedBoxes =
3714+ Torch::squeezeTensor ( rewriter, binder.op , loc , 0 , boxes);
37183715 if (failed (squeezedBoxes))
37193716 return rewriter.notifyMatchFailure (binder.op ,
37203717 " failed to squeeze boxes tensor" );
3721-
3722- FailureOr<Value> squeezedScores = Torch::squeezeTensor (
3723- rewriter, binder.op , binder.getLoc (), 0 , scores);
3718+ FailureOr<Value> squeezedScores =
3719+ Torch::squeezeTensor (rewriter, binder.op , loc, 0 , scores);
37243720 if (failed (squeezedScores))
37253721 return rewriter.notifyMatchFailure (binder.op ,
37263722 " failed to squeeze scores tensor" );
3727- squeezedScores = Torch::squeezeTensor (
3728- rewriter, binder. op , binder. getLoc (), 0 , squeezedScores.value ());
3723+ squeezedScores = Torch::squeezeTensor (rewriter, binder. op , loc, 0 ,
3724+ squeezedScores.value ());
37293725 if (failed (squeezedScores))
37303726 return rewriter.notifyMatchFailure (binder.op ,
37313727 " failed to squeeze scores tensor" );
3732-
37333728 boxes = squeezedBoxes.value ();
37343729 scores = squeezedScores.value ();
37353730
37363731 // TODO: Support score_threshold input
37373732 // Filter out the boxes if the score < score_threshold
37383733 if (operands.size () == 5 ) {
37393734 Value scoreThreshold = rewriter.create <Torch::AtenItemOp>(
3740- binder.getLoc (), rewriter.getType <Torch::FloatType>(),
3741- operands[4 ]);
3735+ loc, rewriter.getType <Torch::FloatType>(), operands[4 ]);
37423736 Value minScores = rewriter.create <Torch::AtenMinOp>(
3743- binder. getLoc () ,
3737+ loc ,
37443738 Torch::ValueTensorType::get (binder.op ->getContext (),
37453739 SmallVector<int64_t >{},
37463740 rewriter.getF32Type ()),
37473741 scores);
37483742 minScores = rewriter.create <Torch::AtenItemOp>(
3749- binder. getLoc () , rewriter.getType <Torch::FloatType>(), minScores);
3743+ loc , rewriter.getType <Torch::FloatType>(), minScores);
37503744
37513745 Value scoresCond = rewriter.create <Torch::AtenGeFloatOp>(
3752- binder. getLoc () , minScores, scoreThreshold);
3746+ loc , minScores, scoreThreshold);
37533747 rewriter.create <Torch::RuntimeAssertOp>(
3754- binder. getLoc () , scoresCond,
3748+ loc , scoresCond,
37553749 rewriter.getStringAttr (
37563750 " unimplemented: score_threshold should be <= min(scores)" ));
37573751 }
37583752
3759- // TODO: Support default iou_threshold
3760- Value iouThreshold = rewriter.create <Torch::AtenItemOp>(
3761- binder.getLoc (), rewriter.getType <Torch::FloatType>(), operands[3 ]);
3753+ // Get max_output_boxes_per_class and iou_threshold
3754+ Value cst0 = rewriter.create <Torch::ConstantIntOp>(
3755+ loc, rewriter.getI64IntegerAttr (0 ));
3756+ Value cst1 = rewriter.create <Torch::ConstantIntOp>(
3757+ loc, rewriter.getI64IntegerAttr (1 ));
3758+ Value maxOutputBoxesPerClass = cst0;
3759+ Value iouThreshold = rewriter.create <Torch::ConstantFloatOp>(
3760+ loc, rewriter.getF64FloatAttr (0.0 ));
3761+ if (operands.size () > 3 &&
3762+ !isa<Torch::NoneType>(operands[3 ].getType ())) {
3763+ iouThreshold = rewriter.create <Torch::AtenItemOp>(
3764+ loc, rewriter.getType <Torch::FloatType>(), operands[3 ]);
3765+ }
3766+ if (operands.size () > 2 &&
3767+ !isa<Torch::NoneType>(operands[2 ].getType ())) {
3768+ maxOutputBoxesPerClass = rewriter.create <Torch::AtenItemOp>(
3769+ loc, rewriter.getType <Torch::IntType>(), operands[2 ]);
3770+ }
3771+
37623772 auto nmsTy = Torch::ValueTensorType::get (
3773+ binder.op ->getContext (), SmallVector<int64_t >{-1 },
3774+ rewriter.getIntegerType (64 , /* signed=*/ true ));
3775+ Value result = rewriter.create <Torch::TorchvisionNmsOp>(
3776+ loc, nmsTy, boxes, scores, iouThreshold);
3777+
3778+ // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
3779+ Value numOutputBoxes =
3780+ rewriter.create <Torch::AtenSizeIntOp>(loc, result, cst0);
3781+ Value boxesCond = rewriter.create <Torch::AtenGtIntOp>(
3782+ loc, numOutputBoxes, maxOutputBoxesPerClass);
3783+
3784+ auto nmsResultTy = Torch::ValueTensorType::get (
37633785 binder.op ->getContext (),
37643786 SmallVector<int64_t >{resultType.getSizes ()[0 ]},
37653787 rewriter.getIntegerType (64 , /* signed=*/ true ));
3766- Value result = rewriter.create <Torch::TorchvisionNmsOp>(
3767- binder.getLoc (), nmsTy, boxes, scores, iouThreshold);
3788+ auto ifSlice = rewriter.create <Torch::PrimIfOp>(
3789+ loc, TypeRange ({nmsResultTy}), boxesCond);
3790+ {
3791+ PatternRewriter::InsertionGuard guard (rewriter);
3792+ rewriter.createBlock (&ifSlice.getThenRegion (),
3793+ ifSlice.getThenRegion ().begin ());
3794+
3795+ Value curResult = rewriter.create <Torch::AtenSliceTensorOp>(
3796+ loc, nmsResultTy, result, /* dim=*/ cst0, /* start=*/ cst0,
3797+ /* end=*/ maxOutputBoxesPerClass, /* step=*/ cst1);
3798+ rewriter.create <Torch::PrimIfYieldOp>(loc, curResult);
3799+ }
3800+ {
3801+ PatternRewriter::InsertionGuard guard (rewriter);
3802+ rewriter.createBlock (&ifSlice.getElseRegion (),
3803+ ifSlice.getElseRegion ().begin ());
3804+
3805+ Value curResult = rewriter.create <Torch::TensorStaticInfoCastOp>(
3806+ loc, nmsResultTy, result);
3807+ rewriter.create <Torch::PrimIfYieldOp>(loc, curResult);
3808+ }
3809+ result = ifSlice.getResult (0 );
37683810
37693811 // The result generated by torchvision.nms op is of shape [n], while the
37703812 // onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
37713813 // and make it of shape [n, 1] and then concatenate it with a zero
37723814 // tensor of shape [n, 2] to make it of shape [n, 3].
3773- Value dim = rewriter.create <Torch::ConstantIntOp>(
3774- binder.getLoc (), rewriter.getI64IntegerAttr (1 ));
37753815 FailureOr<Value> unsqueezedResult =
3776- Torch::unsqueezeTensor (rewriter, binder.op , result, dim );
3816+ Torch::unsqueezeTensor (rewriter, binder.op , result, cst1 );
37773817 if (failed (unsqueezedResult))
37783818 return rewriter.notifyMatchFailure (
37793819 binder.op , " failed to unsqueeze result tensor" );
37803820 result = unsqueezedResult.value ();
37813821
3782- Value numOutputBoxes = rewriter.create <Torch::AtenSizeIntOp>(
3783- binder.getLoc (), result,
3784- rewriter.create <Torch::ConstantIntOp>(
3785- binder.getLoc (), rewriter.getI64IntegerAttr (0 )));
3822+ numOutputBoxes =
3823+ rewriter.create <Torch::AtenSizeIntOp>(loc, result, cst0);
37863824 SmallVector<Value> zerosShapeValues{numOutputBoxes};
37873825 zerosShapeValues.push_back (rewriter.create <Torch::ConstantIntOp>(
3788- binder. getLoc () , rewriter.getI64IntegerAttr (2 )));
3826+ loc , rewriter.getI64IntegerAttr (2 )));
37893827 Value zerosShapeList = rewriter.create <Torch::PrimListConstructOp>(
3790- binder. getLoc () ,
3828+ loc ,
37913829 rewriter.getType <Torch::ListType>(
37923830 rewriter.getType <Torch::IntType>()),
37933831 zerosShapeValues);
3794-
37953832 std::optional<ArrayRef<int64_t >> resultShape =
37963833 cast<Torch::ValueTensorType>(result.getType ()).getOptionalSizes ();
37973834 if (!resultShape.has_value ())
@@ -3800,33 +3837,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
38003837 llvm::SmallVector<int64_t > zerosShape = {resultShape->front (), 2 };
38013838 auto zerosTy = Torch::ValueTensorType::get (
38023839 resultType.getContext (), zerosShape, resultType.getOptionalDtype ());
3803- Value cstNone = rewriter.create <Torch::ConstantNoneOp>(binder. getLoc () );
3840+ Value cstNone = rewriter.create <Torch::ConstantNoneOp>(loc );
38043841 Value zeros = rewriter.create <Torch::AtenZerosOp>(
3805- binder.getLoc (), zerosTy, zerosShapeList, cstNone, cstNone, cstNone,
3806- cstNone);
3842+ loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);
38073843
38083844 Type listElemType =
38093845 cast<Torch::BaseTensorType>(resultType)
38103846 .getWithSizesAndDtype (/* optionalSizes=*/ std::nullopt ,
38113847 /* optionalDtype=*/ nullptr );
38123848 Type listType = Torch::ListType::get (listElemType);
38133849 Value tensorList = rewriter.create <Torch::PrimListConstructOp>(
3814- binder.getLoc (), listType, SmallVector<Value>{zeros, result});
3815-
3816- // TODO: Support max_output_boxes_per_class input
3817- // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
3818- Value maxOutputBoxesPerClass = rewriter.create <Torch::AtenItemOp>(
3819- binder.getLoc (), rewriter.getType <Torch::IntType>(), operands[2 ]);
3820- Value boxesCond = rewriter.create <Torch::AtenLeIntOp>(
3821- binder.getLoc (), numOutputBoxes, maxOutputBoxesPerClass);
3822- rewriter.create <Torch::RuntimeAssertOp>(
3823- binder.getLoc (), boxesCond,
3824- rewriter.getStringAttr (
3825- " unimplemented: number of output boxes per class should be "
3826- " <= max_output_boxes_per_class" ));
3827-
3850+ loc, listType, SmallVector<Value>{zeros, result});
38283851 rewriter.replaceOpWithNewOp <Torch::AtenCatOp>(binder.op , resultType,
3829- tensorList, dim );
3852+ tensorList, cst1 );
38303853 return success ();
38313854 });
38323855}
0 commit comments