|
| 1 | +//===-- LLZKFuseProductLoopsPass.cpp -----------------------------*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLZK Project, under the Apache License v2.0. |
| 4 | +// See LICENSE.txt for license information. |
| 5 | +// Copyright 2026 Project LLZK |
| 6 | +// SPDX-License-Identifier: Apache-2.0 |
| 7 | +// |
| 8 | +//===----------------------------------------------------------------------===// |
| 9 | +/// |
| 10 | +/// \file |
| 11 | +/// This file implements the `-llzk-fuse-product-loops` pass. |
| 12 | +/// |
| 13 | +//===----------------------------------------------------------------------===// |
| 14 | + |
#include "llzk/Dialect/Function/IR/Ops.h"
#include "llzk/Dialect/Polymorphic/IR/Ops.h"
#include "llzk/Dialect/Struct/IR/Ops.h"
#include "llzk/Transforms/LLZKFuseProductLoopsPass.h"
#include "llzk/Transforms/LLZKTransformationPasses.h"
#include "llzk/Util/AlignmentHelper.h"
#include "llzk/Util/Constants.h"

#include <mlir/Dialect/SCF/Utils/Utils.h>

#include <llvm/Support/Debug.h>
#include <llvm/Support/ErrorHandling.h>
#include <llvm/Support/SMTAPI.h>

#include <memory>
| 29 | + |
| 30 | +namespace llzk { |
| 31 | + |
| 32 | +#define GEN_PASS_DECL_FUSEPRODUCTLOOPSPASS |
| 33 | +#define GEN_PASS_DEF_FUSEPRODUCTLOOPSPASS |
| 34 | +#include "llzk/Transforms/LLZKTransformationPasses.h.inc" |
| 35 | + |
| 36 | +using namespace llzk::function; |
| 37 | + |
/// Width, in bits, of the SMT bitvectors used to model values of MLIR `index` type.
constexpr int INDEX_WIDTH {64};
| 40 | + |
| 41 | +class FuseProductLoopsPass : public impl::FuseProductLoopsPassBase<FuseProductLoopsPass> { |
| 42 | + |
| 43 | +public: |
| 44 | + void runOnOperation() override { |
| 45 | + mlir::ModuleOp mod = getOperation(); |
| 46 | + mod.walk([this](FuncDefOp funcDef) { |
| 47 | + if (funcDef.isStructProduct()) { |
| 48 | + if (mlir::failed(fuseMatchingLoopPairs(funcDef.getFunctionBody(), &getContext()))) { |
| 49 | + signalPassFailure(); |
| 50 | + } |
| 51 | + } |
| 52 | + }); |
| 53 | + } |
| 54 | +}; |
| 55 | + |
| 56 | +static inline bool isConstOrStructParam(mlir::Value val) { |
| 57 | + // TODO: doing arithmetic over constants should also be fine? |
| 58 | + return val.getDefiningOp<mlir::arith::ConstantIndexOp>() || |
| 59 | + val.getDefiningOp<llzk::polymorphic::ConstReadOp>(); |
| 60 | +} |
| 61 | + |
| 62 | +llvm::SMTExprRef mkExpr(mlir::Value value, llvm::SMTSolver *solver) { |
| 63 | + if (auto constOp = value.getDefiningOp<mlir::arith::ConstantIndexOp>()) { |
| 64 | + return solver->mkBitvector(llvm::APSInt::get(constOp.value()), INDEX_WIDTH); |
| 65 | + } else if (auto polyReadOp = value.getDefiningOp<llzk::polymorphic::ConstReadOp>()) { |
| 66 | + |
| 67 | + return solver->mkSymbol( |
| 68 | + std::string {polyReadOp.getConstName()}.c_str(), solver->getBitvectorSort(INDEX_WIDTH) |
| 69 | + ); |
| 70 | + } |
| 71 | + assert(false && "unsupported: checking non-constant trip counts"); |
| 72 | + return nullptr; // Unreachable |
| 73 | +} |
| 74 | + |
| 75 | +llvm::SMTExprRef tripCount(mlir::scf::ForOp op, llvm::SMTSolver *solver) { |
| 76 | + const auto *one = solver->mkBitvector(llvm::APSInt::get(1), INDEX_WIDTH); |
| 77 | + return solver->mkBVSDiv( |
| 78 | + solver->mkBVAdd( |
| 79 | + one, |
| 80 | + solver->mkBVSub(mkExpr(op.getUpperBound(), solver), mkExpr(op.getLowerBound(), solver)) |
| 81 | + ), |
| 82 | + mkExpr(op.getStep(), solver) |
| 83 | + ); |
| 84 | +} |
| 85 | + |
| 86 | +static inline bool canLoopsBeFused(mlir::scf::ForOp a, mlir::scf::ForOp b) { |
| 87 | + // A priori, two loops can be fused if: |
| 88 | + // 1. They live in the same parent region, |
| 89 | + // 2. One comes from witgen and the other comes from constraint gen, and |
| 90 | + // 3. They have the same trip count |
| 91 | + |
| 92 | + // Check 1. |
| 93 | + if (a->getParentRegion() != b->getParentRegion()) { |
| 94 | + return false; |
| 95 | + } |
| 96 | + |
| 97 | + // Check 2. |
| 98 | + if (!a->hasAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE) || |
| 99 | + !b->hasAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE)) { |
| 100 | + // Ideally this should never happen, since the pass only runs on fused @product functions, but |
| 101 | + // check anyway just to be safe |
| 102 | + return false; |
| 103 | + } |
| 104 | + if (a->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE) == |
| 105 | + b->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE)) { |
| 106 | + return false; |
| 107 | + } |
| 108 | + |
| 109 | + // Check 3. |
| 110 | + // Easy case: both have a constant trip-count. If the trip counts are not "constant up to a struct |
| 111 | + // param", we definitely can't tell if they're equal. If the trip counts are only "constant up to |
| 112 | + // a struct param" but not actually constant, we can ask a solver if the equations are guaranteed |
| 113 | + // to be the same |
| 114 | + auto tripCountA = mlir::constantTripCount(a.getLowerBound(), a.getUpperBound(), a.getStep()); |
| 115 | + auto tripCountB = mlir::constantTripCount(b.getLowerBound(), b.getUpperBound(), b.getStep()); |
| 116 | + if (tripCountA.has_value() && tripCountB.has_value() && *tripCountA == *tripCountB) { |
| 117 | + return true; |
| 118 | + } |
| 119 | + |
| 120 | + if (!isConstOrStructParam(a.getLowerBound()) || !isConstOrStructParam(a.getUpperBound()) || |
| 121 | + !isConstOrStructParam(a.getStep()) || !isConstOrStructParam(b.getLowerBound()) || |
| 122 | + !isConstOrStructParam(b.getUpperBound()) || !isConstOrStructParam(b.getStep())) { |
| 123 | + return false; |
| 124 | + } |
| 125 | + |
| 126 | + llvm::SMTSolverRef solver = llvm::CreateZ3Solver(); |
| 127 | + solver->addConstraint(/* (actually ask if they "can't be different") */ solver->mkNot( |
| 128 | + solver->mkEqual(tripCount(a, solver.get()), tripCount(b, solver.get())) |
| 129 | + )); |
| 130 | + |
| 131 | + return !*solver->check(); |
| 132 | +} |
| 133 | + |
| 134 | +mlir::LogicalResult fuseMatchingLoopPairs(mlir::Region &body, mlir::MLIRContext *context) { |
| 135 | + // Start by collecting all possible loops |
| 136 | + llvm::SmallVector<mlir::scf::ForOp> witnessLoops, constraintLoops; |
| 137 | + body.walk<mlir::WalkOrder::PreOrder>([&witnessLoops, &constraintLoops](mlir::scf::ForOp forOp) { |
| 138 | + if (!forOp->hasAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE)) { |
| 139 | + return mlir::WalkResult::skip(); |
| 140 | + } |
| 141 | + auto productSource = forOp->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE); |
| 142 | + if (productSource == FUNC_NAME_COMPUTE) { |
| 143 | + witnessLoops.push_back(forOp); |
| 144 | + } else if (productSource == FUNC_NAME_CONSTRAIN) { |
| 145 | + constraintLoops.push_back(forOp); |
| 146 | + } |
| 147 | + // Skipping here, because any nested loops can't possibly be fused at this stage |
| 148 | + return mlir::WalkResult::skip(); |
| 149 | + }); |
| 150 | + |
| 151 | + // A pair of loops will be fused iff (1) they can be fused according to the rules above, and (2) |
| 152 | + // neither can be fused with anything else (so there's no ambiguity) |
| 153 | + auto fusionCandidates = alignmentHelpers::getMatchingPairs<mlir::scf::ForOp>( |
| 154 | + witnessLoops, constraintLoops, canLoopsBeFused |
| 155 | + ); |
| 156 | + |
| 157 | + // This shouldn't happen, since we allow partial matches |
| 158 | + if (mlir::failed(fusionCandidates)) { |
| 159 | + return mlir::failure(); |
| 160 | + } |
| 161 | + |
| 162 | + // Finally, fuse all the marked loops... |
| 163 | + mlir::IRRewriter rewriter {context}; |
| 164 | + for (auto [w, c] : *fusionCandidates) { |
| 165 | + auto fusedLoop = mlir::fuseIndependentSiblingForLoops(w, c, rewriter); |
| 166 | + fusedLoop->setAttr(PRODUCT_SOURCE, rewriter.getAttr<mlir::StringAttr>("fused")); |
| 167 | + // ...and recurse to fuse nested loops |
| 168 | + if (mlir::failed(fuseMatchingLoopPairs(fusedLoop.getBodyRegion(), context))) { |
| 169 | + return mlir::failure(); |
| 170 | + } |
| 171 | + } |
| 172 | + return mlir::success(); |
| 173 | +} |
| 174 | + |
| 175 | +std::unique_ptr<mlir::Pass> createFuseProductLoopsPass() { |
| 176 | + return std::make_unique<FuseProductLoopsPass>(); |
| 177 | +} |
| 178 | +} // namespace llzk |
0 commit comments