265 fuse computeconstrain loops in product program when possible #277
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
raghav198 merged 18 commits into main from 265-fuse-computeconstrain-loops-in-product-program-when-possible on Feb 6, 2026
Changes from 9 commits
Commits
780e9e4
Pass skeleton
297ea97
Really basic fusion works
1d0599e
oops
2878743
ASJKDHLKJASHDA:SJHD LKAS UGH UGUH I HATE MACOS
d1ebb15
Finally
f9bbeaa
Fusing in the right order now
f21bccc
Added testcases
a35d004
Some cleanup + changelog
a692746
Formatting
8926458
Apply suggestions from code review
raghav198
1aa5ded
Code review comments
aee6b84
Ensuring no bad loops
62a3acf
Refactor: can perform transformation on-demand
4dbf36a
license header
8d118f2
Apply suggestions from code review
raghav198
3b0cdbe
Merge branch 'main' into 265-fuse-computeconstrain-loops-in-product-p…
f944feb
Merge branch 'main' into 265-fuse-computeconstrain-loops-in-product-p…
d46abc9
Updated testcases: veridise.lang -> llzk.lang, field -> member
2 changes: 2 additions & 0 deletions
changelogs/unreleased/265-fuse-computeconstrain-loops-in-product-program-when-possible.yaml
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| added: | ||
| - Added "llzk-fuse-product-loops" transformation pass that matches and fuses pairs of scf.for loops from @compute and @constrain |
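As a rough illustration of what fusing a pair of sibling loops means, here is a plain C++ analogy rather than LLZK IR. The function names and the squaring body are made up for the example. It assumes the two loops have the same trip count and that iteration `i` of the second loop only reads what iteration `i` of the first loop wrote, which is what makes the merge behavior-preserving in this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Unfused form: the first loop fills `out` (the "compute" side), and a
// second sibling loop with the same bounds checks it (the "constrain" side).
bool runUnfused(const std::vector<int> &in, std::vector<int> &out) {
  out.assign(in.size(), 0);
  for (std::size_t i = 0; i < in.size(); ++i) {
    out[i] = in[i] * in[i];
  }
  bool ok = true;
  for (std::size_t i = 0; i < in.size(); ++i) {
    ok = ok && (out[i] == in[i] * in[i]);
  }
  return ok;
}

// Fused form: a single loop runs the "compute" body and then the
// "constrain" body for each index.
bool runFused(const std::vector<int> &in, std::vector<int> &out) {
  out.assign(in.size(), 0);
  bool ok = true;
  for (std::size_t i = 0; i < in.size(); ++i) {
    out[i] = in[i] * in[i];
    ok = ok && (out[i] == in[i] * in[i]);
  }
  return ok;
}
```

Both functions produce the same `out` and the same result for any input; the fused version simply walks the index range once.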
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| //===-- AlignmentHelper.h ------------------------------------------*- C++ -*-===// | ||
| // | ||
| // Part of the LLZK Project, under the Apache License v2.0. | ||
| // See LICENSE.txt for license information. | ||
| // Copyright 2025 Veridise Inc. | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <llvm/ADT/SetVector.h> | ||
| #include <llvm/ADT/SmallVectorExtras.h> | ||
| #include <llvm/Support/Debug.h> | ||
| #include <llvm/Support/LogicalResult.h> | ||
|
|
||
| #include <concepts> | ||
|
|
||
| namespace llzk::alignmentHelpers { | ||
|
|
||
| template <class ValueT, class FnT> | ||
| concept Matcher = requires(FnT fn, ValueT val) { | ||
| { fn(val, val) } -> std::convertible_to<bool>; | ||
| }; | ||
|
|
||
| template <class ValueT, class FnT> | ||
| requires Matcher<ValueT, FnT> | ||
| llvm::FailureOr<llvm::SetVector<std::pair<ValueT, ValueT>>> getMatchingPairs( | ||
| llvm::ArrayRef<ValueT> as, llvm::ArrayRef<ValueT> bs, FnT doesMatch, bool allowPartial = true | ||
| ) { | ||
|
|
||
| llvm::SetVector<ValueT> setA {as.begin(), as.end()}, setB {bs.begin(), bs.end()}; | ||
| llvm::DenseMap<size_t, llvm::SmallVector<size_t>> possibleMatchesA, possibleMatchesB; | ||
|
|
||
| for (size_t i = 0; i < as.size(); i++) { | ||
| for (size_t j = 0; j < bs.size(); j++) { | ||
| if (doesMatch(as[i], bs[j])) { | ||
| possibleMatchesA[i].push_back(j); | ||
| possibleMatchesB[j].push_back(i); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| llvm::SetVector<std::pair<ValueT, ValueT>> matches; | ||
| for (auto [a, b] : possibleMatchesA) { | ||
| if (b.size() == 1 && possibleMatchesB[b[0]].size() == 1) { | ||
| setA.remove(as[a]); | ||
| setB.remove(bs[b[0]]); | ||
| matches.insert({as[a], bs[b[0]]}); | ||
| } | ||
| } | ||
|
|
||
| if ((!setA.empty() || !setB.empty()) && !allowPartial) { | ||
| return llvm::failure(); | ||
| } | ||
| return matches; | ||
| } | ||
| } // namespace llzk::alignmentHelpers | ||
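For readers without an LLVM checkout, the matching rule in `getMatchingPairs` can be tried out with a standard-library-only sketch. `uniqueMutualMatches` is a hypothetical stand-in, not part of the PR. It assumes the elements of each input are distinct, uses `std::optional` in place of `llvm::FailureOr`, and iterates by index so that results come out in the order of `as`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-in for getMatchingPairs. A pair (a, b) is kept only if
// b is the sole candidate for a and a is the sole candidate for b. Returns
// std::nullopt when allowPartial is false and something is left unmatched.
template <class T, class Fn>
std::optional<std::vector<std::pair<T, T>>>
uniqueMutualMatches(const std::vector<T> &as, const std::vector<T> &bs,
                    Fn doesMatch, bool allowPartial = true) {
  // candidatesA[i] holds every j with doesMatch(as[i], bs[j]);
  // candidatesB[j] holds the reverse direction.
  std::vector<std::vector<std::size_t>> candidatesA(as.size());
  std::vector<std::vector<std::size_t>> candidatesB(bs.size());
  for (std::size_t i = 0; i < as.size(); ++i) {
    for (std::size_t j = 0; j < bs.size(); ++j) {
      if (doesMatch(as[i], bs[j])) {
        candidatesA[i].push_back(j);
        candidatesB[j].push_back(i);
      }
    }
  }

  std::vector<std::pair<T, T>> matches;
  for (std::size_t i = 0; i < as.size(); ++i) {
    if (candidatesA[i].size() == 1 &&
        candidatesB[candidatesA[i][0]].size() == 1) {
      matches.emplace_back(as[i], bs[candidatesA[i][0]]);
    }
  }

  // Every match consumes exactly one element from each side, so the match
  // is complete only if it covers both inputs entirely.
  const bool complete =
      matches.size() == as.size() && matches.size() == bs.size();
  if (!complete && !allowPartial) {
    return std::nullopt;
  }
  return matches;
}
```

With `as = {1, 2, 3}`, `bs = {11, 12, 22}` and a predicate that compares last digits, only `(1, 11)` is unambiguous: `2` has two candidates and `3` has none, so both are left unmatched, and the call with `allowPartial = false` reports failure.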
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,179 @@ | ||
| //===-- LLZKComputeConstrainToProductPass.cpp -------------------*- C++ -*-===// | ||
| // | ||
| // Part of the LLZK Project, under the Apache License v2.0. | ||
| // See LICENSE.txt for license information. | ||
| // Copyright 2025 Veridise Inc. | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| /// | ||
| /// \file | ||
| /// This file implements the `-llzk-fuse-product-loops` pass. | ||
| /// | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "llzk/Dialect/Function/IR/Ops.h" | ||
| #include "llzk/Dialect/Polymorphic/IR/Ops.h" | ||
| #include "llzk/Dialect/Struct/IR/Ops.h" | ||
| #include "llzk/Transforms/LLZKTransformationPasses.h" | ||
| #include "llzk/Util/AlignmentHelper.h" | ||
| #include "llzk/Util/Constants.h" | ||
|
|
||
| #include <mlir/Dialect/SCF/Utils/Utils.h> | ||
|
|
||
| #include <llvm/Support/Debug.h> | ||
| #include <llvm/Support/SMTAPI.h> | ||
|
|
||
| #include <functional> | ||
| #include <memory> | ||
| namespace llzk { | ||
| #define GEN_PASS_DECL_FUSEPRODUCTLOOPSPASS | ||
| #define GEN_PASS_DEF_FUSEPRODUCTLOOPSPASS | ||
| #include "llzk/Transforms/LLZKTransformationPasses.h.inc" | ||
|
|
||
| using namespace llzk::function; | ||
|
|
||
| // Bitwidth of `index` for instantiating SMT variables | ||
| constexpr int INDEX_WIDTH = 64; | ||
|
|
||
| class FuseProductLoopsPass : public impl::FuseProductLoopsPassBase<FuseProductLoopsPass> { | ||
| /// Identify pairs of scf.for loops that can be fused, fuse them, and then recurse to fuse nested | ||
| /// loops | ||
| void fuseMatchingLoopPairs(mlir::Region &body); | ||
| bool canLoopsBeFused(mlir::scf::ForOp a, mlir::scf::ForOp b); | ||
|
|
||
| public: | ||
| void runOnOperation() override { | ||
| mlir::ModuleOp mod = getOperation(); | ||
| mod.walk([this](FuncDefOp funcDef) { | ||
| if (funcDef.isStructProduct()) { | ||
| fuseMatchingLoopPairs(funcDef.getFunctionBody()); | ||
| } | ||
| }); | ||
| } | ||
| }; | ||
|
|
||
| bool isConstOrStructParam(mlir::Value val) { | ||
| // TODO: doing arithmetic over constants should also be fine? | ||
| return val.getDefiningOp<mlir::arith::ConstantIndexOp>() || | ||
| val.getDefiningOp<llzk::polymorphic::ConstReadOp>(); | ||
| } | ||
|
|
||
| llvm::SMTExprRef mkExpr(mlir::Value value, llvm::SMTSolver *solver) { | ||
| if (auto constOp = value.getDefiningOp<mlir::arith::ConstantIndexOp>()) { | ||
| return solver->mkBitvector(llvm::APSInt::get(constOp.value()), INDEX_WIDTH); | ||
| } else if (auto polyReadOp = value.getDefiningOp<llzk::polymorphic::ConstReadOp>()) { | ||
|
|
||
| return solver->mkSymbol( | ||
| std::string {polyReadOp.getConstName()}.c_str(), solver->getBitvectorSort(INDEX_WIDTH) | ||
| ); | ||
| } | ||
| assert(false && "unsupported: checking non-constant trip counts"); | ||
| return nullptr; // Unreachable | ||
| } | ||
|
|
||
| llvm::SMTExprRef tripCount(mlir::scf::ForOp op, llvm::SMTSolver *solver) { | ||
| const auto *one = solver->mkBitvector(llvm::APSInt::get(1), INDEX_WIDTH); | ||
| return solver->mkBVSDiv( | ||
| solver->mkBVAdd( | ||
| one, | ||
| solver->mkBVSub(mkExpr(op.getUpperBound(), solver), mkExpr(op.getLowerBound(), solver)) | ||
| ), | ||
| mkExpr(op.getStep(), solver) | ||
| ); | ||
| } | ||
|
|
||
| bool FuseProductLoopsPass::canLoopsBeFused(mlir::scf::ForOp a, mlir::scf::ForOp b) { | ||
| // A priori, two loops can be fused if: | ||
| // 1. They live in the same parent region, | ||
| // 2. One comes from witgen and the other comes from constraint gen, and | ||
| // 3. They have the same trip count | ||
|
|
||
| // Check 1. | ||
| if (a->getParentRegion() != b->getParentRegion()) { | ||
| return false; | ||
| } | ||
|
|
||
| // Check 2. | ||
| if (!a->hasAttrOfType<mlir::StringAttr>("product_source") || | ||
| !b->hasAttrOfType<mlir::StringAttr>("product_source")) { | ||
| // Ideally this should never happen, since the pass only runs on fused @product functions, but | ||
| // check anyway just to be safe | ||
| return false; | ||
| } | ||
| if (a->getAttrOfType<mlir::StringAttr>("product_source") == | ||
| b->getAttrOfType<mlir::StringAttr>("product_source")) { | ||
| return false; | ||
| } | ||
|
|
||
| // Check 3. | ||
| // Easy case: both have a constant trip-count | ||
| auto tripCountA = mlir::constantTripCount(a.getLowerBound(), a.getUpperBound(), a.getStep()); | ||
| auto tripCountB = mlir::constantTripCount(b.getLowerBound(), b.getUpperBound(), b.getStep()); | ||
| if (tripCountA.has_value() && tripCountB.has_value() && *tripCountA == *tripCountB) { | ||
| return true; | ||
| } | ||
|
|
||
| // If the trip counts are not "constant up to a struct param", we definitely can't tell if they're | ||
| // equal | ||
| if (!isConstOrStructParam(a.getLowerBound()) || !isConstOrStructParam(a.getUpperBound()) || | ||
| !isConstOrStructParam(a.getStep()) || !isConstOrStructParam(b.getLowerBound()) || | ||
| !isConstOrStructParam(b.getUpperBound()) || !isConstOrStructParam(b.getStep())) { | ||
| return false; | ||
| } | ||
|
|
||
| // If the trip counts are only "constant up to a struct param" but not actually constant, we can | ||
| // ask a solver if the equations are guaranteed to be the same | ||
| llvm::SMTSolverRef solver = llvm::CreateZ3Solver(); | ||
| solver->addConstraint( | ||
| /* (actually ask if they "can't be different") */ solver->mkNot( | ||
| solver->mkEqual(tripCount(a, solver.get()), tripCount(b, solver.get())) | ||
| ) | ||
| ); | ||
|
|
||
| // The loops are fusable if it's impossible for the trip count expressions to be different | ||
| return !*solver->check(); | ||
| } | ||
|
|
||
| void FuseProductLoopsPass::fuseMatchingLoopPairs(mlir::Region &body) { | ||
|
|
||
| // Start by collecting all possible loops | ||
| llvm::SmallVector<mlir::scf::ForOp> witnessLoops, constraintLoops; | ||
| body.walk<mlir::WalkOrder::PreOrder>([&witnessLoops, &constraintLoops](mlir::scf::ForOp forOp) { | ||
| if (!forOp->hasAttrOfType<mlir::StringAttr>("product_source")) { | ||
| return mlir::WalkResult::skip(); | ||
| } | ||
| if (forOp->getAttrOfType<mlir::StringAttr>("product_source") == FUNC_NAME_COMPUTE) { | ||
| witnessLoops.push_back(forOp); | ||
| } else if (forOp->getAttrOfType<mlir::StringAttr>("product_source") == FUNC_NAME_CONSTRAIN) { | ||
| constraintLoops.push_back(forOp); | ||
| } | ||
| // Skipping here, because any nested loops can't possibly be fused at this stage | ||
| return mlir::WalkResult::skip(); | ||
| }); | ||
|
|
||
| // A pair of loops will be fused iff (1) they can be fused according to the rules above, and (2) | ||
| // neither can be fused with anything else (so there's no ambiguity) | ||
| auto fusionCandidates = alignmentHelpers::getMatchingPairs<mlir::scf::ForOp>( | ||
| witnessLoops, constraintLoops, std::bind_front(&FuseProductLoopsPass::canLoopsBeFused, this) | ||
| ); | ||
|
|
||
| // This shouldn't happen, since we allow partial matches | ||
| if (mlir::failed(fusionCandidates)) { | ||
| signalPassFailure(); | ||
| return; | ||
| } | ||
|
|
||
| // Finally, fuse all the marked loops... | ||
| mlir::IRRewriter rewriter {&getContext()}; | ||
| for (auto [w, c] : *fusionCandidates) { | ||
| auto fusedLoop = mlir::fuseIndependentSiblingForLoops(w, c, rewriter); | ||
| fusedLoop->setAttr("product_source", rewriter.getAttr<mlir::StringAttr>("fused")); | ||
| // ...and recurse to fuse nested loops | ||
| fuseMatchingLoopPairs(fusedLoop.getBodyRegion()); | ||
| } | ||
| } | ||
|
|
||
| std::unique_ptr<mlir::Pass> createFuseProductLoopsPass() { | ||
| return std::make_unique<FuseProductLoopsPass>(); | ||
| } | ||
| } // namespace llzk | ||
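The three checks listed in `canLoopsBeFused` can be sketched for the all-constant case without MLIR or a solver. Everything here is hypothetical scaffolding: `LoopInfo` and `canFuseConstant` are invented for the example, and `halfOpenTripCount` uses the ceiling-division count for a half-open range, which is how `scf.for` iterates. It is not a transcription of the `tripCount` helper above, which builds a bitvector expression for the solver, and it does not cover the struct-parameter case.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical summary of one loop: where it lives, which function it came
// from, and its constant bounds.
struct LoopInfo {
  const void *parentRegion;
  std::string productSource; // "compute" or "constrain"; empty if missing
  std::int64_t lowerBound;
  std::int64_t upperBound;
  std::int64_t step;
};

// Number of iterations of `for (i = lb; i < ub; i += step)`.
// Returns std::nullopt for a non-positive step. Assumes ub - lb does not
// overflow.
std::optional<std::int64_t> halfOpenTripCount(std::int64_t lb, std::int64_t ub,
                                              std::int64_t step) {
  if (step <= 0) {
    return std::nullopt;
  }
  if (ub <= lb) {
    return 0;
  }
  return (ub - lb + step - 1) / step; // ceil((ub - lb) / step)
}

bool canFuseConstant(const LoopInfo &a, const LoopInfo &b) {
  // Check 1: same parent region.
  if (a.parentRegion != b.parentRegion) {
    return false;
  }
  // Check 2: both are tagged, and the tags differ.
  if (a.productSource.empty() || b.productSource.empty() ||
      a.productSource == b.productSource) {
    return false;
  }
  // Check 3: both trip counts are known and equal.
  auto countA = halfOpenTripCount(a.lowerBound, a.upperBound, a.step);
  auto countB = halfOpenTripCount(b.lowerBound, b.upperBound, b.step);
  return countA.has_value() && countB.has_value() && *countA == *countB;
}
```

Note that two loops with different bounds can still pass check 3: `0 to 10 step 1` and `5 to 25 step 2` both run ten times.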
Empty file.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| // RUN: llzk-opt --llzk-compute-constrain-to-product="root-struct=A" --llzk-fuse-product-loops %s | FileCheck --enable-var-scope %s | ||
|
|
||
| module attributes {veridise.lang = "llzk"} { | ||
| struct.def @A<[@N]> { | ||
| struct.field @arr : !array.type<@N x !felt.type> | ||
|
|
||
| function.def @compute(%as: !array.type<@N x !felt.type>) -> !struct.type<@A<[@N]>> { | ||
| %self = struct.new : <@A<[@N]>> | ||
|
|
||
| %N = poly.read_const @N : index | ||
| %c0 = arith.constant 0 : index | ||
| %c1 = arith.constant 1 : index | ||
|
|
||
| %arr = array.new : !array.type<@N x !felt.type> | ||
|
|
||
| scf.for %i = %c0 to %N step %c1 { | ||
| %sum0 = felt.const 0 | ||
| %sum = scf.for %j = %i to %N step %c1 iter_args(%iter_sum = %sum0) -> !felt.type { | ||
| %a_j = array.read %as[%j] : !array.type<@N x !felt.type>, !felt.type | ||
| %0 = felt.add %iter_sum, %a_j | ||
| scf.yield %0 : !felt.type | ||
| } | ||
| array.write %arr[%i] = %sum : !array.type<@N x !felt.type>, !felt.type | ||
| scf.yield | ||
| } | ||
|
|
||
| struct.writef %self[@arr] = %arr : !struct.type<@A<[@N]>>, !array.type<@N x !felt.type> | ||
| function.return %self : !struct.type<@A<[@N]>> | ||
| } | ||
| function.def @constrain(%self : !struct.type<@A<[@N]>>, %as: !array.type<@N x !felt.type>) { | ||
| %N = poly.read_const @N : index | ||
| %c0 = arith.constant 0 : index | ||
| %c1 = arith.constant 1 : index | ||
|
|
||
| %arr = struct.readf %self[@arr] : !struct.type<@A<[@N]>>, !array.type<@N x !felt.type> | ||
|
|
||
| scf.for %i = %c0 to %N step %c1 { | ||
| %j = arith.addi %i, %c1 : index | ||
|
|
||
| %arr_i = array.read %arr[%i] : !array.type<@N x !felt.type>, !felt.type | ||
| %arr_j = array.read %arr[%j] : !array.type<@N x !felt.type>, !felt.type | ||
| %a = array.read %as[%i] : !array.type<@N x !felt.type>, !felt.type | ||
| %diff = felt.sub %arr_j, %arr_i | ||
| constrain.eq %a, %diff : !felt.type | ||
| scf.yield | ||
| } | ||
|
|
||
| function.return | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // CHECK-LABEL: scf.for | ||
| // CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]] = | ||
| // CHECK-SAME: %[[VAL_1:[0-9a-zA-Z_\.]+]] to | ||
| // CHECK-SAME: %[[VAL_2:[0-9a-zA-Z_\.]+]] step | ||
| // CHECK-SAME: %[[VAL_3:[0-9a-zA-Z_\.]+]] { | ||
| // CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = felt.const 0 {product_source = "compute"} | ||
| // CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = scf.for %[[VAL_6:[0-9a-zA-Z_\.]+]] = %[[VAL_0]] to %[[VAL_7:[0-9a-zA-Z_\.]+]] step %[[VAL_8:[0-9a-zA-Z_\.]+]] iter_args(%[[VAL_9:[0-9a-zA-Z_\.]+]] = %[[VAL_4]]) -> (!felt.type) { | ||
| // CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_11:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_6]]] : <@N x !felt.type>, !felt.type {product_source = "compute"} | ||
| // CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_9]], %[[VAL_10]] : !felt.type, !felt.type {product_source = "compute"} | ||
| // CHECK-NEXT: scf.yield {product_source = "compute"} %[[VAL_12]] : !felt.type | ||
| // CHECK-NEXT: } {product_source = "compute"} | ||
| // CHECK-NEXT: array.write %[[VAL_13:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] = %[[VAL_5]] : <@N x !felt.type>, !felt.type {product_source = "compute"} | ||
| // CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = arith.addi %[[VAL_0]], %[[VAL_3]] {product_source = "constrain"} : index | ||
| // CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_16:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"} | ||
| // CHECK-NEXT: %[[VAL_17:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_16]]{{\[}}%[[VAL_14]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"} | ||
| // CHECK-NEXT: %[[VAL_18:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"} | ||
| // CHECK-NEXT: %[[VAL_20:[0-9a-zA-Z_\.]+]] = felt.sub %[[VAL_17]], %[[VAL_15]] : !felt.type, !felt.type {product_source = "constrain"} | ||
| // CHECK-NEXT: constrain.eq %[[VAL_18]], %[[VAL_20]] : !felt.type, !felt.type {product_source = "constrain"} | ||
| // CHECK-NEXT: } {product_source = "fused"} | ||