265 fuse computeconstrain loops in product program when possible #277
Merged: raghav198 merged 18 commits into main from 265-fuse-computeconstrain-loops-in-product-program-when-possible on Feb 6, 2026
Commits
780e9e4 Pass skeleton
297ea97 Really basic fusion works
1d0599e oops
2878743 ASJKDHLKJASHDA:SJHD LKAS UGH UGUH I HATE MACOS
d1ebb15 Finally
f9bbeaa Fusing in the right order now
f21bccc Added testcases
a35d004 Some cleanup + changelog
a692746 Formatting
8926458 Apply suggestions from code review
1aa5ded Code review comments
aee6b84 Ensuring no bad loops
62a3acf Refactor: can perform transformation on-demand
4dbf36a license header
8d118f2 Apply suggestions from code review
3b0cdbe Merge branch 'main' into 265-fuse-computeconstrain-loops-in-product-p…
f944feb Merge branch 'main' into 265-fuse-computeconstrain-loops-in-product-p…
d46abc9 Updated testcases: veridise.lang -> llzk.lang, field -> member
2 changes: 2 additions & 0 deletions
changelogs/unreleased/265-fuse-computeconstrain-loops-in-product-program-when-possible.yaml
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| added: | ||
| - Added "llzk-fuse-product-loops" transformation pass that matches and fuses pairs of scf.for loops from @compute and @constrain |
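As a rough illustration of what fusing a `@compute` loop with a `@constrain` loop means, here is a plain C++ sketch. It is not LLZK or MLIR code: `runSeparate`, `runFused`, and the toy "double the input" relation are hypothetical names chosen for this example. The point is only that two sibling loops with the same trip count can share one loop body.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for a witness/constraint pair (hypothetical, not LLZK code):
// "compute" fills out[i] = 2 * in[i], "constrain" checks that relation.

// Unfused form: one loop computes, a second sibling loop checks.
bool runSeparate(const std::vector<int> &in, std::vector<int> &out) {
  out.assign(in.size(), 0);
  for (std::size_t i = 0; i < in.size(); i++) { // "compute" loop
    out[i] = 2 * in[i];
  }
  bool ok = true;
  for (std::size_t i = 0; i < in.size(); i++) { // "constrain" loop
    ok = ok && (out[i] == 2 * in[i]);
  }
  return ok;
}

// Fused form: both loops run the same number of times, so the two bodies
// are placed one after the other inside a single loop.
bool runFused(const std::vector<int> &in, std::vector<int> &out) {
  out.assign(in.size(), 0);
  bool ok = true;
  for (std::size_t i = 0; i < in.size(); i++) {
    out[i] = 2 * in[i];               // body taken from the "compute" loop
    ok = ok && (out[i] == 2 * in[i]); // body taken from the "constrain" loop
  }
  return ok;
}
```

Both functions produce the same `out` vector and the same check result; the fused version simply visits each index once.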
LLZKFuseProductLoopsPass.h
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| //===-- LLZKFuseProductLoopsPass.h-------------------------------*- C++ -*-===// | ||
| // | ||
| // Part of the LLZK Project, under the Apache License v2.0. | ||
| // See LICENSE.txt for license information. | ||
| // Copyright 2026 Project LLZK | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| | ||
| #pragma once | ||
| | ||
| #include <mlir/IR/Region.h> | ||
| #include <mlir/Support/LogicalResult.h> | ||
| | ||
| namespace llzk { | ||
| /// Identify pairs of `scf.for` loops that can be fused, fuse them, and then | ||
| /// recurse to fuse nested loops. | ||
| mlir::LogicalResult fuseMatchingLoopPairs(mlir::Region &body, mlir::MLIRContext *context); | ||
| } // namespace llzk |
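The "fuse, then recurse into nested loops" shape described in the doc comment can be sketched on a toy loop tree. Everything here is an assumption made for illustration: `ToyLoop`, `countLoops`, and `fuseToyLoops` are hypothetical names, a plain integer stands in for the real trip-count check, and the fused loop is placed where the compute loop was.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model of a loop nest; not an MLIR type.
struct ToyLoop {
  std::string source;            // "compute", "constrain", or "fused"
  int tripCount;                 // stand-in for the real trip-count check
  std::vector<ToyLoop> children; // loops nested directly inside this one
};

// Count the loops in `region` with the given source tag and trip count.
int countLoops(const std::vector<ToyLoop> &region, const std::string &source, int trip) {
  int n = 0;
  for (const ToyLoop &loop : region) {
    if (loop.source == source && loop.tripCount == trip) {
      n++;
    }
  }
  return n;
}

// Fuse compute/constrain pairs at this level when the pairing is unambiguous,
// then recurse into the body of every fused loop.
void fuseToyLoops(std::vector<ToyLoop> &region) {
  const std::size_t n = region.size();
  const std::size_t kNone = static_cast<std::size_t>(-1);
  std::vector<std::size_t> partner(n, kNone); // compute index -> constrain index
  std::vector<bool> absorbed(n, false);       // constrain loops merged into a partner

  for (std::size_t i = 0; i < n; i++) {
    if (region[i].source != "compute") {
      continue;
    }
    const int trip = region[i].tripCount;
    if (countLoops(region, "compute", trip) != 1 || countLoops(region, "constrain", trip) != 1) {
      continue; // zero or several candidates: leave the loop alone
    }
    for (std::size_t j = 0; j < n; j++) {
      if (region[j].source == "constrain" && region[j].tripCount == trip) {
        partner[i] = j;
        absorbed[j] = true;
      }
    }
  }

  std::vector<ToyLoop> result;
  for (std::size_t i = 0; i < n; i++) {
    if (absorbed[i]) {
      continue;
    }
    if (partner[i] == kNone) {
      result.push_back(region[i]);
      continue;
    }
    ToyLoop fused{"fused", region[i].tripCount, region[i].children};
    const ToyLoop &c = region[partner[i]];
    fused.children.insert(fused.children.end(), c.children.begin(), c.children.end());
    fuseToyLoops(fused.children); // nested loops only become siblings after fusion
    result.push_back(std::move(fused));
  }
  region = std::move(result);
}
```

The recursion happens only on fused loops because, before fusion, a nested compute loop and a nested constrain loop live in different parents and cannot be siblings.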
AlignmentHelper.h
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| //===-- AlignmentHelper.h --------------------------------------*- C++ -*-===// | ||
| // | ||
| // Part of the LLZK Project, under the Apache License v2.0. | ||
| // See LICENSE.txt for license information. | ||
| // Copyright 2026 Project LLZK | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| | ||
| #pragma once | ||
| | ||
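| #include <llvm/ADT/ArrayRef.h> | ||
| #include <llvm/ADT/DenseMap.h> | ||
| #include <llvm/ADT/SetVector.h> | ||
| #include <llvm/ADT/SmallVector.h> | ||
| #include <llvm/ADT/SmallVectorExtras.h> | ||
| #include <llvm/Support/Debug.h> | ||
| #include <llvm/Support/LogicalResult.h> | ||
| | ||
| #include <concepts> | ||
| #include <cstddef> | ||
| #include <utility> | ||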
| | ||
| namespace llzk::alignmentHelpers { | ||
| | ||
| template <class ValueT, class FnT> | ||
| concept Matcher = requires(FnT fn, ValueT val) { | ||
| { fn(val, val) } -> std::convertible_to<bool>; | ||
| }; | ||
| | ||
| template <class ValueT, class FnT> | ||
| requires Matcher<ValueT, FnT> | ||
| llvm::FailureOr<llvm::SetVector<std::pair<ValueT, ValueT>>> getMatchingPairs( | ||
| llvm::ArrayRef<ValueT> as, llvm::ArrayRef<ValueT> bs, FnT doesMatch, bool allowPartial = true | ||
| ) { | ||
| | ||
| llvm::SetVector<ValueT> setA {as.begin(), as.end()}, setB {bs.begin(), bs.end()}; | ||
| llvm::DenseMap<size_t, llvm::SmallVector<size_t>> possibleMatchesA, possibleMatchesB; | ||
| | ||
| for (size_t i = 0, ea = as.size(), eb = bs.size(); i < ea; i++) { | ||
| for (size_t j = 0; j < eb; j++) { | ||
| if (doesMatch(as[i], bs[j])) { | ||
| possibleMatchesA[i].push_back(j); | ||
| possibleMatchesB[j].push_back(i); | ||
| } | ||
| } | ||
| } | ||
| | ||
| llvm::SetVector<std::pair<ValueT, ValueT>> matches; | ||
| // Keep a pair only when each side's sole candidate is the other (no ambiguity) | ||
| for (const auto &[a, b] : possibleMatchesA) { | ||
| if (b.size() == 1 && possibleMatchesB[b[0]].size() == 1) { | ||
| setA.remove(as[a]); | ||
| setB.remove(bs[b[0]]); | ||
| matches.insert({as[a], bs[b[0]]}); | ||
| } | ||
| } | ||
| | ||
| if ((!setA.empty() || !setB.empty()) && !allowPartial) { | ||
| return llvm::failure(); | ||
| } | ||
| return matches; | ||
| } | ||
| } // namespace llzk::alignmentHelpers |
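The matching rule in `getMatchingPairs` can be tried out without LLVM. The following is a standard-library-only sketch under stated assumptions: `uniqueMatches` is a hypothetical name, `std::vector` and `std::optional` stand in for the LLVM containers and `FailureOr`, and the inputs are assumed to contain no duplicate elements.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

// Standard-library-only sketch of the matching rule; `uniqueMatches` is a
// hypothetical name, not the LLZK helper itself. Assumes `as` and `bs`
// contain no duplicate elements.
template <class T, class Fn>
std::optional<std::vector<std::pair<T, T>>> uniqueMatches(
    const std::vector<T> &as, const std::vector<T> &bs, Fn doesMatch, bool allowPartial = true
) {
  // candidatesA[i] lists every j with doesMatch(as[i], bs[j]);
  // candidatesB[j] lists every i for the reverse direction.
  std::vector<std::vector<std::size_t>> candidatesA(as.size()), candidatesB(bs.size());
  for (std::size_t i = 0; i < as.size(); i++) {
    for (std::size_t j = 0; j < bs.size(); j++) {
      if (doesMatch(as[i], bs[j])) {
        candidatesA[i].push_back(j);
        candidatesB[j].push_back(i);
      }
    }
  }

  // Accept a pair only if each element is the other's single candidate.
  std::vector<std::pair<T, T>> matches;
  for (std::size_t i = 0; i < as.size(); i++) {
    if (candidatesA[i].size() == 1 && candidatesB[candidatesA[i][0]].size() == 1) {
      matches.emplace_back(as[i], bs[candidatesA[i][0]]);
    }
  }

  // Each accepted pair uses one distinct element from each side, so the
  // matching is complete exactly when both sizes equal the pair count.
  const bool complete = matches.size() == as.size() && matches.size() == bs.size();
  if (!complete && !allowPartial) {
    return std::nullopt;
  }
  return matches;
}
```

Unlike the `DenseMap`-based loop in the diff, this sketch walks `as` by index, so the pairs come out in the order of the first input.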
LLZKFuseProductLoopsPass.cpp
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| //===-- LLZKFuseProductLoopsPass.cpp -----------------------------*- C++ -*-===// | ||
| // | ||
| // Part of the LLZK Project, under the Apache License v2.0. | ||
| // See LICENSE.txt for license information. | ||
| // Copyright 2026 Project LLZK | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| /// | ||
| /// \file | ||
| /// This file implements the `-llzk-fuse-product-loops` pass. | ||
| /// | ||
| //===----------------------------------------------------------------------===// | ||
| | ||
| #include "llzk/Dialect/Function/IR/Ops.h" | ||
| #include "llzk/Dialect/Polymorphic/IR/Ops.h" | ||
| #include "llzk/Dialect/Struct/IR/Ops.h" | ||
| #include "llzk/Transforms/LLZKFuseProductLoopsPass.h" | ||
| #include "llzk/Transforms/LLZKTransformationPasses.h" | ||
| #include "llzk/Util/AlignmentHelper.h" | ||
| #include "llzk/Util/Constants.h" | ||
| | ||
| #include <mlir/Dialect/SCF/Utils/Utils.h> | ||
| | ||
| #include <llvm/Support/Debug.h> | ||
| #include <llvm/Support/SMTAPI.h> | ||
| | ||
| #include <memory> | ||
| | ||
| namespace llzk { | ||
| | ||
| #define GEN_PASS_DECL_FUSEPRODUCTLOOPSPASS | ||
| #define GEN_PASS_DEF_FUSEPRODUCTLOOPSPASS | ||
| #include "llzk/Transforms/LLZKTransformationPasses.h.inc" | ||
| | ||
| using namespace llzk::function; | ||
| | ||
| // Bitwidth of `index` for instantiating SMT variables | ||
| constexpr int INDEX_WIDTH = 64; | ||
| | ||
| class FuseProductLoopsPass : public impl::FuseProductLoopsPassBase<FuseProductLoopsPass> { | ||
| | ||
| public: | ||
| void runOnOperation() override { | ||
| mlir::ModuleOp mod = getOperation(); | ||
| mod.walk([this](FuncDefOp funcDef) { | ||
| if (funcDef.isStructProduct()) { | ||
| if (mlir::failed(fuseMatchingLoopPairs(funcDef.getFunctionBody(), &getContext()))) { | ||
| signalPassFailure(); | ||
| } | ||
| } | ||
| }); | ||
| } | ||
| }; | ||
| | ||
| static inline bool isConstOrStructParam(mlir::Value val) { | ||
| // TODO: doing arithmetic over constants should also be fine? | ||
| return val.getDefiningOp<mlir::arith::ConstantIndexOp>() || | ||
| val.getDefiningOp<llzk::polymorphic::ConstReadOp>(); | ||
| } | ||
| | ||
| llvm::SMTExprRef mkExpr(mlir::Value value, llvm::SMTSolver *solver) { | ||
| if (auto constOp = value.getDefiningOp<mlir::arith::ConstantIndexOp>()) { | ||
| return solver->mkBitvector(llvm::APSInt::get(constOp.value()), INDEX_WIDTH); | ||
| } else if (auto polyReadOp = value.getDefiningOp<llzk::polymorphic::ConstReadOp>()) { | ||
| return solver->mkSymbol( | ||
| std::string {polyReadOp.getConstName()}.c_str(), solver->getBitvectorSort(INDEX_WIDTH) | ||
| ); | ||
| } | ||
| assert(false && "unsupported: checking non-constant trip counts"); | ||
| return nullptr; // Unreachable | ||
| } | ||
| | ||
| llvm::SMTExprRef tripCount(mlir::scf::ForOp op, llvm::SMTSolver *solver) { | ||
| const auto *one = solver->mkBitvector(llvm::APSInt::get(1), INDEX_WIDTH); | ||
| return solver->mkBVSDiv( | ||
| solver->mkBVAdd( | ||
| one, | ||
| solver->mkBVSub(mkExpr(op.getUpperBound(), solver), mkExpr(op.getLowerBound(), solver)) | ||
| ), | ||
| mkExpr(op.getStep(), solver) | ||
| ); | ||
| } | ||
| | ||
| static inline bool canLoopsBeFused(mlir::scf::ForOp a, mlir::scf::ForOp b) { | ||
| // A priori, two loops can be fused if: | ||
| // 1. They live in the same parent region, | ||
| // 2. One comes from witgen and the other comes from constraint gen, and | ||
| // 3. They have the same trip count | ||
| | ||
| // Check 1. | ||
| if (a->getParentRegion() != b->getParentRegion()) { | ||
| return false; | ||
| } | ||
| | ||
| // Check 2. | ||
| if (!a->hasAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE) || | ||
| !b->hasAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE)) { | ||
| // Ideally this should never happen, since the pass only runs on fused @product functions, but | ||
| // check anyway just to be safe | ||
| return false; | ||
| } | ||
| if (a->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE) == | ||
| b->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE)) { | ||
| return false; | ||
| } | ||
| | ||
| // Check 3. | ||
| // Easy case: both have a constant trip-count. If the trip counts are not "constant up to a struct | ||
| // param", we definitely can't tell if they're equal. If the trip counts are only "constant up to | ||
| // a struct param" but not actually constant, we can ask a solver if the equations are guaranteed | ||
| // to be the same | ||
| auto tripCountA = mlir::constantTripCount(a.getLowerBound(), a.getUpperBound(), a.getStep()); | ||
| auto tripCountB = mlir::constantTripCount(b.getLowerBound(), b.getUpperBound(), b.getStep()); | ||
| if (tripCountA.has_value() && tripCountB.has_value() && *tripCountA == *tripCountB) { | ||
| return true; | ||
| } | ||
| | ||
| if (!isConstOrStructParam(a.getLowerBound()) || !isConstOrStructParam(a.getUpperBound()) || | ||
| !isConstOrStructParam(a.getStep()) || !isConstOrStructParam(b.getLowerBound()) || | ||
| !isConstOrStructParam(b.getUpperBound()) || !isConstOrStructParam(b.getStep())) { | ||
| return false; | ||
| } | ||
| | ||
| llvm::SMTSolverRef solver = llvm::CreateZ3Solver(); | ||
| solver->addConstraint(/* (actually ask if they "can't be different") */ solver->mkNot( | ||
| solver->mkEqual(tripCount(a, solver.get()), tripCount(b, solver.get())) | ||
| )); | ||
| | ||
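| // `check()` yields an empty optional when the solver cannot decide; treat that as "cannot fuse" | ||
| auto canDiffer = solver->check(); | ||
| return canDiffer.has_value() && !*canDiffer; | ||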
| } | ||
| | ||
| mlir::LogicalResult fuseMatchingLoopPairs(mlir::Region &body, mlir::MLIRContext *context) { | ||
| // Start by collecting all possible loops | ||
| llvm::SmallVector<mlir::scf::ForOp> witnessLoops, constraintLoops; | ||
| body.walk<mlir::WalkOrder::PreOrder>([&witnessLoops, &constraintLoops](mlir::scf::ForOp forOp) { | ||
| if (!forOp->hasAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE)) { | ||
| return mlir::WalkResult::skip(); | ||
| } | ||
| auto productSource = forOp->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE); | ||
| if (productSource == FUNC_NAME_COMPUTE) { | ||
| witnessLoops.push_back(forOp); | ||
| } else if (productSource == FUNC_NAME_CONSTRAIN) { | ||
| constraintLoops.push_back(forOp); | ||
| } | ||
| // Skipping here, because any nested loops can't possibly be fused at this stage | ||
| return mlir::WalkResult::skip(); | ||
| }); | ||
| | ||
| // A pair of loops will be fused iff (1) they can be fused according to the rules above, and (2) | ||
| // neither can be fused with anything else (so there's no ambiguity) | ||
| auto fusionCandidates = alignmentHelpers::getMatchingPairs<mlir::scf::ForOp>( | ||
| witnessLoops, constraintLoops, canLoopsBeFused | ||
| ); | ||
| | ||
| // This shouldn't happen, since we allow partial matches | ||
| if (mlir::failed(fusionCandidates)) { | ||
| return mlir::failure(); | ||
| } | ||
| | ||
| // Finally, fuse all the marked loops... | ||
| mlir::IRRewriter rewriter {context}; | ||
| for (auto [w, c] : *fusionCandidates) { | ||
| auto fusedLoop = mlir::fuseIndependentSiblingForLoops(w, c, rewriter); | ||
| fusedLoop->setAttr(PRODUCT_SOURCE, rewriter.getAttr<mlir::StringAttr>("fused")); | ||
| // ...and recurse to fuse nested loops | ||
| if (mlir::failed(fuseMatchingLoopPairs(fusedLoop.getBodyRegion(), context))) { | ||
| return mlir::failure(); | ||
| } | ||
| } | ||
| return mlir::success(); | ||
| } | ||
| | ||
| std::unique_ptr<mlir::Pass> createFuseProductLoopsPass() { | ||
| return std::make_unique<FuseProductLoopsPass>(); | ||
| } | ||
| } // namespace llzk | ||
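The three conditions listed in the comment of `canLoopsBeFused` (same parent region, different source, equal trip count) can be sketched for the easy case where every bound is a known constant. This is a minimal model, not the pass itself: `ToyFor`, `toyTripCount`, and `toyCanFuse` are hypothetical names, the trip count is assumed to be the ceiling of `(upper - lower) / step` for a half-open range, and the solver-backed path for struct-parameter bounds is not modeled.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical stand-in for an `scf.for` whose bounds are all known constants.
struct ToyFor {
  int parentRegion;   // identifier of the enclosing region
  std::string source; // "compute" or "constrain"
  std::int64_t lower, upper, step;
};

// Number of iterations of `for (i = lower; i < upper; i += step)`.
// Returns std::nullopt for a non-positive step.
std::optional<std::int64_t> toyTripCount(const ToyFor &loop) {
  if (loop.step <= 0) {
    return std::nullopt;
  }
  if (loop.upper <= loop.lower) {
    return 0;
  }
  return (loop.upper - loop.lower + loop.step - 1) / loop.step; // ceiling division
}

bool toyCanFuse(const ToyFor &a, const ToyFor &b) {
  // Check 1: both loops live in the same parent region.
  if (a.parentRegion != b.parentRegion) {
    return false;
  }
  // Check 2: one loop comes from each source.
  if (a.source == b.source) {
    return false;
  }
  // Check 3: both trip counts are known and equal.
  auto tripA = toyTripCount(a);
  auto tripB = toyTripCount(b);
  return tripA.has_value() && tripB.has_value() && *tripA == *tripB;
}
```

Note that the bounds themselves do not have to agree: a loop from 0 to 10 with step 3 and a loop from 2 to 6 with step 1 both run four times, so this sketch treats them as fusable.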
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| // RUN: llzk-opt --llzk-compute-constrain-to-product="root-struct=A" --llzk-fuse-product-loops %s | FileCheck --enable-var-scope %s | ||
| | ||
| module attributes {llzk.lang = "llzk"} { | ||
| struct.def @A<[@N]> { | ||
| struct.member @arr : !array.type<@N x !felt.type> | ||
| | ||
| function.def @compute(%as: !array.type<@N x !felt.type>) -> !struct.type<@A<[@N]>> { | ||
| %self = struct.new : <@A<[@N]>> | ||
| | ||
| %N = poly.read_const @N : index | ||
| %c0 = arith.constant 0 : index | ||
| %c1 = arith.constant 1 : index | ||
| | ||
| %arr = array.new : !array.type<@N x !felt.type> | ||
| | ||
| scf.for %i = %c0 to %N step %c1 { | ||
| %sum0 = felt.const 0 | ||
| %sum = scf.for %j = %i to %N step %c1 iter_args(%iter_sum = %sum0) -> !felt.type { | ||
| %a_j = array.read %as[%j] : !array.type<@N x !felt.type>, !felt.type | ||
| %0 = felt.add %iter_sum, %a_j | ||
| scf.yield %0 : !felt.type | ||
| } | ||
| array.write %arr[%i] = %sum : !array.type<@N x !felt.type>, !felt.type | ||
| scf.yield | ||
| } | ||
| | ||
| struct.writem %self[@arr] = %arr : !struct.type<@A<[@N]>>, !array.type<@N x !felt.type> | ||
| function.return %self : !struct.type<@A<[@N]>> | ||
| } | ||
| function.def @constrain(%self : !struct.type<@A<[@N]>>, %as: !array.type<@N x !felt.type>) { | ||
| %N = poly.read_const @N : index | ||
| %c0 = arith.constant 0 : index | ||
| %c1 = arith.constant 1 : index | ||
| | ||
| %arr = struct.readm %self[@arr] : !struct.type<@A<[@N]>>, !array.type<@N x !felt.type> | ||
| | ||
| scf.for %i = %c0 to %N step %c1 { | ||
| %j = arith.addi %i, %c1 : index | ||
| | ||
| %arr_i = array.read %arr[%i] : !array.type<@N x !felt.type>, !felt.type | ||
| %arr_j = array.read %arr[%j] : !array.type<@N x !felt.type>, !felt.type | ||
| %a = array.read %as[%i] : !array.type<@N x !felt.type>, !felt.type | ||
| %diff = felt.sub %arr_j, %arr_i | ||
| constrain.eq %a, %diff : !felt.type | ||
| scf.yield | ||
| } | ||
| | ||
| function.return | ||
| } | ||
| } | ||
| } | ||
| | ||
| // CHECK-LABEL: scf.for | ||
| // CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]] = | ||
| // CHECK-SAME: %[[VAL_1:[0-9a-zA-Z_\.]+]] to | ||
| // CHECK-SAME: %[[VAL_2:[0-9a-zA-Z_\.]+]] step | ||
| // CHECK-SAME: %[[VAL_3:[0-9a-zA-Z_\.]+]] { | ||
| // CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = felt.const 0 {product_source = "compute"} | ||
| // CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = scf.for %[[VAL_6:[0-9a-zA-Z_\.]+]] = %[[VAL_0]] to %[[VAL_7:[0-9a-zA-Z_\.]+]] step %[[VAL_8:[0-9a-zA-Z_\.]+]] iter_args(%[[VAL_9:[0-9a-zA-Z_\.]+]] = %[[VAL_4]]) -> (!felt.type) { | ||
| // CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_11:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_6]]] : <@N x !felt.type>, !felt.type {product_source = "compute"} | ||
| // CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_9]], %[[VAL_10]] : !felt.type, !felt.type {product_source = "compute"} | ||
| // CHECK-NEXT: scf.yield {product_source = "compute"} %[[VAL_12]] : !felt.type | ||
| // CHECK-NEXT: } {product_source = "compute"} | ||
| // CHECK-NEXT: array.write %[[VAL_13:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] = %[[VAL_5]] : <@N x !felt.type>, !felt.type {product_source = "compute"} | ||
| // CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = arith.addi %[[VAL_0]], %[[VAL_3]] {product_source = "constrain"} : index | ||
| // CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_16:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"} | ||
| // CHECK-NEXT: %[[VAL_17:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_16]]{{\[}}%[[VAL_14]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"} | ||
| // CHECK-NEXT: %[[VAL_18:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"} | ||
| // CHECK-NEXT: %[[VAL_20:[0-9a-zA-Z_\.]+]] = felt.sub %[[VAL_17]], %[[VAL_15]] : !felt.type, !felt.type {product_source = "constrain"} | ||
| // CHECK-NEXT: constrain.eq %[[VAL_18]], %[[VAL_20]] : !felt.type, !felt.type {product_source = "constrain"} | ||
| // CHECK-NEXT: } {product_source = "fused"} | ||
| // CHECK-NOT: scf.for | ||