Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelogs/unreleased/iangneal__plonky3-analysis-bugs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
fixed:
- Interval analysis: fix handling of cast operations
- Analysis passes: improved performance of dataflow analyses
2 changes: 1 addition & 1 deletion include/llzk/Analysis/AnalysisWrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class ModuleAnalysis {
return results.at(getContext());
}

const mlir::DataFlowSolver &getSolver() const { return solver; }
mlir::DataFlowSolver &getSolver() { return solver; }

protected:
mlir::DataFlowSolver solver;
Expand Down
25 changes: 20 additions & 5 deletions include/llzk/Analysis/IntervalAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ class ExpressionValue {
return ExpressionValue(expr, newInterval);
}

/// @brief Return the current expression with a new SMT expression.
ExpressionValue withExpression(const llvm::SMTExprRef &newExpr) const {
return ExpressionValue(newExpr, i);
}

/* Required to be a ScalarLatticeValue. */
/// @brief Fold two expressions together when overapproximating array elements.
ExpressionValue &join(const ExpressionValue & /*rhs*/) {
Expand All @@ -78,6 +83,10 @@ class ExpressionValue {

bool operator==(const ExpressionValue &rhs) const;

bool isBoolSort(llvm::SMTSolverRef solver) const {
return solver->getBoolSort() == solver->getSort(expr);
}

/// @brief Compute the intersection of the lhs and rhs intervals, and create a solver
/// expression that constrains both sides to be equal.
/// @param solver
Expand Down Expand Up @@ -219,6 +228,7 @@ class IntervalAnalysisLattice : public dataflow::AbstractDenseLattice {
mlir::FailureOr<LatticeValue> getValue(mlir::Value v) const;
mlir::FailureOr<LatticeValue> getValue(mlir::Value v, mlir::StringAttr f) const;

mlir::ChangeResult setValue(mlir::Value v, const LatticeValue &val);
mlir::ChangeResult setValue(mlir::Value v, ExpressionValue e);
mlir::ChangeResult setValue(mlir::Value v, mlir::StringAttr f, ExpressionValue e);

Expand Down Expand Up @@ -339,13 +349,14 @@ class IntervalDataFlowAnalysis
/// @param after The current lattice state. Assumes that this has already been joined with the
/// `before` lattice in `visitOperation`, so lookups and updates can be performed on the `after`
/// lattice alone.
mlir::ChangeResult
applyInterval(mlir::Operation *originalOp, Lattice *after, mlir::Value val, Interval newInterval);
mlir::ChangeResult applyInterval(
mlir::Operation *originalOp, Lattice *originalLattice, Lattice *after, mlir::Value val,
Interval newInterval
);

/// @brief Special handling for generalized (s - c0) * (s - c1) * ... * (s - cN) = 0 patterns.
mlir::FailureOr<std::pair<llvm::DenseSet<mlir::Value>, Interval>> getGeneralizedDecompInterval(
const SourceRefLattice *SourceRefLattice, mlir::Value lhs, mlir::Value rhs
);
mlir::FailureOr<std::pair<llvm::DenseSet<mlir::Value>, Interval>>
getGeneralizedDecompInterval(mlir::Operation *baseOp, mlir::Value lhs, mlir::Value rhs);

bool isBoolOp(mlir::Operation *op) const {
return llvm::isa<boolean::AndBoolOp, boolean::OrBoolOp, boolean::XorBoolOp, boolean::NotBoolOp>(
Expand Down Expand Up @@ -390,6 +401,10 @@ class IntervalDataFlowAnalysis
bool isCallOp(mlir::Operation *op) const { return llvm::isa<function::CallOp>(op); }

bool isReturnOp(mlir::Operation *op) const { return llvm::isa<function::ReturnOp>(op); }

/// @brief Get the SourceRefLattice that defines `val`, or the SourceRefLattice after `baseOp`
/// if `val` has no associated SourceRefLattice.
const SourceRefLattice *getSourceRefLattice(mlir::Operation *baseOp, mlir::Value val);
};

/* StructIntervals */
Expand Down
63 changes: 40 additions & 23 deletions lib/Analysis/ConstraintDependencyGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,28 +87,28 @@ void SourceRefAnalysis::visitCallControlFlowTransfer(
auto funcOp = funcOpRes->get();

auto callOp = llvm::dyn_cast<CallOp>(call.getOperation());
ensure(callOp, "call is not a llzk::CallOp");
ensure(callOp, "call is not a CallOp");

for (unsigned i = 0; i < funcOp.getNumArguments(); i++) {
auto key = SourceRef(funcOp.getArgument(i));
auto val = beforeCall->getOrDefault(callOp.getOperand(i));
translation[key] = val;
SourceRef key(funcOp.getArgument(i));
// Look up the lattice that defines the operand value first, but default
// to the beforeCall if the operand is not defined by an operand.
const SourceRefLattice *operandLattice = beforeCall;
Value operand = callOp.getOperand(i);
if (Operation *defOp = operand.getDefiningOp()) {
operandLattice = getLattice(getProgramPointAfter(defOp));
}

translation[key] = operandLattice->getOrDefault(operand);
}

// The lattice at the return is:
// - the lattice before the call, plus
// - the translated, return values, plus
// - any translated internal values (so we can see where values are used)
mlir::ChangeResult updated = after->join(*beforeCall);
// The lattice at the return is the translated return values
mlir::ChangeResult updated = mlir::ChangeResult::NoChange;
for (unsigned i = 0; i < callOp.getNumResults(); i++) {
auto retVal = before.getReturnValue(i);
auto [translatedVal, _] = retVal.translate(translation);
updated |= after->setValue(callOp->getResult(i), translatedVal);
}
for (const auto &[val, refVal] : before.getMap()) {
auto [translatedVal, _] = refVal.translate(translation);
updated |= after->setValue(val, translatedVal);
}
propagateIfChanged(after, updated);
}
/// `action == CallControlFlowAction::External` indicates that:
Expand All @@ -131,13 +131,17 @@ mlir::LogicalResult SourceRefAnalysis::visitOperation(
LLVM_DEBUG(llvm::dbgs() << "SourceRefAnalysis::visitOperation: " << *op << '\n');
// Collect the references that are made by the operands to `op`.
SourceRefLattice::ValueMap operandVals;
for (mlir::OpOperand &operand : op->getOpOperands()) {
operandVals[operand.get()] = before.getOrDefault(operand.get());
for (OpOperand &operand : op->getOpOperands()) {
const SourceRefLattice *prior = &before;
// Lookup the lattice for the operand, if it is op defined.
Value operandVal = operand.get();
if (Operation *defOp = operandVal.getDefiningOp()) {
prior = getLattice(getProgramPointAfter(defOp));
}
// Get the value (if there was a defining operation), or the default value.
operandVals[operandVal] = prior->getOrDefault(operandVal);
}

// Propagate existing state.
join(after, before);

// Add operand values, if not already added. Ensures that the default value
// of a SourceRef (the source of the ref) is visible in the lattice.
ChangeResult res = after->setValues(operandVals);
Expand Down Expand Up @@ -165,8 +169,7 @@ mlir::LogicalResult SourceRefAnalysis::visitOperation(
} else if (auto createArray = llvm::dyn_cast<CreateArrayOp>(op)) {
// Create an array using the operand values, if they exist.
// Currently, the new array must either be fully initialized or uninitialized.

auto newArrayVal = SourceRefLatticeValue(createArray.getType().getShape());
SourceRefLatticeValue newArrayVal(createArray.getType().getShape());
// If the array is statically initialized, iterate through all operands and initialize the array
// value.
const auto &elements = createArray.getElements();
Expand All @@ -193,6 +196,7 @@ mlir::LogicalResult SourceRefAnalysis::visitOperation(
}

propagateIfChanged(after, res);
LLVM_DEBUG(llvm::dbgs().indent(4) << "lattice is of size " << after->size() << '\n');
return success();
}

Expand Down Expand Up @@ -395,13 +399,23 @@ mlir::LogicalResult ConstraintDependencyGraph::computeConstraints(
SourceRefRemappings translations;

ProgramPoint *pp = solver.getProgramPointAfter(fnCall.getOperation());
auto lattice = solver.lookupState<SourceRefLattice>(pp);
ensure(lattice, "could not find lattice for call operation");
auto *afterCallLattice = solver.lookupState<SourceRefLattice>(pp);
ensure(afterCallLattice, "could not find lattice for call operation");

// Map fn parameters to args in the call op
for (unsigned i = 0; i < fn.getNumArguments(); i++) {
SourceRef prefix(fn.getArgument(i));
SourceRefLatticeValue val = lattice->getOrDefault(fnCall.getOperand(i));
// Look up the lattice that defines the operand value first, but default
// to the afterCallLattice if the operand is not defined by an operand.
const SourceRefLattice *operandLattice = afterCallLattice;
Value operand = fnCall.getOperand(i);
if (Operation *defOp = operand.getDefiningOp()) {
ProgramPoint *defPoint = solver.getProgramPointAfter(defOp);
operandLattice = solver.lookupState<SourceRefLattice>(defPoint);
}
ensure(operandLattice, "could not find lattice for call operand");

SourceRefLatticeValue val = operandLattice->getOrDefault(operand);
translations.push_back({prefix, val});
}
auto &childAnalysis =
Expand All @@ -413,6 +427,9 @@ mlir::LogicalResult ConstraintDependencyGraph::computeConstraints(
);
}
auto translatedCDG = childAnalysis.getResult(ctx).translate(translations);
// Update the refMap with the translation
const auto &translatedRef2Val = translatedCDG.getRef2Val();
ref2Val.insert(translatedRef2Val.begin(), translatedRef2Val.end());

// Now, union sets based on the translation
// We should be able to just merge what is in the translatedCDG to the current CDG
Expand Down
Loading
Loading