Skip to content

Commit a3224b6

Browse files
authored
Add QUIRCircuitsAnalysis (#199)
Adds a QUIRCircuitsAnalysis to improve performance of quir.circuits.
1 parent d2744aa commit a3224b6

File tree

5 files changed

+306
-1
lines changed

5 files changed

+306
-1
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===- QUIRCircuitAnalysis.h - Cache circuit argument values ---*- C++ -*-===//
2+
//
3+
// (C) Copyright IBM 2023.
4+
//
5+
// Any modifications or derivative works of this code must retain this
6+
// copyright notice, and modified files need to carry a notice indicating
7+
// that they have been altered from the originals.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
///
11+
/// This file implements an analysis for caching argument attributes with
12+
/// default values for angle and duration arguments.
13+
///
14+
//===----------------------------------------------------------------------===//
15+
16+
#ifndef QUIR_CIRCUITS_ANALYSIS_H
17+
#define QUIR_CIRCUITS_ANALYSIS_H
18+
19+
#include "Dialect/Pulse/IR/PulseOps.h"
20+
#include "Dialect/QCS/Utils/ParameterInitialValueAnalysis.h"
21+
#include "Dialect/QUIR/IR/QUIROps.h"
22+
23+
#include "mlir/IR/BuiltinOps.h"
24+
#include "mlir/Pass/AnalysisManager.h"
25+
#include "mlir/Pass/Pass.h"
26+
27+
#include <tuple>
28+
29+
namespace mlir::quir {
30+
31+
enum QUIRCircuitAnalysisEntry { ANGLE = 0, PARAMETER_NAME, DURATION };
32+
33+
using OperandAttributes =
34+
std::tuple<double, llvm::StringRef, mlir::quir::DurationAttr>;
35+
36+
using CircuitAnalysisMap = std::unordered_map<
37+
mlir::Operation *,
38+
std::unordered_map<mlir::Operation *,
39+
std::unordered_map<int, OperandAttributes>>>;
40+
41+
class QUIRCircuitAnalysis {
42+
private:
43+
CircuitAnalysisMap circuitOperands;
44+
bool invalid_{true};
45+
46+
public:
47+
QUIRCircuitAnalysis(mlir::Operation *op, AnalysisManager &am);
48+
CircuitAnalysisMap &getAnalysisMap() { return circuitOperands; }
49+
50+
void invalidate() { invalid_ = true; }
51+
bool isInvalidated(const mlir::AnalysisManager::PreservedAnalyses &pa) {
52+
return invalid_;
53+
}
54+
55+
private:
56+
double getAngleValue(mlir::Value operand,
57+
mlir::qcs::ParameterInitialValueAnalysis *nameAnalysis);
58+
llvm::StringRef getParameterName(mlir::Value operand);
59+
quir::DurationAttr getDuration(mlir::Value operand);
60+
};
61+
62+
struct QUIRCircuitAnalysisPass
63+
: public mlir::PassWrapper<QUIRCircuitAnalysisPass,
64+
OperationPass<ModuleOp>> {
65+
66+
QUIRCircuitAnalysisPass() = default;
67+
QUIRCircuitAnalysisPass(const QUIRCircuitAnalysisPass &pass) = default;
68+
69+
void runOnOperation() override;
70+
71+
llvm::StringRef getArgument() const override;
72+
llvm::StringRef getDescription() const override;
73+
}; // QUIRCircuitAnalysisPass
74+
75+
llvm::Expected<double>
76+
angleValToDouble(mlir::Value inVal,
77+
mlir::qcs::ParameterInitialValueAnalysis *nameAnalysis,
78+
mlir::quir::QUIRCircuitAnalysis *circuitAnalysis = nullptr);
79+
80+
} // namespace mlir::quir
81+
82+
#endif // QUIR_CIRCUITS_ANALYSIS_H

lib/Dialect/QCS/IR/QCSOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ ParameterType ParameterLoadOp::getInitialValue(
153153
auto paramOpEntry = declareParametersMap.find(paramRefAttr.getValue());
154154

155155
if (paramOpEntry == declareParametersMap.end()) {
156-
op->emitError("Could not find declare parameter op" +
156+
op->emitError("Could not find declare parameter op " +
157157
paramRefAttr.getValue().str());
158158
return 0.0;
159159
}

lib/Dialect/QCS/Utils/ParameterInitialValueAnalysis.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ using namespace mlir::qcs;
2929
ParameterInitialValueAnalysis::ParameterInitialValueAnalysis(
3030
mlir::Operation *moduleOp) {
3131

32+
// ParameterInitialValueAnalysis should only process the top level
33+
// module where parameters are defined
34+
// find the top level module
35+
auto parentOp = moduleOp->getParentOfType<mlir::ModuleOp>();
36+
while (parentOp) {
37+
moduleOp = parentOp;
38+
parentOp = moduleOp->getParentOfType<mlir::ModuleOp>();
39+
}
40+
3241
if (not invalid_)
3342
return;
3443

lib/Dialect/QUIR/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRQUIRTransforms
2323
MergeParallelResets.cpp
2424
Passes.cpp
2525
QuantumDecoration.cpp
26+
QUIRCircuitAnalysis.cpp
2627
RemoveQubitOperands.cpp
2728
ReorderMeasurements.cpp
2829
ReorderCircuits.cpp
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
//===- QUIRCircuitsAnalsysis.cpp - Cache values for circuits ----*- C++ -*-===//
2+
//
3+
// (C) Copyright IBM 2023.
4+
//
5+
// Any modifications or derivative works of this code must retain this
6+
// copyright notice, and modified files need to carry a notice indicating
7+
// that they have been altered from the originals.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
///
11+
/// This file implements a analysis for caching argument values with default
12+
/// values for angle and duration arguments.
13+
///
14+
//===----------------------------------------------------------------------===//
15+
16+
#include "Dialect/QUIR/Transforms/QUIRCircuitAnalysis.h"
17+
18+
#include "Dialect/OQ3/IR/OQ3Ops.h"
19+
#include "Dialect/QCS/IR/QCSOps.h"
20+
#include "Dialect/QUIR/IR/QUIROps.h"
21+
#include "Dialect/QUIR/Utils/Utils.h"
22+
#include "mlir/Pass/AnalysisManager.h"
23+
24+
#include "mlir/IR/Builders.h"
25+
26+
using namespace mlir;
27+
28+
namespace mlir::quir {
29+
30+
double
31+
parameterValToDouble(mlir::qcs::ParameterLoadOp defOp,
32+
mlir::qcs::ParameterInitialValueAnalysis *nameAnalysis) {
33+
assert(nameAnalysis &&
34+
"A valid ParameterInitialValueAnalysis pointer is required");
35+
return std::get<double>(defOp.getInitialValue(nameAnalysis->getNames()));
36+
}
37+
38+
llvm::Expected<double>
39+
angleValToDouble(mlir::Value inVal,
40+
mlir::qcs::ParameterInitialValueAnalysis *nameAnalysis,
41+
mlir::quir::QUIRCircuitAnalysis *circuitAnalysis) {
42+
43+
llvm::StringRef errorStr;
44+
45+
if (auto defOp = inVal.getDefiningOp<mlir::quir::ConstantOp>())
46+
return defOp.getAngleValueFromConstant().convertToDouble();
47+
48+
if (auto defOp = inVal.getDefiningOp<mlir::qcs::ParameterLoadOp>())
49+
return parameterValToDouble(defOp, nameAnalysis);
50+
51+
if (auto blockArg = inVal.dyn_cast<mlir::BlockArgument>()) {
52+
auto circuitOp = mlir::dyn_cast<mlir::quir::CircuitOp>(
53+
inVal.getParentBlock()->getParentOp());
54+
assert(circuitOp && "can only handle circuit arguments");
55+
56+
auto argNum = blockArg.getArgNumber();
57+
if (circuitAnalysis == nullptr) {
58+
59+
auto argAttr = circuitOp.getArgAttrOfType<mlir::quir::AngleAttr>(
60+
argNum, mlir::quir::getAngleAttrName());
61+
return argAttr.getValue().convertToDouble();
62+
}
63+
auto parentModuleOp = circuitOp->getParentOfType<mlir::ModuleOp>();
64+
return std::get<QUIRCircuitAnalysisEntry::ANGLE>(
65+
circuitAnalysis->getAnalysisMap()[parentModuleOp][circuitOp][argNum]);
66+
}
67+
68+
if (auto castOp = inVal.getDefiningOp<mlir::oq3::CastOp>()) {
69+
auto defOp = castOp.arg().getDefiningOp<mlir::qcs::ParameterLoadOp>();
70+
if (defOp)
71+
return parameterValToDouble(defOp, nameAnalysis);
72+
if (auto constOp = castOp.arg().getDefiningOp<mlir::arith::ConstantOp>()) {
73+
if (auto angleAttr = constOp.getValue().dyn_cast<mlir::quir::AngleAttr>())
74+
return angleAttr.getValue().convertToDouble();
75+
if (auto floatAttr = constOp.getValue().dyn_cast<mlir::FloatAttr>())
76+
return floatAttr.getValue().convertToDouble();
77+
errorStr = "unable to cast Angle from constant op";
78+
} else {
79+
errorStr = "unable to cast Angle from defining op";
80+
}
81+
} else {
82+
errorStr = "Non-constant angles are not supported!";
83+
}
84+
return llvm::createStringError(llvm::inconvertibleErrorCode(), errorStr);
85+
} // angleValToDouble
86+
87+
double QUIRCircuitAnalysis::getAngleValue(
88+
mlir::Value operand,
89+
mlir::qcs::ParameterInitialValueAnalysis *nameAnalysis) {
90+
assert(nameAnalysis && "valid nameAnalysis pointer required");
91+
auto valueOrError = angleValToDouble(operand, nameAnalysis);
92+
if (auto err = valueOrError.takeError()) {
93+
operand.getDefiningOp()->emitOpError() << toString(std::move(err)) + "\n";
94+
assert(false && "unhandled value in angleValToDouble");
95+
}
96+
return *valueOrError;
97+
}
98+
99+
llvm::StringRef QUIRCircuitAnalysis::getParameterName(mlir::Value operand) {
100+
llvm::StringRef parameterName = {};
101+
qcs::ParameterLoadOp parameterLoad;
102+
parameterLoad = dyn_cast<qcs::ParameterLoadOp>(operand.getDefiningOp());
103+
104+
if (!parameterLoad) {
105+
auto castOp = dyn_cast<mlir::oq3::CastOp>(operand.getDefiningOp());
106+
if (castOp)
107+
parameterLoad =
108+
dyn_cast<qcs::ParameterLoadOp>(castOp.arg().getDefiningOp());
109+
}
110+
111+
if (parameterLoad &&
112+
parameterLoad->hasAttr(mlir::quir::getInputParameterAttrName())) {
113+
parameterName = parameterLoad->getAttrOfType<StringAttr>(
114+
mlir::quir::getInputParameterAttrName());
115+
}
116+
return parameterName;
117+
}
118+
119+
quir::DurationAttr QUIRCircuitAnalysis::getDuration(mlir::Value operand) {
120+
quir::DurationAttr duration;
121+
auto constantOp = dyn_cast<quir::ConstantOp>(operand.getDefiningOp());
122+
123+
if (constantOp)
124+
return constantOp.value().dyn_cast<DurationAttr>();
125+
return duration;
126+
}
127+
128+
QUIRCircuitAnalysis::QUIRCircuitAnalysis(mlir::Operation *moduleOp,
129+
AnalysisManager &am) {
130+
131+
if (not invalid_)
132+
return;
133+
134+
bool runGetAnalysis = true;
135+
136+
mlir::qcs::ParameterInitialValueAnalysis *nameAnalysis;
137+
auto topLevelModuleOp = moduleOp->getParentOfType<ModuleOp>();
138+
if (topLevelModuleOp) {
139+
auto nameAnalysisOptional =
140+
am.getCachedParentAnalysis<mlir::qcs::ParameterInitialValueAnalysis>(
141+
moduleOp->getParentOfType<ModuleOp>());
142+
if (nameAnalysisOptional.hasValue()) {
143+
nameAnalysis = &nameAnalysisOptional.getValue().get();
144+
runGetAnalysis = false;
145+
}
146+
}
147+
148+
if (runGetAnalysis)
149+
nameAnalysis = &am.getAnalysis<mlir::qcs::ParameterInitialValueAnalysis>();
150+
151+
std::unordered_map<mlir::Operation *, std::map<llvm::StringRef, Operation *>>
152+
circuitOps;
153+
154+
moduleOp->walk([&](CircuitOp circuitOp) {
155+
circuitOps[circuitOp->getParentOfType<ModuleOp>()][circuitOp.sym_name()] =
156+
circuitOp.getOperation();
157+
});
158+
159+
moduleOp->walk([&](CallCircuitOp callCircuitOp) {
160+
auto search = circuitOps[callCircuitOp->getParentOfType<ModuleOp>()].find(
161+
callCircuitOp.calleeAttr().getValue());
162+
163+
if (search ==
164+
circuitOps[callCircuitOp->getParentOfType<ModuleOp>()].end()) {
165+
callCircuitOp->emitOpError("Could not find circuit.");
166+
return;
167+
}
168+
169+
auto circuitOp = dyn_cast<CircuitOp>(search->second);
170+
auto parentModuleOp = circuitOp->getParentOfType<ModuleOp>();
171+
172+
for (uint ii = 0; ii < callCircuitOp.operands().size(); ++ii) {
173+
174+
double value = 0;
175+
llvm::StringRef parameterName = {};
176+
quir::DurationAttr duration;
177+
178+
auto operand = callCircuitOp.operands()[ii];
179+
180+
// cache angle values and parameter names
181+
if (auto angType = operand.getType().dyn_cast<quir::AngleType>()) {
182+
183+
value = getAngleValue(operand, nameAnalysis);
184+
parameterName = getParameterName(operand);
185+
circuitOperands[parentModuleOp][circuitOp][ii] = {value, parameterName,
186+
duration};
187+
}
188+
189+
// cache durations
190+
if (auto durType = operand.getType().dyn_cast<quir::DurationType>()) {
191+
192+
duration = getDuration(operand);
193+
circuitOperands[parentModuleOp][circuitOp][ii] = {value, parameterName,
194+
duration};
195+
}
196+
}
197+
});
198+
invalid_ = false;
199+
}
200+
201+
void QUIRCircuitAnalysisPass::runOnOperation() {
202+
mlir::Pass::getAnalysis<QUIRCircuitAnalysis>();
203+
} // ParameterInitialValueAnalysisPass::runOnOperation()
204+
205+
llvm::StringRef QUIRCircuitAnalysisPass::getArgument() const {
206+
return "quir-circuit-analysis";
207+
}
208+
209+
llvm::StringRef QUIRCircuitAnalysisPass::getDescription() const {
210+
return "Analyze Circuit Inputs";
211+
}
212+
213+
} // namespace mlir::quir

0 commit comments

Comments
 (0)