
Commit 2c72a82

[ONNX] Fix nonzero output type difference between onnx and torch (llvm#3916)
The ONNX output tensor has a shape of (n, z), where n is the number of dimensions in the input tensor and z is the number of non-zero elements. This is different from PyTorch's default behavior, where the dimensions are reversed: torch.nonzero returns a (z, n) tensor.
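For reference, a minimal PyTorch sketch (not part of the commit) of the convention gap this lowering bridges; the tensor values are illustrative only:

    import torch

    # 3x2 input with three non-zero entries.
    x = torch.tensor([[1.0, 0.0],
                      [0.0, 3.0],
                      [2.0, 0.0]])

    # PyTorch convention: one row per non-zero element -> shape (z, n) = (3, 2).
    torch_result = torch.nonzero(x)
    print(torch_result.shape)  # torch.Size([3, 2])

    # ONNX NonZero convention: one row per input dimension -> shape (n, z) = (2, 3).
    # Transposing the PyTorch result reproduces that layout, which is what the new
    # lowering does via torch.aten.transpose.int with dims 0 and 1.
    onnx_style = torch_result.transpose(0, 1)
    print(onnx_style.shape)  # torch.Size([2, 3])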

2 files changed: 37 additions and 18 deletions

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 29 additions & 12 deletions
@@ -1093,18 +1093,35 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
          rewriter.replaceOp(binder.op, nllLoss);
          return success();
        });
-  patterns.onOp("NonZero", 13,
-                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-                  Torch::ValueTensorType resultType;
-                  Value operand;
-                  if (binder.tensorOperand(operand) ||
-                      binder.tensorResultType(resultType)) {
-                    return failure();
-                  }
-                  rewriter.replaceOpWithNewOp<Torch::AtenNonzeroOp>(
-                      binder.op, resultType, operand);
-                  return success();
-                });
+  patterns.onOp(
+      "NonZero", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        if (binder.tensorOperand(operand) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+        Value one = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
+        auto rawSize = resultType.getSizes();
+        SmallVector<int64_t> torchResultSize(rawSize.rbegin(), rawSize.rend());
+        auto torchResultType = rewriter.getType<Torch::ValueTensorType>(
+            torchResultSize, resultType.getDtype());
+        auto nonZero = rewriter.create<Torch::AtenNonzeroOp>(
+            binder.getLoc(), torchResultType, operand);
+        // The ONNX output tensor has a shape of (n, z), where n is the
+        // number of dimensions in the input tensor and z is the
+        // number of non-zero elements. This is different from
+        // PyTorch's default behavior, where the dimensions are
+        // reversed.
+        rewriter.replaceOpWithNewOp<Torch::AtenTransposeIntOp>(
+            binder.op, resultType, nonZero, zero, one);
+        return success();
+      });
   patterns.onOp(
       "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         std::string autoPad;

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 8 additions & 6 deletions
@@ -1580,12 +1580,14 @@ func.func @test_nllloss_iii_reduction_none_ignore_negative(%arg0: !torch.vtensor

 // -----

-// CHECK-LABEL: func.func @test_nonzero
-func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>
-  %0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64>
-  return %0 : !torch.vtensor<[3,4,5],si64>
-}
+func.func @test_nonzero(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[ZERO:.*]] = torch.constant.int 0
+  // CHECK: %[[ONE:.*]] = torch.constant.int 1
+  // CHECK: %[[NONZERO:.*]] = torch.aten.nonzero %arg0 : !torch.vtensor<[?],f32> -> !torch.vtensor<[?,1],si64>
+  // CHECK: %[[TRANSPOSE:.*]] = torch.aten.transpose.int %[[NONZERO]], %[[ZERO]], %[[ONE]] : !torch.vtensor<[?,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64>
+  %0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64>
+  return %0 : !torch.vtensor<[1,?],si64>
+}

 // -----
