@@ -3753,6 +3753,165 @@ class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
 };
 } // namespace
 
+namespace {
+// Decompose aten.group_norm into aten.native_group_norm, keeping only the
+// normalized output and discarding the mean and rstd results.
+class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
+  using OpRewritePattern<AtenGroupNormOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenGroupNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+
+    Value input = op.getInput();
+    Value weight = op.getWeight();
+    Value bias = op.getBias();
+    Value numGroups = op.getNumGroups();
+    Value eps = op.getEps();
+
+    Value cstZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
+
+    // native_group_norm takes N, C, and the flattened spatial size HxW as
+    // explicit operands; derive them from the input tensor.
+    Value N = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
+    Value C = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
+    Value numElements = rewriter.create<AtenNumelOp>(loc, input);
+    Value numElementsDivN =
+        rewriter.create<AtenFloordivIntOp>(loc, numElements, N);
+    Value HxW = rewriter.create<AtenFloordivIntOp>(loc, numElementsDivN, C);
+
+    AtenNativeGroupNormOp newOp = rewriter.create<AtenNativeGroupNormOp>(
+        loc, ArrayRef<Type>{op.getResult().getType(), baseType, baseType},
+        input, weight, bias, N, C, HxW, numGroups, eps);
+
+    rewriter.replaceOp(op, newOp.getResult0());
+    return success();
+  }
+};
+} // namespace
+
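A minimal eager-mode sketch of the relationship this pattern relies on: `aten.group_norm` is `aten.native_group_norm` with only the first result kept. The helper name `group_norm_via_native` is ours, used purely for illustration:

```python
import torch

# Illustrative reference: aten.group_norm forwards to aten.native_group_norm
# and keeps only the normalized output, as the pattern above does.
def group_norm_via_native(x, num_groups, weight=None, bias=None, eps=1e-5):
    N, C = x.shape[0], x.shape[1]
    HxW = x.numel() // N // C  # the same flattened spatial size the pattern computes
    out, mean, rstd = torch.ops.aten.native_group_norm(
        x, weight, bias, N, C, HxW, num_groups, eps)
    return out  # mean and rstd are dropped, matching replaceOp(newOp.getResult0())

x = torch.randn(2, 6, 4, 4)
assert torch.allclose(group_norm_via_native(x, 3), torch.group_norm(x, 3))
```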
+namespace {
+// Decompose aten.native_group_norm into reshape, mean/var, rsqrt, and
+// elementwise affine ops.
+class DecomposeAtenNativeGroupNormOp
+    : public OpRewritePattern<AtenNativeGroupNormOp> {
+  using OpRewritePattern<AtenNativeGroupNormOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNativeGroupNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+
+    Value input = op.getInput();
+    Value weight = op.getWeight();
+    Value bias = op.getBias();
+    Value numGroups = op.getGroup();
+    Value eps = op.getEps();
+
+    // Check that the input and output tensors have known sizes.
+    auto inputType = input.getType().cast<BaseTensorType>();
+    auto outputType = op.getResult0().getType().cast<BaseTensorType>();
+    auto meanType = op.getResult1().getType().cast<BaseTensorType>();
+    auto rsqrtVarType = op.getResult2().getType().cast<BaseTensorType>();
+    if (!inputType.hasSizes() || !outputType.hasSizes() ||
+        !meanType.hasSizes() || !rsqrtVarType.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "input and output tensors should have known sizes");
+    }
+
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    Value cstZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value cstNegativeOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
+
+    // GroupNorm requires the channel dimension (C) to be exactly divisible
+    // by the number of groups.
+    Value channel = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
+    Value remainder =
+        rewriter.create<AtenRemainderIntOp>(loc, channel, numGroups);
+    Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, remainder, cstZero);
+    rewriter.create<RuntimeAssertOp>(
+        loc, eqOrNot,
+        rewriter.getStringAttr("the number of channels must be divisible by "
+                               "the number of groups"));
+
+    // Reshape the input tensor to (N, numGroups, -1) to apply normalization.
+    SmallVector<Value> newShape;
+    newShape.push_back(rewriter.create<AtenSizeIntOp>(loc, input, cstZero));
+    newShape.push_back(numGroups);
+    newShape.push_back(cstNegativeOne);
+    Value reshapedInput = rewriter.create<AtenViewOp>(
+        loc, baseType, input,
+        rewriter.create<PrimListConstructOp>(
+            loc, Torch::ListType::get(IntType::get(context)), newShape));
+
+    // Compute the mean and variance of each group along the flattened last
+    // dimension.
+    Value dimList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        ArrayRef<Value>{cstNegativeOne});
+    auto mean = rewriter.create<AtenMeanDimOp>(
+        loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue,
+        /*dtype=*/none);
+    auto var = rewriter.create<AtenVarDimOp>(
+        loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse,
+        /*keepdim=*/cstTrue);
+
+    // Compute the normalized output: (input - mean) * rsqrt(var + eps).
+    auto varPlusEps = rewriter.create<AtenAddScalarOp>(loc, baseType, var, eps,
+                                                       /*alpha=*/cstOne);
+    auto invStd = rewriter.create<AtenRsqrtOp>(loc, baseType, varPlusEps);
+    auto inputSubMean = rewriter.create<AtenSubTensorOp>(
+        loc, baseType, reshapedInput, mean, /*alpha=*/cstOne);
+    auto normalizedOutput =
+        rewriter.create<AtenMulTensorOp>(loc, baseType, inputSubMean, invStd);
+
+    // Reshape the normalized output back to the original input shape.
+    auto inputShape = rewriter.create<AtenSizeOp>(
+        loc, Torch::ListType::get(IntType::get(context)), input);
+    auto reshapedOutput = rewriter.create<AtenViewOp>(
+        loc, inputType, normalizedOutput, /*shape=*/inputShape);
+
+    // Apply weight and bias if they are not None. Reshape them to
+    // (C, 1, 1, ...) so they broadcast over the spatial dimensions.
+    SmallVector<Value> viewShape = {channel};
+    for (unsigned i = 2; i < inputType.getSizes().size(); i++)
+      viewShape.push_back(cstOne);
+    Value viewShapeSizeList = rewriter.create<PrimListConstructOp>(
+        loc, ListType::get(IntType::get(context)), viewShape);
+
+    Value groupNormOutput = reshapedOutput;
+    if (!weight.getType().isa<Torch::NoneType>()) {
+      auto weightReshaped = rewriter.create<AtenViewOp>(
+          loc, baseType, weight, /*shape=*/viewShapeSizeList);
+      groupNormOutput = rewriter.create<AtenMulTensorOp>(
+          loc, inputType, groupNormOutput, weightReshaped);
+    }
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      auto biasReshaped = rewriter.create<AtenViewOp>(
+          loc, baseType, bias, /*shape=*/viewShapeSizeList);
+      groupNormOutput = rewriter.create<AtenAddTensorOp>(
+          loc, inputType, groupNormOutput, biasReshaped,
+          /*alpha=*/cstOne);
+    }
+
+    // Squeeze away the kept reduction axis so mean and rstd have shape (N, G).
+    Value squeezedMean =
+        rewriter.create<AtenSqueezeDimOp>(loc, meanType, mean, cstNegativeOne);
+    Value squeezedRsqrtVar = rewriter.create<AtenSqueezeDimOp>(
+        loc, rsqrtVarType, invStd, cstNegativeOne);
+
+    rewriter.replaceOp(
+        op, ArrayRef<Value>{groupNormOutput, squeezedMean, squeezedRsqrtVar});
+
+    return success();
+  }
+};
+} // namespace
+
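For clarity, a hand-written eager sketch of the same arithmetic (our own reference following the reshape-to-(N, G, -1) scheme above, not aten's actual implementation; the helper name `native_group_norm_ref` is hypothetical):

```python
import torch

# Reshape to (N, G, -1), normalize per group with the biased variance,
# reshape back, then apply the per-channel affine transform.
def native_group_norm_ref(x, weight, bias, N, C, HxW, group, eps):
    assert C % group == 0  # mirrors the RuntimeAssertOp emitted by the pattern
    g = x.reshape(N, group, -1)
    mean = g.mean(dim=-1, keepdim=True)
    var = g.var(dim=-1, unbiased=False, keepdim=True)
    rstd = torch.rsqrt(var + eps)
    out = ((g - mean) * rstd).reshape(x.shape)
    if weight is not None:
        out = out * weight.reshape(C, *([1] * (x.dim() - 2)))
    if bias is not None:
        out = out + bias.reshape(C, *([1] * (x.dim() - 2)))
    # Squeeze the kept reduction axis so mean and rstd come out as (N, G).
    return out, mean.squeeze(-1), rstd.squeeze(-1)

x = torch.randn(2, 6, 4, 4)
w, b = torch.randn(6), torch.randn(6)
ref = native_group_norm_ref(x, w, b, 2, 6, 16, 3, 1e-5)
got = torch.ops.aten.native_group_norm(x, w, b, 2, 6, 16, 3, 1e-5)
for a, c in zip(ref, got):
    assert torch.allclose(a, c, atol=1e-5)
```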
 namespace {
 class DecomposeAtenNativeBatchNormOp
     : public OpRewritePattern<AtenNativeBatchNormOp> {
@@ -6204,6 +6363,8 @@ class DecomposeComplexOpsPass
         DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNativeGroupNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);