Skip to content

Commit e68560d

Browse files
authored
Add attributes support for onnx.nms (llvm#3920)
- Set default attribute values
- Support `max_output_boxes_per_class` attribute
- e2e test `test_nonmaxsuppression_limit_output_size` passed
1 parent 71cb942 commit e68560d

File tree

2 files changed

+123
-85
lines changed

2 files changed

+123
-85
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 75 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -3688,6 +3688,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
36883688
patterns.onOp(
36893689
"NonMaxSuppression", 10,
36903690
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
3691+
Location loc = binder.getLoc();
36913692
Torch::ValueTensorType resultType;
36923693
SmallVector<Value> operands;
36933694
int64_t centerPointBox;
@@ -3702,96 +3703,132 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
37023703
binder.op, "unimplemented: expected center_point_box "
37033704
"attribute value to be 0");
37043705

3705-
// TODO: Add support for optional arguments to be absent.
3706-
if (operands.size() < 4)
3707-
return rewriter.notifyMatchFailure(
3708-
binder.op, "unimplemented: expected at least 4 arguments");
3709-
3706+
// TODO: Support multiple batches and classes
37103707
// Squeeze the boxes and scores tensor.
37113708
// In Onnx, the shape of boxes is [BxNx4] while the
37123709
// torchvision expects it to be of shape [Nx4]. Similarly, for
37133710
// the scores tensor shape in Onnx is [BxCxN] while the
37143711
// torchvision expects it to be of shape [N].
37153712
Value boxes = operands[0], scores = operands[1];
3716-
FailureOr<Value> squeezedBoxes = Torch::squeezeTensor(
3717-
rewriter, binder.op, binder.getLoc(), 0, boxes);
3713+
FailureOr<Value> squeezedBoxes =
3714+
Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
37183715
if (failed(squeezedBoxes))
37193716
return rewriter.notifyMatchFailure(binder.op,
37203717
"failed to squeeze boxes tensor");
3721-
3722-
FailureOr<Value> squeezedScores = Torch::squeezeTensor(
3723-
rewriter, binder.op, binder.getLoc(), 0, scores);
3718+
FailureOr<Value> squeezedScores =
3719+
Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
37243720
if (failed(squeezedScores))
37253721
return rewriter.notifyMatchFailure(binder.op,
37263722
"failed to squeeze scores tensor");
3727-
squeezedScores = Torch::squeezeTensor(
3728-
rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value());
3723+
squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
3724+
squeezedScores.value());
37293725
if (failed(squeezedScores))
37303726
return rewriter.notifyMatchFailure(binder.op,
37313727
"failed to squeeze scores tensor");
3732-
37333728
boxes = squeezedBoxes.value();
37343729
scores = squeezedScores.value();
37353730

37363731
// TODO: Support score_threshold input
37373732
// Filter out the boxes if the score < score_threshold
37383733
if (operands.size() == 5) {
37393734
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
3740-
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
3741-
operands[4]);
3735+
loc, rewriter.getType<Torch::FloatType>(), operands[4]);
37423736
Value minScores = rewriter.create<Torch::AtenMinOp>(
3743-
binder.getLoc(),
3737+
loc,
37443738
Torch::ValueTensorType::get(binder.op->getContext(),
37453739
SmallVector<int64_t>{},
37463740
rewriter.getF32Type()),
37473741
scores);
37483742
minScores = rewriter.create<Torch::AtenItemOp>(
3749-
binder.getLoc(), rewriter.getType<Torch::FloatType>(), minScores);
3743+
loc, rewriter.getType<Torch::FloatType>(), minScores);
37503744

37513745
Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
3752-
binder.getLoc(), minScores, scoreThreshold);
3746+
loc, minScores, scoreThreshold);
37533747
rewriter.create<Torch::RuntimeAssertOp>(
3754-
binder.getLoc(), scoresCond,
3748+
loc, scoresCond,
37553749
rewriter.getStringAttr(
37563750
"unimplemented: score_threshold should be <= min(scores)"));
37573751
}
37583752

3759-
// TODO: Support default iou_threshold
3760-
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
3761-
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
3753+
// Get max_output_boxes_per_class and iou_threshold
3754+
Value cst0 = rewriter.create<Torch::ConstantIntOp>(
3755+
loc, rewriter.getI64IntegerAttr(0));
3756+
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
3757+
loc, rewriter.getI64IntegerAttr(1));
3758+
Value maxOutputBoxesPerClass = cst0;
3759+
Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
3760+
loc, rewriter.getF64FloatAttr(0.0));
3761+
if (operands.size() > 3 &&
3762+
!isa<Torch::NoneType>(operands[3].getType())) {
3763+
iouThreshold = rewriter.create<Torch::AtenItemOp>(
3764+
loc, rewriter.getType<Torch::FloatType>(), operands[3]);
3765+
}
3766+
if (operands.size() > 2 &&
3767+
!isa<Torch::NoneType>(operands[2].getType())) {
3768+
maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
3769+
loc, rewriter.getType<Torch::IntType>(), operands[2]);
3770+
}
3771+
37623772
auto nmsTy = Torch::ValueTensorType::get(
3773+
binder.op->getContext(), SmallVector<int64_t>{-1},
3774+
rewriter.getIntegerType(64, /*signed=*/true));
3775+
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
3776+
loc, nmsTy, boxes, scores, iouThreshold);
3777+
3778+
// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
3779+
Value numOutputBoxes =
3780+
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
3781+
Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
3782+
loc, numOutputBoxes, maxOutputBoxesPerClass);
3783+
3784+
auto nmsResultTy = Torch::ValueTensorType::get(
37633785
binder.op->getContext(),
37643786
SmallVector<int64_t>{resultType.getSizes()[0]},
37653787
rewriter.getIntegerType(64, /*signed=*/true));
3766-
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
3767-
binder.getLoc(), nmsTy, boxes, scores, iouThreshold);
3788+
auto ifSlice = rewriter.create<Torch::PrimIfOp>(
3789+
loc, TypeRange({nmsResultTy}), boxesCond);
3790+
{
3791+
PatternRewriter::InsertionGuard guard(rewriter);
3792+
rewriter.createBlock(&ifSlice.getThenRegion(),
3793+
ifSlice.getThenRegion().begin());
3794+
3795+
Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
3796+
loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
3797+
/*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
3798+
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
3799+
}
3800+
{
3801+
PatternRewriter::InsertionGuard guard(rewriter);
3802+
rewriter.createBlock(&ifSlice.getElseRegion(),
3803+
ifSlice.getElseRegion().begin());
3804+
3805+
Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
3806+
loc, nmsResultTy, result);
3807+
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
3808+
}
3809+
result = ifSlice.getResult(0);
37683810

37693811
// The result generated by torchvision.nms op is of shape [n], while the
37703812
// onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
37713813
// and make it of shape [n, 1] and then concatenate it with a zero
37723814
// tensor of shape [n, 2] to make it of shape [n, 3].
3773-
Value dim = rewriter.create<Torch::ConstantIntOp>(
3774-
binder.getLoc(), rewriter.getI64IntegerAttr(1));
37753815
FailureOr<Value> unsqueezedResult =
3776-
Torch::unsqueezeTensor(rewriter, binder.op, result, dim);
3816+
Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
37773817
if (failed(unsqueezedResult))
37783818
return rewriter.notifyMatchFailure(
37793819
binder.op, "failed to unsqueeze result tensor");
37803820
result = unsqueezedResult.value();
37813821

3782-
Value numOutputBoxes = rewriter.create<Torch::AtenSizeIntOp>(
3783-
binder.getLoc(), result,
3784-
rewriter.create<Torch::ConstantIntOp>(
3785-
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
3822+
numOutputBoxes =
3823+
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
37863824
SmallVector<Value> zerosShapeValues{numOutputBoxes};
37873825
zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
3788-
binder.getLoc(), rewriter.getI64IntegerAttr(2)));
3826+
loc, rewriter.getI64IntegerAttr(2)));
37893827
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
3790-
binder.getLoc(),
3828+
loc,
37913829
rewriter.getType<Torch::ListType>(
37923830
rewriter.getType<Torch::IntType>()),
37933831
zerosShapeValues);
3794-
37953832
std::optional<ArrayRef<int64_t>> resultShape =
37963833
cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
37973834
if (!resultShape.has_value())
@@ -3800,33 +3837,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
38003837
llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
38013838
auto zerosTy = Torch::ValueTensorType::get(
38023839
resultType.getContext(), zerosShape, resultType.getOptionalDtype());
3803-
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
3840+
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
38043841
Value zeros = rewriter.create<Torch::AtenZerosOp>(
3805-
binder.getLoc(), zerosTy, zerosShapeList, cstNone, cstNone, cstNone,
3806-
cstNone);
3842+
loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);
38073843

38083844
Type listElemType =
38093845
cast<Torch::BaseTensorType>(resultType)
38103846
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
38113847
/*optionalDtype=*/nullptr);
38123848
Type listType = Torch::ListType::get(listElemType);
38133849
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
3814-
binder.getLoc(), listType, SmallVector<Value>{zeros, result});
3815-
3816-
// TODO: Support max_output_boxes_per_class input
3817-
// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
3818-
Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
3819-
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
3820-
Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(
3821-
binder.getLoc(), numOutputBoxes, maxOutputBoxesPerClass);
3822-
rewriter.create<Torch::RuntimeAssertOp>(
3823-
binder.getLoc(), boxesCond,
3824-
rewriter.getStringAttr(
3825-
"unimplemented: number of output boxes per class should be "
3826-
"<= max_output_boxes_per_class"));
3827-
3850+
loc, listType, SmallVector<Value>{zeros, result});
38283851
rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
3829-
tensorList, dim);
3852+
tensorList, cst1);
38303853
return success();
38313854
});
38323855
}

0 commit comments

Comments
 (0)