
Commit 4c28933

feat: simplify diagonal accesses for a dot_general op (#1614)

* feat: simplify diagonal accesses for a dot_general op
* test: update

1 parent: e67cf9f
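The rewrite rests on a standard identity: each diagonal entry of a matrix product is the dot product of the matching row of the LHS with the matching column of the RHS, i.e. diag(A x B) = row-sum(A * B^T). When only diagonal entries are ever read, the full O(M*N*K) product can therefore be replaced by an elementwise multiply and a reduction. A minimal NumPy sketch of the identity (illustrative only, not part of this commit):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3))  # lhs, shape [M, K]
B = rng.standard_normal((3, 5))  # rhs, shape [K, N]

# Full product: O(M*N*K) work even when only the diagonal is consumed.
full_diag = np.diag(A @ B)

# Diagonal-only rewrite: slice both operands down to diagLen = min(M, N),
# multiply elementwise, and reduce over the contracted axis.
diag_len = min(A.shape[0], B.shape[1])
cheap_diag = np.sum(A[:diag_len, :] * B[:, :diag_len].T, axis=1)

assert np.allclose(full_diag, cheap_diag)
```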

File tree

6 files changed: +348 -1 lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 141 additions & 1 deletion
@@ -25483,6 +25483,145 @@ struct DynamicSliceSimplify
     }
 };
 
+// TODO: generalize to higher ranked tensors
+// TODO: if we determine that all accesses are on some offset diagonal,
+// we can still replace it with a multiply combined with pad/slice
+// If we prove that only the diagonal elements of a dot_general are accessed,
+// we replace the dot_general with a cheaper multiply op. Note that
+// this implies `diag(new_op(A, B)) = diag(A x B)`; however,
+// `new_op(A, B) != A x B`.
+struct DotGeneralOnlyDiagonalAccess
+    : public CheckedOpRewritePattern<stablehlo::DotGeneralOp,
+                                     DotGeneralOnlyDiagonalAccess> {
+  using CheckedOpRewritePattern<
+      stablehlo::DotGeneralOp,
+      DotGeneralOnlyDiagonalAccess>::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::DotGeneralOp op,
+                                    PatternRewriter &rewriter) const {
+    auto resTy = cast<RankedTensorType>(op.getType());
+    if (resTy.getRank() != 2)
+      return failure();
+
+    auto M = resTy.getDimSize(0);
+    auto N = resTy.getDimSize(1);
+    auto diagLen = std::min(M, N);
+
+    auto lhs = op.getLhs();
+    auto rhs = op.getRhs();
+    auto dotDimNumbers = op.getDotDimensionNumbers();
+    auto lhsContractingDims = dotDimNumbers.getLhsContractingDimensions();
+    auto rhsContractingDims = dotDimNumbers.getRhsContractingDimensions();
+    auto lhsBatchingDims = dotDimNumbers.getLhsBatchingDimensions();
+    auto rhsBatchingDims = dotDimNumbers.getRhsBatchingDimensions();
+
+    if (lhsContractingDims.size() != 1 || rhsContractingDims.size() != 1 ||
+        lhsBatchingDims.size() != 0 || rhsBatchingDims.size() != 0)
+      return failure();
+
+    llvm::SetVector<Operation *> opsToReplace;
+    llvm::SmallPtrSet<Operation *, 4> seenOps;
+    for (auto user : op->getUsers()) {
+      if (seenOps.count(user))
+        continue;
+      if (!enzyme::allAccessesAreOnMainDiagonal(user, opsToReplace))
+        return failure();
+      seenOps.insert(user);
+    }
+
+    if (opsToReplace.empty())
+      return failure();
+
+    // Rewrite the dot_general to a multiply.
+    // We insert transpose ops here, but those will get removed later.
+    auto lhsContractDim = lhsContractingDims[0];
+    auto rhsContractDim = rhsContractingDims[0];
+    // result[i, i] = sum_k (lhs[i, k] * rhs[k, i])
+    //              = reduce_sum(lhs[i, :] * rhs[:, i])
+    auto lhsNonContractDim = 1 - lhsContractDim;
+    auto rhsNonContractDim = 1 - rhsContractDim;
+
+    if (lhsContractDim == 0) {
+      // move to dim = 1
+      lhs = stablehlo::TransposeOp::create(
+          rewriter, op.getLoc(), lhs, rewriter.getDenseI64ArrayAttr({1, 0}));
+    }
+    lhs = stablehlo::SliceOp::create(
+        rewriter, op.getLoc(), lhs, rewriter.getDenseI64ArrayAttr({0, 0}),
+        rewriter.getDenseI64ArrayAttr(
+            {diagLen, cast<ShapedType>(lhs.getType()).getDimSize(1)}),
+        rewriter.getDenseI64ArrayAttr({1, 1})); // [DiagSize, C]
+
+    if (rhsContractDim == 0) {
+      // move to dim = 1
+      rhs = stablehlo::TransposeOp::create(
+          rewriter, op.getLoc(), rhs, rewriter.getDenseI64ArrayAttr({1, 0}));
+    }
+    rhs = stablehlo::SliceOp::create(
+        rewriter, op.getLoc(), rhs, rewriter.getDenseI64ArrayAttr({0, 0}),
+        rewriter.getDenseI64ArrayAttr(
+            {diagLen, cast<ShapedType>(rhs.getType()).getDimSize(1)}),
+        rewriter.getDenseI64ArrayAttr({1, 1})); // [DiagSize, C]
+
+    auto newMul = stablehlo::MulOp::create(rewriter, op.getLoc(), lhs,
+                                           rhs); // [DiagSize, C]
+
+    auto elemTy = cast<RankedTensorType>(newMul.getType()).getElementType();
+    auto tenElemTy = RankedTensorType::get({}, elemTy);
+    auto reduceOp = stablehlo::ReduceOp::create(
+        rewriter, op.getLoc(), ValueRange(newMul.getResult()),
+        ValueRange(stablehlo::ConstantOp::create(
+                       rewriter, op.getLoc(), tenElemTy,
+                       cast<ElementsAttr>(makeAttr(tenElemTy, 0)))
+                       .getResult()),
+        {1});
+
+    {
+      Region &region = reduceOp.getBody();
+      Block *block = rewriter.createBlock(&region);
+      block->addArgument(tenElemTy, op.getLoc());
+      block->addArgument(tenElemTy, op.getLoc());
+
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(block);
+      auto addOp = stablehlo::AddOp::create(
+          rewriter, op.getLoc(), block->getArgument(0), block->getArgument(1));
+      stablehlo::ReturnOp::create(rewriter, op.getLoc(), addOp.getResult());
+    }
+
+    for (auto &opToReplace : opsToReplace) {
+      if (auto sliceOp = dyn_cast<stablehlo::SliceOp>(opToReplace)) {
+        replaceSliceOp(rewriter, sliceOp, reduceOp, M, N, diagLen);
+      } else {
+        assert(false && "Unknown op to replace. open an issue on github");
+      }
+    }
+
+    return success();
+  }
+
+private:
+  void replaceSliceOp(PatternRewriter &rewriter, stablehlo::SliceOp sliceOp,
+                      stablehlo::ReduceOp reduceOp, int64_t M, int64_t N,
+                      int64_t diagLen) const {
+    int64_t start = sliceOp.getStartIndices()[0];
+    int64_t limit = sliceOp.getLimitIndices()[0];
+    int64_t stride = sliceOp.getStrides()[0];
+    int64_t diagStride = N + 1;
+
+    int64_t newStart = start / diagStride;
+    int64_t newLimit = (limit - 1) / diagStride + 1;
+    int64_t newStride = stride / diagStride;
+
+    rewriter.setInsertionPoint(sliceOp);
+    rewriter.replaceOpWithNewOp<stablehlo::SliceOp>(
+        sliceOp, reduceOp.getResult(0),
+        rewriter.getDenseI64ArrayAttr({newStart}),
+        rewriter.getDenseI64ArrayAttr({newLimit}),
+        rewriter.getDenseI64ArrayAttr({newStride}));
+  }
+};
+
 /////////////// End Imported from stablehlo
 
 // clang-format off
@@ -26117,7 +26256,8 @@ struct EnzymeHLOOptPass
         RemoveNoOpsFromWhileLoop,
         WhileIsCopySimplify,
         SplitVariadicScatterOp,
-        DynamicSliceSimplify
+        DynamicSliceSimplify,
+        DotGeneralOnlyDiagonalAccess
         >(context);
 
     patterns.add<
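For context on the replaceSliceOp index arithmetic above: a strided slice of the row-major flattening of the [M, N] result walks the main diagonal whenever its start and stride are multiples of diagStride = N + 1, and such a slice maps onto a slice of the [diagLen] reduce result via newStart = start / diagStride, newLimit = (limit - 1) / diagStride + 1, newStride = stride / diagStride. A quick NumPy check with illustrative numbers (not taken from the commit):

```python
import numpy as np

M, N = 4, 4
diag_stride = N + 1                 # distance between consecutive main-diagonal
                                    # entries in the row-major flattening
X = np.arange(M * N).reshape(M, N)
flat = X.reshape(-1)                # what the reshape op produces
diag = np.diag(X)                   # stands in for the [diagLen] reduce result

# A slice of the flattened matrix that only touches diagonal entries...
start, limit, stride = 5, 16, 10    # hits flat[5] and flat[15]
old = flat[start:limit:stride]

# ...maps to a slice of the diagonal vector, as in replaceSliceOp.
new_start = start // diag_stride            # 1
new_limit = (limit - 1) // diag_stride + 1  # 4
new_stride = stride // diag_stride          # 2

assert np.array_equal(old, diag[new_start:new_limit:new_stride])
```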

src/enzyme_ad/jax/Passes/StructuredTensors.cpp

Lines changed: 86 additions & 0 deletions
@@ -5,6 +5,8 @@
 #include "src/enzyme_ad/jax/Utils.h"
 #include "stablehlo/dialect/StablehloOps.h"
 
+#include "llvm/ADT/SetVector.h"
+
 namespace mlir {
 namespace enzyme {

@@ -256,5 +258,89 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
   return result;
 }
 
+bool allAccessesAreOnMainDiagonalPostReshape(stablehlo::ReshapeOp op,
+                                             stablehlo::SliceOp sliceOp) {
+  auto reshapeInTy = cast<RankedTensorType>(op.getOperand().getType());
+  auto reshapeOutTy = cast<RankedTensorType>(op.getType());
+
+  if (reshapeOutTy.getRank() != 1 ||
+      reshapeInTy.getRank() != 2) // [M, N] -> [M * N] vector
+    return false;
+
+  auto M = reshapeInTy.getDimSize(0);
+  auto N = reshapeInTy.getDimSize(1);
+  auto diagLen = std::min(M, N);
+  auto diagStride = N + 1;
+
+  int64_t start = sliceOp.getStartIndices()[0];
+  int64_t limit = sliceOp.getLimitIndices()[0];
+  int64_t stride = sliceOp.getStrides()[0];
+
+  if (stride % diagStride != 0)
+    return false;
+
+  // start can be on any of the diagonal elements
+  if (start % diagStride != 0)
+    return false;
+
+  if (limit > M * N)
+    return false; // technically this is illegal
+
+  // sanity check
+  int64_t count = (limit - start + stride - 1) / stride;
+  if (count <= 0 || count > diagLen)
+    return false;
+
+  return true;
+}
+
+bool allAccessesAreOnMainDiagonalPostReshape(
+    stablehlo::ReshapeOp op, Operation *user,
+    llvm::SetVector<Operation *> &opsToReplace) {
+  if (auto sliceOp = dyn_cast<stablehlo::SliceOp>(user)) {
+    if (allAccessesAreOnMainDiagonalPostReshape(op, sliceOp)) {
+      opsToReplace.insert(sliceOp);
+      return true;
+    }
+    return false;
+  }
+  return false;
+}
+
+bool allAccessesAreOnMainDiagonal(Operation *op,
+                                  llvm::SetVector<Operation *> &opsToReplace) {
+  if (auto reshapeOp = dyn_cast<stablehlo::ReshapeOp>(op)) {
+    return allAccessesAreOnMainDiagonal(reshapeOp, opsToReplace);
+  } else if (auto gatherOp = dyn_cast<stablehlo::GatherOp>(op)) {
+    return allAccessesAreOnMainDiagonal(gatherOp, opsToReplace);
+  }
+  return false;
+}
+
+bool allAccessesAreOnMainDiagonal(stablehlo::ReshapeOp op,
+                                  llvm::SetVector<Operation *> &opsToReplace) {
+  auto reshapeInTy = cast<RankedTensorType>(op.getOperand().getType());
+  if (reshapeInTy.getRank() != 2) // [M, N] matrix
+    return false; // quick exit
+
+  llvm::SmallPtrSet<Operation *, 4> seenOps;
+  for (auto user : op->getUsers()) {
+    if (seenOps.count(user))
+      continue;
+
+    if (!allAccessesAreOnMainDiagonalPostReshape(op, user, opsToReplace))
+      return false;
+
+    seenOps.insert(user);
+  }
+
+  return true;
+}
+
+bool allAccessesAreOnMainDiagonal(stablehlo::GatherOp op,
+                                  llvm::SetVector<Operation *> &opsToReplace) {
+  return false; // TODO: implement this where we are doing gather with iota
+}
+
 } // namespace enzyme
 } // namespace mlir
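In the row-major flattening of an [M, N] matrix, the main-diagonal entry (i, i) lands at flat index i * (N + 1), which is what the start/stride divisibility checks above exploit. A hypothetical Python mirror of the predicate (names and example values are mine, not from the commit):

```python
def slice_stays_on_main_diagonal(M, N, start, limit, stride):
    """Mirror of allAccessesAreOnMainDiagonalPostReshape: does the 1-D slice
    [start:limit:stride] of the row-major flattening of an [M, N] matrix
    touch only main-diagonal entries i * (N + 1)?"""
    diag_len = min(M, N)
    diag_stride = N + 1
    if stride % diag_stride != 0:   # each step must jump whole diagonals
        return False
    if start % diag_stride != 0:    # start must itself be a diagonal entry
        return False
    if limit > M * N:               # out-of-bounds slices are illegal anyway
        return False
    count = (limit - start + stride - 1) // stride  # touched entries
    return 0 < count <= diag_len

# A [3, 4] matrix flattens to 12 entries; its diagonal sits at 0, 5, 10.
assert slice_stays_on_main_diagonal(3, 4, start=0, limit=11, stride=5)
assert not slice_stays_on_main_diagonal(3, 4, start=1, limit=11, stride=5)
```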

src/enzyme_ad/jax/Passes/StructuredTensors.h

Lines changed: 12 additions & 0 deletions
@@ -5,6 +5,8 @@
 #include "src/enzyme_ad/jax/Utils.h"
 #include "stablehlo/dialect/StablehloOps.h"
 
+#include "llvm/ADT/SetVector.h"
+
 #include <optional>
 
 namespace mlir {
namespace mlir {
@@ -27,5 +29,15 @@ struct IotaLikeTensor {
 
 std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor);
 
+// Checks that all accesses made by this op and its users thereof are along
+// the main diagonal.
+// TODO: we can do a full analysis and return whether the access is on a
+// specific set of diagonals.
+bool allAccessesAreOnMainDiagonal(
+    mlir::Operation *op, llvm::SetVector<mlir::Operation *> &opsToReplace);
+bool allAccessesAreOnMainDiagonal(
+    stablehlo::ReshapeOp op, llvm::SetVector<mlir::Operation *> &opsToReplace);
+bool allAccessesAreOnMainDiagonal(
+    stablehlo::GatherOp op, llvm::SetVector<mlir::Operation *> &opsToReplace);
+
 } // namespace enzyme
 } // namespace mlir
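The TODOs here and in EnzymeHLOOpt.cpp point at a natural generalization: the offset-k diagonal entry (i, i + k) flattens to i * (N + 1) + k, so a slice whose start is congruent to k modulo N + 1 (with stride still a multiple of N + 1) walks the k-th superdiagonal. A small NumPy sketch of that indexing (my extrapolation of the TODO, not implemented in this commit):

```python
import numpy as np

M, N, k = 4, 5, 2
X = np.arange(M * N).reshape(M, N)

# Offset-k diagonal entries (i, i + k) flatten to i * (N + 1) + k, so a
# start congruent to k modulo N + 1 identifies the k-th superdiagonal.
flat_idx = [i * (N + 1) + k for i in range(min(M, N - k))]

assert np.array_equal(X.reshape(-1)[flat_idx], np.diagonal(X, offset=k))
```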

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 5 additions & 0 deletions
@@ -2472,3 +2472,8 @@ def EnzymeHLOUnroll : EnzymeHLOParameterizedPatternOp<
   }
 }];
 }
+
+def ApplyDotGeneralOnlyDiagonalAccessPatterns : EnzymeHLOPatternOp<
+  "dot_general_only_diagonal_access"> {
+  let patterns = ["DotGeneralOnlyDiagonalAccess"];
+}

src/enzyme_ad/jax/primitives.py

Lines changed: 1 addition & 0 deletions
@@ -346,6 +346,7 @@ def optimization_passes(
         "dynamic_pad_to_pad",
         "remove_no_ops_from_while_loop",
         "while_is_copy_simplify",
+        "dot_general_only_diagonal_access",
     ]
 
     # constant propagation patterns
