Skip to content

Commit 05d9d05

Browse files
authored
Parameters Performance Improvements (#177)
This PR adds performance improvements for working with parameters. Changes to MergeCircuitPass: * Remove some unused vectors * Change circuit merge so that only unique operands / arguments are added when merging circuits rather than the union of the two circuits arguments.
1 parent 3f4a0a5 commit 05d9d05

File tree

7 files changed

+87
-51
lines changed

7 files changed

+87
-51
lines changed

include/Dialect/QCS/IR/QCSOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "Dialect/QCS/IR/QCSTypes.h"
3131

3232
#include "mlir/IR/SymbolTable.h"
33+
#include "llvm/ADT/StringMap.h"
3334

3435
#define GET_OP_CLASSES
3536
#include "Dialect/QCS/IR/QCSOps.h.inc"

include/Dialect/QCS/IR/QCSOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def QCS_ParameterLoadOp : QCS_Op<"parameter_load",
289289

290290
let extraClassDeclaration = [{
291291
// Return the initial value - using ParameterInitialValueAnalysis
292-
ParameterType getInitialValue(std::unordered_map<std::string, ParameterType> &parameterNames);
292+
ParameterType getInitialValue(llvm::StringMap<ParameterType> &parameterNames);
293293

294294
// Return the initial value - slower SymbolTable version
295295
ParameterType getInitialValue();

include/Dialect/QCS/Utils/ParameterInitialValueAnalysis.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@
2020
#define QCS_PARAMETER_INITIAL_VALUE_ANALYSIS_H
2121

2222
#include "Dialect/QCS/IR/QCSOps.h"
23+
#include "HAL/SystemConfiguration.h"
2324

2425
#include "mlir/Pass/AnalysisManager.h"
2526
#include "mlir/Pass/Pass.h"
2627

28+
#include "llvm/ADT/StringMap.h"
29+
#include "llvm/ADT/StringRef.h"
2730
#include "llvm/Support/Error.h"
2831

2932
#include <string>
@@ -33,16 +36,16 @@ namespace mlir::qcs {
3336

3437
using namespace mlir;
3538

39+
using InitialValueType = llvm::StringMap<ParameterType>;
40+
3641
class ParameterInitialValueAnalysis {
3742
private:
38-
std::unordered_map<std::string, ParameterType> initial_values_;
39-
bool invalid_;
43+
InitialValueType initial_values_;
44+
bool invalid_{true};
4045

4146
public:
4247
ParameterInitialValueAnalysis(mlir::Operation *op);
43-
std::unordered_map<std::string, ParameterType> &getNames() {
44-
return initial_values_;
45-
}
48+
InitialValueType &getNames() { return initial_values_; }
4649
void invalidate() { invalid_ = true; }
4750
bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
4851
return invalid_;

lib/Dialect/QCS/IR/QCSOps.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include "mlir/IR/BuiltinAttributes.h"
3030
#include "mlir/IR/BuiltinOps.h"
3131
#include "mlir/IR/SymbolTable.h"
32+
#include "llvm/ADT/StringMap.h"
33+
#include "llvm/ADT/StringRef.h"
3234

3335
using namespace mlir;
3436
using namespace mlir::qcs;
@@ -140,15 +142,15 @@ ParameterType ParameterLoadOp::getInitialValue() {
140142
}
141143

142144
// Returns the float value from the initial value of this parameter
143-
// this version uses a precomputed map of parrameter_name to the intial_value
145+
// this version uses a precomputed map of parameter_name to the initial_value
144146
// in order to avoid slow SymbolTable lookups
145147
ParameterType ParameterLoadOp::getInitialValue(
146-
std::unordered_map<std::string, ParameterType> &declareParametersMap) {
148+
llvm::StringMap<ParameterType> &declareParametersMap) {
147149
auto *op = getOperation();
148150
auto paramRefAttr =
149151
op->getAttrOfType<mlir::FlatSymbolRefAttr>("parameter_name");
150152

151-
auto paramOpEntry = declareParametersMap.find(paramRefAttr.getValue().str());
153+
auto paramOpEntry = declareParametersMap.find(paramRefAttr.getValue());
152154

153155
if (paramOpEntry == declareParametersMap.end()) {
154156
op->emitError("Could not find declare parameter op" +

lib/Dialect/QCS/Utils/ParameterInitialValueAnalysis.cpp

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,27 +27,43 @@
2727
using namespace mlir::qcs;
2828

2929
ParameterInitialValueAnalysis::ParameterInitialValueAnalysis(
30-
mlir::Operation *op) {
31-
op->walk([&](DeclareParameterOp declareParameterOp) {
32-
double initial_value = 0.0;
33-
if (declareParameterOp.initial_value().hasValue()) {
34-
auto angleAttr = declareParameterOp.initial_value()
35-
.getValue()
36-
.dyn_cast<mlir::quir::AngleAttr>();
37-
auto floatAttr =
38-
declareParameterOp.initial_value().getValue().dyn_cast<FloatAttr>();
39-
if (!(angleAttr || floatAttr))
40-
op->emitError(
41-
"Parameters are currently limited to angles or float[64] only.");
30+
mlir::Operation *moduleOp) {
4231

43-
if (angleAttr)
44-
initial_value = angleAttr.getValue().convertToDouble();
32+
if (not invalid_)
33+
return;
4534

46-
if (floatAttr)
47-
initial_value = floatAttr.getValue().convertToDouble();
48-
}
49-
initial_values_[declareParameterOp.sym_name().str()] = initial_value;
50-
});
35+
// process the module top level to cache declareParameterOp initial_values
36+
// this does not use a walk method so that submodule (if present) are not
37+
// processed in order to limit processing time
38+
39+
for (auto &region : moduleOp->getRegions())
40+
for (auto &block : region.getBlocks())
41+
for (auto &op : block.getOperations()) {
42+
auto declareParameterOp = dyn_cast<DeclareParameterOp>(op);
43+
if (!declareParameterOp)
44+
continue;
45+
46+
// moduleOp->walk([&](DeclareParameterOp declareParameterOp) {
47+
double initial_value = 0.0;
48+
if (declareParameterOp.initial_value().hasValue()) {
49+
auto angleAttr = declareParameterOp.initial_value()
50+
.getValue()
51+
.dyn_cast<mlir::quir::AngleAttr>();
52+
auto floatAttr = declareParameterOp.initial_value()
53+
.getValue()
54+
.dyn_cast<FloatAttr>();
55+
if (!(angleAttr || floatAttr))
56+
declareParameterOp.emitError("Parameters are currently limited to "
57+
"angles or float[64] only.");
58+
59+
if (angleAttr)
60+
initial_value = angleAttr.getValue().convertToDouble();
61+
62+
if (floatAttr)
63+
initial_value = floatAttr.getValue().convertToDouble();
64+
}
65+
initial_values_[declareParameterOp.sym_name()] = initial_value;
66+
}
5167
invalid_ = false;
5268
}
5369

@@ -60,7 +76,7 @@ llvm::StringRef ParameterInitialValueAnalysisPass::getArgument() const {
6076
}
6177

6278
llvm::StringRef ParameterInitialValueAnalysisPass::getDescription() const {
63-
return "Run ParameterIntialValueAnalysis";
79+
return "Run ParameterInitialValueAnalysis";
6480
}
6581

6682
// TODO: move registerQCSPasses to separate source file if additional passes

lib/Dialect/QUIR/Transforms/MergeCircuits.cpp

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/ADT/StringRef.h"
3434

3535
#include <algorithm>
36+
#include <unordered_map>
3637
#include <vector>
3738

3839
using namespace mlir;
@@ -238,16 +239,29 @@ MergeCircuitsPass::mergeCallCircuits(PatternRewriter &rewriter,
238239

239240
rewriter.setInsertionPointAfter(nextCircuitOp);
240241

241-
llvm::SmallVector<Type> inputTypes;
242-
llvm::SmallVector<Value> inputValues;
243242
llvm::SmallVector<Type> outputTypes;
244243
llvm::SmallVector<Value> outputValues;
245244

246-
// merge input type into single SmallVector
247-
inputTypes.append(circuitOp->getOperandTypes().begin(),
248-
circuitOp->getOperandTypes().end());
249-
inputTypes.append(nextCircuitOp->getOperandTypes().begin(),
250-
nextCircuitOp->getOperandTypes().end());
245+
// merge the call_circuits
246+
// collect their input values
247+
llvm::SmallVector<Value> callInputValues;
248+
callInputValues.append(callCircuitOp->getOperands().begin(),
249+
callCircuitOp.getOperands().end());
250+
251+
llvm::SmallVector<int> insertedArguments;
252+
std::unordered_map<int, int> reusedArguments;
253+
int index = 0;
254+
for (auto inputValue : nextCallCircuitOp->getOperands()) {
255+
auto *search = find(callInputValues, inputValue);
256+
if (search == callInputValues.end()) {
257+
callInputValues.push_back(inputValue);
258+
insertedArguments.push_back(index);
259+
} else {
260+
int originalIndex = search - callInputValues.begin();
261+
reusedArguments[index] = originalIndex;
262+
}
263+
index++;
264+
}
251265

252266
// merge circuit names
253267
std::string newName =
@@ -269,12 +283,20 @@ MergeCircuitsPass::mergeCallCircuits(PatternRewriter &rewriter,
269283
// argument numbers
270284
BlockAndValueMapping mapper;
271285
auto baseArgNum = newCircuitOp.getNumArguments();
286+
int insertedCount = 0;
272287
for (uint cnt = 0; cnt < nextCircuitOp.getNumArguments(); cnt++) {
273288
auto arg = nextCircuitOp.getArgument(cnt);
274-
auto dictArg = nextCircuitOp.getArgAttrDict(cnt);
275-
newCircuitOp.insertArgument(baseArgNum + cnt, arg.getType(), dictArg,
276-
arg.getLoc());
277-
mapper.map(arg, newCircuitOp.getArgument(baseArgNum + cnt));
289+
int argumentIndex = 0;
290+
if (find(insertedArguments, cnt) != insertedArguments.end()) {
291+
auto dictArg = nextCircuitOp.getArgAttrDict(cnt);
292+
newCircuitOp.insertArgument(baseArgNum + insertedCount, arg.getType(),
293+
dictArg, arg.getLoc());
294+
argumentIndex = baseArgNum + insertedCount;
295+
insertedCount++;
296+
} else {
297+
argumentIndex = reusedArguments[cnt];
298+
}
299+
mapper.map(arg, newCircuitOp.getArgument(argumentIndex));
278300
}
279301

280302
// find return op in new circuit and copy second circuit into the
@@ -331,14 +353,6 @@ MergeCircuitsPass::mergeCallCircuits(PatternRewriter &rewriter,
331353
newCircuitOp->setAttr(mlir::quir::getPhysicalIdsAttrName(),
332354
rewriter.getI32ArrayAttr(ArrayRef<int>(allIds)));
333355

334-
// merge the call_circuits
335-
// collect their input values
336-
llvm::SmallVector<Value> callInputValues;
337-
callInputValues.append(callCircuitOp->getOperands().begin(),
338-
callCircuitOp.getOperands().end());
339-
callInputValues.append(nextCallCircuitOp->getOperands().begin(),
340-
nextCallCircuitOp.getOperands().end());
341-
342356
rewriter.setInsertionPointAfter(nextCallCircuitOp);
343357
auto newCallOp = rewriter.create<mlir::quir::CallCircuitOp>(
344358
callCircuitOp->getLoc(), newName, TypeRange(outputTypes),

test/Dialect/QUIR/Transforms/merge-circuits.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,21 +110,21 @@ module {
110110
%12:2 = quir.call_circuit @circuit_6(%0) : (!quir.qubit<1>) -> (i1, i1)
111111
quir.barrier %200 : (!quir.qubit<1>) -> ()
112112
%13:2 = quir.call_circuit @circuit_6(%0) : (!quir.qubit<1>) -> (i1, i1)
113-
// CHECK: %{{.*}}:4 = quir.call_circuit @circuit_6_q0_circuit_6_q0(%0, %0) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1, i1, i1)
113+
// CHECK: %{{.*}}:4 = quir.call_circuit @circuit_6_q0_circuit_6_q0(%0) : (!quir.qubit<1>) -> (i1, i1, i1, i1)
114114

115115

116116
quir.barrier %0, %1, %200, %201, %202 : (!quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>) -> ()
117117
%14:2 = quir.call_circuit @circuit_7(%0) : (!quir.qubit<1>) -> (i1, i1)
118118
quir.barrier %200, %201 : (!quir.qubit<1>, !quir.qubit<1>) -> ()
119119
quir.barrier %200, %202 : (!quir.qubit<1>, !quir.qubit<1>) -> ()
120120
%15:2 = quir.call_circuit @circuit_7(%0) : (!quir.qubit<1>) -> (i1, i1)
121-
// CHECK: %{{.*}}:4 = quir.call_circuit @circuit_7_q0_circuit_7_q0(%0, %0) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1, i1, i1)
121+
// CHECK: %{{.*}}:4 = quir.call_circuit @circuit_7_q0_circuit_7_q0(%0) : (!quir.qubit<1>) -> (i1, i1, i1, i1)
122122

123123
quir.barrier %0, %1, %200, %201, %202 : (!quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>) -> ()
124124
%16:2 = quir.call_circuit @circuit_8(%0) : (!quir.qubit<1>) -> (i1, i1)
125125
quir.barrier %0 : (!quir.qubit<1>) -> ()
126126
%17:2 = quir.call_circuit @circuit_8(%0) : (!quir.qubit<1>) -> (i1, i1)
127-
// CHECK-NOT: %{{.*}}:4 = quir.call_circuit @circuit_8_q0_circuit_8_q0(%0, %0) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1, i1, i1)
127+
// CHECK-NOT: %{{.*}}:4 = quir.call_circuit @circuit_8_q0_circuit_8_q0(%0) : (!quir.qubit<1>) -> (i1, i1, i1, i1)
128128

129129
%c0_i32 = arith.constant 0 : i32
130130
return %c0_i32 : i32

0 commit comments

Comments
 (0)