Commit 894601b

Comms: Concat to dus (#1847)

* Comms: Concat to dus
* Add reactant_commit entry to test-gb-25.yml
* fix

1 parent: cd10cc0
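In short, the new ConcatToDUSOptimize pattern rewrites a sharded stablehlo.concatenate as a stablehlo.pad of its widest operand followed by stablehlo.dynamic_update_slice ops that write the remaining operands in at their offsets, avoiding the concatenate's cross-device communication. It only fires when the concatenate and all of its operands carry the same sdy sharding and the concat dimension is split across more than one device. A minimal sketch of the rewrite (shapes, values, and SSA names are illustrative, and the sdy shardings the pattern requires are omitted):

    // Before: concatenate along dim 0.
    %cat = stablehlo.concatenate %a, %b, dim = 0
        : (tensor<8xf32>, tensor<4xf32>) -> tensor<12xf32>

    // After: pad the widest operand out to the result shape, then insert the rest.
    %zero = stablehlo.constant dense<0.0> : tensor<f32>
    %pad = stablehlo.pad %a, %zero, low = [0], high = [4], interior = [0]
        : (tensor<8xf32>, tensor<f32>) -> tensor<12xf32>
    %off = stablehlo.constant dense<8> : tensor<i32>
    %dus = stablehlo.dynamic_update_slice %pad, %b, %off
        : (tensor<12xf32>, tensor<4xf32>, tensor<i32>) -> tensor<12xf32>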

File tree

4 files changed: +155 −1 lines changed
.github/workflows/test-gb-25.yml
Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ jobs:
           - 'main'
           # - '0123456789abcdef0123456789abcdef01234567'
         reactant_commit:
-          - 'main'
+          - 'c2d'
 
     steps:
       - name: Check GPUs
src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp
Lines changed: 116 additions & 0 deletions
@@ -3693,6 +3693,119 @@ struct ConcatToPadCommOptimize
   }
 };
 
+struct ConcatToDUSOptimize : public OpRewritePattern<stablehlo::ConcatenateOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::ConcatenateOp concat,
+                                PatternRewriter &rewriter) const override {
+    if (concat->getParentOfType<sdy::ManualComputationOp>())
+      return failure();
+    auto ndims = concat.getType().getShape().size();
+    auto concatShape = concat.getType().getShape();
+    auto concatDimension = concat.getDimension();
+    auto concatDimSize = concatShape[concatDimension];
+    auto elemType = concat.getType().getElementType();
+
+    auto concatSharding = mlir::sdy::getSharding(concat);
+    if (!concatSharding)
+      return failure();
+
+    auto numDevicesAlongDimension =
+        getNumDevicesAlongDimension(concatSharding, concatDimension, concat);
+    if (numDevicesAlongDimension == 1) {
+      return rewriter.notifyMatchFailure(
+          concat,
+          "numDevicesAlongDimension == 1. Communication is already optimized.");
+    }
+
+    if (concat.getNumOperands() == 2 &&
+        isRotateLike(concat.getDimension(), concat.getOperands()[0],
+                     concat.getOperands()[1])) {
+      return rewriter.notifyMatchFailure(concat, "Explicit rotate like comm");
+    }
+
+    SmallVector<int64_t> padLow(ndims, 0);
+    SmallVector<int64_t> padHigh(ndims, 0);
+    SmallVector<int64_t> padInner(ndims, 0);
+
+    SmallVector<Value> addOperands;
+
+    size_t largest_idx = 0;
+    for (auto &&[idx, operand] : llvm::enumerate(concat.getOperands())) {
+      auto operandSharding = mlir::sdy::getSharding(operand);
+      if (!operandSharding || (operandSharding != concatSharding))
+        return failure();
+      if (cast<RankedTensorType>(operand.getType())
+              .getShape()[concatDimension] >
+          cast<RankedTensorType>(concat.getOperands()[largest_idx].getType())
+              .getShape()[concatDimension]) {
+        largest_idx = idx;
+      }
+    }
+
+    auto zero = stablehlo::ConstantOp::create(rewriter, concat.getLoc(),
+                                              rewriter.getZeroAttr(elemType));
+
+    int64_t leftPadding = 0;
+    for (auto [i, operand] : llvm::enumerate(concat.getOperands())) {
+      auto operandConcatDimSize =
+          cast<RankedTensorType>(operand.getType()).getShape()[concatDimension];
+      if (i == largest_idx)
+        break;
+      leftPadding += operandConcatDimSize;
+    }
+
+    padLow[concatDimension] = leftPadding;
+    padHigh[concatDimension] =
+        concatDimSize - leftPadding -
+        cast<RankedTensorType>(concat.getOperands()[largest_idx].getType())
+            .getShape()[concatDimension];
+
+    auto padStart = stablehlo::PadOp::create(rewriter, concat.getLoc(),
+                                             concat.getOperands()[largest_idx],
+                                             zero, padLow, padHigh, padInner);
+    assert(concat.getType() == padStart.getType());
+    sdy::setSharding(padStart, concatSharding);
+
+    Value current = padStart;
+
+    leftPadding = 0;
+
+    auto i32 = RankedTensorType::get({}, concatDimSize < (1ULL << 32)
+                                             ? rewriter.getI32Type()
+                                             : rewriter.getI64Type());
+    auto zeroI32 = stablehlo::ConstantOp::create(rewriter, concat.getLoc(),
+                                                 rewriter.getZeroAttr(i32));
+
+    for (auto [i, operand] : llvm::enumerate(concat.getOperands())) {
+      auto operandConcatDimSize =
+          cast<RankedTensorType>(operand.getType()).getShape()[concatDimension];
+
+      if (isZero(operand) || i == largest_idx) {
+        leftPadding += operandConcatDimSize;
+        continue;
+      }
+
+      SmallVector<Value> idxs(ndims, zeroI32);
+      idxs[concatDimension] = stablehlo::ConstantOp::create(
+          rewriter, concat.getLoc(), i32,
+          cast<ElementsAttr>(makeAttr(i32, leftPadding)));
+
+      auto paddedOperand = stablehlo::DynamicUpdateSliceOp::create(
+          rewriter, concat.getLoc(), current, operand, idxs);
+
+      assert(concat.getType() == paddedOperand.getType());
+      sdy::setSharding(paddedOperand, concatSharding);
+      leftPadding += operandConcatDimSize;
+      current = paddedOperand;
+    }
+
+    rewriter.replaceOp(concat, current);
+    return success();
+  }
+};
+
 // See https://github.com/EnzymeAD/Enzyme-JAX/issues/854 for the motivation
 // TODO: At some point if we can come up with a cost model for this, we can do a
 // greedy search for the best ordering
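As a worked example of the padding arithmetic above, using the shapes from the new lit test at the end of this commit: in @main1 the operands are 80 and 40 wide along dim 2 and the result is 120 wide, so the 80-wide operand is largest_idx with leftPadding = 0, giving padLow[2] = 0 and padHigh[2] = 120 - 0 - 80 = 40; the 40-wide slice is then written in by dynamic_update_slice at offset 80. In @main2 the operand order is reversed, so the pad becomes low = 40, high = 0 and the update offset is 0.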
@@ -3881,6 +3994,9 @@ struct OptimizeCommunicationPass
       patterns.add<ConcatToPadCommOptimize>(context,
                                             PatternBenefit(concat_to_pad_comm));
 
+    if (concat_to_dus > 0)
+      patterns.add<ConcatToDUSOptimize>(context, PatternBenefit(concat_to_dus));
+
     if (concat_two_operands_comm > 0)
       patterns.add<ConcatTwoOperandsCommOptimize>(
           channel_id, context, PatternBenefit(concat_two_operands_comm));
src/enzyme_ad/jax/Passes/Passes.td
Lines changed: 6 additions & 0 deletions
@@ -889,6 +889,12 @@ def OptimizeCommunication : Pass<"optimize-communication"> {
        /*default=*/"0",
        /*description=*/"Convert Concatenate two operands to Manual Computation with CollectivePermute">,
     Option<
+       /*C++ variable name=*/"concat_to_dus",
+       /*CLI argument=*/"concat_to_dus",
+       /*type=*/"int",
+       /*default=*/"0",
+       /*description=*/"Perform a Concatenate with Padding to optimize the communication">,
+    Option<
        /*C++ variable name=*/"concat_to_pad_comm",
        /*CLI argument=*/"concat_to_pad_comm",
        /*type=*/"int",
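The option defaults to "0", so the rewrite stays off unless requested; a non-zero value enables ConcatToDUSOptimize and is also used as its PatternBenefit, as in the registration hunk above. The new lit test below turns it on with optimize-communication{concat_to_dus=1} while disabling the related periodic_concat, concat_to_pad_comm, and dus_to_pad_comm options.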
New MLIR lit test (file path not captured in this view)
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(optimize-communication{periodic_concat=0 concat_to_pad_comm=0 concat_to_dus=1 dus_to_pad_comm=0})" %s | FileCheck %s
+
+sdy.mesh @mesh1 = <["z"=1, "x"=4, "y"=4]>
+func.func @main1(%arg0: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh1, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh1, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x120xf64> {sdy.sharding = #sdy.sharding<@mesh1, [{"z"}, {"y"}, {"x"}]>}) {
+  %0 = stablehlo.slice %arg1 [0:20, 0:24, 0:40] {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>) -> tensor<20x24x40xf64>
+  %1 = stablehlo.concatenate %arg0, %0, dim = 2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor<20x24x40xf64>) -> tensor<20x24x120xf64>
+  return %1 : tensor<20x24x120xf64>
+}
+
+func.func @main2(%arg0: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh1, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh1, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x120xf64> {sdy.sharding = #sdy.sharding<@mesh1, [{"z"}, {"y"}, {"x"}]>}) {
+  %0 = stablehlo.slice %arg1 [0:20, 0:24, 0:40] {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>) -> tensor<20x24x40xf64>
+  %1 = stablehlo.concatenate %0, %arg0, dim = 2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x40xf64>, tensor<20x24x80xf64>) -> tensor<20x24x120xf64>
+  return %1 : tensor<20x24x120xf64>
+}
+
+// CHECK: func.func @main1(%arg0: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh1, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh1, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x120xf64> {sdy.sharding = #sdy.sharding<@mesh1, [{"z"}, {"y"}, {"x"}]>}) {
+// CHECK-NEXT:   %c = stablehlo.constant dense<80> : tensor<i32>
+// CHECK-NEXT:   %c_0 = stablehlo.constant dense<0> : tensor<i32>
+// CHECK-NEXT:   %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
+// CHECK-NEXT:   %0 = stablehlo.slice %arg1 [0:20, 0:24, 0:40] {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>) -> tensor<20x24x40xf64>
+// CHECK-NEXT:   %1 = stablehlo.pad %arg0, %cst, low = [0, 0, 0], high = [0, 0, 40], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor<f64>) -> tensor<20x24x120xf64>
+// CHECK-NEXT:   %2 = stablehlo.dynamic_update_slice %1, %0, %c_0, %c_0, %c {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x120xf64>, tensor<20x24x40xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<20x24x120xf64>
+// CHECK-NEXT:   return %2 : tensor<20x24x120xf64>
+// CHECK-NEXT: }
+// CHECK: func.func @main2(%arg0: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh1, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh1, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x120xf64> {sdy.sharding = #sdy.sharding<@mesh1, [{"z"}, {"y"}, {"x"}]>}) {
+// CHECK-NEXT:   %c = stablehlo.constant dense<0> : tensor<i32>
+// CHECK-NEXT:   %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
+// CHECK-NEXT:   %0 = stablehlo.slice %arg1 [0:20, 0:24, 0:40] {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>) -> tensor<20x24x40xf64>
+// CHECK-NEXT:   %1 = stablehlo.pad %arg0, %cst, low = [0, 0, 40], high = [0, 0, 0], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor<f64>) -> tensor<20x24x120xf64>
+// CHECK-NEXT:   %2 = stablehlo.dynamic_update_slice %1, %0, %c, %c, %c {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x120xf64>, tensor<20x24x40xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<20x24x120xf64>
+// CHECK-NEXT:   return %2 : tensor<20x24x120xf64>
+// CHECK-NEXT: }
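In both functions the CHECK lines show the stablehlo.concatenate replaced by a stablehlo.pad of the full-width operand plus a single stablehlo.dynamic_update_slice, with the original #sdy.sharding_per_value annotations carried over to the new ops.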
