|
| 1 | +//===- QUIRCircuitsAnalsysis.cpp - Cache values for circuits ----*- C++ -*-===// |
| 2 | +// |
| 3 | +// (C) Copyright IBM 2023. |
| 4 | +// |
| 5 | +// Any modifications or derivative works of this code must retain this |
| 6 | +// copyright notice, and modified files need to carry a notice indicating |
| 7 | +// that they have been altered from the originals. |
| 8 | +// |
| 9 | +//===----------------------------------------------------------------------===// |
| 10 | +/// |
| 11 | +/// This file implements a analysis for caching argument values with default |
| 12 | +/// values for angle and duration arguments. |
| 13 | +/// |
| 14 | +//===----------------------------------------------------------------------===// |
| 15 | + |
| 16 | +#include "Dialect/QUIR/Transforms/QUIRCircuitAnalysis.h" |
| 17 | + |
| 18 | +#include "Dialect/OQ3/IR/OQ3Ops.h" |
| 19 | +#include "Dialect/QCS/IR/QCSOps.h" |
| 20 | +#include "Dialect/QUIR/IR/QUIROps.h" |
| 21 | +#include "Dialect/QUIR/Utils/Utils.h" |
| 22 | +#include "mlir/Pass/AnalysisManager.h" |
| 23 | + |
| 24 | +#include "mlir/IR/Builders.h" |
| 25 | + |
| 26 | +using namespace mlir; |
| 27 | + |
| 28 | +namespace mlir::quir { |
| 29 | + |
| 30 | +double |
| 31 | +parameterValToDouble(mlir::qcs::ParameterLoadOp defOp, |
| 32 | + mlir::qcs::ParameterInitialValueAnalysis *nameAnalysis) { |
| 33 | + assert(nameAnalysis && |
| 34 | + "A valid ParameterInitialValueAnalysis pointer is required"); |
| 35 | + return std::get<double>(defOp.getInitialValue(nameAnalysis->getNames())); |
| 36 | +} |
| 37 | + |
| 38 | +llvm::Expected<double> |
| 39 | +angleValToDouble(mlir::Value inVal, |
| 40 | + mlir::qcs::ParameterInitialValueAnalysis *nameAnalysis, |
| 41 | + mlir::quir::QUIRCircuitAnalysis *circuitAnalysis) { |
| 42 | + |
| 43 | + llvm::StringRef errorStr; |
| 44 | + |
| 45 | + if (auto defOp = inVal.getDefiningOp<mlir::quir::ConstantOp>()) |
| 46 | + return defOp.getAngleValueFromConstant().convertToDouble(); |
| 47 | + |
| 48 | + if (auto defOp = inVal.getDefiningOp<mlir::qcs::ParameterLoadOp>()) |
| 49 | + return parameterValToDouble(defOp, nameAnalysis); |
| 50 | + |
| 51 | + if (auto blockArg = inVal.dyn_cast<mlir::BlockArgument>()) { |
| 52 | + auto circuitOp = mlir::dyn_cast<mlir::quir::CircuitOp>( |
| 53 | + inVal.getParentBlock()->getParentOp()); |
| 54 | + assert(circuitOp && "can only handle circuit arguments"); |
| 55 | + |
| 56 | + auto argNum = blockArg.getArgNumber(); |
| 57 | + if (circuitAnalysis == nullptr) { |
| 58 | + |
| 59 | + auto argAttr = circuitOp.getArgAttrOfType<mlir::quir::AngleAttr>( |
| 60 | + argNum, mlir::quir::getAngleAttrName()); |
| 61 | + return argAttr.getValue().convertToDouble(); |
| 62 | + } |
| 63 | + auto parentModuleOp = circuitOp->getParentOfType<mlir::ModuleOp>(); |
| 64 | + return std::get<QUIRCircuitAnalysisEntry::ANGLE>( |
| 65 | + circuitAnalysis->getAnalysisMap()[parentModuleOp][circuitOp][argNum]); |
| 66 | + } |
| 67 | + |
| 68 | + if (auto castOp = inVal.getDefiningOp<mlir::oq3::CastOp>()) { |
| 69 | + auto defOp = castOp.arg().getDefiningOp<mlir::qcs::ParameterLoadOp>(); |
| 70 | + if (defOp) |
| 71 | + return parameterValToDouble(defOp, nameAnalysis); |
| 72 | + if (auto constOp = castOp.arg().getDefiningOp<mlir::arith::ConstantOp>()) { |
| 73 | + if (auto angleAttr = constOp.getValue().dyn_cast<mlir::quir::AngleAttr>()) |
| 74 | + return angleAttr.getValue().convertToDouble(); |
| 75 | + if (auto floatAttr = constOp.getValue().dyn_cast<mlir::FloatAttr>()) |
| 76 | + return floatAttr.getValue().convertToDouble(); |
| 77 | + errorStr = "unable to cast Angle from constant op"; |
| 78 | + } else { |
| 79 | + errorStr = "unable to cast Angle from defining op"; |
| 80 | + } |
| 81 | + } else { |
| 82 | + errorStr = "Non-constant angles are not supported!"; |
| 83 | + } |
| 84 | + return llvm::createStringError(llvm::inconvertibleErrorCode(), errorStr); |
| 85 | +} // angleValToDouble |
| 86 | + |
| 87 | +double QUIRCircuitAnalysis::getAngleValue( |
| 88 | + mlir::Value operand, |
| 89 | + mlir::qcs::ParameterInitialValueAnalysis *nameAnalysis) { |
| 90 | + assert(nameAnalysis && "valid nameAnalysis pointer required"); |
| 91 | + auto valueOrError = angleValToDouble(operand, nameAnalysis); |
| 92 | + if (auto err = valueOrError.takeError()) { |
| 93 | + operand.getDefiningOp()->emitOpError() << toString(std::move(err)) + "\n"; |
| 94 | + assert(false && "unhandled value in angleValToDouble"); |
| 95 | + } |
| 96 | + return *valueOrError; |
| 97 | +} |
| 98 | + |
| 99 | +llvm::StringRef QUIRCircuitAnalysis::getParameterName(mlir::Value operand) { |
| 100 | + llvm::StringRef parameterName = {}; |
| 101 | + qcs::ParameterLoadOp parameterLoad; |
| 102 | + parameterLoad = dyn_cast<qcs::ParameterLoadOp>(operand.getDefiningOp()); |
| 103 | + |
| 104 | + if (!parameterLoad) { |
| 105 | + auto castOp = dyn_cast<mlir::oq3::CastOp>(operand.getDefiningOp()); |
| 106 | + if (castOp) |
| 107 | + parameterLoad = |
| 108 | + dyn_cast<qcs::ParameterLoadOp>(castOp.arg().getDefiningOp()); |
| 109 | + } |
| 110 | + |
| 111 | + if (parameterLoad && |
| 112 | + parameterLoad->hasAttr(mlir::quir::getInputParameterAttrName())) { |
| 113 | + parameterName = parameterLoad->getAttrOfType<StringAttr>( |
| 114 | + mlir::quir::getInputParameterAttrName()); |
| 115 | + } |
| 116 | + return parameterName; |
| 117 | +} |
| 118 | + |
| 119 | +quir::DurationAttr QUIRCircuitAnalysis::getDuration(mlir::Value operand) { |
| 120 | + quir::DurationAttr duration; |
| 121 | + auto constantOp = dyn_cast<quir::ConstantOp>(operand.getDefiningOp()); |
| 122 | + |
| 123 | + if (constantOp) |
| 124 | + return constantOp.value().dyn_cast<DurationAttr>(); |
| 125 | + return duration; |
| 126 | +} |
| 127 | + |
| 128 | +QUIRCircuitAnalysis::QUIRCircuitAnalysis(mlir::Operation *moduleOp, |
| 129 | + AnalysisManager &am) { |
| 130 | + |
| 131 | + if (not invalid_) |
| 132 | + return; |
| 133 | + |
| 134 | + bool runGetAnalysis = true; |
| 135 | + |
| 136 | + mlir::qcs::ParameterInitialValueAnalysis *nameAnalysis; |
| 137 | + auto topLevelModuleOp = moduleOp->getParentOfType<ModuleOp>(); |
| 138 | + if (topLevelModuleOp) { |
| 139 | + auto nameAnalysisOptional = |
| 140 | + am.getCachedParentAnalysis<mlir::qcs::ParameterInitialValueAnalysis>( |
| 141 | + moduleOp->getParentOfType<ModuleOp>()); |
| 142 | + if (nameAnalysisOptional.hasValue()) { |
| 143 | + nameAnalysis = &nameAnalysisOptional.getValue().get(); |
| 144 | + runGetAnalysis = false; |
| 145 | + } |
| 146 | + } |
| 147 | + |
| 148 | + if (runGetAnalysis) |
| 149 | + nameAnalysis = &am.getAnalysis<mlir::qcs::ParameterInitialValueAnalysis>(); |
| 150 | + |
| 151 | + std::unordered_map<mlir::Operation *, std::map<llvm::StringRef, Operation *>> |
| 152 | + circuitOps; |
| 153 | + |
| 154 | + moduleOp->walk([&](CircuitOp circuitOp) { |
| 155 | + circuitOps[circuitOp->getParentOfType<ModuleOp>()][circuitOp.sym_name()] = |
| 156 | + circuitOp.getOperation(); |
| 157 | + }); |
| 158 | + |
| 159 | + moduleOp->walk([&](CallCircuitOp callCircuitOp) { |
| 160 | + auto search = circuitOps[callCircuitOp->getParentOfType<ModuleOp>()].find( |
| 161 | + callCircuitOp.calleeAttr().getValue()); |
| 162 | + |
| 163 | + if (search == |
| 164 | + circuitOps[callCircuitOp->getParentOfType<ModuleOp>()].end()) { |
| 165 | + callCircuitOp->emitOpError("Could not find circuit."); |
| 166 | + return; |
| 167 | + } |
| 168 | + |
| 169 | + auto circuitOp = dyn_cast<CircuitOp>(search->second); |
| 170 | + auto parentModuleOp = circuitOp->getParentOfType<ModuleOp>(); |
| 171 | + |
| 172 | + for (uint ii = 0; ii < callCircuitOp.operands().size(); ++ii) { |
| 173 | + |
| 174 | + double value = 0; |
| 175 | + llvm::StringRef parameterName = {}; |
| 176 | + quir::DurationAttr duration; |
| 177 | + |
| 178 | + auto operand = callCircuitOp.operands()[ii]; |
| 179 | + |
| 180 | + // cache angle values and parameter names |
| 181 | + if (auto angType = operand.getType().dyn_cast<quir::AngleType>()) { |
| 182 | + |
| 183 | + value = getAngleValue(operand, nameAnalysis); |
| 184 | + parameterName = getParameterName(operand); |
| 185 | + circuitOperands[parentModuleOp][circuitOp][ii] = {value, parameterName, |
| 186 | + duration}; |
| 187 | + } |
| 188 | + |
| 189 | + // cache durations |
| 190 | + if (auto durType = operand.getType().dyn_cast<quir::DurationType>()) { |
| 191 | + |
| 192 | + duration = getDuration(operand); |
| 193 | + circuitOperands[parentModuleOp][circuitOp][ii] = {value, parameterName, |
| 194 | + duration}; |
| 195 | + } |
| 196 | + } |
| 197 | + }); |
| 198 | + invalid_ = false; |
| 199 | +} |
| 200 | + |
| 201 | +void QUIRCircuitAnalysisPass::runOnOperation() { |
| 202 | + mlir::Pass::getAnalysis<QUIRCircuitAnalysis>(); |
| 203 | +} // ParameterInitialValueAnalysisPass::runOnOperation() |
| 204 | + |
| 205 | +llvm::StringRef QUIRCircuitAnalysisPass::getArgument() const { |
| 206 | + return "quir-circuit-analysis"; |
| 207 | +} |
| 208 | + |
| 209 | +llvm::StringRef QUIRCircuitAnalysisPass::getDescription() const { |
| 210 | + return "Analyze Circuit Inputs"; |
| 211 | +} |
| 212 | + |
| 213 | +} // namespace mlir::quir |
0 commit comments