Skip to content
Closed
68 changes: 68 additions & 0 deletions include/circt/Analysis/OpDepthAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
//===- OpDepthAnalysis.h - operation depth analyses -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines AIG operation depth analysis.
//
//===----------------------------------------------------------------------===//

#ifndef CIRCT_ANALYSIS_OPDEPTH_ANALYSIS_H
#define CIRCT_ANALYSIS_OPDEPTH_ANALYSIS_H

#include "circt/Support/LLVM.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"

#include "circt/Dialect/AIG/AIGOps.h"
#include "circt/Dialect/HW/HWOps.h"

namespace mlir {
class AnalysisManager;
} // namespace mlir
namespace circt {
namespace aig {
namespace analysis {

class OpDepthAnalysis {
public:
OpDepthAnalysis(hw::HWModuleOp moduleOp, mlir::AnalysisManager &am);

/// Get the depth of operations of a specific name
size_t getOpDepth(AndInverterOp op) const {
assert(opDepths.count(op));
return opDepths.at(op);
}

bool isOnCriticalPath(AndInverterOp op) const {
return criticalPath.count(op);
}

const DenseMap<AndInverterOp, size_t> &getOpDepthMap() const {
return opDepths;
}

size_t updateLevel(AndInverterOp op, bool isRoot = false);
void updateAllLevel();

SmallVector<AndInverterOp> getPOs();

private:
void setCriticalPath(AndInverterOp op);

private:
DenseMap<AndInverterOp, size_t> opDepths;
SetVector<AndInverterOp> criticalPath;
size_t currDepth = 0;
hw::HWModuleOp module;
};

} // namespace analysis
} // namespace aig
} // namespace circt

#endif // CIRCT_ANALYSIS_OPDEPTH_ANALYSIS_H
2 changes: 1 addition & 1 deletion include/circt/Dialect/AIG/AIGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
class AIGOp<string mnemonic, list<Trait> traits = []> :
Op<AIG_Dialect, mnemonic, traits>;

def AndInverterOp : AIGOp<"and_inv", [SameOperandsAndResultType, Pure]> {
def AndInverterOp : AIGOp<"and_inv", [SameOperandsAndResultType, Pure, Commutative]> {
let summary = "AIG dialect AND operation";
let description = [{
The `aig.and_inv` operation represents an And-Inverter in the AIG dialect.
Expand Down
8 changes: 8 additions & 0 deletions include/circt/Dialect/AIG/AIGPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,17 @@ def LowerVariadic : Pass<"aig-lower-variadic", "hw::HWModuleOp"> {
let summary = "Lower variadic AndInverter operations to binary AndInverter";
}

def BalanceVariadic : Pass<"aig-balance-variadic", "hw::HWModuleOp"> {
let summary = "Lower variadic AndInverter operations to binary AndInverter";
}

def LowerWordToBits : Pass<"aig-lower-word-to-bits", "hw::HWModuleOp"> {
let summary = "Lower multi-bit AndInverter to single-bit ones";
let dependentDialects = ["comb::CombDialect"];
}

def MaximumAndCover : Pass<"maximum-and-cover", "hw::HWModuleOp"> {
let summary = "Maximum And Cover";
}

#endif // CIRCT_DIALECT_AIG_AIGPASSES_TD
20 changes: 20 additions & 0 deletions integration_test/circt-synth/aig-balancing-lec.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// REQUIRES: libz3
// REQUIRES: circt-lec-jit

// RUN: circt-opt %s --cse --convert-aig-to-comb -o %t1.mlir
// RUN: circt-opt %s --maximum-and-cover --aig-balance-variadic --cse --convert-aig-to-comb -o %t2.mlir

// RUN: circt-lec %t.mlir %s -c1=aig -c2=aig --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_AIG
// COMB_AIG: c1 == c2
hw.module @aig(in %a: i1, in %b: i1, in %c: i1, in %d: i1, out o1: i1, out o2: i1, out o3: i1) {
%0 = aig.and_inv %a, %b : i1
%1 = aig.and_inv %0, %c : i1
%2 = aig.and_inv %b, %c : i1
%3 = aig.and_inv %2, %d : i1

%4 = aig.and_inv %c, %d : i1
%5 = aig.and_inv %b, %4 : i1
%6 = aig.and_inv %a, %5 : i1

hw.output %1, %3, %6 : i1, i1, i1
}
8 changes: 8 additions & 0 deletions lib/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ set(LLVM_OPTIONAL_SOURCES
DependenceAnalysis.cpp
FIRRTLInstanceInfo.cpp
OpCountAnalysis.cpp
OpDepthAnalysis.cpp
SchedulingAnalysis.cpp
TestPasses.cpp
)
Expand Down Expand Up @@ -35,6 +36,13 @@ add_circt_library(CIRCTOpCountAnalysis
MLIRIR
)

add_circt_library(CIRCTOpDepthAnalysis
OpDepthAnalysis.cpp

LINK_LIBS PUBLIC
MLIRIR
)

add_circt_library(CIRCTSchedulingAnalysis
SchedulingAnalysis.cpp

Expand Down
85 changes: 85 additions & 0 deletions lib/Analysis/OpDepthAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
//===- OpCountAnalysis.cpp - operation count analyses -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the op depth (level) analysis.
//
//===----------------------------------------------------------------------===//

#include "circt/Analysis/OpDepthAnalysis.h"
#include "mlir/IR/Operation.h"

using namespace circt;
using namespace aig;
using namespace analysis;

OpDepthAnalysis::OpDepthAnalysis(hw::HWModuleOp moduleOp,
mlir::AnalysisManager &am)
: module(moduleOp) {
// llvm::dbgs() << "OpDepthAnalysis Init\n";
// updateAllLevel();
}

SmallVector<AndInverterOp> OpDepthAnalysis::getPOs() {
SmallVector<AndInverterOp> po;
for (auto op : module.getOps<AndInverterOp>()) {
bool isPO = true;
for (auto *user : op->getUsers()) {
if (isa<AndInverterOp>(user)) {
isPO = false;
break;
}
}
if (isPO)
po.push_back(op);
}
return po;
}

void OpDepthAnalysis::updateAllLevel() {
auto po = getPOs();
for (auto &op : po) {
currDepth = std::max(currDepth, updateLevel(op));
}

for (auto &op : po) {
if (auto it = opDepths.find(op);
it != opDepths.end() && it->second == currDepth) {
setCriticalPath(op);
break;
}
}
}

size_t OpDepthAnalysis::updateLevel(AndInverterOp op, bool isRoot) {
if (auto it = opDepths.find(op); !isRoot && it != opDepths.end()) {
return it->second;
}

/// PI is level 0, so the minimum level of an AndInverterOp is 1
size_t maxDepth = 1;
for (auto fanin : op.getOperands()) {
auto faninOp = fanin.getDefiningOp<AndInverterOp>();
if (faninOp) {
size_t faninDepth = updateLevel(faninOp);
maxDepth = std::max(maxDepth, faninDepth + 1);
}
}
opDepths[op] = maxDepth;
return maxDepth;
}

void OpDepthAnalysis::setCriticalPath(AndInverterOp op) {
size_t clevel = opDepths[op];
for (auto fanin : op.getOperands()) {
auto faninOp = fanin.getDefiningOp<AndInverterOp>();
if (faninOp && opDepths[faninOp] + 1 == clevel) {
setCriticalPath(faninOp);
// break; // TODO: Should break when there are multiple critical paths?
}
}
}
166 changes: 166 additions & 0 deletions lib/Dialect/AIG/Transforms/BalanceVariadic.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
//===- BalanceVariadic.cpp - Lowering Variadic to Binary Ops ------*- C++
//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass lowers variadic AndInverter operations to balanced binary
// AndInverter operations.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/PriorityQueue.h"

#include "circt/Analysis/OpDepthAnalysis.h"
#include "circt/Dialect/AIG/AIGOps.h"
#include "circt/Dialect/AIG/AIGPasses.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/IR/Iterators.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

#define DEBUG_TYPE "aig-balance-variadic"

namespace circt {
namespace aig {
#define GEN_PASS_DEF_BALANCEVARIADIC
#include "circt/Dialect/AIG/AIGPasses.h.inc"
} // namespace aig
} // namespace circt

using namespace circt;
using namespace aig;

namespace {
/// For wrapping Value and complement information into one object
struct Signal {
Value value;
bool complement;

Signal() = default;
Signal(Value v, bool complement) : value(v), complement(complement) {}

bool isComplement() const { return complement; }
Value getValue() const { return value; }
};

struct BalanceVariadicDriver {
BalanceVariadicDriver(mlir::IRRewriter &rewriter,
aig::analysis::OpDepthAnalysis *opDepthAnalysis)
: rewriter(rewriter), opDepthAnalysis(opDepthAnalysis) {}

struct PairSorter {
bool operator()(const std::pair<size_t, Signal> &lhs,
const std::pair<size_t, Signal> &rhs) const {
return lhs.first > rhs.first;
}
};

using NodeLevelHeap =
llvm::PriorityQueue<std::pair<size_t, Signal>,
std::vector<std::pair<size_t, Signal>>, PairSorter>;

void balanceVariadicAndInverterOp(AndInverterOp op) {
rewriter.setInsertionPoint(op);

NodeLevelHeap sortByLevel;
for (auto [fanin, inverted] :
llvm::zip(op.getOperands(), op.getInverted())) {
auto faninOp = fanin.getDefiningOp<AndInverterOp>();
size_t level = faninOp ? opDepthAnalysis->updateLevel(faninOp, true) : 0;
sortByLevel.push({level, Signal(fanin, inverted)});
}

// extract the top two elements with minimum level
// and replace them with a new AndInverterOp
while (sortByLevel.size() > 2) {
auto [llv, lhs] = sortByLevel.top();
sortByLevel.pop();
auto [rlv, rhs] = sortByLevel.top();
sortByLevel.pop();

auto balanced = rewriter.create<AndInverterOp>(
op.getLoc(), lhs.getValue(), rhs.getValue(), lhs.isComplement(),
rhs.isComplement());

size_t level = std::max(llv, rlv) + 1;
sortByLevel.push({level, Signal(balanced, false)});
}

switch (sortByLevel.size()) {
case 0:
break;
case 1: {
auto signal = sortByLevel.top().second;
sortByLevel.pop();
rewriter.replaceOp(op, signal.getValue());
break;
}
default:
auto lhs = sortByLevel.top().second;
sortByLevel.pop();
auto rhs = sortByLevel.top().second;

rewriter.replaceOp(op, rewriter.create<AndInverterOp>(
op.getLoc(), lhs.getValue(), rhs.getValue(),
lhs.isComplement(), rhs.isComplement()));
}
}

void balanceRecursive(AndInverterOp op) {
if (visited.count(op))
return;

visited.insert(op);
assert(!op->use_empty());

for (auto fanin : op.getOperands()) {
auto faninOp = fanin.getDefiningOp<AndInverterOp>();
if (faninOp) {
balanceRecursive(faninOp);
}
}

if (op.getOperands().size() <= 2)
return;

balanceVariadicAndInverterOp(op);
// opDepthAnalysis->updateLevel(op, true);
}

void balancing() {
// Balance each variadic AndInverterOp in reverse topological order
// Will ignore dangling internal AIG nodes
for (AndInverterOp po : opDepthAnalysis->getPOs()) {
balanceRecursive(po);
}
}

private:
DenseSet<Operation *> visited;
mlir::IRRewriter &rewriter;
aig::analysis::OpDepthAnalysis *opDepthAnalysis;
};

struct BalanceVariadicPass
: public impl::BalanceVariadicBase<BalanceVariadicPass> {
void runOnOperation() override;
};
} // namespace

//===----------------------------------------------------------------------===//
// Balance Variadic pass
//===----------------------------------------------------------------------===//
void BalanceVariadicPass::runOnOperation() {
auto *opDepthAnalysis = &getAnalysis<aig::analysis::OpDepthAnalysis>();

auto module = getOperation();
MLIRContext *ctx = module->getContext();
mlir::IRRewriter rewriter(ctx);

BalanceVariadicDriver driver(rewriter, opDepthAnalysis);
driver.balancing();
}
Loading
Loading