
Commit 732648d

refactor: move broadcast derivative rule to tablegen

1 parent b3d6fde
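
For context, the rule being moved implements the standard broadcast adjoint: the reverse-mode gradient of a broadcast is a sum-reduction over the output dimensions the broadcast introduced. A minimal JAX sketch of that identity (illustrative only, not part of this commit):

    import jax
    import jax.numpy as jnp

    # Broadcast a 2x3 operand into a 4x2x3 result; dims (1, 2) of the
    # output hold the operand dims, dim 0 is newly introduced.
    x = jnp.arange(6.0).reshape(2, 3)
    f = lambda x: jax.lax.broadcast_in_dim(x, (4, 2, 3), (1, 2))

    g = jnp.ones((4, 2, 3))        # incoming cotangent
    _, vjp = jax.vjp(f, x)
    (dx,) = vjp(g)

    # The adjoint is a sum over the introduced dimension.
    assert jnp.allclose(dx, g.sum(axis=0))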

2 files changed: +124 −113 lines changed

src/enzyme_ad/jax/Implementations/HLODerivatives.td

Lines changed: 124 additions & 2 deletions
@@ -977,8 +977,6 @@ def : HLODerivative<"Atan2Op", (Op $x, $y),
   (CheckedDiv (Sub (Mul $x, (Shadow $y)), (Mul $y, (Shadow $x))), (Add (Pow $x, (HLOConstantFP<"2">)), (Pow $y, (HLOConstantFP<"2"> $y))))
 >;
 
-def : HLOReadOnlyIdentityOp<"BroadcastInDimOp">;
-
 // Gather Adjoint
 def ResultDimensionNumbers : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
   op.getDimensionNumbers();
@@ -1446,3 +1444,127 @@ def : HLODerivative<"FftOp",
     Fft (Shadow $x), (FftType), (FftLength)
   )
 >;
+
+def getBroadcastDimensionsWithBatch : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  auto bcastDims = op.getBroadcastDimensions();
+  SmallVector<int64_t> newBcastDims;
+  for (auto dim : bcastDims) {
+    newBcastDims.push_back(to_i64(dim));
+  }
+  if (gutils->width > 1) {
+    newBcastDims.insert(newBcastDims.begin(), gutils->width);
+  }
+  getI64Attr(builder, newBcastDims);
+}]>;
+
+def BroadcastDimsToReductionDims : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  SmallVector<int64_t> reduceDims;
+  auto outRank = cast<RankedTensorType>(op.getType()).getRank();
+  for (int64_t i = 0; i < outRank; i++) {
+    if (!llvm::is_contained(op.getBroadcastDimensions(), i)) {
+      reduceDims.push_back(i);
+    }
+  }
+  if (gutils->width > 1) {
+    for (int64_t i = 0; i < reduceDims.size(); i++) {
+      reduceDims[i] += 1;
+    }
+  }
+  getI64Attr(builder, reduceDims);
+}]>;
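
BroadcastDimsToReductionDims collects the output dimensions that `broadcast_dimensions` does not target; those are exactly the dimensions the adjoint must sum over, each shifted by one when a leading batch dimension is present (vector mode, width > 1). An equivalent Python sketch (illustrative):

    def reduction_dims(out_rank, broadcast_dims, width=1):
        # Output dims not produced from an input dim must be summed away;
        # a leading batch dim shifts every index by one.
        dims = [i for i in range(out_rank) if i not in broadcast_dims]
        return [d + 1 for d in dims] if width > 1 else dims

    # 2x3 -> 4x2x3 with broadcast_dimensions = (1, 2): reduce over dim 0.
    assert reduction_dims(3, (1, 2)) == [0]
    assert reduction_dims(3, (1, 2), width=8) == [1]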
+
+def ReduceAddNullValue : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  auto operand = op->getOperand(0);
+  auto operandElemType = cast<RankedTensorType>(operand.getType()).getElementType();
+  auto bodyTy = RankedTensorType::get({}, operandElemType);
+  cast<AutoDiffTypeInterface>(
+      gutils->getShadowType(bodyTy)).createNullValue(builder, op.getLoc());
+}]>;
+
+def BroadcastDimensionsToInversePermutation : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  auto outTy = cast<RankedTensorType>(op.getType());
+  SmallVector<int64_t> perm(outTy.getRank());
+  SmallVector<int64_t> mapping(outTy.getRank(), -1);
+  for (auto [i, dim] : llvm::enumerate(op.getBroadcastDimensions())) {
+    mapping[to_i64(dim)] = i;
+  }
+
+  int next = op.getBroadcastDimensions().size();
+  for (int i = 0; i < outTy.getRank(); i++) {
+    if (mapping[i] == -1) {
+      mapping[i] = next++;
+    }
+  }
+
+  for (int i = 0; i < outTy.getRank(); i++) {
+    perm[mapping[i]] = i;
+  }
+  getI64Attr(builder, perm);
+}]>;
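
BroadcastDimensionsToInversePermutation builds the transpose that restores operand dimension order: after the reduction, the surviving dimensions sit in output order, so `mapping` records the slot each output dimension occupies and `perm` inverts that assignment. A Python sketch of the same computation (illustrative):

    def inverse_permutation(out_rank, broadcast_dims):
        # mapping[output_dim] = slot the dim occupies after reduce/reshape;
        # reduced-away dims take the remaining slots in increasing order.
        mapping = [-1] * out_rank
        for i, dim in enumerate(broadcast_dims):
            mapping[dim] = i
        nxt = len(broadcast_dims)
        for i in range(out_rank):
            if mapping[i] == -1:
                mapping[i] = nxt
                nxt += 1
        perm = [0] * out_rank
        for i in range(out_rank):
            perm[mapping[i]] = i
        return perm

    assert inverse_permutation(2, (1,)) == [1, 0]
    assert inverse_permutation(3, (2, 0)) == [2, 0, 1]

Equivalently, the result is `list(broadcast_dims)` followed by the remaining output dims in increasing order.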
+
+def InsertDeletedReduceDimsType : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  auto outTy = cast<RankedTensorType>(op.getType());
+  auto outShape = outTy.getShape();
+  SmallVector<int64_t> reshapeShape(outTy.getRank(), -1);
+  for (auto [i, sz] : llvm::enumerate(outShape)) {
+    if (!llvm::is_contained(op.getBroadcastDimensions(), i)) {
+      reshapeShape[i] = 1;
+    } else {
+      reshapeShape[i] = sz;
+    }
+  }
+
+  if (gutils->width > 1) {
+    reshapeShape.insert(reshapeShape.begin(), gutils->width);
+  }
+  RankedTensorType::get(reshapeShape, outTy.getElementType());
+}]>;
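
InsertDeletedReduceDimsType computes the shape that reinstates the summed-away dimensions as size 1, so the transpose above can operate at full output rank before the final reshape collapses them. In Python terms (illustrative):

    def restore_reduced_dims_shape(out_shape, broadcast_dims, width=1):
        # Keep the size of broadcast-target dims, pin reduced dims to 1,
        # and prepend the batch dim in vector mode.
        shape = [sz if i in broadcast_dims else 1
                 for i, sz in enumerate(out_shape)]
        return [width] + shape if width > 1 else shape

    assert restore_reduced_dims_shape((4, 2, 3), (1, 2)) == [1, 2, 3]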
+
+def ReduceAdd : HLOInst<"ReduceOp", ")->getResult(0)", "createAddRegion(">;
+
+def ResultTypeWithBatch : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  auto outTy = cast<RankedTensorType>(op.getType());
+  auto outShape = outTy.getShape();
+  SmallVector<int64_t> newShape(outShape.begin(), outShape.end());
+  if (gutils->width > 1) {
+    newShape.insert(newShape.begin(), gutils->width);
+  }
+  RankedTensorType::get(newShape, outTy.getElementType());
+}]>;
+
+def InputTypeWithBatch : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  auto inTy = cast<RankedTensorType>(op.getOperand().getType());
+  auto inShape = inTy.getShape();
+  SmallVector<int64_t> newShape(inShape.begin(), inShape.end());
+  if (gutils->width > 1) {
+    newShape.insert(newShape.begin(), gutils->width);
+  }
+  RankedTensorType::get(newShape, inTy.getElementType());
+}]>;
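
ResultTypeWithBatch and InputTypeWithBatch prepend the vector-mode width to the result and operand shapes respectively, since shadows in batched mode carry a leading batch dimension. The shape computation, sketched in Python (illustrative):

    def with_batch(shape, width):
        # Prepend the batch dimension only in vector mode (width > 1).
        return (width,) + tuple(shape) if width > 1 else tuple(shape)

    assert with_batch((2, 3), 1) == (2, 3)
    assert with_batch((2, 3), 4) == (4, 2, 3)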
+
+def : HLODerivative<"BroadcastInDimOp", (Op $x),
+  [
+    (
+      Reshape
+      (InputTypeWithBatch),
+      (
+        Transpose
+        (
+          Reshape
+          (InsertDeletedReduceDimsType),
+          (ReduceAdd (DiffeRet), (ReduceAddNullValue), (BroadcastDimsToReductionDims))
+        ),
+        (BroadcastDimensionsToInversePermutation)
+      )
+    )
+  ],
+  (
+    SelectIfActive $x,
+    (
+      BroadcastInDim
+      (ResultTypeWithBatch),
+      (Shadow $x),
+      (getBroadcastDimensionsWithBatch)
+    ),
+    (HLOConstantFP<"0">)
+  )>;
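
The reverse entry of this rule composes ReduceAdd, Reshape, Transpose, and a final Reshape, mirroring the C++ implementation removed below. A hand-composed JAX sketch of the unbatched (width == 1) pipeline, checked against jax.vjp (illustrative; `bcast_dims` and the shapes are made-up test values):

    import jax
    import jax.numpy as jnp

    in_shape, out_shape, bcast_dims = (3, 2), (2, 5, 3), (2, 0)

    def broadcast_adjoint(g):
        red_dims = [i for i in range(len(out_shape)) if i not in bcast_dims]
        r = g.sum(axis=tuple(red_dims))                  # ReduceAdd
        r = r.reshape([sz if i in bcast_dims else 1      # restore 1-dims
                       for i, sz in enumerate(out_shape)])
        perm = list(bcast_dims) + red_dims               # inverse permutation
        return r.transpose(perm).reshape(in_shape)       # Transpose + Reshape

    f = lambda x: jax.lax.broadcast_in_dim(x, out_shape, bcast_dims)
    g = jnp.arange(2.0 * 5 * 3).reshape(out_shape)
    _, vjp = jax.vjp(f, jnp.ones(in_shape))
    assert jnp.allclose(vjp(g)[0], broadcast_adjoint(g))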

src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp

Lines changed: 0 additions & 111 deletions
@@ -1133,116 +1133,6 @@ class AutoDiffWhileRev
   }
 };
 
-class AutoDiffBroadcastInDimRev
-    : public ReverseAutoDiffOpInterface::ExternalModel<
-          AutoDiffBroadcastInDimRev, BroadcastInDimOp> {
-public:
-  LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
-                                         MGradientUtilsReverse *gutils,
-                                         SmallVector<Value> caches) const {
-    auto op = cast<BroadcastInDimOp>(orig);
-    auto inTy = op.getOperand().getType();
-    auto outTy = op.getType();
-    auto inDiffe = gutils->diffe(op, builder);
-    gutils->zeroDiffe(op, builder);
-
-    SmallVector<int64_t> bcastDims(op.getBroadcastDimensions().begin(),
-                                   op.getBroadcastDimensions().end());
-
-    SmallVector<int64_t> reducedDims;
-    SmallVector<int64_t> iterShape;
-    for (auto en : llvm::enumerate(outTy.getShape())) {
-      ssize_t bcastIdx = -1;
-      for (auto en2 : llvm::enumerate(bcastDims)) {
-        if (en2.value() == en.index()) {
-          bcastIdx = en2.index();
-          break;
-        }
-      }
-      if (bcastIdx != -1) {
-        if (en.value() != inTy.getShape()[bcastIdx]) {
-          reducedDims.push_back(en.index());
-          assert(inTy.getShape()[bcastIdx] == 1);
-        } else {
-          iterShape.push_back(inTy.getShape()[bcastIdx]);
-        }
-        continue;
-      }
-      reducedDims.push_back(en.index());
-    }
-
-    SmallVector<int64_t> reshapedShape(outTy.getRank(), -1);
-    for (auto [i, sz] : llvm::enumerate(outTy.getShape())) {
-      if (llvm::is_contained(reducedDims, i)) {
-        reshapedShape[i] = 1;
-      } else {
-        reshapedShape[i] = sz;
-      }
-    }
-
-    SmallVector<int64_t> perm(outTy.getRank(), -1);
-    SmallVector<int64_t> mapping(outTy.getRank(), -1);
-    for (auto [i, dim] : llvm::enumerate(bcastDims)) {
-      mapping[dim] = i;
-    }
-
-    int next = bcastDims.size();
-    for (int i = 0; i < outTy.getRank(); i++) {
-      if (mapping[i] == -1) {
-        mapping[i] = next++;
-      }
-    }
-
-    for (int i = 0; i < outTy.getRank(); i++) {
-      perm[mapping[i]] = i;
-    }
-
-    auto reduceTy = RankedTensorType::get(iterShape, inTy.getElementType());
-    auto bodyTy = RankedTensorType::get({}, inTy.getElementType());
-
-    Value zero = cast<AutoDiffTypeInterface>(gutils->getShadowType(bodyTy))
-                     .createNullValue(builder, op.getLoc());
-
-    auto red = builder.create<ReduceOp>(
-        op.getLoc(), TypeRange(gutils->getShadowType(reduceTy)), inDiffe, zero,
-        reducedDims);
-    red.getBody().push_back(new Block());
-    Block &body = red.getBody().front();
-    OpBuilder bodyBuilder(orig->getContext());
-    bodyBuilder.setInsertionPointToEnd(&body);
-
-    body.addArgument(bodyTy, op.getLoc());
-    body.addArgument(bodyTy, op.getLoc());
-    auto add = bodyBuilder.create<AddOp>(op.getLoc(), body.getArgument(0),
-                                         body.getArgument(1));
-    bodyBuilder.create<ReturnOp>(op.getLoc(), ValueRange(add));
-
-    // for simplicity we do grad -> reduce -> reshape (restore 1 dims) ->
-    // transpose -> reshape
-    // The repeated reshapes are then eliminated via `enzyme-hlo-opt`.
-    auto reshapedRed = builder.create<ReshapeOp>(
-        op.getLoc(),
-        RankedTensorType::get(reshapedShape, inTy.getElementType()),
-        red->getResult(0));
-    auto transposedVal =
-        builder.create<TransposeOp>(op.getLoc(), reshapedRed, perm);
-    auto res = builder.create<ReshapeOp>(
-        op.getLoc(), gutils->getShadowType(op.getOperand().getType()),
-        transposedVal);
-
-    gutils->addToDiffe(op.getOperand(), res, builder);
-    return success();
-  }
-
-  SmallVector<Value> cacheValues(Operation *orig,
-                                 MGradientUtilsReverse *gutils) const {
-    return {};
-  }
-
-  void createShadowValues(Operation *op, OpBuilder &builder,
-                          MGradientUtilsReverse *gutils) const {}
-};
-
 class AutoDiffSliceRev
     : public ReverseAutoDiffOpInterface::ExternalModel<AutoDiffSliceRev,
                                                        SliceOp> {
@@ -3635,7 +3525,6 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
   WhileOp::attachInterface<AutoDiffWhileRev>(*context);
   ReduceOp::attachInterface<AutoDiffReduceCF<ReduceOp>>(*context);
   WhileOp::attachInterface<AutoDiffReduceCF<WhileOp>>(*context);
-  BroadcastInDimOp::attachInterface<AutoDiffBroadcastInDimRev>(*context);
   SliceOp::attachInterface<AutoDiffSliceRev>(*context);
   ReduceOp::attachInterface<AutoDiffReduceRev>(*context);
   ReduceWindowOp::attachInterface<AutoDiffReduceWindowRev>(*context);
