Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions changelogs/unreleased/iangneal__back-prop-fixes.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
added:
- Intra-procedural port of MLIR's sparse data-flow analysis

fixed:
- Fixed backwards propagation of values in Interval Analysis

changed:
- Interval analysis now uses sparse analysis
25 changes: 25 additions & 0 deletions include/llzk/Analysis/AnalysisUtil.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//===-- AnalysisUtil.h - Data-flow analysis utils ---------------*- C++ -*-===//
//
// Part of the LLZK Project, under the Apache License v2.0.
// See LICENSE.txt for license information.
// Copyright 2025 Veridise Inc.
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

#pragma once

#include <mlir/Analysis/DataFlowFramework.h>

namespace llzk::dataflow {

/// LLZK: Added this utility to ensure analysis is performed for all structs
/// in a given module.
///
/// @brief Mark all operations from the top and included in the top operation
/// as live so the solver will perform dataflow analyses.
/// @param solver The solver.
/// @param top The top-level operation.
void markAllOpsAsLive(mlir::DataFlowSolver &solver, mlir::Operation *top);

} // namespace llzk::dataflow
2 changes: 1 addition & 1 deletion include/llzk/Analysis/AnalysisWrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

#pragma once

#include "llzk/Analysis/DenseAnalysis.h"
#include "llzk/Analysis/AnalysisUtil.h"
#include "llzk/Dialect/Struct/IR/Ops.h"
#include "llzk/Util/Compare.h"
#include "llzk/Util/ErrorHelper.h"
Expand Down
13 changes: 0 additions & 13 deletions include/llzk/Analysis/DenseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,6 @@

namespace llzk::dataflow {

//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//

/// LLZK: Added this utility to ensure analysis is performed for all structs
/// in a given module.
///
/// @brief Mark all operations from the top and included in the top operation
/// as live so the solver will perform dataflow analyses.
/// @param solver The solver.
/// @param top The top-level operation.
void markAllOpsAsLive(mlir::DataFlowSolver &solver, mlir::Operation *top);

//===----------------------------------------------------------------------===//
// AbstractDenseForwardDataFlowAnalysis
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 7 additions & 0 deletions include/llzk/Analysis/Field.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

#pragma once

#include "llzk/Util/DynamicAPIntHelper.h"

#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/DynamicAPInt.h>
#include <llvm/Support/SMTAPI.h>
Expand Down Expand Up @@ -51,6 +53,11 @@ class Field {
/// @brief Returns p - 1, which is the max value possible in a prime field described by p.
inline llvm::DynamicAPInt maxVal() const { return prime() - one(); }

/// @brief Returns the multiplicative inverse of `i` in prime field `p`.
llvm::DynamicAPInt inv(const llvm::DynamicAPInt &i) const;

llvm::DynamicAPInt inv(const llvm::APInt &i) const;

/// @brief Returns i mod p and reduces the result into the appropriate bitwidth.
/// Field elements are returned as signed integers so that negation functions
/// as expected (i.e., reducing -1 will yield p-1).
Expand Down
107 changes: 34 additions & 73 deletions include/llzk/Analysis/IntervalAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "llzk/Analysis/DenseAnalysis.h"
#include "llzk/Analysis/Field.h"
#include "llzk/Analysis/Intervals.h"
#include "llzk/Analysis/SparseAnalysis.h"
#include "llzk/Dialect/Array/IR/Ops.h"
#include "llzk/Dialect/Bool/IR/Ops.h"
#include "llzk/Dialect/Cast/IR/Ops.h"
Expand Down Expand Up @@ -47,7 +48,9 @@ class ExpressionValue {
/* Must be default initializable to be a ScalarLatticeValue. */
ExpressionValue() : i(), expr(nullptr) {}

explicit ExpressionValue(const Field &f, llvm::SMTExprRef exprRef)
explicit ExpressionValue(const Field &f) : i(Interval::Entire(f)), expr(nullptr) {}

ExpressionValue(const Field &f, llvm::SMTExprRef exprRef)
: i(Interval::Entire(f)), expr(exprRef) {}

ExpressionValue(const Field &f, llvm::SMTExprRef exprRef, const llvm::DynamicAPInt &singleVal)
Expand Down Expand Up @@ -199,9 +202,7 @@ class IntervalAnalysisLatticeValue

class IntervalDataFlowAnalysis;

/// @brief Maps mlir::Values to LatticeValues.
///
class IntervalAnalysisLattice : public dataflow::AbstractDenseLattice {
class IntervalAnalysisLattice : public dataflow::AbstractSparseLattice {
public:
using LatticeValue = IntervalAnalysisLatticeValue;
// Map mlir::Values to LatticeValues
Expand All @@ -214,23 +215,18 @@ class IntervalAnalysisLattice : public dataflow::AbstractDenseLattice {
// Tracks all constraints and assignments in insertion order
using ConstraintSet = llvm::SetVector<ExpressionValue>;

using AbstractDenseLattice::AbstractDenseLattice;
using AbstractSparseLattice::AbstractSparseLattice;

mlir::ChangeResult join(const AbstractDenseLattice &other) override;
mlir::ChangeResult join(const AbstractSparseLattice &other) override;

mlir::ChangeResult meet(const AbstractDenseLattice & /*rhs*/) override {
llvm::report_fatal_error("IntervalDataFlowAnalysis::meet : unsupported");
return mlir::ChangeResult::NoChange;
}
mlir::ChangeResult meet(const AbstractSparseLattice &other) override;

void print(mlir::raw_ostream &os) const override;

mlir::FailureOr<LatticeValue> getValue(mlir::Value v) const;
mlir::FailureOr<LatticeValue> getValue(mlir::Value v, mlir::StringAttr f) const;
const LatticeValue &getValue() const { return val; }

mlir::ChangeResult setValue(mlir::Value v, const LatticeValue &val);
mlir::ChangeResult setValue(mlir::Value v, ExpressionValue e);
mlir::ChangeResult setValue(mlir::Value v, mlir::StringAttr f, ExpressionValue e);
mlir::ChangeResult setValue(const LatticeValue &val);
mlir::ChangeResult setValue(ExpressionValue e);

mlir::ChangeResult addSolverConstraint(ExpressionValue e);

Expand All @@ -244,27 +240,16 @@ class IntervalAnalysisLattice : public dataflow::AbstractDenseLattice {
mlir::FailureOr<Interval> findInterval(llvm::SMTExprRef expr) const;
mlir::ChangeResult setInterval(llvm::SMTExprRef expr, const Interval &i);

size_t size() const { return valMap.size(); }

const ValueMap &getMap() const { return valMap; }

ValueMap::iterator begin() { return valMap.begin(); }
ValueMap::iterator end() { return valMap.end(); }
ValueMap::const_iterator begin() const { return valMap.begin(); }
ValueMap::const_iterator end() const { return valMap.end(); }

private:
ValueMap valMap;
FieldMap fieldMap;
LatticeValue val;
ConstraintSet constraints;
ExpressionIntervals intervals;
};

/* IntervalDataFlowAnalysis */

class IntervalDataFlowAnalysis
: public dataflow::DenseForwardDataFlowAnalysis<IntervalAnalysisLattice> {
using Base = dataflow::DenseForwardDataFlowAnalysis<IntervalAnalysisLattice>;
: public dataflow::SparseForwardDataFlowAnalysis<IntervalAnalysisLattice> {
using Base = dataflow::SparseForwardDataFlowAnalysis<IntervalAnalysisLattice>;
using Lattice = IntervalAnalysisLattice;
using LatticeValue = IntervalAnalysisLattice::LatticeValue;

Expand All @@ -276,23 +261,28 @@ class IntervalDataFlowAnalysis
mlir::DataFlowSolver &dataflowSolver, llvm::SMTSolverRef smt, const Field &f,
bool propInputConstraints
)
: Base::DenseForwardDataFlowAnalysis(dataflowSolver), _dataflowSolver(dataflowSolver),
: Base::SparseForwardDataFlowAnalysis(dataflowSolver), _dataflowSolver(dataflowSolver),
smtSolver(smt), field(f), propagateInputConstraints(propInputConstraints) {}

void visitCallControlFlowTransfer(
mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const Lattice &before,
Lattice *after
mlir::LogicalResult visitOperation(
mlir::Operation *op, mlir::ArrayRef<const Lattice *> operands,
mlir::ArrayRef<Lattice *> results
) override;

mlir::LogicalResult
visitOperation(mlir::Operation *op, const Lattice &before, Lattice *after) override;

/// @brief Either return the existing SMT expression that corresponds to the SourceRef,
/// or create one.
/// @param r
/// @return
llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r);

const llvm::DenseMap<SourceRef, llvm::DenseSet<Lattice *>> &getFieldReadResults() const {
return fieldReadResults;
}

const llvm::DenseMap<SourceRef, ExpressionValue> &getFieldWriteResults() const {
return fieldWriteResults;
}

private:
mlir::DataFlowSolver &_dataflowSolver;
llvm::SMTSolverRef smtSolver;
Expand All @@ -301,8 +291,14 @@ class IntervalDataFlowAnalysis
bool propagateInputConstraints;
mlir::SymbolTableCollection tables;

// Track field reads so that propagations to fields can be all updated efficiently.
llvm::DenseMap<SourceRef, llvm::DenseSet<Lattice *>> fieldReadResults;
// Track field writes values. For now, we'll overapproximate this.
llvm::DenseMap<SourceRef, ExpressionValue> fieldWriteResults;

void setToEntryState(Lattice *lattice) override {
// initial state should be empty, so do nothing here
// Initialize the value with an interval in our specified field.
(void)lattice->setValue(ExpressionValue(field.get()));
}

llvm::SMTExprRef createFeltSymbol(const SourceRef &r) const;
Expand Down Expand Up @@ -349,57 +345,22 @@ class IntervalDataFlowAnalysis
/// @param after The current lattice state. Assumes that this has already been joined with the
/// `before` lattice in `visitOperation`, so lookups and updates can be performed on the `after`
/// lattice alone.
mlir::ChangeResult applyInterval(
mlir::Operation *originalOp, Lattice *originalLattice, Lattice *after, mlir::Value val,
Interval newInterval
);
void applyInterval(mlir::Operation *originalOp, mlir::Value val, Interval newInterval);

/// @brief Special handling for generalized (s - c0) * (s - c1) * ... * (s - cN) = 0 patterns.
mlir::FailureOr<std::pair<llvm::DenseSet<mlir::Value>, Interval>>
getGeneralizedDecompInterval(mlir::Operation *baseOp, mlir::Value lhs, mlir::Value rhs);

bool isBoolOp(mlir::Operation *op) const {
return llvm::isa<boolean::AndBoolOp, boolean::OrBoolOp, boolean::XorBoolOp, boolean::NotBoolOp>(
op
);
}

bool isConversionOp(mlir::Operation *op) const {
return llvm::isa<cast::IntToFeltOp, cast::FeltToIndexOp>(op);
}

bool isApplyMapOp(mlir::Operation *op) const { return llvm::isa<polymorphic::ApplyMapOp>(op); }

bool isAssertOp(mlir::Operation *op) const { return llvm::isa<boolean::AssertOp>(op); }

bool isReadOp(mlir::Operation *op) const {
return llvm::isa<component::FieldReadOp, polymorphic::ConstReadOp, array::ReadArrayOp>(op);
}

bool isWriteOp(mlir::Operation *op) const {
return llvm::isa<component::FieldWriteOp, array::WriteArrayOp, array::InsertArrayOp>(op);
}

bool isArrayLengthOp(mlir::Operation *op) const { return llvm::isa<array::ArrayLengthOp>(op); }

bool isEmitOp(mlir::Operation *op) const {
return llvm::isa<constrain::EmitEqualityOp, constrain::EmitContainmentOp>(op);
}

bool isCreateOp(mlir::Operation *op) const {
return llvm::isa<component::CreateStructOp, array::CreateArrayOp>(op);
}

bool isExtractArrayOp(mlir::Operation *op) const { return llvm::isa<array::ExtractArrayOp>(op); }

bool isDefinitionOp(mlir::Operation *op) const {
return llvm::isa<
component::StructDefOp, function::FuncDefOp, component::FieldDefOp, global::GlobalDefOp,
mlir::ModuleOp>(op);
}

bool isCallOp(mlir::Operation *op) const { return llvm::isa<function::CallOp>(op); }

bool isReturnOp(mlir::Operation *op) const { return llvm::isa<function::ReturnOp>(op); }

/// @brief Get the SourceRefLattice that defines `val`, or the SourceRefLattice after `baseOp`
Expand Down
Loading
Loading