@@ -3693,6 +3693,119 @@ struct ConcatToPadCommOptimize
   }
 };
 
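+// Part of the communication-optimization patterns: rewrites a sharded
+// concatenate as a zero pad of its largest operand followed by a chain of
+// dynamic_update_slice ops that write the remaining operands into place.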
+struct ConcatToDUSOptimize : public OpRewritePattern<stablehlo::ConcatenateOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::ConcatenateOp concat,
+                                PatternRewriter &rewriter) const override {
+    if (concat->getParentOfType<sdy::ManualComputationOp>())
+      return failure();
+    auto ndims = concat.getType().getShape().size();
+    auto concatShape = concat.getType().getShape();
+    auto concatDimension = concat.getDimension();
+    auto concatDimSize = concatShape[concatDimension];
+    auto elemType = concat.getType().getElementType();
+
+    auto concatSharding = mlir::sdy::getSharding(concat);
+    if (!concatSharding)
+      return failure();
+
+    auto numDevicesAlongDimension =
+        getNumDevicesAlongDimension(concatSharding, concatDimension, concat);
+    if (numDevicesAlongDimension == 1) {
+      return rewriter.notifyMatchFailure(
+          concat,
+          "numDevicesAlongDimension == 1. Communication is already optimized.");
+    }
+
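+    // Leave two-operand concats that express a rotation alone; they are
+    // explicit rotate-like communication.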
+    if (concat.getNumOperands() == 2 &&
+        isRotateLike(concat.getDimension(), concat.getOperands()[0],
+                     concat.getOperands()[1])) {
+      return rewriter.notifyMatchFailure(concat, "Explicit rotate like comm");
+    }
+
+    SmallVector<int64_t> padLow(ndims, 0);
+    SmallVector<int64_t> padHigh(ndims, 0);
+    SmallVector<int64_t> padInner(ndims, 0);
+
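+    // Every operand must carry the same sharding as the concat. Track the
+    // operand with the largest extent along the concat dimension; it seeds
+    // the padded result below.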
+    size_t largest_idx = 0;
+    for (auto &&[idx, operand] : llvm::enumerate(concat.getOperands())) {
+      auto operandSharding = mlir::sdy::getSharding(operand);
+      if (!operandSharding || (operandSharding != concatSharding))
+        return failure();
+      if (cast<RankedTensorType>(operand.getType())
+              .getShape()[concatDimension] >
+          cast<RankedTensorType>(concat.getOperands()[largest_idx].getType())
+              .getShape()[concatDimension]) {
+        largest_idx = idx;
+      }
+    }
+
+    auto zero = stablehlo::ConstantOp::create(rewriter, concat.getLoc(),
+                                              rewriter.getZeroAttr(elemType));
+
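+    // Offset of the largest operand along the concat dimension: the sum of
+    // the extents of the operands that precede it.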
+    int64_t leftPadding = 0;
+    for (auto [i, operand] : llvm::enumerate(concat.getOperands())) {
+      auto operandConcatDimSize =
+          cast<RankedTensorType>(operand.getType()).getShape()[concatDimension];
+      if (i == largest_idx)
+        break;
+      leftPadding += operandConcatDimSize;
+    }
+
+    padLow[concatDimension] = leftPadding;
+    padHigh[concatDimension] =
+        concatDimSize - leftPadding -
+        cast<RankedTensorType>(concat.getOperands()[largest_idx].getType())
+            .getShape()[concatDimension];
+
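+    // Zero-pad the largest operand so it occupies its final position within
+    // the full concatenated shape.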
+    auto padStart = stablehlo::PadOp::create(rewriter, concat.getLoc(),
+                                             concat.getOperands()[largest_idx],
+                                             zero, padLow, padHigh, padInner);
+    assert(concat.getType() == padStart.getType());
+    sdy::setSharding(padStart, concatSharding);
+
+    Value current = padStart;
+
+    leftPadding = 0;
+
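+    // Scalar index type for the dynamic_update_slice start indices: i32
+    // unless the concat dimension is too large to address with 32 bits.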
+    auto i32 = RankedTensorType::get({}, concatDimSize < (1ULL << 32)
+                                             ? rewriter.getI32Type()
+                                             : rewriter.getI64Type());
+    auto zeroI32 = stablehlo::ConstantOp::create(rewriter, concat.getLoc(),
+                                                 rewriter.getZeroAttr(i32));
+
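+    // Write each remaining operand into place with a dynamic_update_slice.
+    // All-zero operands are already covered by the zero padding, and the
+    // largest operand is already in position, so both are skipped.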
+    for (auto [i, operand] : llvm::enumerate(concat.getOperands())) {
+      auto operandConcatDimSize =
+          cast<RankedTensorType>(operand.getType()).getShape()[concatDimension];
+
+      if (isZero(operand) || i == largest_idx) {
+        leftPadding += operandConcatDimSize;
+        continue;
+      }
+
+      SmallVector<Value> idxs(ndims, zeroI32);
+      idxs[concatDimension] = stablehlo::ConstantOp::create(
+          rewriter, concat.getLoc(), i32,
+          cast<ElementsAttr>(makeAttr(i32, leftPadding)));
+
+      auto paddedOperand = stablehlo::DynamicUpdateSliceOp::create(
+          rewriter, concat.getLoc(), current, operand, idxs);
+
+      assert(concat.getType() == paddedOperand.getType());
+      sdy::setSharding(paddedOperand, concatSharding);
+      leftPadding += operandConcatDimSize;
+      current = paddedOperand;
+    }
+
+    rewriter.replaceOp(concat, current);
+    return success();
+  }
+};
+
 // See https://github.com/EnzymeAD/Enzyme-JAX/issues/854 for the motivation
 // TODO: At some point if we can come up with a cost model for this, we can do a
 // greedy search for the best ordering
@@ -3881,6 +3994,9 @@ struct OptimizeCommunicationPass
     patterns.add<ConcatToPadCommOptimize>(context,
                                           PatternBenefit(concat_to_pad_comm));
 
+    if (concat_to_dus > 0)
+      patterns.add<ConcatToDUSOptimize>(context, PatternBenefit(concat_to_dus));
+
     if (concat_two_operands_comm > 0)
       patterns.add<ConcatTwoOperandsCommOptimize>(
           channel_id, context, PatternBenefit(concat_two_operands_comm));