Commit 3eb2475
[OnnxToTorch] Lower onnx.MeanVarianceNorm to torch dialect without expansion (#4219)
This PR takes care of #4218.

- Lower the onnx MeanVarianceNorm op to torch primitive ops.
- Remove **function expansion** during import from onnx.

---------

Signed-off-by: Zahid Wakeel <[email protected]>
1 parent 46c3888 commit 3eb2475
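For reference, the new pattern lowers onnx.MeanVarianceNormalization to a direct composition of torch reduction ops: the mean and the biased variance are computed over the `axes` attribute (default `{0, 2, 3}`) with keepdim, and a small epsilon is added before the square root, i.e.

$$Y = \frac{X - \operatorname{mean}(X,\ \text{axes})}{\sqrt{\operatorname{var}(X,\ \text{axes}) + 10^{-9}}}$$

The keepdim-shaped mean and variance then broadcast back over the input in the elementwise subtract and divide.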

File tree

- lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
- test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

2 files changed: +144 -0 lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 68 additions & 0 deletions
@@ -1606,6 +1606,74 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             /* cudnn enabled */ boolFalse);
         return success();
       });
+  patterns.onOp(
+      "MeanVarianceNormalization", 13,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value input;
+        SmallVector<int64_t> axes;
+
+        if (binder.tensorOperand(input) ||
+            binder.s64IntegerArrayAttr(axes, "axes",
+                                       llvm::SmallVector<int64_t>({0, 2, 3})) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+        if (!resultType.hasSizes() || !resultType.hasDtype()) {
+          return failure();
+        }
+        auto inputTy = cast<Torch::ValueTensorType>(input.getType());
+        if (!inputTy || !inputTy.hasSizes()) {
+          return failure();
+        }
+        int64_t inputRank = inputTy.getSizes().size();
+
+        Location loc = binder.getLoc();
+        Value keepDim = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+        Value unBiased = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+
+        ArrayRef<int64_t> output_shape = resultType.getSizes();
+        SmallVector<int64_t> reduced_shape(output_shape);
+
+        for (int64_t i : axes) {
+          int64_t dim = Torch::toPositiveDim(i, inputRank);
+          if (!Torch::isValidDim(dim, inputRank)) {
+            return failure();
+          }
+          reduced_shape[dim] = 1;
+        }
+        Torch::ValueTensorType reducedOutTy = Torch::ValueTensorType::get(
+            resultType.getContext(), reduced_shape, resultType.getDtype());
+        SmallVector<Value> cstAxes;
+        for (int64_t i : axes) {
+          cstAxes.push_back(rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(i)));
+        }
+        Value axes_list = rewriter.create<Torch::PrimListConstructOp>(
+            loc,
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstAxes);
+        Value mean = rewriter.create<Torch::AtenMeanDimOp>(
+            loc, reducedOutTy, input, axes_list, keepDim, none);
+        Value variance = rewriter.create<Torch::AtenVarDimOp>(
+            loc, reducedOutTy, input, axes_list, unBiased, keepDim);
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(1));
+        Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(1e-9));
+        variance = rewriter.create<Torch::AtenAddScalarOp>(
+            loc, reducedOutTy, variance, cstEps, cstOne);
+        Value sqrtVar =
+            rewriter.create<Torch::AtenSqrtOp>(loc, reducedOutTy, variance);
+        Value inputMinusMean = rewriter.create<Torch::AtenSubTensorOp>(
+            loc, resultType, input, mean, cstOne);
+        Value meanVarNorm = rewriter.create<Torch::AtenDivTensorOp>(
+            loc, resultType, inputMinusMean, sqrtVar);
+
+        rewriter.replaceOp(binder.op, meanVarNorm);
+        return success();
+      });
   patterns.onOp(
       "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
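To make the emitted arithmetic concrete, here is a standalone C++ sanity check. It is not part of the patch; the 2x2 shape, the values, and the choice of axis 0 are invented for illustration. It mirrors the mean / biased-variance / add-epsilon / sqrt / subtract / divide sequence the pattern builds:

// Standalone sanity check (not part of this patch): mirrors the arithmetic of
// the lowering above on a tiny 2x2 tensor, normalizing over axis 0 only.
#include <cmath>
#include <cstdio>

int main() {
  const double x[2][2] = {{1.0, 2.0}, {3.0, 6.0}};
  const double eps = 1e-9; // same epsilon the lowering adds before the sqrt

  for (int col = 0; col < 2; ++col) {
    // torch.aten.mean.dim with keepdim=true: per-column mean over axis 0.
    double mean = (x[0][col] + x[1][col]) / 2.0;
    // torch.aten.var.dim with unbiased=false: biased (population) variance.
    double var = ((x[0][col] - mean) * (x[0][col] - mean) +
                  (x[1][col] - mean) * (x[1][col] - mean)) /
                 2.0;
    // add.Scalar(eps) -> sqrt -> sub.Tensor -> div.Tensor, as in the pattern.
    double denom = std::sqrt(var + eps);
    for (int row = 0; row < 2; ++row)
      std::printf("y[%d][%d] = %f\n", row, col, (x[row][col] - mean) / denom);
  }
  return 0;
}

Column 0 ({1, 3}) has mean 2 and biased variance 1, column 1 ({2, 6}) has mean 4 and variance 4; both normalize to roughly {-1, 1}, up to the epsilon.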

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 76 additions & 0 deletions
@@ -1611,6 +1611,82 @@ func.func @test_mod_int64_no_fmod(%arg0: !torch.vtensor<[6],si64>, %arg1: !torch
 
 // -----
 
+// CHECK-LABEL: func.func @test_meanvarnorm(
+func.func @test_meanvarnorm(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK: %[[VAL_0:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_1:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_7:.*]] = torch.aten.mean.dim %[[ARG0]], %[[VAL_6]], %[[VAL_0]], %[[VAL_2]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,5,1,1],f32>
+// CHECK: %[[VAL_8:.*]] = torch.aten.var.dim %[[ARG0]], %[[VAL_6]], %[[VAL_1]], %[[VAL_0]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,5,1,1],f32>
+// CHECK: %[[VAL_9:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_10:.*]] = torch.constant.float 1.000000e-09
+// CHECK: %[[VAL_11:.*]] = torch.aten.add.Scalar %[[VAL_8]], %[[VAL_10]], %[[VAL_9]] : !torch.vtensor<[1,5,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,5,1,1],f32>
+// CHECK: %[[VAL_12:.*]] = torch.aten.sqrt %[[VAL_11]] : !torch.vtensor<[1,5,1,1],f32> -> !torch.vtensor<[1,5,1,1],f32>
+// CHECK: %[[VAL_13:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[1,5,1,1],f32>, !torch.int -> !torch.vtensor<[3,5,2,2],f32>
+// CHECK: %[[VAL_14:.*]] = torch.aten.div.Tensor %[[VAL_13]], %[[VAL_12]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[1,5,1,1],f32> -> !torch.vtensor<[3,5,2,2],f32>
+// CHECK: return %[[VAL_14]] : !torch.vtensor<[3,5,2,2],f32>
+// CHECK: }
+  %0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
+  return %0 : !torch.vtensor<[3,5,2,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_meanvarnorm_axes(
+func.func @test_meanvarnorm_axes(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK: %[[VAL_0:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_1:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[ARG0]], %[[VAL_5]], %[[VAL_0]], %[[VAL_2]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2,1],f32>
+// CHECK: %[[VAL_7:.*]] = torch.aten.var.dim %[[ARG0]], %[[VAL_5]], %[[VAL_1]], %[[VAL_0]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,1,2,1],f32>
+// CHECK: %[[VAL_8:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_9:.*]] = torch.constant.float 1.000000e-09
+// CHECK: %[[VAL_10:.*]] = torch.aten.add.Scalar %[[VAL_7]], %[[VAL_9]], %[[VAL_8]] : !torch.vtensor<[3,1,2,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[3,1,2,1],f32>
+// CHECK: %[[VAL_11:.*]] = torch.aten.sqrt %[[VAL_10]] : !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,1,2,1],f32>
+// CHECK: %[[VAL_12:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAL_6]], %[[VAL_8]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32>, !torch.int -> !torch.vtensor<[3,5,2,2],f32>
+// CHECK: %[[VAL_13:.*]] = torch.aten.div.Tensor %[[VAL_12]], %[[VAL_11]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,5,2,2],f32>
+// CHECK: return %[[VAL_13]] : !torch.vtensor<[3,5,2,2],f32>
+// CHECK: }
+  %0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) {torch.onnx.axes = [1 : si64, 3 : si64]} : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
+  return %0 : !torch.vtensor<[3,5,2,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_meanvarnorm_neg_axes(
+func.func @test_meanvarnorm_neg_axes(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK: %[[VAL_0:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_1:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.int -1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int -3
+// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[ARG0]], %[[VAL_5]], %[[VAL_0]], %[[VAL_2]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2,1],f32>
+// CHECK: %[[VAL_7:.*]] = torch.aten.var.dim %[[ARG0]], %[[VAL_5]], %[[VAL_1]], %[[VAL_0]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,1,2,1],f32>
+// CHECK: %[[VAL_8:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_9:.*]] = torch.constant.float 1.000000e-09
+// CHECK: %[[VAL_10:.*]] = torch.aten.add.Scalar %[[VAL_7]], %[[VAL_9]], %[[VAL_8]] : !torch.vtensor<[3,1,2,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[3,1,2,1],f32>
+// CHECK: %[[VAL_11:.*]] = torch.aten.sqrt %[[VAL_10]] : !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,1,2,1],f32>
+// CHECK: %[[VAL_12:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAL_6]], %[[VAL_8]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32>, !torch.int -> !torch.vtensor<[3,5,2,2],f32>
+// CHECK: %[[VAL_13:.*]] = torch.aten.div.Tensor %[[VAL_12]], %[[VAL_11]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,5,2,2],f32>
+// CHECK: return %[[VAL_13]] : !torch.vtensor<[3,5,2,2],f32>
+// CHECK: }
+  %0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) {torch.onnx.axes = [-1 : si64, -3 : si64]} : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
+  return %0 : !torch.vtensor<[3,5,2,2],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_not_2d
 func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
 // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>
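The three new tests above cover the default axes ({0, 2, 3}), explicit positive axes, and negative axes, which the lowering resolves with Torch::toPositiveDim when computing the reduced result shape. Like the rest of this file, they are presumably driven by the file's existing RUN line (torch-mlir-opt with the convert-torch-onnx-to-torch pass, piped into FileCheck).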
