Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
added:
- Added "llzk-fuse-product-loops" transformation pass that matches and fuses pairs of scf.for loops from @compute and @constrain
2 changes: 2 additions & 0 deletions include/llzk/Transforms/LLZKTransformationPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

namespace llzk {

std::unique_ptr<mlir::Pass> createFuseProductLoopsPass();

std::unique_ptr<mlir::Pass> createComputeConstrainToProductPass();

std::unique_ptr<mlir::Pass> createFlatteningPass();
Expand Down
9 changes: 9 additions & 0 deletions include/llzk/Transforms/LLZKTransformationPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ def ComputeConstrainToProductPass
"Root struct at which to start alignment "
"(default to `@Main`)">];
}

def FuseProductLoopsPass : LLZKPass<"llzk-fuse-product-loops"> {
let summary =
"Fuses matching witness/constraint loops in a @product function";
let description = summary;
let constructor = "llzk::createFuseProductLoopsPass()";
let options = [];
}

#ifdef WITH_PCL
def PCLLoweringPass : LLZKPass<"llzk-to-pcl"> {
let summary = "Rewrites constraints to be compatible with PCL constraints "
Expand Down
58 changes: 58 additions & 0 deletions include/llzk/Util/AlignmentHelper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
//===-- AlignmentHelper.h ------------------------------------------*- C++ -*-===//
//
// Part of the LLZK Project, under the Apache License v2.0.
// See LICENSE.txt for license information.
// Copyright 2025 Veridise Inc.
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

#pragma once

#include <llvm/ADT/SetVector.h>
#include <llvm/ADT/SmallVectorExtras.h>
#include <llvm/Support/Debug.h>
#include <llvm/Support/LogicalResult.h>

#include <concepts>

namespace llzk::alignmentHelpers {

template <class ValueT, class FnT>
concept Matcher = requires(FnT fn, ValueT val) {
{ fn(val, val) } -> std::convertible_to<bool>;
};

template <class ValueT, class FnT>
requires Matcher<ValueT, FnT>
llvm::FailureOr<llvm::SetVector<std::pair<ValueT, ValueT>>> getMatchingPairs(
llvm::ArrayRef<ValueT> as, llvm::ArrayRef<ValueT> bs, FnT doesMatch, bool allowPartial = true
) {

llvm::SetVector<ValueT> setA {as.begin(), as.end()}, setB {bs.begin(), bs.end()};
llvm::DenseMap<size_t, llvm::SmallVector<size_t>> possibleMatchesA, possibleMatchesB;

for (size_t i = 0; i < as.size(); i++) {
for (size_t j = 0; j < bs.size(); j++) {
if (doesMatch(as[i], bs[j])) {
possibleMatchesA[i].push_back(j);
possibleMatchesB[j].push_back(i);
}
}
}

llvm::SetVector<std::pair<ValueT, ValueT>> matches;
for (auto [a, b] : possibleMatchesA) {
if (b.size() == 1 && possibleMatchesB[b[0]].size() == 1) {
setA.remove(as[a]);
setB.remove(bs[b[0]]);
matches.insert({as[a], bs[b[0]]});
}
}

if ((!setA.empty() || !setB.empty()) && !allowPartial) {
return llvm::failure();
}
return matches;
}
} // namespace llzk::alignmentHelpers
179 changes: 179 additions & 0 deletions lib/Transforms/LLZKFuseProductLoopsPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
//===-- LLZKComputeConstrainToProductPass.cpp -------------------*- C++ -*-===//
//
// Part of the LLZK Project, under the Apache License v2.0.
// See LICENSE.txt for license information.
// Copyright 2025 Veridise Inc.
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the `-llzk-fuse-product-loops` pass.
///
//===----------------------------------------------------------------------===//

#include "llzk/Dialect/Function/IR/Ops.h"
#include "llzk/Dialect/Polymorphic/IR/Ops.h"
#include "llzk/Dialect/Struct/IR/Ops.h"
#include "llzk/Transforms/LLZKTransformationPasses.h"
#include "llzk/Util/AlignmentHelper.h"
#include "llzk/Util/Constants.h"

#include <mlir/Dialect/SCF/Utils/Utils.h>

#include <llvm/Support/Debug.h>
#include <llvm/Support/SMTAPI.h>

#include <functional>
#include <memory>
namespace llzk {
#define GEN_PASS_DECL_FUSEPRODUCTLOOPSPASS
#define GEN_PASS_DEF_FUSEPRODUCTLOOPSPASS
#include "llzk/Transforms/LLZKTransformationPasses.h.inc"

using namespace llzk::function;

// Bitwidth of `index` for instantiating SMT variables
constexpr int INDEX_WIDTH = 64;

class FuseProductLoopsPass : public impl::FuseProductLoopsPassBase<FuseProductLoopsPass> {
/// Identify pairs of scf.for loops that can be fused, fuse them, and then recurse to fuse nested
/// loops
void fuseMatchingLoopPairs(mlir::Region &body);
bool canLoopsBeFused(mlir::scf::ForOp a, mlir::scf::ForOp b);

public:
void runOnOperation() override {
mlir::ModuleOp mod = getOperation();
mod.walk([this](FuncDefOp funcDef) {
if (funcDef.isStructProduct()) {
fuseMatchingLoopPairs(funcDef.getFunctionBody());
}
});
}
};

bool isConstOrStructParam(mlir::Value val) {
// TODO: doing arithmetic over constants should also be fine?
return val.getDefiningOp<mlir::arith::ConstantIndexOp>() ||
val.getDefiningOp<llzk::polymorphic::ConstReadOp>();
}

llvm::SMTExprRef mkExpr(mlir::Value value, llvm::SMTSolver *solver) {
if (auto constOp = value.getDefiningOp<mlir::arith::ConstantIndexOp>()) {
return solver->mkBitvector(llvm::APSInt::get(constOp.value()), INDEX_WIDTH);
} else if (auto polyReadOp = value.getDefiningOp<llzk::polymorphic::ConstReadOp>()) {

return solver->mkSymbol(
std::string {polyReadOp.getConstName()}.c_str(), solver->getBitvectorSort(INDEX_WIDTH)
);
}
assert(false && "unsupported: checking non-constant trip counts");
return nullptr; // Unreachable
}

llvm::SMTExprRef tripCount(mlir::scf::ForOp op, llvm::SMTSolver *solver) {
const auto *one = solver->mkBitvector(llvm::APSInt::get(1), INDEX_WIDTH);
return solver->mkBVSDiv(
solver->mkBVAdd(
one,
solver->mkBVSub(mkExpr(op.getUpperBound(), solver), mkExpr(op.getLowerBound(), solver))
),
mkExpr(op.getStep(), solver)
);
}

bool FuseProductLoopsPass::canLoopsBeFused(mlir::scf::ForOp a, mlir::scf::ForOp b) {
// A priori, loops two loops can be fused if:
// 1. They live in the same parent region,
// 2. One comes from witgen and the other comes from constraint gen, and
// 3. They have the same trip count

// Check 1.
if (a->getParentRegion() != b->getParentRegion()) {
return false;
}

// Check 2.
if (!a->hasAttrOfType<mlir::StringAttr>("product_source") ||
!b->hasAttrOfType<mlir::StringAttr>("product_source")) {
// Ideally this should never happen, since the pass only runs on fused @product functions, but
// check anyway just to be safe
return false;
}
if (a->getAttrOfType<mlir::StringAttr>("product_source") ==
b->getAttrOfType<mlir::StringAttr>("product_source")) {
return false;
}

// Check 3.
// Easy case: both have a constant trip-count
auto tripCountA = mlir::constantTripCount(a.getLowerBound(), a.getUpperBound(), a.getStep());
auto tripCountB = mlir::constantTripCount(b.getLowerBound(), b.getUpperBound(), b.getStep());
if (tripCountA.has_value() && tripCountB.has_value() && *tripCountA == *tripCountB) {
return true;
}

// If the trip counts are not "constant up to a struct param", we definitely can't tell if they're
// equal
if (!isConstOrStructParam(a.getLowerBound()) || !isConstOrStructParam(a.getUpperBound()) ||
!isConstOrStructParam(a.getStep()) || !isConstOrStructParam(b.getLowerBound()) ||
!isConstOrStructParam(b.getUpperBound()) || !isConstOrStructParam(b.getStep())) {
return false;
}

// If the trip counts are only "constant up to a struct param" but not actually constant, we can
// ask a solver if the equations are guaranteed to be the same
llvm::SMTSolverRef solver = llvm::CreateZ3Solver();
solver->addConstraint(
/* (actually ask if they "can't be different") */ solver->mkNot(
solver->mkEqual(tripCount(a, solver.get()), tripCount(b, solver.get()))
)
);

// The loops are fusable if its impossible for the trip count expressions to be different
return !*solver->check();
}

void FuseProductLoopsPass::fuseMatchingLoopPairs(mlir::Region &body) {

// Start by collecting all possible loops
llvm::SmallVector<mlir::scf::ForOp> witnessLoops, constraintLoops;
body.walk<mlir::WalkOrder::PreOrder>([&witnessLoops, &constraintLoops](mlir::scf::ForOp forOp) {
if (!forOp->hasAttrOfType<mlir::StringAttr>("product_source")) {
return mlir::WalkResult::skip();
}
if (forOp->getAttrOfType<mlir::StringAttr>("product_source") == FUNC_NAME_COMPUTE) {
witnessLoops.push_back(forOp);
} else if (forOp->getAttrOfType<mlir::StringAttr>("product_source") == FUNC_NAME_CONSTRAIN) {
constraintLoops.push_back(forOp);
}
// Skipping here, because any nested loops can't possibly be fused at this stage
return mlir::WalkResult::skip();
});

// A pair of loops will be fused iff (1) they can be fused according to the rules above, and (2)
// neither can be fused with anything else (so there's no ambiguity)
auto fusionCandidates = alignmentHelpers::getMatchingPairs<mlir::scf::ForOp>(
witnessLoops, constraintLoops, std::bind_front(&FuseProductLoopsPass::canLoopsBeFused, this)
);

// This shouldn't happen, since we allow partial matches
if (mlir::failed(fusionCandidates)) {
signalPassFailure();
}

// Finally, fuse all the marked loops...
mlir::IRRewriter rewriter {&getContext()};
for (auto [w, c] : *fusionCandidates) {
auto fusedLoop = mlir::fuseIndependentSiblingForLoops(w, c, rewriter);
fusedLoop->setAttr("product_source", rewriter.getAttr<mlir::StringAttr>("fused"));
// ...and recurse to fuse nested loops
fuseMatchingLoopPairs(fusedLoop.getBodyRegion());
}
}

std::unique_ptr<mlir::Pass> createFuseProductLoopsPass() {
return std::make_unique<FuseProductLoopsPass>();
}
} // namespace llzk
5 changes: 4 additions & 1 deletion lib/Transforms/TransformationPassPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ void registerTransformationPassPipelines() {
PassPipelineRegistration<>(
"llzk-product-program",
"Convert @compute/@constrain functions to @product function and perform alignment",
[](OpPassManager &pm) { pm.addPass(llzk::createComputeConstrainToProductPass()); }
[](OpPassManager &pm) {
pm.addPass(llzk::createComputeConstrainToProductPass());
pm.addPass(llzk::createFuseProductLoopsPass());
}
);
}

Expand Down
Empty file added lib/Util/AlignmentHelper.cpp
Empty file.
71 changes: 71 additions & 0 deletions test/Transforms/FuseProductLoops/fuse_outer.llzk
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// RUN: llzk-opt --llzk-compute-constrain-to-product="root-struct=A" --llzk-fuse-product-loops %s | FileCheck --enable-var-scope %s

module attributes {veridise.lang = "llzk"} {
struct.def @A<[@N]> {
struct.field @arr : !array.type<@N x !felt.type>

function.def @compute(%as: !array.type<@N x !felt.type>) -> !struct.type<@A<[@N]>> {
%self = struct.new : <@A<[@N]>>

%N = poly.read_const @N : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

%arr = array.new : !array.type<@N x !felt.type>

scf.for %i = %c0 to %N step %c1 {
%sum0 = felt.const 0
%sum = scf.for %j = %i to %N step %c1 iter_args(%iter_sum = %sum0) -> !felt.type {
%a_j = array.read %as[%j] : !array.type<@N x !felt.type>, !felt.type
%0 = felt.add %iter_sum, %a_j
scf.yield %0 : !felt.type
}
array.write %arr[%i] = %sum : !array.type<@N x !felt.type>, !felt.type
scf.yield
}

struct.writef %self[@arr] = %arr : !struct.type<@A<[@N]>>, !array.type<@N x !felt.type>
function.return %self : !struct.type<@A<[@N]>>
}
function.def @constrain(%self : !struct.type<@A<[@N]>>, %as: !array.type<@N x !felt.type>) {
%N = poly.read_const @N : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

%arr = struct.readf %self[@arr] : !struct.type<@A<[@N]>>, !array.type<@N x !felt.type>

scf.for %i = %c0 to %N step %c1 {
%j = arith.addi %i, %c1 : index

%arr_i = array.read %arr[%i] : !array.type<@N x !felt.type>, !felt.type
%arr_j = array.read %arr[%j] : !array.type<@N x !felt.type>, !felt.type
%a = array.read %as[%i] : !array.type<@N x !felt.type>, !felt.type
%diff = felt.sub %arr_j, %arr_i
constrain.eq %a, %diff : !felt.type
scf.yield
}

function.return
}
}
}

// CHECK-LABEL: scf.for
// CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]] =
// CHECK-SAME: %[[VAL_1:[0-9a-zA-Z_\.]+]] to
// CHECK-SAME: %[[VAL_2:[0-9a-zA-Z_\.]+]] step
// CHECK-SAME: %[[VAL_3:[0-9a-zA-Z_\.]+]] {
// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = felt.const 0 {product_source = "compute"}
// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = scf.for %[[VAL_6:[0-9a-zA-Z_\.]+]] = %[[VAL_0]] to %[[VAL_7:[0-9a-zA-Z_\.]+]] step %[[VAL_8:[0-9a-zA-Z_\.]+]] iter_args(%[[VAL_9:[0-9a-zA-Z_\.]+]] = %[[VAL_4]]) -> (!felt.type) {
// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_11:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_6]]] : <@N x !felt.type>, !felt.type {product_source = "compute"}
// CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_9]], %[[VAL_10]] : !felt.type, !felt.type {product_source = "compute"}
// CHECK-NEXT: scf.yield {product_source = "compute"} %[[VAL_12]] : !felt.type
// CHECK-NEXT: } {product_source = "compute"}
// CHECK-NEXT: array.write %[[VAL_13:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] = %[[VAL_5]] : <@N x !felt.type>, !felt.type {product_source = "compute"}
// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = arith.addi %[[VAL_0]], %[[VAL_3]] {product_source = "constrain"} : index
// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_16:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"}
// CHECK-NEXT: %[[VAL_17:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_16]]{{\[}}%[[VAL_14]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"}
// CHECK-NEXT: %[[VAL_18:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19:[0-9a-zA-Z_\.]+]]{{\[}}%[[VAL_0]]] : <@N x !felt.type>, !felt.type {product_source = "constrain"}
// CHECK-NEXT: %[[VAL_20:[0-9a-zA-Z_\.]+]] = felt.sub %[[VAL_17]], %[[VAL_15]] : !felt.type, !felt.type {product_source = "constrain"}
// CHECK-NEXT: constrain.eq %[[VAL_18]], %[[VAL_20]] : !felt.type, !felt.type {product_source = "constrain"}
// CHECK-NEXT: } {product_source = "fused"}
Loading