@@ -3512,6 +3512,67 @@ class ReplaceCastLikeByCastPattern : public OpRewritePattern<ONNXCastLikeOp> {
35123512 }
35133513};
35143514
3515+ // =============================================================================
3516+ // Decompose InstanceNormalization to LayerNormalization
3517+ // =============================================================================
3518+ struct DecomposeInstanceNormPattern
3519+ : public OpRewritePattern<ONNXInstanceNormalizationOp> {
3520+ using OpRewritePattern<ONNXInstanceNormalizationOp>::OpRewritePattern;
3521+
3522+ static bool isDecomposable (ONNXInstanceNormalizationOp instanceNormOp) {
3523+ return onnx_mlir::hasStaticShape (instanceNormOp.getInput ().getType ()) &&
3524+ onnx_mlir::hasStaticShape (instanceNormOp.getOutput ().getType ());
3525+ }
3526+
3527+ LogicalResult matchAndRewrite (ONNXInstanceNormalizationOp instanceNormOp,
3528+ PatternRewriter &rewriter) const final {
3529+ // Match.
3530+ if (!isDecomposable (instanceNormOp)) {
3531+ return failure ();
3532+ }
3533+
3534+ // Get info.
3535+ Value input = instanceNormOp.getInput ();
3536+ Value scale = instanceNormOp.getScale ();
3537+ Value bias = instanceNormOp.getB ();
3538+ ShapedType inputType = mlir::cast<ShapedType>(input.getType ());
3539+ Type elementType = inputType.getElementType ();
3540+ auto inputShape = inputType.getShape ();
3541+ int64_t C = inputShape[1 ];
3542+ int64_t inputRank = inputType.getRank ();
3543+ int64_t nonSpacialRank = 2 ; // Batch N and Channel C: 2 dimensions.
3544+ assert (inputRank > nonSpacialRank &&
3545+ " expected instance norm with input ranks > 2" );
3546+
3547+ // Rewrite.
3548+ onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create (
3549+ rewriter, instanceNormOp.getLoc ());
3550+ int64_t axis = nonSpacialRank;
3551+ int64_t numInNorm = inputRank - axis;
3552+ // Unsqueeze scale/bias from [C] to [C x 1 x 1 x ... x 1] with numInNorm
3553+ // 1s.
3554+ llvm::SmallVector<int64_t , 4 > axesList, biasScaleShape;
3555+ biasScaleShape.emplace_back (C);
3556+ for (int64_t i = 1 ; i <= numInNorm; ++i) {
3557+ biasScaleShape.emplace_back (1 );
3558+ axesList.emplace_back (i);
3559+ }
3560+ Value axes = create.onnx .constantInt64 (axesList);
3561+ Type biasScaleType = RankedTensorType::get (biasScaleShape, elementType);
3562+ Value newScale = create.onnx .unsqueeze (biasScaleType, scale, axes);
3563+ Value newBias = create.onnx .unsqueeze (biasScaleType, bias, axes);
3564+ // Create output using layer norm.
3565+ Value Y = create.onnx .layerNorm (inputType, input, newScale, newBias, axis,
3566+ instanceNormOp.getEpsilonAttr ());
3567+ // Set the type of the output to be the same as the output of the original
3568+ // operation we are trying to replace.
3569+ Y.setType (instanceNormOp.getResult ().getType ());
3570+ // Replace operation.
3571+ rewriter.replaceOp (instanceNormOp, Y);
3572+ return success ();
3573+ }
3574+ };
3575+
35153576// =============================================================================
35163577// Decompose Hardswish to simpler ONNX ops
35173578// =============================================================================
@@ -3577,13 +3638,15 @@ struct DecomposeONNXToONNXPass
35773638 DecomposeONNXToONNXPass (const std::string &target,
35783639 bool enableConvTransposeDecompose = false ,
35793640 bool enableConvTransposeDecomposeToPhasedConv = false ,
3580- bool enableConvTranspose1dDecomposeToPhasedConv = false ) {
3641+ bool enableConvTranspose1dDecomposeToPhasedConv = false ,
3642+ bool enableInstanceNormDecompose = true ) {
35813643 this ->target = target;
35823644 this ->enableConvTransposeDecompose = enableConvTransposeDecompose;
35833645 this ->enableConvTransposeDecomposeToPhasedConv =
35843646 enableConvTransposeDecomposeToPhasedConv;
35853647 this ->enableConvTranspose1dDecomposeToPhasedConv =
35863648 enableConvTranspose1dDecomposeToPhasedConv;
3649+ this ->enableInstanceNormDecompose = enableInstanceNormDecompose;
35873650 }
35883651
35893652 DecomposeONNXToONNXPass (const DecomposeONNXToONNXPass &pass)
@@ -3594,6 +3657,8 @@ struct DecomposeONNXToONNXPass
35943657 pass.enableConvTransposeDecompose .getValue ();
35953658 this ->enableConvTransposeDecomposeToPhasedConv =
35963659 pass.enableConvTransposeDecomposeToPhasedConv .getValue ();
3660+ this ->enableInstanceNormDecompose =
3661+ pass.enableInstanceNormDecompose .getValue ();
35973662 }
35983663
35993664 StringRef getArgument () const override { return " decompose-onnx" ; }
@@ -3623,6 +3688,12 @@ struct DecomposeONNXToONNXPass
36233688 " phased Conv" ),
36243689 ::llvm::cl::init (false )};
36253690
3691+ Option<bool > enableInstanceNormDecompose{*this ,
3692+ " enable-instancenorm-decompose" ,
3693+ llvm::cl::desc (" Enable decomposition of InstanceNormalization to "
3694+ " LayerNormalization" ),
3695+ ::llvm::cl::init (true )};
3696+
36263697 void runOnOperation () final ;
36273698
36283699 typedef PassWrapper<DecomposeONNXToONNXPass, OperationPass<func::FuncOp>>
@@ -3635,7 +3706,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
36353706 RewritePatternSet patterns (context);
36363707 onnx_mlir::getDecomposeONNXToONNXPatterns (patterns,
36373708 enableConvTransposeDecompose, enableConvTransposeDecomposeToPhasedConv,
3638- enableConvTranspose1dDecomposeToPhasedConv);
3709+ enableConvTranspose1dDecomposeToPhasedConv, enableInstanceNormDecompose );
36393710 patterns.insert <ReplaceCastLikeByCastPattern>(context);
36403711
36413712#ifdef ONNX_MLIR_ENABLE_STABLEHLO
@@ -3653,7 +3724,8 @@ void DecomposeONNXToONNXPass::runOnOperation() {
36533724void onnx_mlir::getDecomposeONNXToONNXPatterns (
36543725 mlir::RewritePatternSet &patterns, bool enableConvTransposeDecompose,
36553726 bool enableConvTransposeDecomposeToPhasedConv,
3656- bool enableConvTranspose1dDecomposeToPhasedConv) {
3727+ bool enableConvTranspose1dDecomposeToPhasedConv,
3728+ bool enableInstanceNormDecompose) {
36573729 MLIRContext *context = patterns.getContext ();
36583730 populateWithGenerated (patterns);
36593731 if (enableConvTransposeDecompose)
@@ -3662,6 +3734,8 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
36623734 convtranspose_phased::populateWithGenerated (patterns);
36633735 if (enableConvTranspose1dDecomposeToPhasedConv)
36643736 convtranspose_1d_phased::populateWithGenerated (patterns);
3737+ if (enableInstanceNormDecompose)
3738+ patterns.insert <DecomposeInstanceNormPattern>(context);
36653739 patterns.insert <onnx_mlir::DecomposeEinsumPattern>(context);
36663740 patterns.insert <ConcatFusePattern>(context);
36673741 patterns.insert <DecomposeHardSwishPattern>(context);
@@ -3699,8 +3773,9 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
36993773std::unique_ptr<mlir::Pass> onnx_mlir::createDecomposeONNXToONNXPass (
37003774 const std::string &target, bool enableConvTransposeDecompose,
37013775 bool enableConvTransposeDecomposeToPhasedConv,
3702- bool enableConvTranspose1dDecomposeToPhasedConv) {
3776+ bool enableConvTranspose1dDecomposeToPhasedConv,
3777+ bool enableInstanceNormDecompose) {
37033778 return std::make_unique<DecomposeONNXToONNXPass>(target,
37043779 enableConvTransposeDecompose, enableConvTransposeDecomposeToPhasedConv,
3705- enableConvTranspose1dDecomposeToPhasedConv);
3780+ enableConvTranspose1dDecomposeToPhasedConv, enableInstanceNormDecompose );
37063781}
0 commit comments