Skip to content

Commit f71b018

Browse files
raghav198Raghav Maliktim-hoffmaniangneal
authored
Fuse computeconstrain loops in product program when possible
* Pass skeleton * Really basic fusion works * oops * ASJKDHLKJASHDA:SJHD LKAS UGH UGUH I HATE MACOS * Finally * Fusing in the right order now * Added testcases * Some cleanup + changelog * Formatting * Apply suggestions from code review Co-authored-by: Timothy Hoffman <4001421+tim-hoffman@users.noreply.github.com> * Code review comments * Ensuring no bad loops * Refactor: can perform transformation on-demand * license header * Apply suggestions from code review Co-authored-by: Ian Neal <ian@veridise.com> * Updated testcases: veridise.lang -> llzk.lang, field -> member --------- Co-authored-by: Raghav Malik <raghav@veridise.com> Co-authored-by: Timothy Hoffman <4001421+tim-hoffman@users.noreply.github.com> Co-authored-by: Ian Neal <ian@veridise.com>
1 parent dbc391d commit f71b018

15 files changed

+690
-3
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
added:
2+
- Added "llzk-fuse-product-loops" transformation pass that matches and fuses pairs of scf.for loops from @compute and @constrain
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//===-- LLZKFuseProductLoopsPass.h-------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLZK Project, under the Apache License v2.0.
4+
// See LICENSE.txt for license information.
5+
// Copyright 2026 Project LLZK
6+
// SPDX-License-Identifier: Apache-2.0
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#pragma once
11+
12+
#include <mlir/IR/Region.h>
13+
#include <mlir/Support/LogicalResult.h>
14+
15+
namespace llzk {
16+
/// Identify pairs of `scf.for` loops that can be fused, fuse them, and then
17+
/// recurse to fuse nested loops.
18+
mlir::LogicalResult fuseMatchingLoopPairs(mlir::Region &body, mlir::MLIRContext *context);
19+
} // namespace llzk

include/llzk/Transforms/LLZKTransformationPasses.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
namespace llzk {
2020

21+
std::unique_ptr<mlir::Pass> createFuseProductLoopsPass();
22+
2123
std::unique_ptr<mlir::Pass> createComputeConstrainToProductPass();
2224

2325
std::unique_ptr<mlir::Pass> createFlatteningPass();

include/llzk/Transforms/LLZKTransformationPasses.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,15 @@ def ComputeConstrainToProductPass
119119
"Root struct at which to start alignment "
120120
"(default to `@Main`)">];
121121
}
122+
123+
def FuseProductLoopsPass : LLZKPass<"llzk-fuse-product-loops"> {
124+
let summary =
125+
"Fuses matching witness/constraint loops in a @product function";
126+
let description = summary;
127+
let constructor = "llzk::createFuseProductLoopsPass()";
128+
let options = [];
129+
}
130+
122131
#ifdef WITH_PCL
123132
def PCLLoweringPass : LLZKPass<"llzk-to-pcl"> {
124133
let summary = "Rewrites constraints to be compatible with PCL constraints "
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//===-- AlignmentHelper.h --------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLZK Project, under the Apache License v2.0.
4+
// See LICENSE.txt for license information.
5+
// Copyright 2026 Project LLZK
6+
// SPDX-License-Identifier: Apache-2.0
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#pragma once
11+
12+
#include <llvm/ADT/SetVector.h>
13+
#include <llvm/ADT/SmallVectorExtras.h>
14+
#include <llvm/Support/Debug.h>
15+
#include <llvm/Support/LogicalResult.h>
16+
17+
#include <concepts>
18+
19+
namespace llzk::alignmentHelpers {
20+
21+
template <class ValueT, class FnT>
22+
concept Matcher = requires(FnT fn, ValueT val) {
23+
{ fn(val, val) } -> std::convertible_to<bool>;
24+
};
25+
26+
template <class ValueT, class FnT>
27+
requires Matcher<ValueT, FnT>
28+
llvm::FailureOr<llvm::SetVector<std::pair<ValueT, ValueT>>> getMatchingPairs(
29+
llvm::ArrayRef<ValueT> as, llvm::ArrayRef<ValueT> bs, FnT doesMatch, bool allowPartial = true
30+
) {
31+
32+
llvm::SetVector<ValueT> setA {as.begin(), as.end()}, setB {bs.begin(), bs.end()};
33+
llvm::DenseMap<size_t, llvm::SmallVector<size_t>> possibleMatchesA, possibleMatchesB;
34+
35+
for (size_t i = 0, ea = as.size(), eb = bs.size(); i < ea; i++) {
36+
for (size_t j = 0; j < eb; j++) {
37+
if (doesMatch(as[i], bs[j])) {
38+
possibleMatchesA[i].push_back(j);
39+
possibleMatchesB[j].push_back(i);
40+
}
41+
}
42+
}
43+
44+
llvm::SetVector<std::pair<ValueT, ValueT>> matches;
45+
for (auto [a, b] : possibleMatchesA) {
46+
if (b.size() == 1 && possibleMatchesB[b[0]].size() == 1) {
47+
setA.remove(as[a]);
48+
setB.remove(bs[b[0]]);
49+
matches.insert({as[a], bs[b[0]]});
50+
}
51+
}
52+
53+
if ((!setA.empty() || !setB.empty()) && !allowPartial) {
54+
return llvm::failure();
55+
}
56+
return matches;
57+
}
58+
} // namespace llzk::alignmentHelpers

include/llzk/Util/Constants.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,7 @@ constexpr char LANG_ATTR_NAME[] = "llzk.lang";
2727
/// a `TypeAttr` specifying the `StructType` of the main struct.
2828
constexpr char MAIN_ATTR_NAME[] = "llzk.main";
2929

30+
/// Name of the attribute on aligned product program ops that specifies where they came from.
31+
constexpr char PRODUCT_SOURCE[] = "product_source";
32+
3033
} // namespace llzk

lib/Transforms/LLZKComputeConstrainToProductPass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,11 @@ FuncDefOp ProductAligner::alignFuncs(StructDefOp root, FuncDefOp compute, FuncDe
111111

112112
// Add compute/constrain attributes
113113
compute.walk([&funcBuilder](Operation *op) {
114-
op->setAttr("product_source", funcBuilder.getStringAttr(FUNC_NAME_COMPUTE));
114+
op->setAttr(PRODUCT_SOURCE, funcBuilder.getStringAttr(FUNC_NAME_COMPUTE));
115115
});
116116

117117
constrain.walk([&funcBuilder](Operation *op) {
118-
op->setAttr("product_source", funcBuilder.getStringAttr(FUNC_NAME_CONSTRAIN));
118+
op->setAttr(PRODUCT_SOURCE, funcBuilder.getStringAttr(FUNC_NAME_CONSTRAIN));
119119
});
120120

121121
// Create an empty @product func...
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
//===-- LLZKFuseProductLoopsPass.cpp -----------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLZK Project, under the Apache License v2.0.
4+
// See LICENSE.txt for license information.
5+
// Copyright 2026 Project LLZK
6+
// SPDX-License-Identifier: Apache-2.0
7+
//
8+
//===----------------------------------------------------------------------===//
9+
///
10+
/// \file
11+
/// This file implements the `-llzk-fuse-product-loops` pass.
12+
///
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "llzk/Dialect/Function/IR/Ops.h"
16+
#include "llzk/Dialect/Polymorphic/IR/Ops.h"
17+
#include "llzk/Dialect/Struct/IR/Ops.h"
18+
#include "llzk/Transforms/LLZKFuseProductLoopsPass.h"
19+
#include "llzk/Transforms/LLZKTransformationPasses.h"
20+
#include "llzk/Util/AlignmentHelper.h"
21+
#include "llzk/Util/Constants.h"
22+
23+
#include <mlir/Dialect/SCF/Utils/Utils.h>
24+
25+
#include <llvm/Support/Debug.h>
26+
#include <llvm/Support/SMTAPI.h>
27+
28+
#include <memory>
29+
30+
namespace llzk {
31+
32+
#define GEN_PASS_DECL_FUSEPRODUCTLOOPSPASS
33+
#define GEN_PASS_DEF_FUSEPRODUCTLOOPSPASS
34+
#include "llzk/Transforms/LLZKTransformationPasses.h.inc"
35+
36+
using namespace llzk::function;
37+
38+
// Bitwidth of `index` for instantiating SMT variables
39+
constexpr int INDEX_WIDTH = 64;
40+
41+
class FuseProductLoopsPass : public impl::FuseProductLoopsPassBase<FuseProductLoopsPass> {
42+
43+
public:
44+
void runOnOperation() override {
45+
mlir::ModuleOp mod = getOperation();
46+
mod.walk([this](FuncDefOp funcDef) {
47+
if (funcDef.isStructProduct()) {
48+
if (mlir::failed(fuseMatchingLoopPairs(funcDef.getFunctionBody(), &getContext()))) {
49+
signalPassFailure();
50+
}
51+
}
52+
});
53+
}
54+
};
55+
56+
static inline bool isConstOrStructParam(mlir::Value val) {
57+
// TODO: doing arithmetic over constants should also be fine?
58+
return val.getDefiningOp<mlir::arith::ConstantIndexOp>() ||
59+
val.getDefiningOp<llzk::polymorphic::ConstReadOp>();
60+
}
61+
62+
llvm::SMTExprRef mkExpr(mlir::Value value, llvm::SMTSolver *solver) {
63+
if (auto constOp = value.getDefiningOp<mlir::arith::ConstantIndexOp>()) {
64+
return solver->mkBitvector(llvm::APSInt::get(constOp.value()), INDEX_WIDTH);
65+
} else if (auto polyReadOp = value.getDefiningOp<llzk::polymorphic::ConstReadOp>()) {
66+
67+
return solver->mkSymbol(
68+
std::string {polyReadOp.getConstName()}.c_str(), solver->getBitvectorSort(INDEX_WIDTH)
69+
);
70+
}
71+
assert(false && "unsupported: checking non-constant trip counts");
72+
return nullptr; // Unreachable
73+
}
74+
75+
llvm::SMTExprRef tripCount(mlir::scf::ForOp op, llvm::SMTSolver *solver) {
76+
const auto *one = solver->mkBitvector(llvm::APSInt::get(1), INDEX_WIDTH);
77+
return solver->mkBVSDiv(
78+
solver->mkBVAdd(
79+
one,
80+
solver->mkBVSub(mkExpr(op.getUpperBound(), solver), mkExpr(op.getLowerBound(), solver))
81+
),
82+
mkExpr(op.getStep(), solver)
83+
);
84+
}
85+
86+
static inline bool canLoopsBeFused(mlir::scf::ForOp a, mlir::scf::ForOp b) {
87+
// A priori, two loops can be fused if:
88+
// 1. They live in the same parent region,
89+
// 2. One comes from witgen and the other comes from constraint gen, and
90+
// 3. They have the same trip count
91+
92+
// Check 1.
93+
if (a->getParentRegion() != b->getParentRegion()) {
94+
return false;
95+
}
96+
97+
// Check 2.
98+
if (!a->hasAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE) ||
99+
!b->hasAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE)) {
100+
// Ideally this should never happen, since the pass only runs on fused @product functions, but
101+
// check anyway just to be safe
102+
return false;
103+
}
104+
if (a->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE) ==
105+
b->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE)) {
106+
return false;
107+
}
108+
109+
// Check 3.
110+
// Easy case: both have a constant trip-count. If the trip counts are not "constant up to a struct
111+
// param", we definitely can't tell if they're equal. If the trip counts are only "constant up to
112+
// a struct param" but not actually constant, we can ask a solver if the equations are guaranteed
113+
// to be the same
114+
auto tripCountA = mlir::constantTripCount(a.getLowerBound(), a.getUpperBound(), a.getStep());
115+
auto tripCountB = mlir::constantTripCount(b.getLowerBound(), b.getUpperBound(), b.getStep());
116+
if (tripCountA.has_value() && tripCountB.has_value() && *tripCountA == *tripCountB) {
117+
return true;
118+
}
119+
120+
if (!isConstOrStructParam(a.getLowerBound()) || !isConstOrStructParam(a.getUpperBound()) ||
121+
!isConstOrStructParam(a.getStep()) || !isConstOrStructParam(b.getLowerBound()) ||
122+
!isConstOrStructParam(b.getUpperBound()) || !isConstOrStructParam(b.getStep())) {
123+
return false;
124+
}
125+
126+
llvm::SMTSolverRef solver = llvm::CreateZ3Solver();
127+
solver->addConstraint(/* (actually ask if they "can't be different") */ solver->mkNot(
128+
solver->mkEqual(tripCount(a, solver.get()), tripCount(b, solver.get()))
129+
));
130+
131+
return !*solver->check();
132+
}
133+
134+
mlir::LogicalResult fuseMatchingLoopPairs(mlir::Region &body, mlir::MLIRContext *context) {
135+
// Start by collecting all possible loops
136+
llvm::SmallVector<mlir::scf::ForOp> witnessLoops, constraintLoops;
137+
body.walk<mlir::WalkOrder::PreOrder>([&witnessLoops, &constraintLoops](mlir::scf::ForOp forOp) {
138+
if (!forOp->hasAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE)) {
139+
return mlir::WalkResult::skip();
140+
}
141+
auto productSource = forOp->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE);
142+
if (productSource == FUNC_NAME_COMPUTE) {
143+
witnessLoops.push_back(forOp);
144+
} else if (productSource == FUNC_NAME_CONSTRAIN) {
145+
constraintLoops.push_back(forOp);
146+
}
147+
// Skipping here, because any nested loops can't possibly be fused at this stage
148+
return mlir::WalkResult::skip();
149+
});
150+
151+
// A pair of loops will be fused iff (1) they can be fused according to the rules above, and (2)
152+
// neither can be fused with anything else (so there's no ambiguity)
153+
auto fusionCandidates = alignmentHelpers::getMatchingPairs<mlir::scf::ForOp>(
154+
witnessLoops, constraintLoops, canLoopsBeFused
155+
);
156+
157+
// This shouldn't happen, since we allow partial matches
158+
if (mlir::failed(fusionCandidates)) {
159+
return mlir::failure();
160+
}
161+
162+
// Finally, fuse all the marked loops...
163+
mlir::IRRewriter rewriter {context};
164+
for (auto [w, c] : *fusionCandidates) {
165+
auto fusedLoop = mlir::fuseIndependentSiblingForLoops(w, c, rewriter);
166+
fusedLoop->setAttr(PRODUCT_SOURCE, rewriter.getAttr<mlir::StringAttr>("fused"));
167+
// ...and recurse to fuse nested loops
168+
if (mlir::failed(fuseMatchingLoopPairs(fusedLoop.getBodyRegion(), context))) {
169+
return mlir::failure();
170+
}
171+
}
172+
return mlir::success();
173+
}
174+
175+
std::unique_ptr<mlir::Pass> createFuseProductLoopsPass() {
176+
return std::make_unique<FuseProductLoopsPass>();
177+
}
178+
} // namespace llzk

lib/Transforms/TransformationPassPipelines.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,10 @@ void registerTransformationPassPipelines() {
8484
PassPipelineRegistration<>(
8585
"llzk-product-program",
8686
"Convert @compute/@constrain functions to @product function and perform alignment",
87-
[](OpPassManager &pm) { pm.addPass(llzk::createComputeConstrainToProductPass()); }
87+
[](OpPassManager &pm) {
88+
pm.addPass(llzk::createComputeConstrainToProductPass());
89+
pm.addPass(llzk::createFuseProductLoopsPass());
90+
}
8891
);
8992
}
9093

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// RUN: llzk-opt --llzk-compute-constrain-to-product="root-struct=A" --llzk-fuse-product-loops %s | FileCheck --enable-var-scope %s
2+
3+
module attributes {llzk.lang = "llzk"} {
4+
struct.def @A<[@N]> {
5+
struct.member @arr : !array.type<@N x !felt.type>
6+
7+
function.def @compute(%as: !array.type<@N x !felt.type>) -> !struct.type<@A<[@N]>> {
8+
%self = struct.new : <@A<[@N]>>
9+
10+
%N = poly.read_const @N : index
11+
%c0 = arith.constant 0 : index
12+
%c1 = arith.constant 1 : index
13+
14+
%arr = array.new : !array.type<@N x !felt.type>
15+
16+
scf.for %i = %c0 to %N step %c1 {
17+
%sum0 = felt.const 0
18+
%sum = scf.for %j = %i to %N step %c1 iter_args(%iter_sum = %sum0) -> !felt.type {
19+
%a_j = array.read %as[%j] : !array.type<@N x !felt.type>, !felt.type
20+
%0 = felt.add %iter_sum, %a_j
21+
scf.yield %0 : !felt.type
22+
}
23+
array.write %arr[%i] = %sum : !array.type<@N x !felt.type>, !felt.type
24+
scf.yield
25+
}
26+
27+
struct.writem %self[@arr] = %arr : !struct.type<@A<[@N]>>, !array.type<@N x !felt.type>
28+
function.return %self : !struct.type<@A<[@N]>>
29+
}
30+
function.def @constrain(%self : !struct.type<@A<[@N]>>, %as: !array.type<@N x !felt.type>) {
31+
%N = poly.read_const @N : index
32+
%c0 = arith.constant 0 : index
33+
%c1 = arith.constant 1 : index
34+
35+
%arr = struct.readm %self[@arr] : !struct.type<@A<[@N]>>, !array.type<@N x !felt.type>
36+
37+
scf.for %i = %c0 to %N step %c1 {
38+
%j = arith.addi %i, %c1 : index
39+
40+
%arr_i = array.read %arr[%i] : !array.type<@N x !felt.type>, !felt.type
41+
%arr_j = array.read %arr[%j] : !array.type<@N x !felt.type>, !felt.type
42+
%a = array.read %as[%i] : !array.type<@N x !felt.type>, !felt.type
43+
%diff = felt.sub %arr_j, %arr_i
44+
constrain.eq %a, %diff : !felt.type
45+
scf.yield
46+
}
47+
48+
function.return
49+
}
50+
}
51+
}
52+
53+
// CHECK-LABEL: scf.for
54+
// CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]] =
55+
// CHECK-SAME: %[[VAL_1:[0-9a-zA-Z_\.]+]] to
56+
// CHECK-SAME: %[[VAL_2:[0-9a-zA-Z_\.]+]] step
57+
// CHECK-SAME: %[[VAL_3:[0-9a-zA-Z_\.]+]] {
58+
// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = felt.const 0 {product_source = "compute"}
59+
// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = scf.for %[[VAL_6:[0-9a-zA-Z_\.]+]] = %[[VAL_0]] to %[[VAL_7:[0-9a-zA-Z_\.]+]] step %[[VAL_8:[0-9a-zA-Z_\.]+]] iter_args(%[[VAL_9:[0-9a-zA-Z_\.]+]] = %[[VAL_4]]) -> (!felt.type) {
60+
// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_11:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_6]]] : <@N x !felt.type>, !felt.type {product_source = "compute"}
61+
// CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_9]], %[[VAL_10]] : !felt.type, !felt.type {product_source = "compute"}
62+
// CHECK-NEXT: scf.yield {product_source = "compute"} %[[VAL_12]] : !felt.type
63+
// CHECK-NEXT: } {product_source = "compute"}
64+
// CHECK-NEXT: array.write %[[VAL_13:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] = %[[VAL_5]] : <@N x !felt.type>, !felt.type {product_source = "compute"}
65+
// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = arith.addi %[[VAL_0]], %[[VAL_3]] {product_source = "constrain"} : index
66+
// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_16:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"}
67+
// CHECK-NEXT: %[[VAL_17:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_16]]{{\[}}%[[VAL_14]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"}
68+
// CHECK-NEXT: %[[VAL_18:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"}
69+
// CHECK-NEXT: %[[VAL_20:[0-9a-zA-Z_\.]+]] = felt.sub %[[VAL_17]], %[[VAL_15]] : !felt.type, !felt.type {product_source = "constrain"}
70+
// CHECK-NEXT: constrain.eq %[[VAL_18]], %[[VAL_20]] : !felt.type, !felt.type {product_source = "constrain"}
71+
// CHECK-NEXT: } {product_source = "fused"}
72+
// CHECK-NOT: scf.for

0 commit comments

Comments
 (0)