
Commit 85a9c86

cats-marin authored and Lallapallooza committed
[OnnxToTorch] Casting float to integer should round to nearest for pow with int result type (#4228)
Fixes #4091. ~~I assume this will also need to be fixed for AtenPowScalarOp and AtenPowTensorScalarOp. I'm putting up a PR to ensure the initial approach is correct (new contributor :D ) before I put up another fix for AtenPowScalarOp and AtenPowTensorScalarOp.~~
1 parent: e1ac1c4
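Why the round is needed, in brief: the ONNX Pow lowering computes integer-input pow in f64 and then casts back to the integer result type, and a plain float-to-int cast truncates toward zero. If the f64 pow result lands just below the exact integer, truncation is off by one. A minimal standalone C++ sketch of the hazard (illustrative, not code from the patch; the literal 124.99999999999999 is a hypothetical stand-in for an imprecise pow result such as 5^3 on a backend whose pow is not correctly rounded):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // Hypothetical stand-in for an imprecise f64 pow result: the exact
  // value of 5^3 is 125, but a backend's pow may return just under it.
  double p = 124.99999999999999;

  // Truncation toward zero, which is what a bare float->int cast does: wrong.
  int truncated = static_cast<int>(p); // 124

  // Round to nearest first, then cast: right.
  int rounded = static_cast<int>(std::round(p)); // 125

  std::printf("truncated=%d rounded=%d\n", truncated, rounded);
  return 0;
}
```

The patch applies the same idea at the graph level: it inserts an AtenRoundOp on the f64 pow result before AtenToDtypeOp performs the cast.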

File tree: 2 files changed (+18, -1 lines)


lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 1 addition & 0 deletions
@@ -3076,6 +3076,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                       static_cast<int64_t>(outDtype)));

+        pow = rewriter.create<Torch::AtenRoundOp>(loc, powType, pow);
         rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
             binder.op, resultType, pow, outTyConst, cstFalse, cstFalse, none);
test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 17 additions & 1 deletion
@@ -1253,14 +1253,30 @@ func.func @test_pow_i32(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtens
   // CHECK: %[[NONE:.+]] = torch.constant.none
   // CHECK: %[[POW:.+]] = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],f64>
   // CHECK: %[[DTY:.+]] = torch.constant.int 3
-  // CHECK: %[[RES:.+]] = torch.aten.to.dtype %[[POW]], %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]]
+  // CHECK: %[[ROUND:.+]] = torch.aten.round %[[POW]] : !torch.vtensor<[3,4,5],f64> -> !torch.vtensor<[3,4,5],f64>
+  // CHECK: %[[RES:.+]] = torch.aten.to.dtype %[[ROUND]], %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]]
   // CHECK: return %[[RES]]
   %0 = torch.operator "onnx.Pow"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32>
   return %0 : !torch.vtensor<[3,4,5],si32>
 }

 // -----

+// CHECK-LABEL: func.func @test_pow_i32_f32_to_i32
+func.func @test_pow_i32_f32_to_i32(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[POW:.+]] = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f64>
+  // CHECK: %[[DTY:.+]] = torch.constant.int 3
+  // CHECK: %[[ROUND:.+]] = torch.aten.round %[[POW]] : !torch.vtensor<[3,4,5],f64> -> !torch.vtensor<[3,4,5],f64>
+  // CHECK: %[[RES:.+]] = torch.aten.to.dtype %[[ROUND]], %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]]
+  // CHECK: return %[[RES]]
+  %0 = torch.operator "onnx.Pow"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si32>
+  return %0 : !torch.vtensor<[3,4,5],si32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_hardsigmoid_example
 func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01
