
Commit cfc8de3

[MLIR][TORCH] Add E2E support for aten.native_layer_norm. (#470)
This commit adds support for the aten.native_layer_norm operation. The previous code for aten.layer_norm is tweaked slightly so that the mean and variance values are produced along with the layer norm value. This commit also adds a decomposition of aten.layer_norm into aten.native_layer_norm, which was previously lowered directly to linalg.

Signed-off-by: Prateek Gupta <[email protected]>
1 parent 5a47f92 commit cfc8de3
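
For orientation, a minimal PyTorch-level sketch of the op this commit wires up. This is a sketch only: it assumes a PyTorch build whose aten::native_layer_norm matches the registered signature (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor); the shapes and eps value are borrowed from the new e2e test, and the result names follow the commit's (layer_norm, mean, variance) convention.

import torch

x = torch.rand(2, 5, 2, 2, 3)
weight = torch.rand(2, 2, 3)
bias = torch.rand(2, 2, 3)

# Unlike aten.layer_norm, aten.native_layer_norm returns three tensors:
# the normalized output plus the statistics computed along the way.
out, mean, var = torch.ops.aten.native_layer_norm(x, [2, 2, 3], weight, bias, 0.5)
assert out.shape == x.shape  # result 0 keeps the input shape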

6 files changed: +137 −7 lines


e2e_testing/torchscript/batchnorm.py

Lines changed: 27 additions & 0 deletions
@@ -87,8 +87,33 @@ def forward(self, x):
 def BatchNorm3DModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5, 3, 6, 4))
 
+# ==============================================================================
+
+
+class NativeLayerNormModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 5, 2, 2, 3], torch.float32, True),
+        ([2, 2, 3], torch.float32, True),
+        ([2, 2, 3], torch.float32, True),
+    ])
+    def forward(self, x, weight, bias):
+        list = [2, 2, 3]
+        return torch.ops.aten.native_layer_norm(
+            x, list, weight, bias, eps=0.5)[0]
+
+
+@register_test_case(module_factory=lambda: NativeLayerNormModule())
+def NativeLayerNormModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
 
 # ==============================================================================
+
+
 class LayerNormModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -138,6 +163,8 @@ def LayerNormLastDimModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5, 2, 2, 3))
 
 # ==============================================================================
+
+
 class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 20 additions & 0 deletions
@@ -1336,6 +1336,26 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
   let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps `,` $cudnn_enable attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `,` type($cudnn_enable) `->` type($result)";
 }
 
+def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    TorchIntListType:$normalized_shape,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    Torch_FloatType:$eps
+  );
+  let results = (outs
+    AnyTorchTensorType:$layer_norm,
+    AnyTorchTensorType:$mean,
+    AnyTorchTensorType:$variance
+  );
+  let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `->` type($layer_norm) `,` type($mean) `,` type($variance)";
+}
+
 def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
     AllowsTypeRefinement,
     HasValueSemantics

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 13 additions & 7 deletions
@@ -693,11 +693,12 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
 // Step 4. Get var.
 // Step 5. Get layernorm.
 namespace {
-class ConvertAtenLayerNormOp : public OpConversionPattern<AtenLayerNormOp> {
+class ConvertAtenNativeLayerNormOp
+    : public OpConversionPattern<AtenNativeLayerNormOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenLayerNormOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MLIRContext *context = op->getContext();
     Location loc = op->getLoc();
@@ -889,9 +890,14 @@ class ConvertAtenLayerNormOp : public OpConversionPattern<AtenLayerNormOp> {
               b.create<linalg::YieldOp>(loc, result);
             })
             .getResult(0);
-
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, layerNorm);
+    Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
+    Type meanResultType = getTypeConverter()->convertType(op.getType(1));
+    Type varResultType = getTypeConverter()->convertType(op.getType(2));
+    Value layerNorm_ =
+        rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
+    Value mean_ = rewriter.create<tensor::CastOp>(loc, meanResultType, mean);
+    Value var_ = rewriter.create<tensor::CastOp>(loc, varResultType, var);
+    rewriter.replaceOp(op, {layerNorm_, mean_, var_});
     return success();
   }
 };
@@ -3659,8 +3665,8 @@ class ConvertTorchToLinalg
     patterns.add<ConvertAtenCatOp>(typeConverter, context);
     target.addIllegalOp<AtenGatherOp>();
     patterns.add<ConvertAtenGatherOp>(typeConverter, context);
-    target.addIllegalOp<AtenLayerNormOp>();
-    patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
+    target.addIllegalOp<AtenNativeLayerNormOp>();
+    patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
     target.addIllegalOp<AtenBroadcastToOp>();
     patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
     target.addIllegalOp<AtenArgmaxOp>();

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 30 additions & 0 deletions
@@ -477,6 +477,33 @@ class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
     return success();
   }
 };
+
+class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
+  using OpRewritePattern<AtenLayerNormOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenLayerNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    auto input = op.input().getType().cast<BaseTensorType>();
+    if (!input.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "input tensor should have known sizes.");
+    int64_t inputRank = input.getSizes().size();
+    Value normalizedShape = op.normalized_shape();
+    SmallVector<Value> normalizedShapeSizesTorchInt;
+    getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
+    std::vector<int64_t> meanVarSizes;
+    for (int i = normalizedShapeSizesTorchInt.size(); i < inputRank; i++)
+      meanVarSizes.push_back(input.getSizes()[i]);
+    auto meanVarType = input.getWithSizesAndDtype(
+        llvm::makeArrayRef(meanVarSizes), input.getDtype());
+    auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
+        loc, op.getType(), meanVarType, meanVarType, op.input(),
+        op.normalized_shape(), op.weight(), op.bias(), op.eps());
+    rewriter.replaceOp(op, nativeLayerNorm.getResult(0));
+    return success();
+  }
+};
 } // namespace
 
 namespace {
@@ -522,6 +549,9 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<AtenAddcmulOp>();
     patterns.add<DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(context);
     target.addIllegalOp<AtenAddcdivOp>();
+    target.addIllegalOp<AtenLayerNormOp>();
+    patterns.add<DecomposeAtenLayerNormOp>(context);
+
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
       return signalPassFailure();
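
The decomposition above keeps only result 0 of the new aten.native_layer_norm op (rewriter.replaceOp(op, nativeLayerNorm.getResult(0))), relying on aten.layer_norm's single result being exactly that first output. A PyTorch-level sketch of that relationship, with shapes reused from the e2e test and not part of the commit itself:

import torch

x = torch.rand(2, 5, 2, 2, 3)
weight = torch.rand(2, 2, 3)
bias = torch.rand(2, 2, 3)

# aten.layer_norm's only result should coincide with result 0 of
# aten.native_layer_norm, which is what DecomposeAtenLayerNormOp exploits.
full = torch.ops.aten.native_layer_norm(x, [2, 2, 3], weight, bias, 0.5)
only = torch.ops.aten.layer_norm(x, [2, 2, 3], weight, bias, 0.5, False)
assert torch.allclose(only, full[0])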

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 44 additions & 0 deletions
@@ -473,6 +473,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       return visitBinaryScalarOp(scalarOp);
     } else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
       return visitAtenNllLossForwardOp(nllForwardOp, operands);
+    } else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
+      return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
     }
 
     // Otherwise, this is an unknown operation. Just mark all results as
@@ -609,6 +611,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   ChangeResult
   visitAtenNllLossForwardOp(AtenNllLossForwardOp op,
                             ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult visitAtenNativeLayerNormOp(
+      AtenNativeLayerNormOp op,
+      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 };
 } // namespace
 
@@ -1605,6 +1610,45 @@ ChangeResult TypeAnalyzer::visitAtenAddCLikeOp(
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitAtenNativeLayerNormOp(
+    AtenNativeLayerNormOp op,
+    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+
+  auto layerNormKnowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  auto meanKnowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  auto varKnowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+
+  layerNormKnowledge.hasSizes = input.hasSizes;
+  layerNormKnowledge.sizes = input.sizes;
+  layerNormKnowledge.dtype = input.dtype;
+
+  int64_t layerNormSize = input.sizes.size();
+  Value normalizedShape = op.normalized_shape();
+  SmallVector<Value> normalizedShapeSizesTorchInt;
+  getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
+  std::vector<int64_t> meanVarSizes;
+  if (input.hasSizes) {
+    for (int i = normalizedShapeSizesTorchInt.size(); i < layerNormSize; i++)
+      meanVarSizes.push_back(input.sizes[i]);
+  }
+  meanKnowledge.hasSizes = input.hasSizes;
+  meanKnowledge.sizes = meanVarSizes;
+  meanKnowledge.dtype = input.dtype;
+  varKnowledge.hasSizes = input.hasSizes;
+  varKnowledge.sizes = meanVarSizes;
+  varKnowledge.dtype = input.dtype;
+
+  auto resultLattice =
+      getLatticeElement(op.getResult(0)).join(layerNormKnowledge);
+  resultLattice |= getLatticeElement(op.getResult(1)).join(meanKnowledge);
+  resultLattice |= getLatticeElement(op.getResult(2)).join(varKnowledge);
+
+  return resultLattice;
+}
 // -----------------------------------------------------------------------------
 // Transforms.
 // -----------------------------------------------------------------------------

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 3 additions & 0 deletions
@@ -502,6 +502,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
     )
+    emit (
+        "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
+    )
     emit(
         "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
     )
