Skip to content

Commit b482585

Browse files
iangnealtim-hoffman
authored andcommitted
Interval Analysis Pass Rewrite (#210)
* Add inverse logic * - Port SparseAnalysis from MLIR - Fix TypeA -> Degenerate interval conversion * - Disable interprocedural parts of SparseAnalysis - Add tests - Refine boolean ops - Refine back propagation of add/sub * Add changelog, apply formatting fixes * Add koalabear field * Add field checks
1 parent f3c7191 commit b482585

File tree

16 files changed

+1323
-460
lines changed

16 files changed

+1323
-460
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
added:
2+
- Intra-procedural port of MLIR's sparse data-flow analysis
3+
4+
fixed:
5+
- Fixed backwards propagation of values in Interval Analysis
6+
7+
changed:
8+
- Interval analysis now uses sparse analysis
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===-- AnalysisUtil.h - Data-flow analysis utils ---------------*- C++ -*-===//
2+
//
3+
// Part of the LLZK Project, under the Apache License v2.0.
4+
// See LICENSE.txt for license information.
5+
// Copyright 2025 Veridise Inc.
6+
// SPDX-License-Identifier: Apache-2.0
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#pragma once
11+
12+
#include <mlir/Analysis/DataFlowFramework.h>
13+
14+
namespace llzk::dataflow {
15+
16+
/// LLZK: Added this utility to ensure analysis is performed for all structs
17+
/// in a given module.
18+
///
19+
/// @brief Mark all operations from the top and included in the top operation
20+
/// as live so the solver will perform dataflow analyses.
21+
/// @param solver The solver.
22+
/// @param top The top-level operation.
23+
void markAllOpsAsLive(mlir::DataFlowSolver &solver, mlir::Operation *top);
24+
25+
} // namespace llzk::dataflow

include/llzk/Analysis/AnalysisWrappers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
#pragma once
2424

25-
#include "llzk/Analysis/DenseAnalysis.h"
25+
#include "llzk/Analysis/AnalysisUtil.h"
2626
#include "llzk/Dialect/Struct/IR/Ops.h"
2727
#include "llzk/Util/Compare.h"
2828
#include "llzk/Util/ErrorHelper.h"

include/llzk/Analysis/DenseAnalysis.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,6 @@
3636

3737
namespace llzk::dataflow {
3838

39-
//===----------------------------------------------------------------------===//
40-
// Utilities
41-
//===----------------------------------------------------------------------===//
42-
43-
/// LLZK: Added this utility to ensure analysis is performed for all structs
44-
/// in a given module.
45-
///
46-
/// @brief Mark all operations from the top and included in the top operation
47-
/// as live so the solver will perform dataflow analyses.
48-
/// @param solver The solver.
49-
/// @param top The top-level operation.
50-
void markAllOpsAsLive(mlir::DataFlowSolver &solver, mlir::Operation *top);
51-
5239
//===----------------------------------------------------------------------===//
5340
// AbstractDenseForwardDataFlowAnalysis
5441
//===----------------------------------------------------------------------===//

include/llzk/Analysis/Field.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
#pragma once
1111

12+
#include "llzk/Util/DynamicAPIntHelper.h"
13+
1214
#include <llvm/ADT/DenseMap.h>
1315
#include <llvm/ADT/DynamicAPInt.h>
1416
#include <llvm/Support/SMTAPI.h>
@@ -51,6 +53,11 @@ class Field {
5153
/// @brief Returns p - 1, which is the max value possible in a prime field described by p.
5254
inline llvm::DynamicAPInt maxVal() const { return prime() - one(); }
5355

56+
/// @brief Returns the multiplicative inverse of `i` in prime field `p`.
57+
llvm::DynamicAPInt inv(const llvm::DynamicAPInt &i) const;
58+
59+
llvm::DynamicAPInt inv(const llvm::APInt &i) const;
60+
5461
/// @brief Returns i mod p and reduces the result into the appropriate bitwidth.
5562
/// Field elements are returned as signed integers so that negation functions
5663
/// as expected (i.e., reducing -1 will yield p-1).

include/llzk/Analysis/IntervalAnalysis.h

Lines changed: 34 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "llzk/Analysis/DenseAnalysis.h"
1616
#include "llzk/Analysis/Field.h"
1717
#include "llzk/Analysis/Intervals.h"
18+
#include "llzk/Analysis/SparseAnalysis.h"
1819
#include "llzk/Dialect/Array/IR/Ops.h"
1920
#include "llzk/Dialect/Bool/IR/Ops.h"
2021
#include "llzk/Dialect/Cast/IR/Ops.h"
@@ -47,7 +48,9 @@ class ExpressionValue {
4748
/* Must be default initializable to be a ScalarLatticeValue. */
4849
ExpressionValue() : i(), expr(nullptr) {}
4950

50-
explicit ExpressionValue(const Field &f, llvm::SMTExprRef exprRef)
51+
explicit ExpressionValue(const Field &f) : i(Interval::Entire(f)), expr(nullptr) {}
52+
53+
ExpressionValue(const Field &f, llvm::SMTExprRef exprRef)
5154
: i(Interval::Entire(f)), expr(exprRef) {}
5255

5356
ExpressionValue(const Field &f, llvm::SMTExprRef exprRef, const llvm::DynamicAPInt &singleVal)
@@ -199,9 +202,7 @@ class IntervalAnalysisLatticeValue
199202

200203
class IntervalDataFlowAnalysis;
201204

202-
/// @brief Maps mlir::Values to LatticeValues.
203-
///
204-
class IntervalAnalysisLattice : public dataflow::AbstractDenseLattice {
205+
class IntervalAnalysisLattice : public dataflow::AbstractSparseLattice {
205206
public:
206207
using LatticeValue = IntervalAnalysisLatticeValue;
207208
// Map mlir::Values to LatticeValues
@@ -214,23 +215,18 @@ class IntervalAnalysisLattice : public dataflow::AbstractDenseLattice {
214215
// Tracks all constraints and assignments in insertion order
215216
using ConstraintSet = llvm::SetVector<ExpressionValue>;
216217

217-
using AbstractDenseLattice::AbstractDenseLattice;
218+
using AbstractSparseLattice::AbstractSparseLattice;
218219

219-
mlir::ChangeResult join(const AbstractDenseLattice &other) override;
220+
mlir::ChangeResult join(const AbstractSparseLattice &other) override;
220221

221-
mlir::ChangeResult meet(const AbstractDenseLattice & /*rhs*/) override {
222-
llvm::report_fatal_error("IntervalDataFlowAnalysis::meet : unsupported");
223-
return mlir::ChangeResult::NoChange;
224-
}
222+
mlir::ChangeResult meet(const AbstractSparseLattice &other) override;
225223

226224
void print(mlir::raw_ostream &os) const override;
227225

228-
mlir::FailureOr<LatticeValue> getValue(mlir::Value v) const;
229-
mlir::FailureOr<LatticeValue> getValue(mlir::Value v, mlir::StringAttr f) const;
226+
const LatticeValue &getValue() const { return val; }
230227

231-
mlir::ChangeResult setValue(mlir::Value v, const LatticeValue &val);
232-
mlir::ChangeResult setValue(mlir::Value v, ExpressionValue e);
233-
mlir::ChangeResult setValue(mlir::Value v, mlir::StringAttr f, ExpressionValue e);
228+
mlir::ChangeResult setValue(const LatticeValue &val);
229+
mlir::ChangeResult setValue(ExpressionValue e);
234230

235231
mlir::ChangeResult addSolverConstraint(ExpressionValue e);
236232

@@ -244,27 +240,16 @@ class IntervalAnalysisLattice : public dataflow::AbstractDenseLattice {
244240
mlir::FailureOr<Interval> findInterval(llvm::SMTExprRef expr) const;
245241
mlir::ChangeResult setInterval(llvm::SMTExprRef expr, const Interval &i);
246242

247-
size_t size() const { return valMap.size(); }
248-
249-
const ValueMap &getMap() const { return valMap; }
250-
251-
ValueMap::iterator begin() { return valMap.begin(); }
252-
ValueMap::iterator end() { return valMap.end(); }
253-
ValueMap::const_iterator begin() const { return valMap.begin(); }
254-
ValueMap::const_iterator end() const { return valMap.end(); }
255-
256243
private:
257-
ValueMap valMap;
258-
FieldMap fieldMap;
244+
LatticeValue val;
259245
ConstraintSet constraints;
260-
ExpressionIntervals intervals;
261246
};
262247

263248
/* IntervalDataFlowAnalysis */
264249

265250
class IntervalDataFlowAnalysis
266-
: public dataflow::DenseForwardDataFlowAnalysis<IntervalAnalysisLattice> {
267-
using Base = dataflow::DenseForwardDataFlowAnalysis<IntervalAnalysisLattice>;
251+
: public dataflow::SparseForwardDataFlowAnalysis<IntervalAnalysisLattice> {
252+
using Base = dataflow::SparseForwardDataFlowAnalysis<IntervalAnalysisLattice>;
268253
using Lattice = IntervalAnalysisLattice;
269254
using LatticeValue = IntervalAnalysisLattice::LatticeValue;
270255

@@ -276,23 +261,28 @@ class IntervalDataFlowAnalysis
276261
mlir::DataFlowSolver &dataflowSolver, llvm::SMTSolverRef smt, const Field &f,
277262
bool propInputConstraints
278263
)
279-
: Base::DenseForwardDataFlowAnalysis(dataflowSolver), _dataflowSolver(dataflowSolver),
264+
: Base::SparseForwardDataFlowAnalysis(dataflowSolver), _dataflowSolver(dataflowSolver),
280265
smtSolver(smt), field(f), propagateInputConstraints(propInputConstraints) {}
281266

282-
void visitCallControlFlowTransfer(
283-
mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const Lattice &before,
284-
Lattice *after
267+
mlir::LogicalResult visitOperation(
268+
mlir::Operation *op, mlir::ArrayRef<const Lattice *> operands,
269+
mlir::ArrayRef<Lattice *> results
285270
) override;
286271

287-
mlir::LogicalResult
288-
visitOperation(mlir::Operation *op, const Lattice &before, Lattice *after) override;
289-
290272
/// @brief Either return the existing SMT expression that corresponds to the SourceRef,
291273
/// or create one.
292274
/// @param r
293275
/// @return
294276
llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r);
295277

278+
const llvm::DenseMap<SourceRef, llvm::DenseSet<Lattice *>> &getFieldReadResults() const {
279+
return fieldReadResults;
280+
}
281+
282+
const llvm::DenseMap<SourceRef, ExpressionValue> &getFieldWriteResults() const {
283+
return fieldWriteResults;
284+
}
285+
296286
private:
297287
mlir::DataFlowSolver &_dataflowSolver;
298288
llvm::SMTSolverRef smtSolver;
@@ -301,8 +291,14 @@ class IntervalDataFlowAnalysis
301291
bool propagateInputConstraints;
302292
mlir::SymbolTableCollection tables;
303293

294+
// Track field reads so that propagations to fields can be all updated efficiently.
295+
llvm::DenseMap<SourceRef, llvm::DenseSet<Lattice *>> fieldReadResults;
296+
// Track field writes values. For now, we'll overapproximate this.
297+
llvm::DenseMap<SourceRef, ExpressionValue> fieldWriteResults;
298+
304299
void setToEntryState(Lattice *lattice) override {
305-
// initial state should be empty, so do nothing here
300+
// Initialize the value with an interval in our specified field.
301+
(void)lattice->setValue(ExpressionValue(field.get()));
306302
}
307303

308304
llvm::SMTExprRef createFeltSymbol(const SourceRef &r) const;
@@ -349,57 +345,22 @@ class IntervalDataFlowAnalysis
349345
/// @param after The current lattice state. Assumes that this has already been joined with the
350346
/// `before` lattice in `visitOperation`, so lookups and updates can be performed on the `after`
351347
/// lattice alone.
352-
mlir::ChangeResult applyInterval(
353-
mlir::Operation *originalOp, Lattice *originalLattice, Lattice *after, mlir::Value val,
354-
Interval newInterval
355-
);
348+
void applyInterval(mlir::Operation *originalOp, mlir::Value val, Interval newInterval);
356349

357350
/// @brief Special handling for generalized (s - c0) * (s - c1) * ... * (s - cN) = 0 patterns.
358351
mlir::FailureOr<std::pair<llvm::DenseSet<mlir::Value>, Interval>>
359352
getGeneralizedDecompInterval(mlir::Operation *baseOp, mlir::Value lhs, mlir::Value rhs);
360353

361-
bool isBoolOp(mlir::Operation *op) const {
362-
return llvm::isa<boolean::AndBoolOp, boolean::OrBoolOp, boolean::XorBoolOp, boolean::NotBoolOp>(
363-
op
364-
);
365-
}
366-
367-
bool isConversionOp(mlir::Operation *op) const {
368-
return llvm::isa<cast::IntToFeltOp, cast::FeltToIndexOp>(op);
369-
}
370-
371-
bool isApplyMapOp(mlir::Operation *op) const { return llvm::isa<polymorphic::ApplyMapOp>(op); }
372-
373-
bool isAssertOp(mlir::Operation *op) const { return llvm::isa<boolean::AssertOp>(op); }
374-
375354
bool isReadOp(mlir::Operation *op) const {
376355
return llvm::isa<component::FieldReadOp, polymorphic::ConstReadOp, array::ReadArrayOp>(op);
377356
}
378357

379-
bool isWriteOp(mlir::Operation *op) const {
380-
return llvm::isa<component::FieldWriteOp, array::WriteArrayOp, array::InsertArrayOp>(op);
381-
}
382-
383-
bool isArrayLengthOp(mlir::Operation *op) const { return llvm::isa<array::ArrayLengthOp>(op); }
384-
385-
bool isEmitOp(mlir::Operation *op) const {
386-
return llvm::isa<constrain::EmitEqualityOp, constrain::EmitContainmentOp>(op);
387-
}
388-
389-
bool isCreateOp(mlir::Operation *op) const {
390-
return llvm::isa<component::CreateStructOp, array::CreateArrayOp>(op);
391-
}
392-
393-
bool isExtractArrayOp(mlir::Operation *op) const { return llvm::isa<array::ExtractArrayOp>(op); }
394-
395358
bool isDefinitionOp(mlir::Operation *op) const {
396359
return llvm::isa<
397360
component::StructDefOp, function::FuncDefOp, component::FieldDefOp, global::GlobalDefOp,
398361
mlir::ModuleOp>(op);
399362
}
400363

401-
bool isCallOp(mlir::Operation *op) const { return llvm::isa<function::CallOp>(op); }
402-
403364
bool isReturnOp(mlir::Operation *op) const { return llvm::isa<function::ReturnOp>(op); }
404365

405366
/// @brief Get the SourceRefLattice that defines `val`, or the SourceRefLattice after `baseOp`

0 commit comments

Comments
 (0)