Skip to content

Commit 0c41119

Browse files
[TorchOnnxToTorch] Handle default value for Resize op attribute (#4044)
Adds check for default value of the resize op attribute keep_aspect_ratio_policy = "stretch"
1 parent 3db6aea commit 0c41119

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2702,7 +2702,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27022702
"Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
27032703
Torch::ValueTensorType outputTensorType;
27042704
llvm::SmallVector<Value> operands;
2705-
std::string mode, nearest_mode, coordTfMode;
2705+
std::string mode, nearest_mode, coordTfMode, keepAspectRatioPolicy;
27062706
int64_t antialias, exclude_outside;
27072707
float extrapolation_value, cubic_coeff_a;
27082708

@@ -2711,12 +2711,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27112711
binder.op,
27122712
"unimplemented: support not present for axes attribute");
27132713
}
2714-
if (auto attr =
2715-
binder.op->getAttr("torch.onnx.keep_aspect_ratio_policy")) {
2716-
return rewriter.notifyMatchFailure(
2717-
binder.op, "unimplemented: support not present for "
2718-
"keep_aspect_ratio_policy attribute");
2719-
}
27202714

27212715
if (binder.tensorOperandsList(operands) ||
27222716
binder.tensorResultType(outputTensorType) ||
@@ -2729,6 +2723,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27292723
0.0) ||
27302724
binder.customOpNameStringAttr(nearest_mode, "nearest_mode",
27312725
"round_prefer_floor") ||
2726+
binder.customOpNameStringAttr(keepAspectRatioPolicy,
2727+
"torch.onnx.keep_aspect_ratio_policy",
2728+
"stretch") ||
27322729
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
27332730
return failure();
27342731

@@ -2799,6 +2796,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27992796
binder.op, "unimplemented: cubic coeff must be -0.75");
28002797
}
28012798

2799+
if (keepAspectRatioPolicy != "stretch") {
2800+
return rewriter.notifyMatchFailure(
2801+
binder.op, "unimplemented: non-default keep_aspect_ratio_policy "
2802+
"attribute for resize");
2803+
}
2804+
28022805
auto loc = binder.getLoc();
28032806

28042807
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,6 +2300,25 @@ f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_ve
23002300

23012301
// -----
23022302

2303+
// CHECK-LABEL: func.func @test_resize_defaults
2304+
func.func @test_resize_defaults(%arg0: !torch.vtensor<[1,3,224,224],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,3,800,800],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
2305+
%none = torch.constant.none
2306+
// CHECK: torch.aten.__interpolate.size_list_scale_list
2307+
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
2308+
torch.onnx.antialias = 0 : si64,
2309+
torch.onnx.coordinate_transformation_mode = "half_pixel",
2310+
torch.onnx.cubic_coeff_a = -7.500000e-01 : f32,
2311+
torch.onnx.exclude_outside = 0 : si64,
2312+
torch.onnx.extrapolation_value = 0.000000e+00 : f32,
2313+
torch.onnx.keep_aspect_ratio_policy = "stretch",
2314+
torch.onnx.mode = "nearest",
2315+
torch.onnx.nearest_mode = "round_prefer_floor"
2316+
} : (!torch.vtensor<[1,3,224,224],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,3,800,800],f32>
2317+
return %0 : !torch.vtensor<[1,3,800,800],f32>
2318+
}
2319+
2320+
// -----
2321+
23032322
// CHECK-LABEL: @test_roialign_avg
23042323
func.func @test_roialign_avg(%arg0: !torch.vtensor<[6,2,100,100],f32>, %arg1: !torch.vtensor<[30,4],f32>, %arg2: !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
23052324
// CHECK: %[[Dim:.*]] = torch.constant.int 1

0 commit comments

Comments
 (0)