Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
added:
- Added "llzk-fuse-product-loops" transformation pass that matches and fuses pairs of scf.for loops from @compute and @constrain
19 changes: 19 additions & 0 deletions include/llzk/Transforms/LLZKFuseProductLoopsPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===-- LLZKFuseProductLoopsPass.h-------------------------------*- C++ -*-===//
//
// Part of the LLZK Project, under the Apache License v2.0.
// See LICENSE.txt for license information.
// Copyright 2026 Project LLZK
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

#pragma once

#include <mlir/IR/Region.h>
#include <mlir/Support/LogicalResult.h>

namespace llzk {
/// Identify pairs of `scf.for` loops that can be fused, fuse them, and then
/// recurse to fuse nested loops.
///
/// Only loops carrying the `product_source` string attribute are considered.
/// Loops tagged as coming from `@compute` are matched against loops tagged as
/// coming from `@constrain`; a pair is fused only when each loop has exactly
/// one fusable partner, so there is no ambiguity about which loops belong
/// together. Loops nested inside another `scf.for` are not candidates at the
/// current level; they are only revisited through the recursion into the body
/// of a freshly fused loop. Each fused loop is re-tagged with
/// `product_source = "fused"`.
///
/// \param body    Region whose `scf.for` loops should be examined.
/// \param context Context used to build the rewriter that performs the fusion.
/// \returns `success()` unless pair matching reports a failure or a recursive
///          call on a fused loop body fails.
mlir::LogicalResult fuseMatchingLoopPairs(mlir::Region &body, mlir::MLIRContext *context);
} // namespace llzk
2 changes: 2 additions & 0 deletions include/llzk/Transforms/LLZKTransformationPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

namespace llzk {

std::unique_ptr<mlir::Pass> createFuseProductLoopsPass();

std::unique_ptr<mlir::Pass> createComputeConstrainToProductPass();

std::unique_ptr<mlir::Pass> createFlatteningPass();
Expand Down
9 changes: 9 additions & 0 deletions include/llzk/Transforms/LLZKTransformationPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ def ComputeConstrainToProductPass
"Root struct at which to start alignment "
"(default to `@Main`)">];
}

// Registers the `llzk-fuse-product-loops` pass, which is constructed through
// `llzk::createFuseProductLoopsPass()`.
def FuseProductLoopsPass : LLZKPass<"llzk-fuse-product-loops"> {
let summary =
"Fuses matching witness/constraint loops in a @product function";
// No separate long-form description is provided; the one-line summary is reused.
let description = summary;
let constructor = "llzk::createFuseProductLoopsPass()";
// The pass currently exposes no command-line options.
let options = [];
}

#ifdef WITH_PCL
def PCLLoweringPass : LLZKPass<"llzk-to-pcl"> {
let summary = "Rewrites constraints to be compatible with PCL constraints "
Expand Down
58 changes: 58 additions & 0 deletions include/llzk/Util/AlignmentHelper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
//===-- AlignmentHelper.h --------------------------------------*- C++ -*-===//
//
// Part of the LLZK Project, under the Apache License v2.0.
// See LICENSE.txt for license information.
// Copyright 2026 Project LLZK
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

#pragma once

#include <llvm/ADT/SetVector.h>
#include <llvm/ADT/SmallVectorExtras.h>
#include <llvm/Support/Debug.h>
#include <llvm/Support/LogicalResult.h>

#include <concepts>

namespace llzk::alignmentHelpers {

/// A `Matcher` for `ValueT` is any callable that can be invoked with two
/// `ValueT` arguments and whose result is convertible to `bool`.
template <class ValueT, class FnT>
concept Matcher = requires(FnT fn, ValueT val) {
  { fn(val, val) } -> std::convertible_to<bool>;
};

/// Compute the unambiguous one-to-one matches between `as` and `bs`.
///
/// `doesMatch(a, b)` is evaluated for every combination of an element of `as`
/// with an element of `bs`. A pair `(a, b)` is reported only when `b` is the
/// sole candidate for `a` AND `a` is the sole candidate for `b`; elements with
/// zero or several candidates are left unmatched.
///
/// Matches are reported in the order in which their first component appears in
/// `as`, so the result does not depend on any hashing order.
///
/// \param as           Left-hand candidates.
/// \param bs           Right-hand candidates.
/// \param doesMatch    Predicate deciding whether two elements may be paired.
/// \param allowPartial When false, the function fails unless every element of
///                     `as` and every element of `bs` ends up in some pair.
/// \returns The set of matched pairs, or failure when a complete matching was
///          required but could not be found.
template <class ValueT, class FnT>
  requires Matcher<ValueT, FnT>
llvm::FailureOr<llvm::SetVector<std::pair<ValueT, ValueT>>> getMatchingPairs(
    llvm::ArrayRef<ValueT> as, llvm::ArrayRef<ValueT> bs, FnT doesMatch, bool allowPartial = true
) {
  const size_t numA = as.size();
  const size_t numB = bs.size();

  // candidatesForA[i] holds every index j such that doesMatch(as[i], bs[j]);
  // candidatesForB[j] is the inverse relation. Plain index-addressed vectors
  // keep the traversal below in input order.
  llvm::SmallVector<llvm::SmallVector<size_t>> candidatesForA(numA);
  llvm::SmallVector<llvm::SmallVector<size_t>> candidatesForB(numB);
  for (size_t i = 0; i < numA; ++i) {
    for (size_t j = 0; j < numB; ++j) {
      if (doesMatch(as[i], bs[j])) {
        candidatesForA[i].push_back(j);
        candidatesForB[j].push_back(i);
      }
    }
  }

  // Track which elements are still unmatched so a required-complete matching
  // can be validated at the end.
  llvm::SetVector<ValueT> unmatchedA(as.begin(), as.end());
  llvm::SetVector<ValueT> unmatchedB(bs.begin(), bs.end());

  llvm::SetVector<std::pair<ValueT, ValueT>> matches;
  for (size_t i = 0; i < numA; ++i) {
    const auto &partners = candidatesForA[i];
    if (partners.size() != 1) {
      continue; // no partner, or ambiguous on the A side
    }
    const size_t j = partners.front();
    if (candidatesForB[j].size() != 1) {
      continue; // ambiguous on the B side
    }
    unmatchedA.remove(as[i]);
    unmatchedB.remove(bs[j]);
    matches.insert({as[i], bs[j]});
  }

  if (!allowPartial && (!unmatchedA.empty() || !unmatchedB.empty())) {
    return llvm::failure();
  }
  return matches;
}
} // namespace llzk::alignmentHelpers
3 changes: 3 additions & 0 deletions include/llzk/Util/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,7 @@ constexpr char LANG_ATTR_NAME[] = "llzk.lang";
/// a `TypeAttr` specifying the `StructType` of the main struct.
constexpr char MAIN_ATTR_NAME[] = "llzk.main";

/// Name of the attribute on aligned product program ops that specifies where they came from.
/// The attribute holds a string: the name of the originating function (`FUNC_NAME_COMPUTE` or
/// `FUNC_NAME_CONSTRAIN`) as set by the compute/constrain-to-product alignment, or "fused" on
/// loops produced by the `llzk-fuse-product-loops` pass.
constexpr char PRODUCT_SOURCE[] = "product_source";

} // namespace llzk
4 changes: 2 additions & 2 deletions lib/Transforms/LLZKComputeConstrainToProductPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,11 @@ FuncDefOp ProductAligner::alignFuncs(StructDefOp root, FuncDefOp compute, FuncDe

// Add compute/constrain attributes
compute.walk([&funcBuilder](Operation *op) {
op->setAttr("product_source", funcBuilder.getStringAttr(FUNC_NAME_COMPUTE));
op->setAttr(PRODUCT_SOURCE, funcBuilder.getStringAttr(FUNC_NAME_COMPUTE));
});

constrain.walk([&funcBuilder](Operation *op) {
op->setAttr("product_source", funcBuilder.getStringAttr(FUNC_NAME_CONSTRAIN));
op->setAttr(PRODUCT_SOURCE, funcBuilder.getStringAttr(FUNC_NAME_CONSTRAIN));
});

// Create an empty @product func...
Expand Down
178 changes: 178 additions & 0 deletions lib/Transforms/LLZKFuseProductLoopsPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
//===-- LLZKFuseProductLoopsPass.cpp -----------------------------*- C++ -*-===//
//
// Part of the LLZK Project, under the Apache License v2.0.
// See LICENSE.txt for license information.
// Copyright 2026 Project LLZK
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the `-llzk-fuse-product-loops` pass.
///
//===----------------------------------------------------------------------===//

#include "llzk/Dialect/Function/IR/Ops.h"
#include "llzk/Dialect/Polymorphic/IR/Ops.h"
#include "llzk/Dialect/Struct/IR/Ops.h"
#include "llzk/Transforms/LLZKFuseProductLoopsPass.h"
#include "llzk/Transforms/LLZKTransformationPasses.h"
#include "llzk/Util/AlignmentHelper.h"
#include "llzk/Util/Constants.h"

#include <mlir/Dialect/SCF/Utils/Utils.h>

#include <llvm/Support/Debug.h>
#include <llvm/Support/SMTAPI.h>

#include <memory>

namespace llzk {

#define GEN_PASS_DECL_FUSEPRODUCTLOOPSPASS
#define GEN_PASS_DEF_FUSEPRODUCTLOOPSPASS
#include "llzk/Transforms/LLZKTransformationPasses.h.inc"

using namespace llzk::function;

// Bitwidth of `index` for instantiating SMT variables. Used both for bitvector literals built
// from constant loop bounds and for the sort of symbols standing in for struct parameters.
// NOTE(review): assumes a 64-bit `index`; confirm this matches the targets the pass is used for.
constexpr int INDEX_WIDTH = 64;

/// Pass driver for `llzk-fuse-product-loops`: applies loop-pair fusion to the
/// body of every function in the module for which `isStructProduct()` holds.
class FuseProductLoopsPass : public impl::FuseProductLoopsPassBase<FuseProductLoopsPass> {

public:
  /// Visit each `FuncDefOp` under the root module and fuse matching loop pairs
  /// inside struct product functions. The pass is marked as failed if fusion
  /// reports a failure for any of them.
  void runOnOperation() override {
    mlir::ModuleOp rootModule = getOperation();
    rootModule.walk([this](FuncDefOp fn) {
      // Functions that are not struct product functions are left untouched.
      if (!fn.isStructProduct()) {
        return;
      }
      const mlir::LogicalResult outcome =
          fuseMatchingLoopPairs(fn.getFunctionBody(), &getContext());
      if (mlir::failed(outcome)) {
        signalPassFailure();
      }
    });
  }
};

/// Returns true when `val` is defined by either an `arith` constant-index op
/// or a `llzk::polymorphic::ConstReadOp`; these are the only loop-bound forms
/// that `mkExpr` knows how to translate for the solver.
static inline bool isConstOrStructParam(mlir::Value val) {
  // TODO: doing arithmetic over constants should also be fine?
  if (val.getDefiningOp<mlir::arith::ConstantIndexOp>()) {
    return true;
  }
  return static_cast<bool>(val.getDefiningOp<llzk::polymorphic::ConstReadOp>());
}

/// Translate a loop-bound `value` into an SMT bitvector expression of width
/// `INDEX_WIDTH`.
///
/// - A constant index becomes a bitvector literal.
/// - A `ConstReadOp` becomes a symbol named after the constant it reads, so
///   two reads of the same name yield the same symbol name.
///
/// Any other defining op is unsupported and trips the assertion; callers are
/// expected to have filtered inputs with `isConstOrStructParam` first.
llvm::SMTExprRef mkExpr(mlir::Value value, llvm::SMTSolver *solver) {
  if (auto indexConst = value.getDefiningOp<mlir::arith::ConstantIndexOp>()) {
    const llvm::APSInt literal = llvm::APSInt::get(indexConst.value());
    return solver->mkBitvector(literal, INDEX_WIDTH);
  }

  if (auto paramRead = value.getDefiningOp<llzk::polymorphic::ConstReadOp>()) {
    // Keep the name alive in a local so the C string stays valid for the call.
    const std::string symbolName {paramRead.getConstName()};
    return solver->mkSymbol(symbolName.c_str(), solver->getBitvectorSort(INDEX_WIDTH));
  }

  assert(false && "unsupported: checking non-constant trip counts");
  return nullptr; // Unreachable
}

/// Build an SMT bitvector expression for the trip count of `op`.
///
/// An `scf.for` from `lb` to `ub` (exclusive) with step `s` runs
/// `ceil((ub - lb) / s)` times. For a positive step this is encoded as
/// `(ub - lb + s - 1) / s` using signed bitvector division, so that two loops
/// receive equal expressions only when they really iterate the same number of
/// times (e.g. `0 to 4 step 3` yields 2 while `0 to 3 step 3` yields 1).
///
/// The bounds and step must be values accepted by `mkExpr`.
llvm::SMTExprRef tripCount(mlir::scf::ForOp op, llvm::SMTSolver *solver) {
  const llvm::SMTExprRef one = solver->mkBitvector(llvm::APSInt::get(1), INDEX_WIDTH);
  const llvm::SMTExprRef lower = mkExpr(op.getLowerBound(), solver);
  const llvm::SMTExprRef upper = mkExpr(op.getUpperBound(), solver);
  const llvm::SMTExprRef step = mkExpr(op.getStep(), solver);

  // Ceiling division: add (step - 1) to the span before dividing by step.
  const llvm::SMTExprRef span = solver->mkBVSub(upper, lower);
  const llvm::SMTExprRef roundUp = solver->mkBVSub(step, one);
  return solver->mkBVSDiv(solver->mkBVAdd(span, roundUp), step);
}

/// Decide whether loops `a` and `b` are candidates for fusion.
///
/// A priori, two loops can be fused if:
///  1. They live in the same parent region,
///  2. One comes from witgen and the other comes from constraint gen, and
///  3. They have the same trip count
///
/// \returns true only when all three conditions are established; any
///          uncertainty (including an inconclusive solver query) yields false.
static inline bool canLoopsBeFused(mlir::scf::ForOp a, mlir::scf::ForOp b) {
  // Check 1.
  if (a->getParentRegion() != b->getParentRegion()) {
    return false;
  }

  // Check 2.
  const auto sourceA = a->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE);
  const auto sourceB = b->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE);
  if (!sourceA || !sourceB) {
    // Ideally this should never happen, since the pass only runs on fused @product functions, but
    // check anyway just to be safe
    return false;
  }
  if (sourceA == sourceB) {
    return false;
  }

  // Check 3.
  // Easy case: both have a constant trip-count, so the answer is simply whether the two counts
  // agree. In particular, two known-but-different counts must be rejected here rather than being
  // handed to the solver.
  auto tripCountA = mlir::constantTripCount(a.getLowerBound(), a.getUpperBound(), a.getStep());
  auto tripCountB = mlir::constantTripCount(b.getLowerBound(), b.getUpperBound(), b.getStep());
  if (tripCountA.has_value() && tripCountB.has_value()) {
    return *tripCountA == *tripCountB;
  }

  // If the trip counts are not "constant up to a struct param", we definitely can't tell if
  // they're equal.
  const mlir::Value bounds[] = {a.getLowerBound(), a.getUpperBound(), a.getStep(),
                                b.getLowerBound(), b.getUpperBound(), b.getStep()};
  for (mlir::Value bound : bounds) {
    if (!isConstOrStructParam(bound)) {
      return false;
    }
  }

  // Otherwise the trip counts are only "constant up to a struct param", so ask a solver whether
  // the two expressions are guaranteed to be the same, i.e. whether "they differ" is
  // unsatisfiable.
  // NOTE(review): CreateZ3Solver presumably requires an LLVM build with Z3 enabled; confirm the
  // build configuration guarantees that.
  llvm::SMTSolverRef solver = llvm::CreateZ3Solver();
  solver->addConstraint(
      solver->mkNot(solver->mkEqual(tripCount(a, solver.get()), tripCount(b, solver.get())))
  );

  // `check()` yields no value when the solver cannot decide; treat that as "cannot prove equal"
  // instead of dereferencing an empty optional.
  const auto canDiffer = solver->check();
  return canDiffer.has_value() && !*canDiffer;
}

/// Fuse every unambiguous (witness loop, constraint loop) pair found in `body`
/// and then repeat the process inside each fused loop.
///
/// \param body    Region to scan for tagged `scf.for` loops.
/// \param context Context used to construct the rewriter.
/// \returns failure if pair matching fails or any recursive invocation fails;
///          success otherwise.
mlir::LogicalResult fuseMatchingLoopPairs(mlir::Region &body, mlir::MLIRContext *context) {
  // Gather the outermost tagged loops, bucketed by where they came from.
  llvm::SmallVector<mlir::scf::ForOp> witnessLoops, constraintLoops;
  body.walk<mlir::WalkOrder::PreOrder>([&](mlir::scf::ForOp loop) {
    if (auto origin = loop->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE)) {
      if (origin == FUNC_NAME_COMPUTE) {
        witnessLoops.push_back(loop);
      } else if (origin == FUNC_NAME_CONSTRAIN) {
        constraintLoops.push_back(loop);
      }
    }
    // Never descend into a loop here: nested loops can't possibly be fused at this stage and
    // are handled by the recursion below once their parents have been fused.
    return mlir::WalkResult::skip();
  });

  // A pair is selected iff the two loops are fusable and neither of them is fusable with any
  // other loop, which rules out ambiguous pairings.
  auto pairsToFuse = alignmentHelpers::getMatchingPairs<mlir::scf::ForOp>(
      witnessLoops, constraintLoops, canLoopsBeFused
  );

  // Partial matchings are allowed, so this is not expected to trigger.
  if (mlir::failed(pairsToFuse)) {
    return mlir::failure();
  }

  // NOTE(review): the fusion utility is named for *independent* sibling loops; no dependence
  // check between the two loops is performed here — confirm that is guaranteed by construction.
  mlir::IRRewriter rewriter(context);
  for (auto [witnessLoop, constraintLoop] : *pairsToFuse) {
    mlir::scf::ForOp merged =
        mlir::fuseIndependentSiblingForLoops(witnessLoop, constraintLoop, rewriter);
    merged->setAttr(PRODUCT_SOURCE, rewriter.getAttr<mlir::StringAttr>("fused"));

    // Loops that used to be nested in the two originals are now siblings inside the merged
    // body, so give them a chance to fuse as well.
    if (mlir::failed(fuseMatchingLoopPairs(merged.getBodyRegion(), context))) {
      return mlir::failure();
    }
  }
  return mlir::success();
}

/// Factory for the loop-fusion pass; returns a newly allocated
/// `FuseProductLoopsPass` owned by the caller.
std::unique_ptr<mlir::Pass> createFuseProductLoopsPass() {
  std::unique_ptr<mlir::Pass> pass = std::make_unique<FuseProductLoopsPass>();
  return pass;
}
} // namespace llzk
5 changes: 4 additions & 1 deletion lib/Transforms/TransformationPassPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ void registerTransformationPassPipelines() {
PassPipelineRegistration<>(
"llzk-product-program",
"Convert @compute/@constrain functions to @product function and perform alignment",
[](OpPassManager &pm) { pm.addPass(llzk::createComputeConstrainToProductPass()); }
[](OpPassManager &pm) {
pm.addPass(llzk::createComputeConstrainToProductPass());
pm.addPass(llzk::createFuseProductLoopsPass());
}
);
}

Expand Down
72 changes: 72 additions & 0 deletions test/Transforms/FuseProductLoops/fuse_outer.llzk
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// RUN: llzk-opt --llzk-compute-constrain-to-product="root-struct=A" --llzk-fuse-product-loops %s | FileCheck --enable-var-scope %s

module attributes {llzk.lang = "llzk"} {
struct.def @A<[@N]> {
struct.member @arr : !array.type<@N x !felt.type>

// Witness generation: fills %arr so that %arr[i] holds the sum of %as[j] for j in [i, N),
// then stores it in member @arr. The outer loop runs from 0 to N with step 1.
function.def @compute(%as: !array.type<@N x !felt.type>) -> !struct.type<@A<[@N]>> {
%self = struct.new : <@A<[@N]>>

// Loop bounds: a struct parameter read plus two index constants.
%N = poly.read_const @N : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

%arr = array.new : !array.type<@N x !felt.type>

// Outer loop over i.
scf.for %i = %c0 to %N step %c1 {
%sum0 = felt.const 0
// Inner loop over j starting at i, accumulating %as[j] through an iter_arg.
%sum = scf.for %j = %i to %N step %c1 iter_args(%iter_sum = %sum0) -> !felt.type {
%a_j = array.read %as[%j] : !array.type<@N x !felt.type>, !felt.type
%0 = felt.add %iter_sum, %a_j
scf.yield %0 : !felt.type
}
array.write %arr[%i] = %sum : !array.type<@N x !felt.type>, !felt.type
scf.yield
}

struct.writem %self[@arr] = %arr : !struct.type<@A<[@N]>>, !array.type<@N x !felt.type>
function.return %self : !struct.type<@A<[@N]>>
}
// Constraint generation: a single loop from 0 to N with step 1 (the same bounds as the outer
// loop of @compute) that emits one equality constraint per iteration.
function.def @constrain(%self : !struct.type<@A<[@N]>>, %as: !array.type<@N x !felt.type>) {
%N = poly.read_const @N : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

%arr = struct.readm %self[@arr] : !struct.type<@A<[@N]>>, !array.type<@N x !felt.type>

scf.for %i = %c0 to %N step %c1 {
// NOTE(review): %j equals N on the last iteration, one past the last index of %arr; this
// looks acceptable for a fusion-only test but confirm it is intentional.
%j = arith.addi %i, %c1 : index

%arr_i = array.read %arr[%i] : !array.type<@N x !felt.type>, !felt.type
%arr_j = array.read %arr[%j] : !array.type<@N x !felt.type>, !felt.type
%a = array.read %as[%i] : !array.type<@N x !felt.type>, !felt.type
// Constrain %as[i] to equal %arr[i + 1] - %arr[i].
%diff = felt.sub %arr_j, %arr_i
constrain.eq %a, %diff : !felt.type
scf.yield
}

function.return
}
}
}

// CHECK-LABEL: scf.for
// CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]] =
// CHECK-SAME: %[[VAL_1:[0-9a-zA-Z_\.]+]] to
// CHECK-SAME: %[[VAL_2:[0-9a-zA-Z_\.]+]] step
// CHECK-SAME: %[[VAL_3:[0-9a-zA-Z_\.]+]] {
// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = felt.const 0 {product_source = "compute"}
// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = scf.for %[[VAL_6:[0-9a-zA-Z_\.]+]] = %[[VAL_0]] to %[[VAL_7:[0-9a-zA-Z_\.]+]] step %[[VAL_8:[0-9a-zA-Z_\.]+]] iter_args(%[[VAL_9:[0-9a-zA-Z_\.]+]] = %[[VAL_4]]) -> (!felt.type) {
// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_11:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_6]]] : <@N x !felt.type>, !felt.type {product_source = "compute"}
// CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_9]], %[[VAL_10]] : !felt.type, !felt.type {product_source = "compute"}
// CHECK-NEXT: scf.yield {product_source = "compute"} %[[VAL_12]] : !felt.type
// CHECK-NEXT: } {product_source = "compute"}
// CHECK-NEXT: array.write %[[VAL_13:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] = %[[VAL_5]] : <@N x !felt.type>, !felt.type {product_source = "compute"}
// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = arith.addi %[[VAL_0]], %[[VAL_3]] {product_source = "constrain"} : index
// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_16:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"}
// CHECK-NEXT: %[[VAL_17:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_16]]{{\[}}%[[VAL_14]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"}
// CHECK-NEXT: %[[VAL_18:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"}
// CHECK-NEXT: %[[VAL_20:[0-9a-zA-Z_\.]+]] = felt.sub %[[VAL_17]], %[[VAL_15]] : !felt.type, !felt.type {product_source = "constrain"}
// CHECK-NEXT: constrain.eq %[[VAL_18]], %[[VAL_20]] : !felt.type, !felt.type {product_source = "constrain"}
// CHECK-NEXT: } {product_source = "fused"}
// CHECK-NOT: scf.for
Loading