From 525608605058123ef06c2d6c3f61cdbfea9b5ab9 Mon Sep 17 00:00:00 2001 From: Ian Glen Neal Date: Tue, 11 Nov 2025 00:45:57 -0500 Subject: [PATCH 1/6] Add inverse logic --- include/llzk/Analysis/Field.h | 7 ++++++ include/llzk/Analysis/IntervalAnalysis.h | 32 ------------------------ include/llzk/Util/DynamicAPIntHelper.h | 7 ++++++ lib/Analysis/Field.cpp | 4 +++ lib/Analysis/IntervalAnalysis.cpp | 4 +-- lib/Util/DynamicAPIntHelper.cpp | 28 +++++++++++++++++++++ 6 files changed, 48 insertions(+), 34 deletions(-) diff --git a/include/llzk/Analysis/Field.h b/include/llzk/Analysis/Field.h index 941782742..fadc9139e 100644 --- a/include/llzk/Analysis/Field.h +++ b/include/llzk/Analysis/Field.h @@ -15,6 +15,8 @@ #include +#include "llzk/Util/DynamicAPIntHelper.h" + namespace llzk { /// @brief Information about the prime finite field used for the interval analysis. @@ -51,6 +53,11 @@ class Field { /// @brief Returns p - 1, which is the max value possible in a prime field described by p. inline llvm::DynamicAPInt maxVal() const { return prime() - one(); } + /// @brief Returns the multiplicative inverse of `i` in prime field `p`. + llvm::DynamicAPInt inv(const llvm::DynamicAPInt &i) const; + + llvm::DynamicAPInt inv(const llvm::APInt &i) const; + /// @brief Returns i mod p and reduces the result into the appropriate bitwidth. /// Field elements are returned as signed integers so that negation functions /// as expected (i.e., reducing -1 will yield p-1). 
diff --git a/include/llzk/Analysis/IntervalAnalysis.h b/include/llzk/Analysis/IntervalAnalysis.h index 13311d5c7..78c190b20 100644 --- a/include/llzk/Analysis/IntervalAnalysis.h +++ b/include/llzk/Analysis/IntervalAnalysis.h @@ -358,48 +358,16 @@ class IntervalDataFlowAnalysis mlir::FailureOr, Interval>> getGeneralizedDecompInterval(mlir::Operation *baseOp, mlir::Value lhs, mlir::Value rhs); - bool isBoolOp(mlir::Operation *op) const { - return llvm::isa( - op - ); - } - - bool isConversionOp(mlir::Operation *op) const { - return llvm::isa(op); - } - - bool isApplyMapOp(mlir::Operation *op) const { return llvm::isa(op); } - - bool isAssertOp(mlir::Operation *op) const { return llvm::isa(op); } - bool isReadOp(mlir::Operation *op) const { return llvm::isa(op); } - bool isWriteOp(mlir::Operation *op) const { - return llvm::isa(op); - } - - bool isArrayLengthOp(mlir::Operation *op) const { return llvm::isa(op); } - - bool isEmitOp(mlir::Operation *op) const { - return llvm::isa(op); - } - - bool isCreateOp(mlir::Operation *op) const { - return llvm::isa(op); - } - - bool isExtractArrayOp(mlir::Operation *op) const { return llvm::isa(op); } - bool isDefinitionOp(mlir::Operation *op) const { return llvm::isa< component::StructDefOp, function::FuncDefOp, component::FieldDefOp, global::GlobalDefOp, mlir::ModuleOp>(op); } - bool isCallOp(mlir::Operation *op) const { return llvm::isa(op); } - bool isReturnOp(mlir::Operation *op) const { return llvm::isa(op); } /// @brief Get the SourceRefLattice that defines `val`, or the SourceRefLattice after `baseOp` diff --git a/include/llzk/Util/DynamicAPIntHelper.h b/include/llzk/Util/DynamicAPIntHelper.h index a83e2c548..bae476c1d 100644 --- a/include/llzk/Util/DynamicAPIntHelper.h +++ b/include/llzk/Util/DynamicAPIntHelper.h @@ -42,4 +42,11 @@ inline llvm::DynamicAPInt toDynamicAPInt(const llvm::APInt &i) { llvm::APSInt toAPSInt(const llvm::DynamicAPInt &i); +llvm::DynamicAPInt modExp(const llvm::DynamicAPInt &base, + const 
llvm::DynamicAPInt &exp, + const llvm::DynamicAPInt &mod); + +llvm::DynamicAPInt modInversePrime(const llvm::DynamicAPInt &f, + const llvm::DynamicAPInt &p); + } // namespace llzk diff --git a/lib/Analysis/Field.cpp b/lib/Analysis/Field.cpp index eee72f581..61ddb5cf1 100644 --- a/lib/Analysis/Field.cpp +++ b/lib/Analysis/Field.cpp @@ -64,4 +64,8 @@ DynamicAPInt Field::reduce(const DynamicAPInt &i) const { DynamicAPInt Field::reduce(const APInt &i) const { return reduce(toDynamicAPInt(i)); } +DynamicAPInt Field::inv(const DynamicAPInt &i) const { return modInversePrime(i, prime()); } + +DynamicAPInt Field::inv(const llvm::APInt &i) const { return modInversePrime(toDynamicAPInt(i), prime()); } + } // namespace llzk diff --git a/lib/Analysis/IntervalAnalysis.cpp b/lib/Analysis/IntervalAnalysis.cpp index 9120605a0..29806b40b 100644 --- a/lib/Analysis/IntervalAnalysis.cpp +++ b/lib/Analysis/IntervalAnalysis.cpp @@ -606,7 +606,7 @@ IntervalDataFlowAnalysis::visitOperation(Operation *op, const Lattice &before, L // to the operand. changed |= applyInterval( assertOp, after, after, assertOp.getCondition(), - Interval::Degenerate(field.get(), field.get().one()) + Interval::True(field.get()) ); // Also add the solver constraint that the expression must be true. auto assertExpr = operandVals[0].getScalarValue(); @@ -1116,7 +1116,7 @@ IntervalDataFlowAnalysis::getGeneralizedDecompInterval(Operation *baseOp, Value // Now, we aggregate the Interval. If we have sparse values (e.g., 0, 2, 4), // we will create a larger range of [0, 4], since we don't support multiple intervals. 
std::sort(consts.begin(), consts.end()); - Interval iv = Interval::TypeA(field.get(), consts.front(), consts.back()); + Interval iv = UnreducedInterval(consts.front(), consts.back()).reduce(field.get()); return std::make_pair(std::move(signalVals), iv); } diff --git a/lib/Util/DynamicAPIntHelper.cpp b/lib/Util/DynamicAPIntHelper.cpp index e4524025e..b7ab82f3e 100644 --- a/lib/Util/DynamicAPIntHelper.cpp +++ b/lib/Util/DynamicAPIntHelper.cpp @@ -134,4 +134,32 @@ APSInt toAPSInt(const DynamicAPInt &i) { return res; } +DynamicAPInt modExp(const DynamicAPInt &base, + const DynamicAPInt &exp, + const DynamicAPInt &mod) { + DynamicAPInt result(1); + DynamicAPInt b = base; + DynamicAPInt e = exp; + DynamicAPInt one(1); + + while (e != 0) { + if (e % 2 != 0) { + result = (result * b) % mod; + } + + b = (b * b) % mod; + e = e >> one; + } + assert((base * result) % mod == 1 && "inverse is incorrect"); + return result; +} + +llvm::DynamicAPInt modInversePrime(const DynamicAPInt &f, + const DynamicAPInt &p) { + assert(f != 0 && "0 has no inverse"); + // Fermat: f^(p-2) mod p + DynamicAPInt exp = p - 2; + return modExp(f, exp, p); +} + } // namespace llzk From 19a06d34042035eedcf7bb7cf0367879ac3d5e44 Mon Sep 17 00:00:00 2001 From: Ian Glen Neal Date: Tue, 11 Nov 2025 23:34:20 -0500 Subject: [PATCH 2/6] - Port SparseAnalysis from MLIR - Fix TypeA -> Degenerate interval conversion --- include/llzk/Analysis/AnalysisUtil.h | 25 + include/llzk/Analysis/AnalysisWrappers.h | 2 +- include/llzk/Analysis/DenseAnalysis.h | 13 - include/llzk/Analysis/IntervalAnalysis.h | 68 +- include/llzk/Analysis/SparseAnalysis.h | 248 +++++++ lib/Analysis/AnalysisUtil.cpp | 32 + lib/Analysis/DenseAnalysis.cpp | 16 - lib/Analysis/IntervalAnalysis.cpp | 606 ++++++++---------- lib/Analysis/Intervals.cpp | 29 +- lib/Analysis/SparseAnalysis.cpp | 360 +++++++++++ .../interval_analysis_pass.llzk | 2 +- 11 files changed, 992 insertions(+), 409 deletions(-) create mode 100644 include/llzk/Analysis/AnalysisUtil.h 
create mode 100644 include/llzk/Analysis/SparseAnalysis.h create mode 100644 lib/Analysis/AnalysisUtil.cpp create mode 100644 lib/Analysis/SparseAnalysis.cpp diff --git a/include/llzk/Analysis/AnalysisUtil.h b/include/llzk/Analysis/AnalysisUtil.h new file mode 100644 index 000000000..7dbc9d244 --- /dev/null +++ b/include/llzk/Analysis/AnalysisUtil.h @@ -0,0 +1,25 @@ +//===-- AnalysisUtil.h - Data-flow analysis utils ---------------*- C++ -*-===// +// +// Part of the LLZK Project, under the Apache License v2.0. +// See LICENSE.txt for license information. +// Copyright 2025 Veridise Inc. +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace llzk::dataflow { + +/// LLZK: Added this utility to ensure analysis is performed for all structs +/// in a given module. +/// +/// @brief Mark all operations from the top and included in the top operation +/// as live so the solver will perform dataflow analyses. +/// @param solver The solver. +/// @param top The top-level operation. 
+void markAllOpsAsLive(mlir::DataFlowSolver &solver, mlir::Operation *top); + +} // namespace llzk::dataflow diff --git a/include/llzk/Analysis/AnalysisWrappers.h b/include/llzk/Analysis/AnalysisWrappers.h index c65326469..527664197 100644 --- a/include/llzk/Analysis/AnalysisWrappers.h +++ b/include/llzk/Analysis/AnalysisWrappers.h @@ -22,7 +22,7 @@ #pragma once -#include "llzk/Analysis/DenseAnalysis.h" +#include "llzk/Analysis/AnalysisUtil.h" #include "llzk/Dialect/Struct/IR/Ops.h" #include "llzk/Util/Compare.h" #include "llzk/Util/ErrorHelper.h" diff --git a/include/llzk/Analysis/DenseAnalysis.h b/include/llzk/Analysis/DenseAnalysis.h index 4494d9baf..1d1bbc0d1 100644 --- a/include/llzk/Analysis/DenseAnalysis.h +++ b/include/llzk/Analysis/DenseAnalysis.h @@ -36,19 +36,6 @@ namespace llzk::dataflow { -//===----------------------------------------------------------------------===// -// Utilities -//===----------------------------------------------------------------------===// - -/// LLZK: Added this utility to ensure analysis is performed for all structs -/// in a given module. -/// -/// @brief Mark all operations from the top and included in the top operation -/// as live so the solver will perform dataflow analyses. -/// @param solver The solver. -/// @param top The top-level operation. 
-void markAllOpsAsLive(mlir::DataFlowSolver &solver, mlir::Operation *top); - //===----------------------------------------------------------------------===// // AbstractDenseForwardDataFlowAnalysis //===----------------------------------------------------------------------===// diff --git a/include/llzk/Analysis/IntervalAnalysis.h b/include/llzk/Analysis/IntervalAnalysis.h index 78c190b20..952a9fea3 100644 --- a/include/llzk/Analysis/IntervalAnalysis.h +++ b/include/llzk/Analysis/IntervalAnalysis.h @@ -15,6 +15,7 @@ #include "llzk/Analysis/DenseAnalysis.h" #include "llzk/Analysis/Field.h" #include "llzk/Analysis/Intervals.h" +#include "llzk/Analysis/SparseAnalysis.h" #include "llzk/Dialect/Array/IR/Ops.h" #include "llzk/Dialect/Bool/IR/Ops.h" #include "llzk/Dialect/Cast/IR/Ops.h" @@ -199,9 +200,7 @@ class IntervalAnalysisLatticeValue class IntervalDataFlowAnalysis; -/// @brief Maps mlir::Values to LatticeValues. -/// -class IntervalAnalysisLattice : public dataflow::AbstractDenseLattice { +class IntervalAnalysisLattice : public dataflow::AbstractSparseLattice { public: using LatticeValue = IntervalAnalysisLatticeValue; // Map mlir::Values to LatticeValues @@ -214,23 +213,18 @@ class IntervalAnalysisLattice : public dataflow::AbstractDenseLattice { // Tracks all constraints and assignments in insertion order using ConstraintSet = llvm::SetVector; - using AbstractDenseLattice::AbstractDenseLattice; + using AbstractSparseLattice::AbstractSparseLattice; - mlir::ChangeResult join(const AbstractDenseLattice &other) override; + mlir::ChangeResult join(const AbstractSparseLattice &other) override; - mlir::ChangeResult meet(const AbstractDenseLattice & /*rhs*/) override { - llvm::report_fatal_error("IntervalDataFlowAnalysis::meet : unsupported"); - return mlir::ChangeResult::NoChange; - } + mlir::ChangeResult meet(const AbstractSparseLattice &other) override; void print(mlir::raw_ostream &os) const override; - mlir::FailureOr getValue(mlir::Value v) const; - 
mlir::FailureOr getValue(mlir::Value v, mlir::StringAttr f) const; + const LatticeValue &getValue() const { return val; } - mlir::ChangeResult setValue(mlir::Value v, const LatticeValue &val); - mlir::ChangeResult setValue(mlir::Value v, ExpressionValue e); - mlir::ChangeResult setValue(mlir::Value v, mlir::StringAttr f, ExpressionValue e); + mlir::ChangeResult setValue(const LatticeValue &val); + mlir::ChangeResult setValue(ExpressionValue e); mlir::ChangeResult addSolverConstraint(ExpressionValue e); @@ -244,27 +238,16 @@ class IntervalAnalysisLattice : public dataflow::AbstractDenseLattice { mlir::FailureOr findInterval(llvm::SMTExprRef expr) const; mlir::ChangeResult setInterval(llvm::SMTExprRef expr, const Interval &i); - size_t size() const { return valMap.size(); } - - const ValueMap &getMap() const { return valMap; } - - ValueMap::iterator begin() { return valMap.begin(); } - ValueMap::iterator end() { return valMap.end(); } - ValueMap::const_iterator begin() const { return valMap.begin(); } - ValueMap::const_iterator end() const { return valMap.end(); } - private: - ValueMap valMap; - FieldMap fieldMap; + LatticeValue val; ConstraintSet constraints; - ExpressionIntervals intervals; }; /* IntervalDataFlowAnalysis */ class IntervalDataFlowAnalysis - : public dataflow::DenseForwardDataFlowAnalysis { - using Base = dataflow::DenseForwardDataFlowAnalysis; + : public dataflow::SparseForwardDataFlowAnalysis { + using Base = dataflow::SparseForwardDataFlowAnalysis; using Lattice = IntervalAnalysisLattice; using LatticeValue = IntervalAnalysisLattice::LatticeValue; @@ -276,23 +259,28 @@ class IntervalDataFlowAnalysis mlir::DataFlowSolver &dataflowSolver, llvm::SMTSolverRef smt, const Field &f, bool propInputConstraints ) - : Base::DenseForwardDataFlowAnalysis(dataflowSolver), _dataflowSolver(dataflowSolver), + : Base::SparseForwardDataFlowAnalysis(dataflowSolver), _dataflowSolver(dataflowSolver), smtSolver(smt), field(f), 
propagateInputConstraints(propInputConstraints) {} - void visitCallControlFlowTransfer( - mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const Lattice &before, - Lattice *after + mlir::LogicalResult visitOperation( + mlir::Operation *op, mlir::ArrayRef operands, + mlir::ArrayRef results ) override; - mlir::LogicalResult - visitOperation(mlir::Operation *op, const Lattice &before, Lattice *after) override; - /// @brief Either return the existing SMT expression that corresponds to the SourceRef, /// or create one. /// @param r /// @return llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r); + const llvm::DenseMap> &getFieldReadResults() const { + return fieldReadResults; + } + + const llvm::DenseMap &getFieldWriteResults() const { + return fieldWriteResults; + } + private: mlir::DataFlowSolver &_dataflowSolver; llvm::SMTSolverRef smtSolver; @@ -301,6 +289,11 @@ class IntervalDataFlowAnalysis bool propagateInputConstraints; mlir::SymbolTableCollection tables; + // Track field reads so that propagations to fields can be all updated efficiently. + llvm::DenseMap> fieldReadResults; + // Track field writes values. For now, we'll overapproximate this. + llvm::DenseMap fieldWriteResults; + void setToEntryState(Lattice *lattice) override { // initial state should be empty, so do nothing here } @@ -349,10 +342,7 @@ class IntervalDataFlowAnalysis /// @param after The current lattice state. Assumes that this has already been joined with the /// `before` lattice in `visitOperation`, so lookups and updates can be performed on the `after` /// lattice alone. - mlir::ChangeResult applyInterval( - mlir::Operation *originalOp, Lattice *originalLattice, Lattice *after, mlir::Value val, - Interval newInterval - ); + void applyInterval(mlir::Operation *originalOp, mlir::Value val, Interval newInterval); /// @brief Special handling for generalized (s - c0) * (s - c1) * ... * (s - cN) = 0 patterns. 
mlir::FailureOr, Interval>> diff --git a/include/llzk/Analysis/SparseAnalysis.h b/include/llzk/Analysis/SparseAnalysis.h new file mode 100644 index 000000000..44dd1f770 --- /dev/null +++ b/include/llzk/Analysis/SparseAnalysis.h @@ -0,0 +1,248 @@ +//===- SparseAnalysis.h - Sparse data-flow analysis -----------------------===// +// +// Part of the LLZK Project, under the Apache License v2.0. +// See LICENSE.txt for license information. +// Copyright 2025 Veridise Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// Adapted from mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements sparse data-flow analysis using the data-flow analysis +/// framework. The analysis is forward and conditional and uses the results of +/// dead code analysis to prune dead code during the analysis. +/// +/// This file has been ported from the MLIR analysis so that it may be +/// tailored to work for LLZK modules, +/// as LLZK modules have different symbol lookup mechanisms that are currently +/// incompatible with the builtin MLIR dataflow analyses. +/// This file is mostly left as original in MLIR, with notes added where +/// changes have been made. 
+/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace llzk::dataflow { + +using AbstractSparseLattice = mlir::dataflow::AbstractSparseLattice; + +//===----------------------------------------------------------------------===// +// AbstractSparseForwardDataFlowAnalysis +//===----------------------------------------------------------------------===// + +/// LLZK: This class has been ported from the MLIR DenseAnalysis utilities to +/// allow for the use of custom LLZK symbol lookup logic. The class has been +/// left as unmodified as possible, with explicit comments added where modifications +/// have been made. +/// +/// Base class for sparse forward data-flow analyses. A sparse analysis +/// implements a transfer function on operations from the lattices of the +/// operands to the lattices of the results. This analysis will propagate +/// lattices across control-flow edges and the callgraph using liveness +/// information. +/// +/// Visit a program point in sparse forward data-flow analysis will invoke the +/// transfer function of the operation preceding the program point iterator. +/// Visit a program point at the begining of block will visit the block itself. +class AbstractSparseForwardDataFlowAnalysis : public mlir::DataFlowAnalysis { +public: + /// Initialize the analysis by visiting every owner of an SSA value: all + /// operations and blocks. + mlir::LogicalResult initialize(mlir::Operation *top) override; + + /// Visit a program point. If this is at beginning of block and all + /// control-flow predecessors or callsites are known, then the arguments + /// lattices are propagated from them. If this is after call operation or an + /// operation with region control-flow, then its result lattices are set + /// accordingly. Otherwise, the operation transfer function is invoked. 
+ mlir::LogicalResult visit(mlir::ProgramPoint *point) override; + +protected: + explicit AbstractSparseForwardDataFlowAnalysis(mlir::DataFlowSolver &solver); + + /// The operation transfer function. Given the operand lattices, this + /// function is expected to set the result lattices. + virtual mlir::LogicalResult visitOperationImpl( + mlir::Operation *op, mlir::ArrayRef operandLattices, + mlir::ArrayRef resultLattices + ) = 0; + + /// The transfer function for calls to external functions. + virtual void visitExternalCallImpl( + mlir::CallOpInterface call, mlir::ArrayRef argumentLattices, + mlir::ArrayRef resultLattices + ) = 0; + + /// Given an operation with region control-flow, the lattices of the operands, + /// and a region successor, compute the lattice values for block arguments + /// that are not accounted for by the branching control flow (ex. the bounds + /// of loops). + virtual void visitNonControlFlowArgumentsImpl( + mlir::Operation *op, const mlir::RegionSuccessor &successor, + mlir::ArrayRef argLattices, unsigned firstIndex + ) = 0; + + /// Get the lattice element of a value. + virtual AbstractSparseLattice *getLatticeElement(mlir::Value value) = 0; + + /// Get a read-only lattice element for a value and add it as a dependency to + /// a program point. + const AbstractSparseLattice *getLatticeElementFor(mlir::ProgramPoint *point, mlir::Value value); + + /// Set the given lattice element(s) at control flow entry point(s). + virtual void setToEntryState(AbstractSparseLattice *lattice) = 0; + void setAllToEntryStates(mlir::ArrayRef lattices); + + /// Join the lattice element and propagate and update if it changed. + void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs); + + /// LLZK: Added for use of symbol helper caching. + mlir::SymbolTableCollection tables; + +private: + /// Recursively initialize the analysis on nested operations and blocks. 
+ mlir::LogicalResult initializeRecursively(mlir::Operation *op); + + /// Visit an operation. If this is a call operation or an operation with + /// region control-flow, then its result lattices are set accordingly. + /// Otherwise, the operation transfer function is invoked. + mlir::LogicalResult visitOperation(mlir::Operation *op); + + /// Visit a block to compute the lattice values of its arguments. If this is + /// an entry block, then the argument values are determined from the block's + /// "predecessors" as set by `PredecessorState`. The predecessors can be + /// region terminators or callable callsites. Otherwise, the values are + /// determined from block predecessors. + void visitBlock(mlir::Block *block); + + /// Visit a program point `point` with predecessors within a region branch + /// operation `branch`, which can either be the entry block of one of the + /// regions or the parent operation itself, and set either the argument or + /// parent result lattices. + void visitRegionSuccessors( + mlir::ProgramPoint *point, mlir::RegionBranchOpInterface branch, + mlir::RegionBranchPoint successor, mlir::ArrayRef lattices + ); +}; + +//===----------------------------------------------------------------------===// +// SparseForwardDataFlowAnalysis +//===----------------------------------------------------------------------===// + +/// A sparse forward data-flow analysis for propagating SSA value lattices +/// across the IR by implementing transfer functions for operations. +/// +/// `StateT` is expected to be a subclass of `AbstractSparseLattice`. +template +class SparseForwardDataFlowAnalysis : public AbstractSparseForwardDataFlowAnalysis { + static_assert( + std::is_base_of::value, + "analysis state class expected to subclass AbstractSparseLattice" + ); + +public: + explicit SparseForwardDataFlowAnalysis(mlir::DataFlowSolver &solver) + : AbstractSparseForwardDataFlowAnalysis(solver) {} + + /// Visit an operation with the lattices of its operands. 
This function is + /// expected to set the lattices of the operation's results. + virtual mlir::LogicalResult visitOperation( + mlir::Operation *op, mlir::ArrayRef operands, mlir::ArrayRef results + ) = 0; + + /// Visit a call operation to an externally defined function given the + /// lattices of its arguments. + virtual void visitExternalCall( + mlir::CallOpInterface call, mlir::ArrayRef argumentLattices, + mlir::ArrayRef resultLattices + ) { + setAllToEntryStates(resultLattices); + } + + /// Given an operation with possible region control-flow, the lattices of the + /// operands, and a region successor, compute the lattice values for block + /// arguments that are not accounted for by the branching control flow (ex. + /// the bounds of loops). By default, this method marks all such lattice + /// elements as having reached a pessimistic fixpoint. `firstIndex` is the + /// index of the first element of `argLattices` that is set by control-flow. + virtual void visitNonControlFlowArguments( + mlir::Operation *op, const mlir::RegionSuccessor &successor, + mlir::ArrayRef argLattices, unsigned firstIndex + ) { + setAllToEntryStates(argLattices.take_front(firstIndex)); + setAllToEntryStates(argLattices.drop_front(firstIndex + successor.getSuccessorInputs().size())); + } + +protected: + /// Get the lattice element for a value. + StateT *getLatticeElement(mlir::Value value) override { return getOrCreate(value); } + + /// Get the lattice element for a value and create a dependency on the + /// provided program point. + const StateT *getLatticeElementFor(mlir::ProgramPoint *point, mlir::Value value) { + return static_cast( + AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(point, value) + ); + } + + /// Set the given lattice element(s) at control flow entry point(s). 
+ virtual void setToEntryState(StateT *lattice) = 0; + void setAllToEntryStates(mlir::ArrayRef lattices) { + AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates( + {reinterpret_cast(lattices.begin()), lattices.size()} + ); + } + +private: + /// Type-erased wrappers that convert the abstract lattice operands to derived + /// lattices and invoke the virtual hooks operating on the derived lattices. + llvm::LogicalResult visitOperationImpl( + mlir::Operation *op, mlir::ArrayRef operandLattices, + mlir::ArrayRef resultLattices + ) override { + return visitOperation( + op, + {reinterpret_cast(operandLattices.begin()), operandLattices.size()}, + {reinterpret_cast(resultLattices.begin()), resultLattices.size()} + ); + } + void visitExternalCallImpl( + mlir::CallOpInterface call, mlir::ArrayRef argumentLattices, + mlir::ArrayRef resultLattices + ) override { + visitExternalCall( + call, + {reinterpret_cast(argumentLattices.begin()), + argumentLattices.size()}, + {reinterpret_cast(resultLattices.begin()), resultLattices.size()} + ); + } + void visitNonControlFlowArgumentsImpl( + mlir::Operation *op, const mlir::RegionSuccessor &successor, + mlir::ArrayRef argLattices, unsigned firstIndex + ) override { + visitNonControlFlowArguments( + op, successor, {reinterpret_cast(argLattices.begin()), argLattices.size()}, + firstIndex + ); + } + void setToEntryState(AbstractSparseLattice *lattice) override { + return setToEntryState(reinterpret_cast(lattice)); + } +}; + +} // namespace llzk::dataflow diff --git a/lib/Analysis/AnalysisUtil.cpp b/lib/Analysis/AnalysisUtil.cpp new file mode 100644 index 000000000..a86d0fa11 --- /dev/null +++ b/lib/Analysis/AnalysisUtil.cpp @@ -0,0 +1,32 @@ +//===-- AnalysisUtil.cpp - Data-flow analysis utils -------------*- C++ -*-===// +// +// Part of the LLZK Project, under the Apache License v2.0. +// See LICENSE.txt for license information. +// Copyright 2025 Veridise Inc. 
+// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#include "llzk/Analysis/AnalysisUtil.h" + +#include + +using namespace mlir; + +using Executable = mlir::dataflow::Executable; + +namespace llzk::dataflow { + +void markAllOpsAsLive(DataFlowSolver &solver, Operation *top) { + for (Region ®ion : top->getRegions()) { + for (Block &block : region) { + ProgramPoint *point = solver.getProgramPointBefore(&block); + (void)solver.getOrCreateState(point)->setToLive(); + for (Operation &oper : block) { + markAllOpsAsLive(solver, &oper); + } + } + } +} + +} // namespace llzk::dataflow diff --git a/lib/Analysis/DenseAnalysis.cpp b/lib/Analysis/DenseAnalysis.cpp index bb2027f28..fc41f741c 100644 --- a/lib/Analysis/DenseAnalysis.cpp +++ b/lib/Analysis/DenseAnalysis.cpp @@ -42,22 +42,6 @@ using namespace function; namespace dataflow { -//===----------------------------------------------------------------------===// -// Utilities -//===----------------------------------------------------------------------===// - -void markAllOpsAsLive(DataFlowSolver &solver, Operation *top) { - for (Region ®ion : top->getRegions()) { - for (Block &block : region) { - ProgramPoint *point = solver.getProgramPointBefore(&block); - (void)solver.getOrCreateState(point)->setToLive(); - for (Operation &oper : block) { - markAllOpsAsLive(solver, &oper); - } - } - } -} - //===----------------------------------------------------------------------===// // AbstractDenseForwardDataFlowAnalysis //===----------------------------------------------------------------------===// diff --git a/lib/Analysis/IntervalAnalysis.cpp b/lib/Analysis/IntervalAnalysis.cpp index 29806b40b..748a58311 100644 --- a/lib/Analysis/IntervalAnalysis.cpp +++ b/lib/Analysis/IntervalAnalysis.cpp @@ -9,6 +9,7 @@ #include "llzk/Analysis/IntervalAnalysis.h" #include "llzk/Analysis/Matchers.h" +#include "llzk/Dialect/Array/IR/Ops.h" #include "llzk/Util/Debug.h" 
#include "llzk/Util/StreamHelper.h" @@ -277,104 +278,55 @@ void ExpressionValue::print(mlir::raw_ostream &os) const { /* IntervalAnalysisLattice */ -ChangeResult IntervalAnalysisLattice::join(const AbstractDenseLattice &other) { +ChangeResult IntervalAnalysisLattice::join(const AbstractSparseLattice &other) { const auto *rhs = dynamic_cast(&other); if (!rhs) { llvm::report_fatal_error("invalid join lattice type"); } - ChangeResult res = ChangeResult::NoChange; - for (auto &[k, v] : rhs->valMap) { - auto it = valMap.find(k); - if (it == valMap.end() || it->second != v) { - valMap[k] = v; - res |= ChangeResult::Change; - } - } + ChangeResult res = val.update(rhs->getValue()); for (auto &v : rhs->constraints) { if (!constraints.contains(v)) { constraints.insert(v); res |= ChangeResult::Change; } } - for (auto &[e, i] : rhs->intervals) { - auto it = intervals.find(e); - if (it == intervals.end() || it->second != i) { - intervals[e] = i; - res |= ChangeResult::Change; - } - } return res; } -void IntervalAnalysisLattice::print(mlir::raw_ostream &os) const { - os << "IntervalAnalysisLattice { "; - for (auto &[ref, val] : valMap) { - os << "\n (valMap) " << ref << " := " << val; +ChangeResult IntervalAnalysisLattice::meet(const AbstractSparseLattice &other) { + const auto *rhs = dynamic_cast(&other); + if (!rhs) { + llvm::report_fatal_error("invalid meet lattice type"); } - for (auto &[expr, interval] : intervals) { - os << "\n (intervals) "; - if (!expr) { - os << ""; - } else { - expr->print(os); + // Intersect the intervals + ExpressionValue lhsExpr = val.getScalarValue(); + ExpressionValue rhsExpr = rhs->getValue().getScalarValue(); + Interval newInterval = lhsExpr.getInterval().intersect(rhsExpr.getInterval()); + ChangeResult res = setValue(lhsExpr.withInterval(newInterval)); + for (auto &v : rhs->constraints) { + if (!constraints.contains(v)) { + constraints.insert(v); + res |= ChangeResult::Change; } - os << " in " << interval; - } - if (!valMap.empty()) { - os << '\n'; - } - 
os << '}'; -} - -FailureOr IntervalAnalysisLattice::getValue(Value v) const { - auto it = valMap.find(v); - if (it == valMap.end()) { - return failure(); - } - return it->second; -} - -FailureOr -IntervalAnalysisLattice::getValue(Value v, StringAttr f) const { - auto it = fieldMap.find(v); - if (it == fieldMap.end()) { - return failure(); - } - auto fit = it->second.find(f); - if (fit == it->second.end()) { - return failure(); } - return fit->second; + return res; } -ChangeResult IntervalAnalysisLattice::setValue(Value v, const LatticeValue &val) { - if (valMap[v] == val) { - return ChangeResult::NoChange; - } - valMap[v] = val; - ExpressionValue e = val.foldToScalar(); - intervals[e.getExpr()] = e.getInterval(); - return ChangeResult::Change; +void IntervalAnalysisLattice::print(mlir::raw_ostream &os) const { + os << "IntervalAnalysisLattice { " << val << " }"; } -ChangeResult IntervalAnalysisLattice::setValue(Value v, ExpressionValue e) { - LatticeValue val(e); - if (valMap[v] == val) { +ChangeResult IntervalAnalysisLattice::setValue(const LatticeValue &newVal) { + if (val == newVal) { return ChangeResult::NoChange; } - valMap[v] = val; - intervals[e.getExpr()] = e.getInterval(); + val = newVal; return ChangeResult::Change; } -ChangeResult IntervalAnalysisLattice::setValue(Value v, StringAttr f, ExpressionValue e) { - LatticeValue val(e); - if (fieldMap[v][f] == val) { - return ChangeResult::NoChange; - } - fieldMap[v][f] = val; - intervals[e.getExpr()] = e.getInterval(); - return ChangeResult::Change; +ChangeResult IntervalAnalysisLattice::setValue(ExpressionValue e) { + LatticeValue newVal(e); + return setValue(newVal); } ChangeResult IntervalAnalysisLattice::addSolverConstraint(ExpressionValue e) { @@ -385,67 +337,8 @@ ChangeResult IntervalAnalysisLattice::addSolverConstraint(ExpressionValue e) { return ChangeResult::NoChange; } -FailureOr IntervalAnalysisLattice::findInterval(llvm::SMTExprRef expr) const { - auto it = intervals.find(expr); - if (it != 
intervals.end()) { - return it->second; - } - return failure(); -} - -ChangeResult IntervalAnalysisLattice::setInterval(llvm::SMTExprRef expr, const Interval &i) { - auto it = intervals.find(expr); - if (it != intervals.end() && it->second == i) { - return ChangeResult::NoChange; - } - intervals[expr] = i; - return ChangeResult::Change; -} - /* IntervalDataFlowAnalysis */ -/// @brief The interval analysis is intraprocedural only for now, so this control -/// flow transfer function passes no data to the callee and sets the post-call -/// state to that of the pre-call state (i.e., calls are ignored). -void IntervalDataFlowAnalysis::visitCallControlFlowTransfer( - CallOpInterface call, dataflow::CallControlFlowAction action, - const IntervalAnalysisLattice &before, IntervalAnalysisLattice *after -) { - /// `action == CallControlFlowAction::Enter` indicates that: - /// - `before` is the state before the call operation; - /// - `after` is the state at the beginning of the callee entry block; - if (action == dataflow::CallControlFlowAction::EnterCallee) { - // We skip updating the incoming lattice for function calls, - // as values are relative to the containing function/struct, so we don't need to pollute - // the callee with the callers values. - setToEntryState(after); - } - /// `action == CallControlFlowAction::Exit` indicates that: - /// - `before` is the state at the end of a callee exit block; - /// - `after` is the state after the call operation. - else if (action == dataflow::CallControlFlowAction::ExitCallee) { - // Get the argument values of the lattice by getting the state as it would - // have been for the callsite. 
- const dataflow::AbstractDenseLattice *beforeCall = getLattice(getProgramPointBefore(call)); - ensure(beforeCall, "could not get prior lattice"); - - // The lattice at the return is the lattice before the call - propagateIfChanged(after, after->join(*beforeCall)); - } - /// `action == CallControlFlowAction::External` indicates that: - /// - `before` is the state before the call operation. - /// - `after` is the state after the call operation, since there is no callee - /// body to enter into. - else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) { - // For external calls, we propagate what information we already have from - // before the call to after the call, since the external call won't invalidate - // any of that information. It also, conservatively, makes no assumptions about - // external calls and their computation, so CDG edges will not be computed over - // input arguments to external functions. - join(after, before); - } -} - const SourceRefLattice * IntervalDataFlowAnalysis::getSourceRefLattice(Operation *baseOp, Value val) { ProgramPoint *pp = _dataflowSolver.getProgramPointAfter(baseOp); @@ -460,39 +353,25 @@ IntervalDataFlowAnalysis::getSourceRefLattice(Operation *baseOp, Value val) { return defaultSourceRefLattice; } -mlir::LogicalResult -IntervalDataFlowAnalysis::visitOperation(Operation *op, const Lattice &before, Lattice *after) { +mlir::LogicalResult IntervalDataFlowAnalysis::visitOperation( + Operation *op, ArrayRef operands, ArrayRef results +) { // We only perform the visitation on operations within functions FuncDefOp fn = op->getParentOfType(); if (!fn) { return success(); } - ChangeResult changed = ChangeResult::NoChange; - // We always propagate the values of the function args from the function - // entry as the function context; if the input values are changed, this will - // force the recomputation of intervals throughout the function. 
- for (BlockArgument blockArg : fn.getArguments()) { - auto blockArgLookupRes = before.getValue(blockArg); - if (succeeded(blockArgLookupRes)) { - changed |= after->setValue(blockArg, *blockArgLookupRes); - } + // If there are no operands or results, skip. + if (operands.empty() && results.empty()) { + return success(); } - auto getAfter = [&](Value val) { - if (Operation *defOp = val.getDefiningOp()) { - return getLattice(getProgramPointAfter(defOp)); - } else if (auto blockArg = dyn_cast(val)) { - Operation *blockEntry = &blockArg.getOwner()->front(); - return getLattice(getProgramPointBefore(blockEntry)); - } - return getLattice(getProgramPointBefore(op)); - }; - + // Get the values or defaults from the operand lattices llvm::SmallVector operandVals; llvm::SmallVector> operandRefs; - for (OpOperand &operand : op->getOpOperands()) { - Value val = operand.get(); + for (unsigned opNum = 0; opNum < op->getNumOperands(); ++opNum) { + Value val = op->getOperand(opNum); SourceRefLatticeValue refSet = getSourceRefLattice(op, val)->getOrDefault(val); if (refSet.isSingleValue()) { operandRefs.push_back(refSet.getSingleValue()); @@ -500,11 +379,9 @@ IntervalDataFlowAnalysis::visitOperation(Operation *op, const Lattice &before, L operandRefs.push_back(std::nullopt); } // First, lookup the operand value after it is initialized - Lattice *valLattice = getAfter(val); - auto priorState = valLattice->getValue(val); - if (succeeded(priorState) && priorState->getScalarValue().getExpr() != nullptr) { - operandVals.push_back(*priorState); - changed |= after->setValue(val, *priorState); + auto priorState = operands[opNum]->getValue(); + if (priorState.getScalarValue().getExpr() != nullptr) { + operandVals.push_back(priorState); continue; } @@ -513,9 +390,8 @@ IntervalDataFlowAnalysis::visitOperation(Operation *op, const Lattice &before, L // are currently limited to non-Signal structs and arrays. 
Type valTy = val.getType(); if (llvm::isa(valTy) && !isSignalType(valTy)) { - LatticeValue empty; - operandVals.push_back(empty); - changed |= after->setValue(val, empty); + ExpressionValue anyVal(field.get(), createFeltSymbol(val)); + operandVals.emplace_back(anyVal); continue; } @@ -530,7 +406,6 @@ IntervalDataFlowAnalysis::visitOperation(Operation *op, const Lattice &before, L "state of ", val, " is empty; defining operation is unsupported by SourceRef analysis" ) .report(); - propagateIfChanged(after, changed); // We still return success so we can return overapproximated and partial // results to the user. return success(); @@ -542,24 +417,23 @@ IntervalDataFlowAnalysis::visitOperation(Operation *op, const Lattice &before, L // Here, we will override the prior lattice value with a new symbol, representing // "any" value, then use that value for the operands. ExpressionValue anyVal(field.get(), createFeltSymbol(val)); - changed |= after->setValue(val, anyVal); operandVals.emplace_back(anyVal); } else { const SourceRef &ref = refSet.getSingleValue(); - ExpressionValue exprVal(field.get(), getOrCreateSymbol(ref)); - if (succeeded(priorState)) { - exprVal = exprVal.withInterval(priorState->getScalarValue().getInterval()); + // See if we've written the value before. If so, use that. + if (auto it = fieldWriteResults.find(ref); it != fieldWriteResults.end()) { + operandVals.emplace_back(it->second); + } else { + ExpressionValue exprVal(field.get(), getOrCreateSymbol(ref)); + operandVals.emplace_back(exprVal); } - changed |= after->setValue(val, exprVal); - operandVals.emplace_back(exprVal); } // Since we initialized a value that was not found in the before lattice, // update that value in the lattice so we can find it later, but we don't // need to propagate the changes, since we already have what we need. 
- auto res = after->getValue(val); - ensure(succeeded(res), "expected precondition is that value is set"); - (void)valLattice->setValue(val, *res); + Lattice *operandLattice = getLatticeElement(val); + (void)operandLattice->setValue(operandVals[opNum]); } // Now, the way we update is dependent on the type of the operation. @@ -567,19 +441,16 @@ IntervalDataFlowAnalysis::visitOperation(Operation *op, const Lattice &before, L llvm::DynamicAPInt constVal = getConst(op); llvm::SMTExprRef expr = createConstBitvectorExpr(constVal); ExpressionValue latticeVal(field.get(), expr, constVal); - changed |= after->setValue(op->getResult(0), latticeVal); + propagateIfChanged(results[0], results[0]->setValue(latticeVal)); } else if (isArithmeticOp(op)) { - ensure(operandVals.size() <= 2, "arithmetic op with the wrong number of operands"); ExpressionValue result; - if (operandVals.size() == 2) { + if (operands.size() == 2) { result = performBinaryArithmetic(op, operandVals[0], operandVals[1]); } else { result = performUnaryArithmetic(op, operandVals[0]); } - - changed |= after->setValue(op->getResult(0), result); + propagateIfChanged(results[0], results[0]->setValue(result)); } else if (EmitEqualityOp emitEq = llvm::dyn_cast(op)) { - ensure(operandVals.size() == 2, "constraint op with the wrong number of operands"); Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs(); ExpressionValue lhsExpr = operandVals[0].getScalarValue(); ExpressionValue rhsExpr = operandVals[1].getScalarValue(); @@ -589,7 +460,7 @@ IntervalDataFlowAnalysis::visitOperation(Operation *op, const Lattice &before, L auto res = getGeneralizedDecompInterval(op, lhsVal, rhsVal); if (succeeded(res)) { for (Value signalVal : res->first) { - changed |= applyInterval(emitEq, after, getAfter(signalVal), signalVal, res->second); + applyInterval(emitEq, signalVal, res->second); } } @@ -597,46 +468,28 @@ IntervalDataFlowAnalysis::visitOperation(Operation *op, const Lattice &before, L // Update the LHS and RHS to the 
same value, but restricted intervals // based on the constraints. const Interval &constrainInterval = constraint.getInterval(); - changed |= applyInterval(emitEq, after, getAfter(lhsVal), lhsVal, constrainInterval); - changed |= applyInterval(emitEq, after, getAfter(rhsVal), rhsVal, constrainInterval); - changed |= after->addSolverConstraint(constraint); - } else if (AssertOp assertOp = llvm::dyn_cast(op)) { - ensure(operandVals.size() == 1, "assert op with the wrong number of operands"); + applyInterval(emitEq, lhsVal, constrainInterval); + applyInterval(emitEq, rhsVal, constrainInterval); + } else if (auto assertOp = llvm::dyn_cast(op)) { // assert enforces that the operand is true. So we apply an interval of [1, 1] // to the operand. - changed |= applyInterval( - assertOp, after, after, assertOp.getCondition(), - Interval::True(field.get()) - ); + Value cond = assertOp.getCondition(); + applyInterval(assertOp, cond, Interval::True(field.get())); // Also add the solver constraint that the expression must be true. auto assertExpr = operandVals[0].getScalarValue(); - changed |= after->addSolverConstraint(assertExpr); + // No need to propagate the constraint + (void)getLatticeElement(cond)->addSolverConstraint(assertExpr); } else if (auto readf = llvm::dyn_cast(op)) { Value cmp = readf.getComponent(); if (isSignalType(cmp.getType())) { // The reg value read from the signal type is equal to the value of the Signal // struct overall. - changed |= after->setValue(readf.getVal(), operandVals[0].getScalarValue()); - } else { - auto storedVal = getAfter(cmp)->getValue(cmp, readf.getFieldNameAttr().getAttr()); - if (succeeded(storedVal)) { - // The result value is the value previously written to this field. 
- changed |= after->setValue(readf.getVal(), storedVal->getScalarValue()); - } else if (operandRefs[0].has_value()) { - // Initialize the value - auto fieldDefRes = readf.getFieldDefOp(tables); - if (succeeded(fieldDefRes)) { - SourceRef ref = operandRefs[0]->createChild(SourceRefIndex(*fieldDefRes)); - ExpressionValue exprVal(field.get(), getOrCreateSymbol(ref)); - changed |= after->setValue(readf.getVal(), exprVal); - } - } + propagateIfChanged(results[0], results[0]->setValue(operandVals[0])); } } else if (auto writef = llvm::dyn_cast(op)) { // Update values stored in a field ExpressionValue writeVal = operandVals[1].getScalarValue(); auto cmp = writef.getComponent(); - changed |= after->setValue(cmp, writef.getFieldNameAttr().getAttr(), writeVal); // We also need to update the interval on the assigned symbol SourceRefLatticeValue refSet = getSourceRefLattice(op, cmp)->getOrDefault(cmp); if (refSet.isSingleValue()) { @@ -645,7 +498,23 @@ IntervalDataFlowAnalysis::visitOperation(Operation *op, const Lattice &before, L SourceRefIndex idx(fieldDefRes.value()); SourceRef fieldRef = refSet.getSingleValue().createChild(idx); llvm::SMTExprRef expr = getOrCreateSymbol(fieldRef); - changed |= after->setInterval(expr, writeVal.getInterval()); + ExpressionValue written(expr, writeVal.getInterval()); + + if (auto it = fieldWriteResults.find(fieldRef); it != fieldWriteResults.end()) { + const ExpressionValue &old = it->second; + Interval combinedWrite = old.getInterval().join(written.getInterval()); + fieldWriteResults[fieldRef] = old.withInterval(combinedWrite); + } else { + fieldWriteResults[fieldRef] = written; + } + + // Propagate to all field readers we've collected so far. 
+ for (Lattice *readerLattice : fieldReadResults[fieldRef]) { + ExpressionValue prior = readerLattice->getValue().getScalarValue(); + Interval intersection = prior.getInterval().intersect(written.getInterval()); + ExpressionValue newVal = prior.withInterval(intersection); + propagateIfChanged(readerLattice, readerLattice->setValue(newVal)); + } } } } else if (isa(op)) { @@ -657,32 +526,27 @@ IntervalDataFlowAnalysis::visitOperation(Operation *op, const Lattice &before, L if (expr.isBoolSort(smtSolver)) { expr = boolToFelt(smtSolver, expr, field.get().bitWidth()); } - changed |= after->setValue(op->getResult(0), expr); + propagateIfChanged(results[0], results[0]->setValue(expr)); } else if (auto yieldOp = dyn_cast(op)) { // Fetch the lattice for after the parent operation so we can propagate // the yielded value to subsequent operations. Operation *parent = op->getParentOp(); ensure(parent, "yield operation must have parent operation"); - auto postYieldLattice = getLattice(getProgramPointAfter(parent)); - ensure(postYieldLattice, "could not fetch post-yield lattice"); // Bind the operand values to the result values of the parent for (unsigned idx = 0; idx < yieldOp.getResults().size(); ++idx) { Value parentRes = parent->getResult(idx); + Lattice *resLattice = getLatticeElement(parentRes); // Merge with the existing value, if present (e.g., another branch) // has possible value that must be merged. 
- auto exprValRes = postYieldLattice->getValue(parentRes); + ExpressionValue exprVal = resLattice->getValue().getScalarValue(); ExpressionValue newResVal = operandVals[idx].getScalarValue(); - if (succeeded(exprValRes)) { - ExpressionValue existingVal = exprValRes->getScalarValue(); - newResVal = - existingVal.withInterval(existingVal.getInterval().join(newResVal.getInterval())); + if (exprVal.getExpr() != nullptr) { + newResVal = exprVal.withInterval(exprVal.getInterval().join(newResVal.getInterval())); } else { newResVal = ExpressionValue(createFeltSymbol(parentRes), newResVal.getInterval()); } - changed |= after->setValue(parentRes, newResVal); + propagateIfChanged(resLattice, resLattice->setValue(newResVal)); } - - propagateIfChanged(postYieldLattice, postYieldLattice->join(*after)); } else if ( // We do not need to explicitly handle read ops since they are resolved at the operand value // step where `SourceRef`s are queries (with the exception of the Signal struct, see above). @@ -698,7 +562,6 @@ IntervalDataFlowAnalysis::visitOperation(Operation *op, const Lattice &before, L op->emitWarning("unhandled operation, analysis may be incomplete").report(); } - propagateIfChanged(after, changed); return success(); } @@ -810,23 +673,20 @@ IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op, const LatticeVal return res; } -ChangeResult IntervalDataFlowAnalysis::applyInterval( - Operation *originalOp, Lattice *originalLattice, Lattice *after, Value val, Interval newInterval -) { - auto latValRes = after->getValue(val); - if (failed(latValRes)) { - // visitOperation didn't add val to the lattice, so there's nothing to do - return ChangeResult::NoChange; +void IntervalDataFlowAnalysis::applyInterval(Operation *valUser, Value val, Interval newInterval) { + Lattice *valLattice = getLatticeElement(val); + ExpressionValue oldLatticeVal = valLattice->getValue().getScalarValue(); + // Intersect with the current value to accumulate restrictions across constraints. 
+ Interval intersection = oldLatticeVal.getInterval().intersect(newInterval); + ExpressionValue newLatticeVal = oldLatticeVal.withInterval(intersection); + ChangeResult changed = valLattice->setValue(newLatticeVal); + + if (changed == ChangeResult::NoChange) { + // We don't need to keep recursing since `val`s interval hasn't changed + return; } - ExpressionValue newLatticeVal = latValRes->getScalarValue().withInterval(newInterval); - propagateIfChanged(after, after->setValue(val, newLatticeVal)); - ChangeResult res = originalLattice->setValue(val, newLatticeVal); - // To allow the dataflow analysis to do its fixed-point iteration, we need to - // add the new expression to val's lattice as well. - Lattice *valLattice = nullptr; - if (Operation *valOp = val.getDefiningOp()) { - valLattice = getLattice(getProgramPointAfter(valOp)); - } else if (auto blockArg = llvm::dyn_cast(val)) { + + if (auto blockArg = llvm::dyn_cast(val)) { auto fnOp = dyn_cast(blockArg.getOwner()->getParentOp()); Operation *blockEntry = &blockArg.getOwner()->front(); @@ -835,45 +695,21 @@ ChangeResult IntervalDataFlowAnalysis::applyInterval( blockArg.getArgNumber() > 0 && !newInterval.isEntire()) { auto structOp = fnOp->getParentOfType(); FuncDefOp computeFn = structOp.getComputeFuncOp(); - Operation *computeEntry = &computeFn.getRegion().front().front(); BlockArgument computeArg = computeFn.getArgument(blockArg.getArgNumber() - 1); - Lattice *computeEntryLattice = getLattice(getProgramPointBefore(computeEntry)); + Lattice *computeEntryLattice = getLatticeElement(computeArg); SourceRef ref(computeArg); ExpressionValue newArgVal(getOrCreateSymbol(ref), newInterval); - ChangeResult computeRes = computeEntryLattice->setValue(computeArg, newArgVal); - propagateIfChanged(computeEntryLattice, computeRes); + propagateIfChanged(computeEntryLattice, computeEntryLattice->setValue(newArgVal)); } - - valLattice = getLattice(getProgramPointBefore(blockEntry)); - } else { - valLattice = getLattice(val); } - 
ensure(valLattice, "val should have a lattice"); - auto setNewVal = [&valLattice, &val, &newLatticeVal, &res, this]() { - propagateIfChanged(valLattice, valLattice->setValue(val, newLatticeVal)); - return res; - }; - // Now we descend into val's operands, if it has any. Operation *definingOp = val.getDefiningOp(); if (!definingOp) { - return setNewVal(); + propagateIfChanged(valLattice, changed); + return; } - Lattice *definingOpLattice = getLattice(getProgramPointAfter(definingOp)); - auto getOperandLattice = [&](Value operand) { - if (Operation *defOp = operand.getDefiningOp()) { - return getLattice(getProgramPointAfter(defOp)); - } else if (auto blockArg = dyn_cast(operand)) { - Operation *blockEntry = &blockArg.getOwner()->front(); - return getLattice(getProgramPointBefore(blockEntry)); - } - return definingOpLattice; - }; - auto getOperandLatticeVal = [&](Value operand) { - return getOperandLattice(operand)->getValue(operand); - }; const Field &f = field.get(); @@ -892,18 +728,15 @@ ChangeResult IntervalDataFlowAnalysis::applyInterval( ); if (!newInterval.isDegenerate()) { // The comparison result is unknown, so we can't update the operand ranges - return ChangeResult::NoChange; + return; } bool cmpTrue = newInterval.rhs() == f.one(); Value lhs = cmpOp.getLhs(), rhs = cmpOp.getRhs(); - auto lhsLatValRes = getOperandLatticeVal(lhs), rhsLatValRes = getOperandLatticeVal(rhs); - if (failed(lhsLatValRes) || failed(rhsLatValRes)) { - return ChangeResult::NoChange; - } - ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(), - rhsExpr = rhsLatValRes->getScalarValue(); + auto lhsLat = getLatticeElement(lhs), rhsLat = getLatticeElement(rhs); + ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(), + rhsExpr = rhsLat->getValue().getScalarValue(); Interval newLhsInterval, newRhsInterval; const Interval &lhsInterval = lhsExpr.getInterval(); @@ -971,69 +804,170 @@ ChangeResult IntervalDataFlowAnalysis::applyInterval( newRhsInterval = 
rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f); } else { cmpOp->emitWarning("unhandled cmp predicate").report(); - return ChangeResult::NoChange; + return; } // Now we recurse to each operand - return applyInterval(originalOp, originalLattice, getOperandLattice(lhs), lhs, newLhsInterval) | - applyInterval(originalOp, originalLattice, getOperandLattice(rhs), rhs, newRhsInterval); + applyInterval(cmpOp, lhs, newLhsInterval); + applyInterval(cmpOp, rhs, newRhsInterval); }; - // If the result of a multiplication is non-zero, then both operands must be + // Multiplication cases: + // - If the result of a multiplication is non-zero, then both operands must be // non-zero. + // - If one operand is a constant, we can propagate the new interval when multiplied + // by the multiplicative inverse of the constant. auto mulCase = [&](MulFeltOp mulOp) { + // We check for the constant case first. + auto constCase = [&](FeltConstantOp constOperand, Value multiplicand) { + auto latVal = getLatticeElement(multiplicand)->getValue().getScalarValue(); + APInt constVal = constOperand.getValue(); + if (constVal.isZero()) { + // There's no inverse for zero, so we do nothing. + return; + } + Interval updatedInterval = newInterval * Interval::Degenerate(f, f.inv(constVal)); + applyInterval(mulOp, multiplicand, updatedInterval); + }; + + Value lhs = mulOp.getLhs(), rhs = mulOp.getRhs(); + + auto lhsConstOp = dyn_cast_if_present(lhs.getDefiningOp()); + auto rhsConstOp = dyn_cast_if_present(rhs.getDefiningOp()); + // If both are consts, we don't need to do anything + if (lhsConstOp && rhsConstOp) { + return; + } else if (lhsConstOp) { + constCase(lhsConstOp, rhs); + return; + } else if (rhsConstOp) { + constCase(rhsConstOp, lhs); + return; + } + + // Otherwise, try to propagate non-zero information. 
auto zeroInt = Interval::Degenerate(f, f.zero()); if (newInterval.intersect(zeroInt).isNotEmpty()) { // The multiplication may be zero, so we can't reduce the operands to be non-zero - return ChangeResult::NoChange; + return; } - Value lhs = mulOp.getLhs(), rhs = mulOp.getRhs(); - auto lhsLatValRes = getOperandLatticeVal(lhs), rhsLatValRes = getOperandLatticeVal(rhs); - if (failed(lhsLatValRes) || failed(rhsLatValRes)) { - return ChangeResult::NoChange; - } - ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(), - rhsExpr = rhsLatValRes->getScalarValue(); + auto lhsLat = getLatticeElement(lhs), rhsLat = getLatticeElement(rhs); + ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(), + rhsExpr = rhsLat->getValue().getScalarValue(); Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt); Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt); - return applyInterval(originalOp, originalLattice, getOperandLattice(lhs), lhs, newLhsInterval) | - applyInterval(originalOp, originalLattice, getOperandLattice(rhs), rhs, newRhsInterval); + applyInterval(mulOp, lhs, newLhsInterval); + applyInterval(mulOp, rhs, newRhsInterval); + }; + + // Addition case: + // - If one operand is a constant, we can propagate the new interval when subtracting + // the constant + auto addCase = [&](AddFeltOp addOp) { + // We check for the constant case first. 
+ auto constCase = [&](FeltConstantOp constOperand, Value operand) { + auto latVal = getLatticeElement(operand)->getValue().getScalarValue(); + auto constVal = toDynamicAPInt(constOperand.getValue()); + Interval updatedInterval = newInterval - Interval::Degenerate(f, constVal); + applyInterval(addOp, operand, updatedInterval); + }; + + Value lhs = addOp.getLhs(), rhs = addOp.getRhs(); + + auto lhsConstOp = dyn_cast_if_present(lhs.getDefiningOp()); + auto rhsConstOp = dyn_cast_if_present(rhs.getDefiningOp()); + // If both are consts, we don't need to do anything + if (lhsConstOp && !rhsConstOp) { + constCase(lhsConstOp, rhs); + } else if (rhsConstOp && !lhsConstOp) { + constCase(rhsConstOp, lhs); + } + }; + + // Subtraction case: + // - If one operand is a constant, we can propagate the new interval when adding + // the constant. + auto subCase = [&](SubFeltOp subOp) { + // We check for the constant case first. + Value lhs = subOp.getLhs(), rhs = subOp.getRhs(); + + auto lhsConstOp = dyn_cast_if_present(lhs.getDefiningOp()); + auto rhsConstOp = dyn_cast_if_present(rhs.getDefiningOp()); + // If both are consts, we don't need to do anything + if (lhsConstOp && !rhsConstOp) { + auto constVal = toDynamicAPInt(lhsConstOp.getValue()); + Interval updatedInterval = Interval::Degenerate(f, constVal) - newInterval; + applyInterval(subOp, rhs, updatedInterval); + } else if (rhsConstOp && !lhsConstOp) { + auto constVal = toDynamicAPInt(rhsConstOp.getValue()); + Interval updatedInterval = newInterval + Interval::Degenerate(f, constVal); + applyInterval(subOp, lhs, updatedInterval); + } }; - // We have a special case for the Signal struct: if this value is created - // from reading a Signal struct's reg field, we also apply the interval to - // the struct itself. 
auto readfCase = [&](FieldReadOp readfOp) { + const SourceRefLattice *sourceRefLattice = getSourceRefLattice(valUser, val); + SourceRefLatticeValue sourceRefVal = sourceRefLattice->getOrDefault(val); + + if (sourceRefVal.isSingleValue()) { + const SourceRef &ref = sourceRefVal.getSingleValue(); + fieldReadResults[ref].insert(valLattice); + + // Also propagate to all other field read results for this field + for (Lattice *l : fieldReadResults[ref]) { + if (l != valLattice) { + propagateIfChanged(l, l->setValue(newLatticeVal)); + } + } + } + + // We have a special case for the Signal struct: if this value is created + // from reading a Signal struct's reg field, we also apply the interval to + // the struct itself. Value comp = readfOp.getComponent(); if (isSignalType(comp.getType())) { - return applyInterval(originalOp, originalLattice, getOperandLattice(comp), comp, newInterval); + applyInterval(readfOp, comp, newInterval); } - return ChangeResult::NoChange; }; - // For casts, just pass the interval along to the cast's operand. - auto castCase = [&](Operation *op) { - Value operand = op->getOperand(0); - return applyInterval( - originalOp, originalLattice, getOperandLattice(operand), operand, newInterval - ); + auto readArrCase = [&](ReadArrayOp _) { + const SourceRefLattice *sourceRefLattice = getSourceRefLattice(valUser, val); + SourceRefLatticeValue sourceRefVal = sourceRefLattice->getOrDefault(val); + + if (sourceRefVal.isSingleValue()) { + const SourceRef &ref = sourceRefVal.getSingleValue(); + fieldReadResults[ref].insert(valLattice); + + // Also propagate to all other field read results for this field + for (Lattice *l : fieldReadResults[ref]) { + if (l != valLattice) { + propagateIfChanged(l, l->setValue(newLatticeVal)); + } + } + } }; + // For casts, just pass the interval along to the cast's operand. + auto castCase = [&](Operation *op) { applyInterval(op, op->getOperand(0), newInterval); }; + // - Apply the rules given the op. 
// NOTE: disabling clang-format for this because it makes the last case statement // look ugly. // clang-format off - res |= TypeSwitch(definingOp) - .Case([&](auto op) { return cmpCase(op); }) - .Case([&](auto op) { return mulCase(op); }) - .Case([&](auto op){ return readfCase(op); }) - .Case([&](auto op) { return castCase(op); }) - .Default([&](Operation *) { return ChangeResult::NoChange; }); + TypeSwitch(definingOp) + .Case([&](auto op) { cmpCase(op); }) + .Case([&](auto op) { return addCase(op); }) + .Case([&](auto op) { return subCase(op); }) + .Case([&](auto op) { mulCase(op); }) + .Case([&](auto op){ readfCase(op); }) + .Case([&](auto op){ readArrCase(op); }) + .Case([&](auto op) { castCase(op); }) + .Default([&](Operation *) { }); // clang-format on - // Set the new val after recursion to avoid having recursive calls unset the value. - return setNewVal(); + // Propagate after recursion to avoid having recursive calls unset the value. + propagateIfChanged(valLattice, changed); } FailureOr, Interval>> @@ -1137,7 +1071,20 @@ LogicalResult StructIntervals::computeIntervals( mlir::DataFlowSolver &solver, const IntervalAnalysisContext &ctx ) { - auto computeIntervalsImpl = [&solver, &ctx, this]( + auto validSourceRefType = [](const SourceRef &ref) { + // We only want to compute intervals for field elements and not composite types, + // with the exception of the Signal struct. + if (!ref.isScalar() && !ref.isSignal()) { + return false; + } + // We also don't want to show the interval for a Signal and its internal reg. 
+ if (auto parentOr = ref.getParentPrefix(); succeeded(parentOr) && parentOr->isSignal()) { + return false; + } + return true; + }; + + auto computeIntervalsImpl = [&solver, &ctx, &validSourceRefType, this]( FuncDefOp fn, llvm::MapVector &fieldRanges, llvm::SetVector &solverConstraints ) { @@ -1150,40 +1097,33 @@ LogicalResult StructIntervals::computeIntervals( for (const auto &ref : SourceRef::getAllSourceRefs(structDef, fn)) { // We only want to compute intervals for field elements and not composite types, // with the exception of the Signal struct. - if (!ref.isScalar() && !ref.isSignal()) { - continue; + if (validSourceRefType(ref)) { + searchSet.insert(ref); } - // We also don't want to show the interval for a Signal and its internal reg. - if (auto parentOr = ref.getParentPrefix(); succeeded(parentOr) && parentOr->isSignal()) { - continue; + } + + // Iterate over arguments + for (BlockArgument arg : fn.getArguments()) { + SourceRef ref {arg}; + if (searchSet.erase(ref)) { + const IntervalAnalysisLattice *lattice = solver.lookupState(arg); + fieldRanges[ref] = lattice->getValue().getScalarValue().getInterval(); } - searchSet.insert(ref); } - // Get all ops in reverse order, including nested ops. 
- llvm::SmallVector opList; - getReversedOps(&fn.getBody(), opList); - - // Also traverse the function op itself - opList.push_back(fn); - - for (Operation *op : opList) { - ProgramPoint *pp = solver.getProgramPointAfter(op); - const IntervalAnalysisLattice *lattice = solver.lookupState(pp); - const auto &c = lattice->getConstraints(); - solverConstraints.insert(c.begin(), c.end()); - - SourceRefSet newSearchSet; - for (const auto &ref : searchSet) { - auto symbol = ctx.getSymbol(ref); - auto intervalRes = lattice->findInterval(symbol); - if (succeeded(intervalRes)) { - fieldRanges[ref] = *intervalRes; - } else { - newSearchSet.insert(ref); - } + // Iterate over fields that were touched by the analysis + for (const auto &[ref, lattices] : ctx.intervalDFA->getFieldReadResults()) { + // All lattices should have the same value, so we can get the front. + if (!lattices.empty() && searchSet.erase(ref)) { + const IntervalAnalysisLattice *lattice = *lattices.begin(); + fieldRanges[ref] = lattice->getValue().getScalarValue().getInterval(); + } + } + + for (const auto &[ref, val] : ctx.intervalDFA->getFieldWriteResults()) { + if (searchSet.erase(ref)) { + fieldRanges[ref] = val.getInterval(); } - searchSet = newSearchSet; } // For all unfound refs, default to the entire range. 
diff --git a/lib/Analysis/Intervals.cpp b/lib/Analysis/Intervals.cpp index 142b07aa6..6c45c15ac 100644 --- a/lib/Analysis/Intervals.cpp +++ b/lib/Analysis/Intervals.cpp @@ -197,7 +197,12 @@ Interval Interval::join(const Interval &rhs) const { if (areOneOf< {Type::TypeA, Type::TypeA}, {Type::TypeB, Type::TypeB}, {Type::TypeC, Type::TypeC}, {Type::TypeA, Type::TypeC}, {Type::TypeB, Type::TypeC}>(lhs, rhs)) { - return Interval(rhs.ty, f, std::min(lhs.a, rhs.a), std::max(lhs.b, rhs.b)); + auto newLhs = std::min(lhs.a, rhs.a); + auto newRhs = std::max(lhs.b, rhs.b); + if (newLhs == newRhs) { + return Interval::Degenerate(f, newLhs); + } + return Interval(rhs.ty, f, newLhs, newRhs); } if (areOneOf<{Type::TypeA, Type::TypeB}>(lhs, rhs)) { auto lhsUnred = lhs.firstUnreduced(); @@ -232,6 +237,9 @@ Interval Interval::intersect(const Interval &rhs) const { const auto &lhs = *this; const Field &f = checkFields(lhs, rhs); // Trivial cases + if (lhs == rhs) { + return lhs; + } if (lhs.isEmpty() || rhs.isEmpty()) { return Interval::Empty(f); } @@ -241,8 +249,15 @@ Interval Interval::intersect(const Interval &rhs) const { if (rhs.isEntire()) { return lhs; } - if (lhs.isDegenerate() || rhs.isDegenerate()) { - return lhs.toUnreduced().intersect(rhs.toUnreduced()).reduce(f); + if (lhs.isDegenerate() && rhs.isDegenerate()) { + // These must not be equal + return Interval::Empty(f); + } + if (lhs.isDegenerate()) { + return Interval::TypeA(f, lhs.a, lhs.a).intersect(rhs); + } + if (rhs.isDegenerate()) { + return Interval::TypeA(f, rhs.a, rhs.a).intersect(lhs); } // More complex cases @@ -251,8 +266,10 @@ Interval Interval::intersect(const Interval &rhs) const { {Type::TypeA, Type::TypeC}, {Type::TypeB, Type::TypeC}>(lhs, rhs)) { auto maxA = std::max(lhs.a, rhs.a); auto minB = std::min(lhs.b, rhs.b); - if (maxA <= minB) { + if (maxA < minB) { return Interval(lhs.ty, f, maxA, minB); + } else if (maxA == minB) { + return Interval::Degenerate(f, maxA); } else { return Interval::Empty(f); } 
@@ -359,10 +376,10 @@ Interval Interval::operator~() const { Interval operator+(const Interval &lhs, const Interval &rhs) { const Field &f = checkFields(lhs, rhs); - if (lhs.isEmpty()) { + if (lhs.isEmpty() || rhs.isEntire()) { return rhs; } - if (rhs.isEmpty()) { + if (rhs.isEmpty() || lhs.isEntire()) { return lhs; } return (lhs.firstUnreduced() + rhs.firstUnreduced()).reduce(f); diff --git a/lib/Analysis/SparseAnalysis.cpp b/lib/Analysis/SparseAnalysis.cpp new file mode 100644 index 000000000..417c4a329 --- /dev/null +++ b/lib/Analysis/SparseAnalysis.cpp @@ -0,0 +1,360 @@ +//===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===// +// +// Part of the LLZK Project, under the Apache License v2.0. +// See LICENSE.txt for license information. +// Copyright 2025 Veridise Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// Adapted from mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// +//===----------------------------------------------------------------------===// + +#include "llzk/Analysis/SparseAnalysis.h" +#include "llzk/Dialect/Function/IR/Ops.h" +#include "llzk/Util/ErrorHelper.h" +#include "llzk/Util/SymbolHelper.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +using namespace mlir; +using namespace mlir::dataflow; +using namespace llzk::function; + +namespace llzk::dataflow { + +//===----------------------------------------------------------------------===// +// AbstractSparseForwardDataFlowAnalysis +//===----------------------------------------------------------------------===// + +AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver) + : DataFlowAnalysis(solver) { + registerAnchorKind(); +} + +LogicalResult AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) { + // Mark the entry block arguments as having reached their pessimistic + // fixpoints. + for (Region ®ion : top->getRegions()) { + if (region.empty()) { + continue; + } + for (Value argument : region.front().getArguments()) { + setToEntryState(getLatticeElement(argument)); + } + } + + return initializeRecursively(top); +} + +LogicalResult AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) { + // Initialize the analysis by visiting every owner of an SSA value (all + // operations and blocks). + if (failed(visitOperation(op))) { + return failure(); + } + + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + getOrCreate(getProgramPointBefore(&block))->blockContentSubscribe(this); + visitBlock(&block); + // LLZK: Renamed "op" -> "containedOp" to avoid shadowing. 
+ for (Operation &containedOp : block) { + if (failed(initializeRecursively(&containedOp))) { + return failure(); + } + } + } + } + + return success(); +} + +LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) { + if (!point->isBlockStart()) { + return visitOperation(point->getPrevOp()); + } + visitBlock(point->getBlock()); + return success(); +} + +LogicalResult AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { + /// LLZK: Visit operations with no results, as they may still affect values + /// (e.g., constraints and assertions). + /// The MLIR version doesn't visit result-less operations. + + // If the containing block is not executable, bail out. + if (op->getBlock() != nullptr && + !getOrCreate(getProgramPointBefore(op->getBlock()))->isLive()) { + return success(); + } + + // Get the result lattices. + SmallVector resultLattices; + resultLattices.reserve(op->getNumResults()); + for (Value result : op->getResults()) { + AbstractSparseLattice *resultLattice = getLatticeElement(result); + resultLattices.push_back(resultLattice); + } + + // The results of a region branch operation are determined by control-flow. + if (auto branch = dyn_cast(op)) { + visitRegionSuccessors( + getProgramPointAfter(branch), branch, + /*successor=*/RegionBranchPoint::parent(), resultLattices + ); + return success(); + } + + // Grab the lattice elements of the operands. + SmallVector operandLattices; + operandLattices.reserve(op->getNumOperands()); + for (Value operand : op->getOperands()) { + AbstractSparseLattice *operandLattice = getLatticeElement(operand); + operandLattice->useDefSubscribe(this); + operandLattices.push_back(operandLattice); + } + + if (auto call = dyn_cast(op)) { + /// LLZK: Use LLZK resolveCallable interface. + // If the call operation is to an external function, attempt to infer the + // results from the call arguments. 
+ auto callable = resolveCallable(tables, call); + if (!getSolverConfig().isInterprocedural() || + (succeeded(callable) && !callable->get().getCallableRegion())) { + visitExternalCallImpl(call, operandLattices, resultLattices); + return success(); + } + + // Otherwise, the results of a call operation are determined by the + // callgraph. + /// LLZK: The PredecessorState Analysis state does not work for LLZK's custom calls. + /// We therefore accumulate predecessor operations (return ops) manually. + SmallVector predecessors; + callable->get().walk([&predecessors](ReturnOp ret) mutable { predecessors.push_back(ret); }); + + // If not all return sites are known, then conservatively assume we can't + // reason about the data-flow. + if (predecessors.empty()) { + setAllToEntryStates(resultLattices); + return success(); + } + for (Operation *predecessor : predecessors) { + for (auto &&[operand, resLattice] : llvm::zip(predecessor->getOperands(), resultLattices)) { + join(resLattice, *getLatticeElementFor(getProgramPointAfter(op), operand)); + } + } + return success(); + } + + // Invoke the operation transfer function. + return visitOperationImpl(op, operandLattices, resultLattices); +} + +/// LLZK: Removing use of PredecessorState because it does not work with LLZK's +/// CallOp and FuncDefOp definitions. +void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { + // Exit early on blocks with no arguments. + if (block->getNumArguments() == 0) { + return; + } + + // If the block is not executable, bail out. + if (!getOrCreate(getProgramPointBefore(block))->isLive()) { + return; + } + + // Get the argument lattices. + SmallVector argLattices; + argLattices.reserve(block->getNumArguments()); + for (BlockArgument argument : block->getArguments()) { + AbstractSparseLattice *argLattice = getLatticeElement(argument); + argLattices.push_back(argLattice); + } + + // The argument lattices of entry blocks are set by region control-flow or the + // callgraph. 
+ if (block->isEntryBlock()) { + // Check if this block is the entry block of a callable region. + auto callable = dyn_cast(block->getParentOp()); + if (callable && callable.getCallableRegion() == block->getParent()) { + /// LLZK: Get callsites of the callable as the predecessors. + auto moduleOpRes = getTopRootModule(callable.getOperation()); + ensure(succeeded(moduleOpRes), "could not get root module from callable"); + SmallVector callsites; + moduleOpRes->walk([this, &callable, &callsites](CallOp call) mutable { + auto calledFnRes = resolveCallable(tables, call); + if (succeeded(calledFnRes) && + calledFnRes->get().getCallableRegion() == callable.getCallableRegion()) { + callsites.push_back(call); + } + }); + // If not all callsites are known, conservatively mark all lattices as + // having reached their pessimistic fixpoints. + if (callsites.empty() || !getSolverConfig().isInterprocedural()) { + return setAllToEntryStates(argLattices); + } + for (Operation *callsite : callsites) { + auto call = cast(callsite); + for (auto it : llvm::zip(call.getArgOperands(), argLattices)) { + join( + std::get<1>(it), *getLatticeElementFor(getProgramPointBefore(block), std::get<0>(it)) + ); + } + } + return; + } + + // Check if the lattices can be determined from region control flow. + if (auto branch = dyn_cast(block->getParentOp())) { + return visitRegionSuccessors( + getProgramPointBefore(block), branch, block->getParent(), argLattices + ); + } + + // Otherwise, we can't reason about the data-flow. + return visitNonControlFlowArgumentsImpl( + block->getParentOp(), RegionSuccessor(block->getParent()), argLattices, /*firstIndex=*/0 + ); + } + + // Iterate over the predecessors of the non-entry block. + for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { + Block *predecessor = *it; + + // If the edge from the predecessor block to the current block is not live, + // bail out. 
+ auto *edgeExecutable = getOrCreate(getLatticeAnchor(predecessor, block)); + edgeExecutable->blockContentSubscribe(this); + if (!edgeExecutable->isLive()) { + continue; + } + + // Check if we can reason about the data-flow from the predecessor. + if (auto branch = dyn_cast(predecessor->getTerminator())) { + SuccessorOperands operands = branch.getSuccessorOperands(it.getSuccessorIndex()); + for (auto [idx, lattice] : llvm::enumerate(argLattices)) { + if (Value operand = operands[idx]) { + join(lattice, *getLatticeElementFor(getProgramPointBefore(block), operand)); + } else { + // Conservatively consider internally produced arguments as entry + // points. + setAllToEntryStates(lattice); + } + } + } else { + return setAllToEntryStates(argLattices); + } + } +} + +/// LLZK: Removing use of PredecessorState because it does not work with LLZK's lookup logic. +void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( + ProgramPoint *point, RegionBranchOpInterface branch, RegionBranchPoint successor, + ArrayRef lattices +) { + Operation *op = point->isBlockStart() ? point->getBlock()->getParentOp() : point->getPrevOp(); + + if (op) { + // Get the incoming successor operands. + std::optional operands; + + // Check if the predecessor is the parent op. + if (op == branch) { + operands = branch.getEntrySuccessorOperands(successor); + // Otherwise, try to deduce the operands from a region return-like op. + } else if (auto regionTerminator = dyn_cast(op)) { + operands = regionTerminator.getSuccessorOperands(successor); + } + + if (!operands) { + // We can't reason about the data-flow. + return setAllToEntryStates(lattices); + } + + ValueRange inputs; + + /// LLZK: We only handle these kinds of region ops with inputs for now. + if (auto forOp = dyn_cast(op)) { + inputs = forOp.getRegionIterArgs(); + } else if (auto whileOp = dyn_cast(op)) { + inputs = whileOp.getRegionIterArgs(); + } + + if (inputs.size() != operands->size()) { + // We can't reason about the data-flow. 
+ return setAllToEntryStates(lattices); + } + + unsigned firstIndex = 0; + if (inputs.size() != lattices.size()) { + if (!point->isBlockStart()) { + if (!inputs.empty()) { + firstIndex = cast(inputs.front()).getResultNumber(); + } + visitNonControlFlowArgumentsImpl( + branch, RegionSuccessor(branch->getResults().slice(firstIndex, inputs.size())), + lattices, firstIndex + ); + } else { + if (!inputs.empty()) { + firstIndex = cast(inputs.front()).getArgNumber(); + } + Region *region = point->getBlock()->getParent(); + visitNonControlFlowArgumentsImpl( + branch, + RegionSuccessor(region, region->getArguments().slice(firstIndex, inputs.size())), + lattices, firstIndex + ); + } + } + + for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex))) { + join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it))); + } + } +} + +const AbstractSparseLattice * +AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint *point, Value value) { + AbstractSparseLattice *state = getLatticeElement(value); + addDependency(state, point); + return state; +} + +void AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates( + ArrayRef lattices +) { + for (AbstractSparseLattice *lattice : lattices) { + setToEntryState(lattice); + } +} + +void AbstractSparseForwardDataFlowAnalysis::join( + AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs +) { + propagateIfChanged(lhs, lhs->join(rhs)); +} + +} // namespace llzk::dataflow diff --git a/test/Analysis/interval_analysis/interval_analysis_pass.llzk b/test/Analysis/interval_analysis/interval_analysis_pass.llzk index 401bac3e7..ff2c46062 100644 --- a/test/Analysis/interval_analysis/interval_analysis_pass.llzk +++ b/test/Analysis/interval_analysis/interval_analysis_pass.llzk @@ -1,4 +1,4 @@ -// RUN: llzk-opt -I %S -split-input-file -llzk-print-interval-analysis %s 2>&1 | FileCheck %s +// RUN: llzk-opt -split-input-file -llzk-print-interval-analysis %s 2>&1 | FileCheck %s module attributes {veridise.lang = 
"llzk"} { struct.def @ConstantConstraint { From db8e151e409adf02a80a4660cf5df5506e9c7e03 Mon Sep 17 00:00:00 2001 From: Ian Glen Neal Date: Wed, 12 Nov 2025 11:48:39 -0500 Subject: [PATCH 3/6] - Disable interprocedural parts of SparseAnalysis - Add tests - Refine boolean ops - Refine back propagation of add/sub --- lib/Analysis/IntervalAnalysis.cpp | 117 ++++----- lib/Analysis/SparseAnalysis.cpp | 6 + .../interval_analysis_pass.llzk | 236 ++++++++++++++++++ 3 files changed, 303 insertions(+), 56 deletions(-) diff --git a/lib/Analysis/IntervalAnalysis.cpp b/lib/Analysis/IntervalAnalysis.cpp index 748a58311..e3e9524a7 100644 --- a/lib/Analysis/IntervalAnalysis.cpp +++ b/lib/Analysis/IntervalAnalysis.cpp @@ -156,15 +156,39 @@ cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const Expre break; case FeltCmpPredicate::LT: res.expr = solver->mkBVUlt(lhs.expr, rhs.expr); + if (lhs.i.toUnreduced().computeGEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) { + res.i = Interval::True(f); + } + if (lhs.i.toUnreduced().computeLTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) { + res.i = Interval::False(f); + } break; case FeltCmpPredicate::LE: res.expr = solver->mkBVUle(lhs.expr, rhs.expr); + if (lhs.i.toUnreduced().computeGTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) { + res.i = Interval::True(f); + } + if (lhs.i.toUnreduced().computeLEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) { + res.i = Interval::False(f); + } break; case FeltCmpPredicate::GT: res.expr = solver->mkBVUgt(lhs.expr, rhs.expr); + if (lhs.i.toUnreduced().computeLEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) { + res.i = Interval::True(f); + } + if (lhs.i.toUnreduced().computeGTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) { + res.i = Interval::False(f); + } break; case FeltCmpPredicate::GE: res.expr = solver->mkBVUge(lhs.expr, rhs.expr); + if (lhs.i.toUnreduced().computeLTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) { + res.i = Interval::True(f); + } + if 
(lhs.i.toUnreduced().computeGEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) { + res.i = Interval::False(f); + } break; } return res; @@ -300,7 +324,7 @@ ChangeResult IntervalAnalysisLattice::meet(const AbstractSparseLattice &other) { } // Intersect the intervals ExpressionValue lhsExpr = val.getScalarValue(); - ExpressionValue rhsExpr = val.getScalarValue(); + ExpressionValue rhsExpr = rhs->getValue().getScalarValue(); Interval newInterval = lhsExpr.getInterval().intersect(rhsExpr.getInterval()); ChangeResult res = setValue(lhsExpr.withInterval(newInterval)); for (auto &v : rhs->constraints) { @@ -449,6 +473,11 @@ mlir::LogicalResult IntervalDataFlowAnalysis::visitOperation( } else { result = performUnaryArithmetic(op, operandVals[0]); } + // Also intersect with prior interval, if it's initialized + const ExpressionValue &prior = results[0]->getValue().getScalarValue(); + if (prior.getExpr()) { + result = result.withInterval(result.getInterval().intersect(prior.getInterval())); + } propagateIfChanged(results[0], results[0]->setValue(result)); } else if (EmitEqualityOp emitEq = llvm::dyn_cast(op)) { Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs(); @@ -681,14 +710,8 @@ void IntervalDataFlowAnalysis::applyInterval(Operation *valUser, Value val, Inte ExpressionValue newLatticeVal = oldLatticeVal.withInterval(intersection); ChangeResult changed = valLattice->setValue(newLatticeVal); - if (changed == ChangeResult::NoChange) { - // We don't need to keep recursing since `val`s interval hasn't changed - return; - } - if (auto blockArg = llvm::dyn_cast(val)) { auto fnOp = dyn_cast(blockArg.getOwner()->getParentOp()); - Operation *blockEntry = &blockArg.getOwner()->front(); // Apply the interval from the constrain function inputs to the compute function inputs if (propagateInputConstraints && fnOp && fnOp.isStructConstrain() && @@ -721,10 +744,12 @@ void IntervalDataFlowAnalysis::applyInterval(Operation *valUser, Value val, Inte // cmp. 
restricts each side of the comparison if the result is known. auto cmpCase = [&](CmpOp cmpOp) { // Cmp output range is [0, 1], so in order to do something, we must have newInterval - // either "true" (1) or "false" (0) + // either "true" (1) or "false" (0). + // -- In the case of a contradictory circuit, however, the cmp result is allowed + // to be empty. ensure( - newInterval.isBoolean(), - "new interval for CmpOp outside of allowed boolean range or is empty" + newInterval.isBoolean() || newInterval.isEmpty(), + "new interval for CmpOp is not boolean or empty" ); if (!newInterval.isDegenerate()) { // The comparison result is unknown, so we can't update the operand ranges @@ -861,49 +886,40 @@ void IntervalDataFlowAnalysis::applyInterval(Operation *valUser, Value val, Inte applyInterval(mulOp, rhs, newRhsInterval); }; - // Addition case: - // - If one operand is a constant, we can propagate the new interval when subtracting - // the constant auto addCase = [&](AddFeltOp addOp) { - // We check for the constant case first. 
- auto constCase = [&](FeltConstantOp constOperand, Value operand) { - auto latVal = getLatticeElement(operand)->getValue().getScalarValue(); - auto constVal = toDynamicAPInt(constOperand.getValue()); - Interval updatedInterval = newInterval - Interval::Degenerate(f, constVal); - applyInterval(addOp, operand, updatedInterval); - }; - Value lhs = addOp.getLhs(), rhs = addOp.getRhs(); + Lattice *lhsLat = getLatticeElement(lhs), *rhsLat = getLatticeElement(rhs); + ExpressionValue lhsVal = lhsLat->getValue().getScalarValue(); + ExpressionValue rhsVal = rhsLat->getValue().getScalarValue(); - auto lhsConstOp = dyn_cast_if_present(lhs.getDefiningOp()); - auto rhsConstOp = dyn_cast_if_present(rhs.getDefiningOp()); - // If both are consts, we don't need to do anything - if (lhsConstOp && !rhsConstOp) { - constCase(lhsConstOp, rhs); - } else if (rhsConstOp && !lhsConstOp) { - constCase(rhsConstOp, lhs); - } + const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval(); + + Interval derivedLhsInt = newInterval - currRhsInt; + Interval derivedRhsInt = newInterval - currLhsInt; + + Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt); + Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt); + + applyInterval(addOp, lhs, finalLhsInt); + applyInterval(addOp, rhs, finalRhsInt); }; - // Subtraction case: - // - If one operand is a constant, we can propagate the new interval when adding - // the constant. auto subCase = [&](SubFeltOp subOp) { - // We check for the constant case first. 
Value lhs = subOp.getLhs(), rhs = subOp.getRhs(); + Lattice *lhsLat = getLatticeElement(lhs), *rhsLat = getLatticeElement(rhs); + ExpressionValue lhsVal = lhsLat->getValue().getScalarValue(); + ExpressionValue rhsVal = rhsLat->getValue().getScalarValue(); - auto lhsConstOp = dyn_cast_if_present(lhs.getDefiningOp()); - auto rhsConstOp = dyn_cast_if_present(rhs.getDefiningOp()); - // If both are consts, we don't need to do anything - if (lhsConstOp && !rhsConstOp) { - auto constVal = toDynamicAPInt(lhsConstOp.getValue()); - Interval updatedInterval = Interval::Degenerate(f, constVal) - newInterval; - applyInterval(subOp, rhs, updatedInterval); - } else if (rhsConstOp && !lhsConstOp) { - auto constVal = toDynamicAPInt(rhsConstOp.getValue()); - Interval updatedInterval = newInterval + Interval::Degenerate(f, constVal); - applyInterval(subOp, lhs, updatedInterval); - } + const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval(); + + Interval derivedLhsInt = newInterval + currRhsInt; + Interval derivedRhsInt = currLhsInt - newInterval; + + Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt); + Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt); + + applyInterval(subOp, lhs, finalLhsInt); + applyInterval(subOp, rhs, finalRhsInt); }; auto readfCase = [&](FieldReadOp readfOp) { @@ -1056,17 +1072,6 @@ IntervalDataFlowAnalysis::getGeneralizedDecompInterval(Operation *baseOp, Value /* StructIntervals */ -static void getReversedOps(Region *r, llvm::SmallVector &opList) { - for (Block &b : llvm::reverse(*r)) { - for (Operation &op : llvm::reverse(b)) { - for (Region &nested : llvm::reverse(op.getRegions())) { - getReversedOps(&nested, opList); - } - opList.push_back(&op); - } - } -} - LogicalResult StructIntervals::computeIntervals( mlir::DataFlowSolver &solver, const IntervalAnalysisContext &ctx ) { diff --git a/lib/Analysis/SparseAnalysis.cpp b/lib/Analysis/SparseAnalysis.cpp index 417c4a329..80a936e5e 100644 --- 
a/lib/Analysis/SparseAnalysis.cpp +++ b/lib/Analysis/SparseAnalysis.cpp @@ -136,6 +136,8 @@ LogicalResult AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *o operandLattices.push_back(operandLattice); } + // LLZK TODO: Enable for interprocedural analysis. + /* if (auto call = dyn_cast(op)) { /// LLZK: Use LLZK resolveCallable interface. // If the call operation is to an external function, attempt to infer the @@ -167,6 +169,7 @@ LogicalResult AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *o } return success(); } + */ // Invoke the operation transfer function. return visitOperationImpl(op, operandLattices, resultLattices); @@ -197,6 +200,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { // callgraph. if (block->isEntryBlock()) { // Check if this block is the entry block of a callable region. + // LLZK TODO: Enable for interprocedural analysis. + /* auto callable = dyn_cast(block->getParentOp()); if (callable && callable.getCallableRegion() == block->getParent()) { /// LLZK: Get callsites of the callable as the predecessors. @@ -225,6 +230,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { } return; } + */ // Check if the lattices can be determined from region control flow. 
if (auto branch = dyn_cast(block->getParentOp())) { diff --git a/test/Analysis/interval_analysis/interval_analysis_pass.llzk b/test/Analysis/interval_analysis/interval_analysis_pass.llzk index ff2c46062..40ff269f7 100644 --- a/test/Analysis/interval_analysis/interval_analysis_pass.llzk +++ b/test/Analysis/interval_analysis/interval_analysis_pass.llzk @@ -478,3 +478,239 @@ module attributes {veridise.lang = "llzk"} { // CHECK-LABEL: @CheckBitInv StructIntervals { // CHECK: %arg1 in TypeA:[ 0, 1 ] // CHECK: } + +// ----- + +module attributes {veridise.lang = "llzk"} { + struct.def @IntTest<[]> { + struct.field @x : !felt.type {llzk.pub} + struct.field @y : !felt.type {llzk.pub} + function.def @compute() -> !struct.type<@IntTest<[]>> attributes {function.allow_witness} { + %self = struct.new : <@IntTest<[]>> + function.return %self : !struct.type<@IntTest<[]>> + } + function.def @constrain(%arg0: !struct.type<@IntTest<[]>>) attributes {function.allow_constraint} { + %0 = struct.readf %arg0[@x] : <@IntTest<[]>>, !felt.type + %1 = struct.readf %arg0[@y] : <@IntTest<[]>>, !felt.type + %felt_const_65536 = felt.const 65536 + %felt_const_1 = felt.const 1 + %felt_const_256 = felt.const 256 + %felt_const_512 = felt.const 512 + %2 = felt.add %0, %1 : !felt.type, !felt.type + %3 = bool.cmp lt(%2, %felt_const_65536) + %4 = bool.cmp lt(%0, %felt_const_512) + %5 = cast.tofelt %3 : i1 + %6 = cast.tofelt %4 : i1 + constrain.eq %5, %felt_const_1 : !felt.type, !felt.type + constrain.eq %6, %felt_const_1 : !felt.type, !felt.type + function.return + } + } +} + +// CHECK-LABEL: @IntTest StructIntervals { +// CHECK-NEXT: %arg0[@x] in TypeA:[ 0, 511 ] +// CHECK-NEXT: %arg0[@y] in TypeF:[ 21888242871839275222246405745257275088696311157297823662689037894645226208072, 65535 ] +// CHECK-NEXT: } + +// ----- + +module attributes {veridise.lang = "llzk"} { + struct.def @IntTest<[]> { + struct.field @x : !felt.type {llzk.pub} + struct.field @y : !felt.type {llzk.pub} + struct.field @z : !felt.type 
{llzk.pub} + function.def @compute() -> !struct.type<@IntTest<[]>> attributes {function.allow_witness} { + %self = struct.new : <@IntTest<[]>> + function.return %self : !struct.type<@IntTest<[]>> + } + function.def @constrain(%arg0: !struct.type<@IntTest<[]>>) attributes {function.allow_constraint} { + %0 = struct.readf %arg0[@x] : <@IntTest<[]>>, !felt.type + %1 = struct.readf %arg0[@y] : <@IntTest<[]>>, !felt.type + %2 = struct.readf %arg0[@z] : <@IntTest<[]>>, !felt.type + %felt_const_65536 = felt.const 65536 + %felt_const_1 = felt.const 1 + %felt_const_256 = felt.const 256 + %felt_const_512 = felt.const 512 + %3 = felt.mul %0, %1 : !felt.type, !felt.type + %4 = bool.cmp lt(%3, %felt_const_65536) + %5 = bool.cmp ge(%1, %felt_const_256) + %6 = bool.cmp lt(%1, %felt_const_512) + %7 = bool.cmp lt(%0, %felt_const_512) + %8 = cast.tofelt %4 : i1 + %9 = cast.tofelt %5 : i1 + %10 = cast.tofelt %7 : i1 + %11 = cast.tofelt %6 : i1 + constrain.eq %8, %felt_const_1 : !felt.type, !felt.type + constrain.eq %9, %felt_const_1 : !felt.type, !felt.type + constrain.eq %10, %felt_const_1 : !felt.type, !felt.type + constrain.eq %11, %felt_const_1 : !felt.type, !felt.type + %12 = felt.mul %1, %felt_const_256 : !felt.type, !felt.type + %13 = felt.add %12, %0 : !felt.type, !felt.type + constrain.eq %2, %13 : !felt.type, !felt.type + function.return + } + } +} + +// TODO: Refine multiplication range +// CHECK-LABEL: @IntTest StructIntervals { +// CHECK-NEXT: %arg0[@x] in TypeA:[ 0, 511 ] +// CHECK-NEXT: %arg0[@y] in TypeA:[ 256, 511 ] +// CHECK-NEXT: %arg0[@z] in TypeA:[ 65536, 131327 ] +// CHECK-NEXT: } + +// ----- + +module attributes {veridise.lang = "llzk"} { + struct.def @IntTest<[]> { + struct.field @x: !felt.type {llzk.pub} + struct.field @y : !felt.type {llzk.pub} + + function.def @compute() -> !struct.type<@IntTest<[]>> attributes {function.allow_witness} { + %self = struct.new : <@IntTest<[]>> + function.return %self : !struct.type<@IntTest<[]>> + } + + function.def 
@constrain(%self: !struct.type<@IntTest<[]>>) attributes {function.allow_constraint} { + %x = struct.readf %self[@x] : <@IntTest<[]>>, !felt.type + %y = struct.readf %self[@y] : <@IntTest<[]>>, !felt.type + %65536 = felt.const 65536 + + %one = felt.const 1 + %256 = felt.const 256 + %512 = felt.const 512 + %m = felt.add %x, %y : !felt.type, !felt.type + // These two conditions are contradictory + %0 = bool.cmp lt(%m, %65536) + %1 = bool.cmp ge(%m, %65536) + %2 = bool.cmp lt(%x, %512) + %bool0 = cast.tofelt %0 : i1 + %bool1 = cast.tofelt %1 : i1 + %bool2 = cast.tofelt %2 : i1 + constrain.eq %bool0, %one : !felt.type, !felt.type + constrain.eq %bool1, %one : !felt.type, !felt.type + constrain.eq %bool2, %one : !felt.type, !felt.type + function.return + } + } +} + +// CHECK-LABEL: @IntTest StructIntervals { +// CHECK-NEXT: %arg0[@x] in TypeA:[ 0, 511 ] +// CHECK-NEXT: %arg0[@y] in Entire +// CHECK-NEXT: } + +// ----- + +module attributes {veridise.lang = "llzk"} { + struct.def @IntTest<[]> { + struct.field @x: !felt.type {llzk.pub} + struct.field @y : !felt.type {llzk.pub} + + function.def @compute() -> !struct.type<@IntTest<[]>> attributes {function.allow_witness} { + %self = struct.new : <@IntTest<[]>> + function.return %self : !struct.type<@IntTest<[]>> + } + + function.def @constrain(%self: !struct.type<@IntTest<[]>>) attributes {function.allow_constraint} { + %x = struct.readf %self[@x] : <@IntTest<[]>>, !felt.type + %y = struct.readf %self[@y] : <@IntTest<[]>>, !felt.type + %65536 = felt.const 65536 + %32768 = felt.const 32768 + + %one = felt.const 1 + %256 = felt.const 256 + %512 = felt.const 512 + %m = felt.add %x, %y : !felt.type, !felt.type + %0 = bool.cmp lt(%m, %65536) + %1 = bool.cmp ge(%m, %32768) + %2 = bool.cmp lt(%x, %512) + %bool0 = cast.tofelt %0 : i1 + %bool1 = cast.tofelt %1 : i1 + %bool2 = cast.tofelt %2 : i1 + constrain.eq %bool0, %one : !felt.type, !felt.type + constrain.eq %bool1, %one : !felt.type, !felt.type + constrain.eq %bool2, %one : 
!felt.type, !felt.type + function.return + } + } +} + +// CHECK-LABEL: @IntTest StructIntervals { +// CHECK-NEXT: %arg0[@x] in TypeA:[ 0, 511 ] +// CHECK-NEXT: %arg0[@y] in TypeA:[ 32257, 65535 ] +// CHECK-NEXT: } + +// ----- + +module attributes {veridise.lang = "llzk"} { + struct.def @SubTest<[]> { + struct.field @x : !felt.type {llzk.pub} + struct.field @y : !felt.type {llzk.pub} + struct.field @z : !felt.type {llzk.pub} + function.def @compute() -> !struct.type<@SubTest<[]>> attributes {function.allow_witness} { + %self = struct.new : <@SubTest<[]>> + function.return %self : !struct.type<@SubTest<[]>> + } + function.def @constrain(%arg0: !struct.type<@SubTest<[]>>) attributes {function.allow_constraint} { + %0 = struct.readf %arg0[@x] : <@SubTest<[]>>, !felt.type + %1 = struct.readf %arg0[@y] : <@SubTest<[]>>, !felt.type + // Note: y is read twice + %2 = struct.readf %arg0[@y] : <@SubTest<[]>>, !felt.type + %felt_const_1 = felt.const 1 + %felt_const_256 = felt.const 256 + %3 = felt.sub %0, %1 : !felt.type, !felt.type + %4 = bool.cmp lt(%0, %felt_const_256) + %5 = bool.cmp lt(%1, %felt_const_256) + %6 = cast.tofelt %4 : i1 + %7 = cast.tofelt %5 : i1 + constrain.eq %6, %felt_const_1 : !felt.type, !felt.type + constrain.eq %7, %felt_const_1 : !felt.type, !felt.type + constrain.eq %2, %3 : !felt.type, !felt.type + function.return + } + } +} + +// CHECK-LABEL: @SubTest StructIntervals { +// CHECK-NEXT: %arg0[@x] in TypeA:[ 0, 255 ] +// CHECK-NEXT: %arg0[@y] in TypeA:[ 0, 255 ] +// CHECK-NEXT: %arg0[@z] in Entire +// CHECK-NEXT: } + +// ----- + +module attributes {veridise.lang = "llzk"} { + struct.def @SubTest<[]> { + struct.field @x : !felt.type {llzk.pub} + struct.field @y : !felt.type {llzk.pub} + struct.field @z : !felt.type {llzk.pub} + function.def @compute() -> !struct.type<@SubTest<[]>> attributes {function.allow_witness} { + %self = struct.new : <@SubTest<[]>> + function.return %self : !struct.type<@SubTest<[]>> + } + function.def @constrain(%arg0: 
!struct.type<@SubTest<[]>>) attributes {function.allow_constraint} { + %0 = struct.readf %arg0[@x] : <@SubTest<[]>>, !felt.type + %1 = struct.readf %arg0[@y] : <@SubTest<[]>>, !felt.type + %2 = struct.readf %arg0[@z] : <@SubTest<[]>>, !felt.type + %felt_const_1 = felt.const 1 + %felt_const_256 = felt.const 256 + %3 = felt.sub %0, %1 : !felt.type, !felt.type + %4 = bool.cmp lt(%0, %felt_const_256) + %5 = bool.cmp lt(%1, %felt_const_256) + %6 = cast.tofelt %4 : i1 + %7 = cast.tofelt %5 : i1 + constrain.eq %6, %felt_const_1 : !felt.type, !felt.type + constrain.eq %7, %felt_const_1 : !felt.type, !felt.type + constrain.eq %2, %3 : !felt.type, !felt.type + function.return + } + } +} + +// CHECK-LABEL: @SubTest StructIntervals { +// CHECK-NEXT: %arg0[@x] in TypeA:[ 0, 255 ] +// CHECK-NEXT: %arg0[@y] in TypeA:[ 0, 255 ] +// CHECK-NEXT: %arg0[@z] in TypeF:[ 21888242871839275222246405745257275088696311157297823662689037894645226208328, 255 ] +// CHECK-NEXT: } From 1accbfb4b083098765d7cd3558b109ab1b027111 Mon Sep 17 00:00:00 2001 From: Ian Glen Neal Date: Wed, 12 Nov 2025 12:28:55 -0500 Subject: [PATCH 4/6] Add changelog, apply formatting fixes --- changelogs/unreleased/iangneal__back-prop-fixes.yaml | 8 ++++++++ include/llzk/Analysis/Field.h | 4 ++-- include/llzk/Util/DynamicAPIntHelper.h | 9 ++++----- lib/Analysis/Field.cpp | 4 +++- lib/Util/DynamicAPIntHelper.cpp | 7 ++----- 5 files changed, 19 insertions(+), 13 deletions(-) create mode 100644 changelogs/unreleased/iangneal__back-prop-fixes.yaml diff --git a/changelogs/unreleased/iangneal__back-prop-fixes.yaml b/changelogs/unreleased/iangneal__back-prop-fixes.yaml new file mode 100644 index 000000000..f2e41ff6b --- /dev/null +++ b/changelogs/unreleased/iangneal__back-prop-fixes.yaml @@ -0,0 +1,8 @@ +added: + - Intra-procedural port of MLIR's sparse data-flow analysis + +fixed: + - Fixed backwards propagation of values in Interval Analysis + +changed: + - Interval analysis now uses sparse analysis diff --git 
a/include/llzk/Analysis/Field.h b/include/llzk/Analysis/Field.h index fadc9139e..e78f2a904 100644 --- a/include/llzk/Analysis/Field.h +++ b/include/llzk/Analysis/Field.h @@ -9,14 +9,14 @@ #pragma once +#include "llzk/Util/DynamicAPIntHelper.h" + #include #include #include #include -#include "llzk/Util/DynamicAPIntHelper.h" - namespace llzk { /// @brief Information about the prime finite field used for the interval analysis. diff --git a/include/llzk/Util/DynamicAPIntHelper.h b/include/llzk/Util/DynamicAPIntHelper.h index bae476c1d..15a678bb3 100644 --- a/include/llzk/Util/DynamicAPIntHelper.h +++ b/include/llzk/Util/DynamicAPIntHelper.h @@ -42,11 +42,10 @@ inline llvm::DynamicAPInt toDynamicAPInt(const llvm::APInt &i) { llvm::APSInt toAPSInt(const llvm::DynamicAPInt &i); -llvm::DynamicAPInt modExp(const llvm::DynamicAPInt &base, - const llvm::DynamicAPInt &exp, - const llvm::DynamicAPInt &mod); +llvm::DynamicAPInt modExp( + const llvm::DynamicAPInt &base, const llvm::DynamicAPInt &exp, const llvm::DynamicAPInt &mod +); -llvm::DynamicAPInt modInversePrime(const llvm::DynamicAPInt &f, - const llvm::DynamicAPInt &p); +llvm::DynamicAPInt modInversePrime(const llvm::DynamicAPInt &f, const llvm::DynamicAPInt &p); } // namespace llzk diff --git a/lib/Analysis/Field.cpp b/lib/Analysis/Field.cpp index 61ddb5cf1..8d23a70fd 100644 --- a/lib/Analysis/Field.cpp +++ b/lib/Analysis/Field.cpp @@ -66,6 +66,8 @@ DynamicAPInt Field::reduce(const APInt &i) const { return reduce(toDynamicAPInt( DynamicAPInt Field::inv(const DynamicAPInt &i) const { return modInversePrime(i, prime()); } -DynamicAPInt Field::inv(const llvm::APInt &i) const { return modInversePrime(toDynamicAPInt(i), prime()); } +DynamicAPInt Field::inv(const llvm::APInt &i) const { + return modInversePrime(toDynamicAPInt(i), prime()); +} } // namespace llzk diff --git a/lib/Util/DynamicAPIntHelper.cpp b/lib/Util/DynamicAPIntHelper.cpp index b7ab82f3e..a7cf104bb 100644 --- a/lib/Util/DynamicAPIntHelper.cpp +++ 
b/lib/Util/DynamicAPIntHelper.cpp @@ -134,9 +134,7 @@ APSInt toAPSInt(const DynamicAPInt &i) { return res; } -DynamicAPInt modExp(const DynamicAPInt &base, - const DynamicAPInt &exp, - const DynamicAPInt &mod) { +DynamicAPInt modExp(const DynamicAPInt &base, const DynamicAPInt &exp, const DynamicAPInt &mod) { DynamicAPInt result(1); DynamicAPInt b = base; DynamicAPInt e = exp; @@ -154,8 +152,7 @@ DynamicAPInt modExp(const DynamicAPInt &base, return result; } -llvm::DynamicAPInt modInversePrime(const DynamicAPInt &f, - const DynamicAPInt &p) { +llvm::DynamicAPInt modInversePrime(const DynamicAPInt &f, const DynamicAPInt &p) { assert(f != 0 && "0 has no inverse"); // Fermat: f^(p-2) mod p DynamicAPInt exp = p - 2; From 05899c653e02773aa468976708a9d64eb930bf08 Mon Sep 17 00:00:00 2001 From: Ian Glen Neal Date: Wed, 12 Nov 2025 13:39:57 -0500 Subject: [PATCH 5/6] Add koalabear field --- lib/Analysis/Field.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/Analysis/Field.cpp b/lib/Analysis/Field.cpp index 8d23a70fd..718884881 100644 --- a/lib/Analysis/Field.cpp +++ b/lib/Analysis/Field.cpp @@ -52,6 +52,8 @@ void Field::initKnownFields(DenseMap &knownFields) { knownFields.try_emplace("goldilocks", Field("18446744069414584321")); // 2^31 - 1, used for Plonky3 knownFields.try_emplace("mersenne31", Field("2147483647")); + // 2^31 - 2^24 + 1, also for Plonky3 + knownFields.try_emplace("koalabear", Field("2130706433")); } DynamicAPInt Field::reduce(const DynamicAPInt &i) const { From 1703701cb5bdf0a3de2ad9b0205c1551d18da26f Mon Sep 17 00:00:00 2001 From: Ian Glen Neal Date: Wed, 12 Nov 2025 15:24:26 -0500 Subject: [PATCH 6/6] Add field checks --- include/llzk/Analysis/IntervalAnalysis.h | 7 +++++-- lib/Analysis/IntervalAnalysis.cpp | 10 +++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/include/llzk/Analysis/IntervalAnalysis.h b/include/llzk/Analysis/IntervalAnalysis.h index 952a9fea3..d61901e37 100644 --- 
a/include/llzk/Analysis/IntervalAnalysis.h +++ b/include/llzk/Analysis/IntervalAnalysis.h @@ -48,7 +48,9 @@ class ExpressionValue { /* Must be default initializable to be a ScalarLatticeValue. */ ExpressionValue() : i(), expr(nullptr) {} - explicit ExpressionValue(const Field &f, llvm::SMTExprRef exprRef) + explicit ExpressionValue(const Field &f) : i(Interval::Entire(f)), expr(nullptr) {} + + ExpressionValue(const Field &f, llvm::SMTExprRef exprRef) : i(Interval::Entire(f)), expr(exprRef) {} ExpressionValue(const Field &f, llvm::SMTExprRef exprRef, const llvm::DynamicAPInt &singleVal) @@ -295,7 +297,8 @@ class IntervalDataFlowAnalysis llvm::DenseMap fieldWriteResults; void setToEntryState(Lattice *lattice) override { - // initial state should be empty, so do nothing here + // Initialize the value with an interval in our specified field. + (void)lattice->setValue(ExpressionValue(field.get())); } llvm::SMTExprRef createFeltSymbol(const SourceRef &r) const; diff --git a/lib/Analysis/IntervalAnalysis.cpp b/lib/Analysis/IntervalAnalysis.cpp index e3e9524a7..b7caf3e99 100644 --- a/lib/Analysis/IntervalAnalysis.cpp +++ b/lib/Analysis/IntervalAnalysis.cpp @@ -1112,7 +1112,13 @@ LogicalResult StructIntervals::computeIntervals( SourceRef ref {arg}; if (searchSet.erase(ref)) { const IntervalAnalysisLattice *lattice = solver.lookupState(arg); - fieldRanges[ref] = lattice->getValue().getScalarValue().getInterval(); + // If we never referenced this argument, use a default value + ExpressionValue expr = lattice->getValue().getScalarValue(); + if (!expr.getExpr()) { + expr = expr.withInterval(Interval::Entire(ctx.getField())); + } + fieldRanges[ref] = expr.getInterval(); + assert(fieldRanges[ref].getField() == ctx.getField() && "bad interval defaults"); } } @@ -1122,12 +1128,14 @@ LogicalResult StructIntervals::computeIntervals( if (!lattices.empty() && searchSet.erase(ref)) { const IntervalAnalysisLattice *lattice = *lattices.begin(); fieldRanges[ref] = 
lattice->getValue().getScalarValue().getInterval(); + assert(fieldRanges[ref].getField() == ctx.getField() && "bad interval defaults"); } } for (const auto &[ref, val] : ctx.intervalDFA->getFieldWriteResults()) { if (searchSet.erase(ref)) { fieldRanges[ref] = val.getInterval(); + assert(fieldRanges[ref].getField() == ctx.getField() && "bad interval defaults"); } }