@@ -25483,6 +25483,145 @@ struct DynamicSliceSimplify
2548325483 }
2548425484};
2548525485
25486+ // TODO: generalize to higher ranked tensors
25487+ // TODO: if we determine that all accesses are on some offset diagonal,
25488+ // we can still replace it will a multiply combined with pad/slice
25489+ // If we prove that only the diagonal elements of a dot_general are accessed,
25490+ // we replace the dot_general with a cheaper multiply op. Note that
25491+ // this implies `diag(new_op(A, B)) = diag(A x B)` however
25492+ // `new_op(A, B) != A x B`
25493+ struct DotGeneralOnlyDiagonalAccess
25494+ : public CheckedOpRewritePattern<stablehlo::DotGeneralOp,
25495+ DotGeneralOnlyDiagonalAccess> {
25496+ using CheckedOpRewritePattern<
25497+ stablehlo::DotGeneralOp,
25498+ DotGeneralOnlyDiagonalAccess>::CheckedOpRewritePattern;
25499+
25500+ LogicalResult matchAndRewriteImpl(stablehlo::DotGeneralOp op,
25501+ PatternRewriter &rewriter) const {
25502+ auto resTy = cast<RankedTensorType>(op.getType());
25503+ if (resTy.getRank() != 2)
25504+ return failure();
25505+
25506+ auto M = resTy.getDimSize(0);
25507+ auto N = resTy.getDimSize(1);
25508+ auto diagLen = std::min(M, N);
25509+
25510+ auto lhs = op.getLhs();
25511+ auto rhs = op.getRhs();
25512+ auto dotDimNumbers = op.getDotDimensionNumbers();
25513+ auto lhsContractingDims = dotDimNumbers.getLhsContractingDimensions();
25514+ auto rhsContractingDims = dotDimNumbers.getRhsContractingDimensions();
25515+ auto lhsBatchingDims = dotDimNumbers.getLhsBatchingDimensions();
25516+ auto rhsBatchingDims = dotDimNumbers.getRhsBatchingDimensions();
25517+
25518+ if (lhsContractingDims.size() != 1 || rhsContractingDims.size() != 1 ||
25519+ lhsBatchingDims.size() != 0 || rhsBatchingDims.size() != 0)
25520+ return failure();
25521+
25522+ llvm::SetVector<Operation *> opsToReplace;
25523+ llvm::SmallPtrSet<Operation *, 4> seenOps;
25524+ for (auto user : op->getUsers()) {
25525+ if (seenOps.count(user))
25526+ continue;
25527+ if (!enzyme::allAccessesAreOnMainDiagonal(user, opsToReplace))
25528+ return failure();
25529+ seenOps.insert(user);
25530+ }
25531+
25532+ if (opsToReplace.empty())
25533+ return failure();
25534+
25535+ // rewrite the dot_general to a multiply.
25536+ // we insert transpose ops here, but those will get removed later
25537+ auto lhsContractDim = lhsContractingDims[0];
25538+ auto rhsContractDim = rhsContractingDims[0];
25539+ // result[i, i] = sum_k (lhs[i, k] * rhs[k, i])
25540+ // = reduce_sum(lhs[i, :] * rhs[:, i])
25541+ auto lhsNonContractDim = 1 - lhsContractDim;
25542+ auto rhsNonContractDim = 1 - rhsContractDim;
25543+
25544+ if (lhsContractDim == 0) {
25545+ // move to dim = 1
25546+ lhs = stablehlo::TransposeOp::create(
25547+ rewriter, op.getLoc(), lhs, rewriter.getDenseI64ArrayAttr({1, 0}));
25548+ }
25549+ lhs = stablehlo::SliceOp::create(
25550+ rewriter, op.getLoc(), lhs, rewriter.getDenseI64ArrayAttr({0, 0}),
25551+ rewriter.getDenseI64ArrayAttr(
25552+ {diagLen, cast<ShapedType>(lhs.getType()).getDimSize(1)}),
25553+ rewriter.getDenseI64ArrayAttr({1, 1})); // [DiagSize, C]
25554+
25555+ if (rhsContractDim == 0) {
25556+ // move to dim = 1
25557+ rhs = stablehlo::TransposeOp::create(
25558+ rewriter, op.getLoc(), rhs, rewriter.getDenseI64ArrayAttr({1, 0}));
25559+ }
25560+ rhs = stablehlo::SliceOp::create(
25561+ rewriter, op.getLoc(), rhs, rewriter.getDenseI64ArrayAttr({0, 0}),
25562+ rewriter.getDenseI64ArrayAttr(
25563+ {diagLen, cast<ShapedType>(rhs.getType()).getDimSize(1)}),
25564+ rewriter.getDenseI64ArrayAttr({1, 1})); // [DiagSize, C]
25565+
25566+ auto newMul = stablehlo::MulOp::create(rewriter, op.getLoc(), lhs,
25567+ rhs); // [DiagSize, C]
25568+
25569+ auto elemTy = cast<RankedTensorType>(newMul.getType()).getElementType();
25570+ auto tenElemTy = RankedTensorType::get({}, elemTy);
25571+ auto reduceOp = stablehlo::ReduceOp::create(
25572+ rewriter, op.getLoc(), ValueRange(newMul.getResult()),
25573+ ValueRange(stablehlo::ConstantOp::create(
25574+ rewriter, op.getLoc(), tenElemTy,
25575+ cast<ElementsAttr>(makeAttr(tenElemTy, 0)))
25576+ .getResult()),
25577+ {1});
25578+
25579+ {
25580+ Region ®ion = reduceOp.getBody();
25581+ Block *block = rewriter.createBlock(®ion);
25582+ block->addArgument(tenElemTy, op.getLoc());
25583+ block->addArgument(tenElemTy, op.getLoc());
25584+
25585+ OpBuilder::InsertionGuard guard(rewriter);
25586+ rewriter.setInsertionPointToStart(block);
25587+ auto addOp = stablehlo::AddOp::create(
25588+ rewriter, op.getLoc(), block->getArgument(0), block->getArgument(1));
25589+ stablehlo::ReturnOp::create(rewriter, op.getLoc(), addOp.getResult());
25590+ }
25591+
25592+ for (auto &opToReplace : opsToReplace) {
25593+ if (auto sliceOp = dyn_cast<stablehlo::SliceOp>(opToReplace)) {
25594+ replaceSliceOp(rewriter, sliceOp, reduceOp, M, N, diagLen);
25595+ } else {
25596+ assert(false && "Unknown op to replace. open an issue on github");
25597+ }
25598+ }
25599+
25600+ return success();
25601+ }
25602+
25603+ private:
25604+ void replaceSliceOp(PatternRewriter &rewriter, stablehlo::SliceOp sliceOp,
25605+ stablehlo::ReduceOp reduceOp, int64_t M, int64_t N,
25606+ int64_t diagLen) const {
25607+ int64_t start = sliceOp.getStartIndices()[0];
25608+ int64_t limit = sliceOp.getLimitIndices()[0];
25609+ int64_t stride = sliceOp.getStrides()[0];
25610+ int64_t diagStride = N + 1;
25611+
25612+ int64_t newStart = start / diagStride;
25613+ int64_t newLimit = (limit - 1) / diagStride + 1;
25614+ int64_t newStride = stride / diagStride;
25615+
25616+ rewriter.setInsertionPoint(sliceOp);
25617+ rewriter.replaceOpWithNewOp<stablehlo::SliceOp>(
25618+ sliceOp, reduceOp.getResult(0),
25619+ rewriter.getDenseI64ArrayAttr({newStart}),
25620+ rewriter.getDenseI64ArrayAttr({newLimit}),
25621+ rewriter.getDenseI64ArrayAttr({newStride}));
25622+ }
25623+ };
25624+
2548625625/////////////// End Imported from stablehlo
2548725626
2548825627// clang-format off
@@ -26117,7 +26256,8 @@ struct EnzymeHLOOptPass
2611726256 RemoveNoOpsFromWhileLoop,
2611826257 WhileIsCopySimplify,
2611926258 SplitVariadicScatterOp,
26120- DynamicSliceSimplify
26259+ DynamicSliceSimplify,
26260+ DotGeneralOnlyDiagonalAccess
2612126261 >(context);
2612226262
2612326263 patterns.add<
0 commit comments