Skip to content

Commit 82883b1

Browse files
authored
Improve SubroutineCloningPass and MergeCircuitsPass performance (#192)
Improves the performance of the SubroutineCloningPassm and MergeCircuitsPass by replacing the usage of SymbolTable::lookupSymbolIn with a llvm::StringMap storing the symbols of interest.
1 parent d3156c4 commit 82883b1

File tree

8 files changed

+78
-32
lines changed

8 files changed

+78
-32
lines changed

cmake/apple-clang.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ if(ENABLE_ADDRESS_SANITIZER OR ENABLE_UNDEFINED_SANITIZER OR ENABLE_THREAD_SANIT
3939
endif()
4040

4141
set (CMAKE_CXX_FLAGS_DEBUG "-g3 -O0")
42-
set (CMAKE_CXX_FLAGS_RELEASE "-g -O2 -DNDEBUG")
42+
set (CMAKE_CXX_FLAGS_RELEASE "-g -O2 -DNOVERIFY")
4343

4444
set (CMAKE_INSTALL_RPATH_USE_LINK_PATH ON)

cmake/linux-gcc.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ if(ENABLE_THREAD_SANITIZER)
4040
endif()
4141

4242
set (CMAKE_CXX_FLAGS_DEBUG "-g3 -O0")
43-
set (CMAKE_CXX_FLAGS_RELEASE "-g -O2 -DNDEBUG")
43+
set (CMAKE_CXX_FLAGS_RELEASE "-g -O2 -DNOVERIFY")
4444

4545
# set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fuse-ld=gold")
4646
set (CMAKE_CXX_STANDARD_LIBRARIES "-lstdc++fs -lpthread")

cmake/llvm-clang.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ if(ENABLE_ADDRESS_SANITIZER OR ENABLE_UNDEFINED_SANITIZER OR ENABLE_THREAD_SANIT
4242
endif()
4343

4444
set (CMAKE_CXX_FLAGS_DEBUG "-g3 -O0")
45-
set (CMAKE_CXX_FLAGS_RELEASE "-g -O2 -DNDEBUG")
45+
set (CMAKE_CXX_FLAGS_RELEASE "-g -O2 -DNOVERIFY")
4646

4747
set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fuse-ld=gold")
4848
set (CMAKE_CXX_STANDARD_LIBRARIES "-lstdc++fs -lpthread")

include/Dialect/QUIR/Transforms/MergeCircuits.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ struct MergeCircuitsPass
3434
: public PassWrapper<MergeCircuitsPass, OperationPass<>> {
3535
void runOnOperation() override;
3636

37-
static CircuitOp getCircuitOp(CallCircuitOp callCircuitOp);
38-
static LogicalResult mergeCallCircuits(PatternRewriter &rewriter,
39-
CallCircuitOp callCircuitOp,
40-
CallCircuitOp nextCallCircuitOp);
37+
static CircuitOp getCircuitOp(CallCircuitOp callCircuitOp,
38+
llvm::StringMap<Operation *> *symbolMap);
39+
static LogicalResult
40+
mergeCallCircuits(PatternRewriter &rewriter, CallCircuitOp callCircuitOp,
41+
CallCircuitOp nextCallCircuitOp,
42+
llvm::StringMap<Operation *> *symbolMap);
4143

4244
llvm::StringRef getArgument() const override;
4345
llvm::StringRef getDescription() const override;

include/Dialect/QUIR/Transforms/SubroutineCloning.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,17 @@ class Operation;
3535
} // namespace mlir
3636

3737
namespace mlir::quir {
38+
39+
using SymbolOpMap = llvm::StringMap<Operation *>;
40+
3841
struct SubroutineCloningPass
3942
: public PassWrapper<SubroutineCloningPass, OperationPass<>> {
4043
auto lookupQubitId(const Value val) -> int;
4144

4245
template <class CallLikeOp>
4346
auto getMangledName(Operation *op) -> std::string;
4447
template <class CallLikeOp, class FuncLikeOp>
45-
void processCallOp(Operation *op);
48+
void processCallOp(Operation *op, SymbolOpMap &symbolOpMap);
4649

4750
void runOnOperation() override;
4851

lib/API/api.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ static llvm::cl::opt<bool> verifyDiagnostics(
8585
"expected-* lines on the corresponding line"),
8686
llvm::cl::init(false), llvm::cl::cat(qssc::config::getQSSCCategory()));
8787

88-
#ifndef NDEBUG
88+
#ifndef NOVERIFY
8989
#define VERIFY_PASSES_DEFAULT true
9090
#else
9191
#define VERIFY_PASSES_DEFAULT false

lib/Dialect/QUIR/Transforms/MergeCircuits.cpp

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,13 @@ bool moveUsers(Operation *curOp, MoveListVec &moveList) {
5252

5353
// This pattern matches on two CallCircuitOps separated by non-quantum ops
5454
struct CircuitAndCircuitPattern : public OpRewritePattern<CallCircuitOp> {
55-
explicit CircuitAndCircuitPattern(MLIRContext *ctx)
56-
: OpRewritePattern<CallCircuitOp>(ctx) {}
55+
explicit CircuitAndCircuitPattern(MLIRContext *ctx,
56+
llvm::StringMap<Operation *> &symbolMap)
57+
: OpRewritePattern<CallCircuitOp>(ctx) {
58+
_symbolMap = &symbolMap;
59+
}
60+
61+
llvm::StringMap<Operation *> *_symbolMap;
5762

5863
LogicalResult matchAndRewrite(CallCircuitOp callCircuitOp,
5964
PatternRewriter &rewriter) const override {
@@ -141,7 +146,7 @@ struct CircuitAndCircuitPattern : public OpRewritePattern<CallCircuitOp> {
141146
return failure();
142147

143148
return MergeCircuitsPass::mergeCallCircuits(rewriter, callCircuitOp,
144-
nextCallCircuitOp);
149+
nextCallCircuitOp, _symbolMap);
145150
} // matchAndRewrite
146151
}; // struct CircuitAndCircuitPattern
147152

@@ -220,22 +225,25 @@ struct CircuitAndBarrierPattern : public OpRewritePattern<CallCircuitOp> {
220225

221226
} // end anonymous namespace
222227

223-
CircuitOp MergeCircuitsPass::getCircuitOp(CallCircuitOp callCircuitOp) {
224-
auto circuitAttr = callCircuitOp->getAttrOfType<FlatSymbolRefAttr>("callee");
225-
assert(circuitAttr && "Requires a 'callee' symbol reference attribute");
228+
CircuitOp
229+
MergeCircuitsPass::getCircuitOp(CallCircuitOp callCircuitOp,
230+
llvm::StringMap<Operation *> *symbolMap) {
231+
// look for func def match
232+
assert(symbolMap && "a valid symbolMap pointer must be provided");
233+
auto search = symbolMap->find(callCircuitOp.callee());
234+
235+
assert(search != symbolMap->end() && "matching circuit not found");
226236

227-
auto circuitOp = SymbolTable::lookupNearestSymbolFrom<CircuitOp>(
228-
callCircuitOp, circuitAttr);
237+
auto circuitOp = dyn_cast<CircuitOp>(search->second);
229238
assert(circuitOp && "matching circuit not found");
230239
return circuitOp;
231240
}
232241

233-
LogicalResult
234-
MergeCircuitsPass::mergeCallCircuits(PatternRewriter &rewriter,
235-
CallCircuitOp callCircuitOp,
236-
CallCircuitOp nextCallCircuitOp) {
237-
auto circuitOp = getCircuitOp(callCircuitOp);
238-
auto nextCircuitOp = getCircuitOp(nextCallCircuitOp);
242+
LogicalResult MergeCircuitsPass::mergeCallCircuits(
243+
PatternRewriter &rewriter, CallCircuitOp callCircuitOp,
244+
CallCircuitOp nextCallCircuitOp, llvm::StringMap<Operation *> *symbolMap) {
245+
auto circuitOp = getCircuitOp(callCircuitOp, symbolMap);
246+
auto nextCircuitOp = getCircuitOp(nextCallCircuitOp, symbolMap);
239247

240248
rewriter.setInsertionPointAfter(nextCircuitOp);
241249

@@ -365,14 +373,24 @@ MergeCircuitsPass::mergeCallCircuits(PatternRewriter &rewriter,
365373
rewriter.replaceOp(nextCallCircuitOp,
366374
ResultRange(iterSep, newCallOp.result_end()));
367375

376+
// add new name to symbolMap
377+
// do not remove old in case the are multiple calls
378+
(*symbolMap)[newName] = newCircuitOp.getOperation();
379+
368380
return success();
369381
}
370382

371383
void MergeCircuitsPass::runOnOperation() {
372384
Operation *moduleOperation = getOperation();
373385

386+
llvm::StringMap<Operation *> circuitOpsMap;
387+
388+
moduleOperation->walk([&](CircuitOp circuitOp) {
389+
circuitOpsMap[circuitOp.sym_name()] = circuitOp.getOperation();
390+
});
391+
374392
RewritePatternSet patterns(&getContext());
375-
patterns.add<CircuitAndCircuitPattern>(&getContext());
393+
patterns.add<CircuitAndCircuitPattern>(&getContext(), circuitOpsMap);
376394
patterns.add<BarrierAndCircuitPattern>(&getContext());
377395
patterns.add<CircuitAndBarrierPattern>(&getContext());
378396

lib/Dialect/QUIR/Transforms/SubroutineCloning.cpp

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,21 @@ auto SubroutineCloningPass::getMangledName(Operation *op) -> std::string {
8282
} // getMangledName
8383

8484
template <class CallLikeOp, class FuncLikeOp>
85-
void SubroutineCloningPass::processCallOp(Operation *op) {
85+
void SubroutineCloningPass::processCallOp(Operation *op,
86+
SymbolOpMap &symbolOps) {
8687
auto callOp = dyn_cast<CallLikeOp>(op);
8788
OpBuilder build(moduleOperation->getRegion(0));
8889

8990
// look for func def match
90-
Operation *findOp =
91-
SymbolTable::lookupSymbolIn(moduleOperation, callOp.callee());
91+
auto search = symbolOps.find(callOp.callee());
92+
93+
if (search == symbolOps.end()) {
94+
callOp->emitOpError() << "No matching function def found for "
95+
<< callOp.callee() << "\n";
96+
return signalPassFailure();
97+
}
98+
99+
Operation *findOp = search->second;
92100
if (findOp) {
93101
std::vector<Value> qOperands;
94102
qubitCallOperands(callOp, qOperands);
@@ -99,9 +107,7 @@ void SubroutineCloningPass::processCallOp(Operation *op) {
99107
FlatSymbolRefAttr::get(&getContext(), mangledName));
100108

101109
// does the mangled function already exist?
102-
Operation *mangledOp =
103-
SymbolTable::lookupSymbolIn(moduleOperation, mangledName);
104-
if (mangledOp) // nothing to do
110+
if (symbolOps.find(mangledName) != symbolOps.end())
105111
return;
106112

107113
// clone the func def with the new name
@@ -127,6 +133,8 @@ void SubroutineCloningPass::processCallOp(Operation *op) {
127133
// add calls within the new func def to the callWorkList
128134
newFunc->walk([&](CallLikeOp op) { callWorkList.push_back(op); });
129135

136+
symbolOps[mangledName] = newFunc.getOperation();
137+
130138
} else { // matching function not found
131139
callOp->emitOpError() << "No matching function def found for "
132140
<< callOp.callee() << "\n";
@@ -148,18 +156,33 @@ void SubroutineCloningPass::runOnOperation() {
148156

149157
mainFunc->walk([&](CallSubroutineOp op) { callWorkList.push_back(op); });
150158

159+
SymbolOpMap symbolOps;
160+
161+
if (!callWorkList.empty()) {
162+
moduleOperation->walk([&](FuncOp functionOp) {
163+
symbolOps[functionOp.sym_name()] = functionOp.getOperation();
164+
});
165+
}
166+
151167
while (!callWorkList.empty()) {
152168
Operation *op = callWorkList.front();
153169
callWorkList.pop_front();
154-
processCallOp<CallSubroutineOp, FuncOp>(op);
170+
processCallOp<CallSubroutineOp, FuncOp>(op, symbolOps);
155171
}
156172

157173
mainFunc->walk([&](CallCircuitOp op) { callWorkList.push_back(op); });
158174

175+
if (!callWorkList.empty()) {
176+
symbolOps.clear();
177+
moduleOperation->walk([&](CircuitOp circuitOp) {
178+
symbolOps[circuitOp.sym_name()] = circuitOp.getOperation();
179+
});
180+
}
181+
159182
while (!callWorkList.empty()) {
160183
Operation *op = callWorkList.front();
161184
callWorkList.pop_front();
162-
processCallOp<CallCircuitOp, CircuitOp>(op);
185+
processCallOp<CallCircuitOp, CircuitOp>(op, symbolOps);
163186
}
164187

165188
// All subroutine defs that have been cloned are no longer needed

0 commit comments

Comments
 (0)