-
Notifications
You must be signed in to change notification settings - Fork 355
Expand file tree
/
Copy pathAddMeasurements.cpp
More file actions
146 lines (124 loc) · 5.42 KB
/
AddMeasurements.cpp
File metadata and controls
146 lines (124 loc) · 5.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
/*******************************************************************************
* Copyright (c) 2025 - 2026 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/
#include "PassDetails.h"
#include "cudaq/Frontend/nvqpp/AttributeNames.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
namespace cudaq::opt {
#define GEN_PASS_DEF_ADDMEASUREMENTS
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
} // namespace cudaq::opt
#define DEBUG_TYPE "add-measurements"
using namespace mlir;
namespace {
/// Single pass over a function that records what the add-measurements rewrite
/// needs:
///  - whether any operation carrying the `cudaq::QuantumMeasure` trait exists,
///  - every `quake.alloca` (the qubits that will be measured), and
///  - every `func.return` (to be redirected to a common exit block).
/// The walk stops at the first measurement it meets. As a result,
/// `allocations` and `returns` are only complete when `hasMeasurement` is
/// false.
struct Analysis {
  Analysis() = default;

  explicit Analysis(func::FuncOp func) {
    func.walk([&](Operation *op) -> WalkResult {
      // One measurement settles the question; no need to look further.
      if (op->hasTrait<cudaq::QuantumMeasure>()) {
        hasMeasurement = true;
        return WalkResult::interrupt();
      }
      if (auto allocaOp = dyn_cast<quake::AllocaOp>(op)) {
        allocations.push_back(allocaOp);
        return WalkResult::advance();
      }
      if (auto returnOp = dyn_cast<func::ReturnOp>(op))
        returns.push_back(returnOp);
      return WalkResult::advance();
    });
  }

  /// Set once an operation with the `QuantumMeasure` trait has been seen.
  bool hasMeasurement = false;
  /// Every `quake.alloca` visited before the walk finished or was interrupted.
  SmallVector<quake::AllocaOp> allocations;
  /// Every `func.return` visited before the walk finished or was interrupted.
  SmallVector<func::ReturnOp> returns;

  /// Returns true if at least one `quake.alloca` was recorded.
  bool hasQubitAlloca() const { return !allocations.empty(); }
};
/// Add measurement operations for all allocated qubits in a function.
/// This transformation creates a new block at the end of the function,
/// redirects all return operations to branch to this block, adds `quake.mz`
/// measurement operations for each qubit allocation, and adds a final return.
/// For vector allocations, the measurements are collected into a vector of
/// measurement results.
///
/// \param funcOp            Function to rewrite; must have a body.
/// \param allocations       Allocations to measure, in the given order.
/// \param returnsToReplace  `func.return` ops to redirect to the new exit
///                          block. Each one is erased.
/// \returns failure, with an error attached to the offending allocation, if
///          some allocation does not dominate every return. In that case the
///          function is left unmodified. Otherwise returns success.
LogicalResult
addMeasurements(func::FuncOp funcOp, ArrayRef<quake::AllocaOp> allocations,
                ArrayRef<func::ReturnOp> returnsToReplace) {
  // All measurements are emitted in a single exit block reached from every
  // redirected return. Each allocation must therefore dominate every one of
  // those returns. An alloca inside a nested region, or on only some
  // control-flow paths, would otherwise produce a use not dominated by its
  // definition (invalid IR). Check before mutating anything so a failure
  // never leaves the function half-rewritten. The analysis is scoped so it
  // is discarded before the CFG changes.
  {
    DominanceInfo domInfo(funcOp);
    for (quake::AllocaOp alloca : allocations)
      for (func::ReturnOp returnOp : returnsToReplace)
        if (!domInfo.properlyDominates(alloca.getOperation(),
                                       returnOp.getOperation()))
          return alloca.emitError(
              "cannot add measurement: allocation does not dominate every "
              "return of the function");
  }

  Location loc = funcOp.getLoc();
  MLIRContext *ctx = funcOp.getContext();
  OpBuilder builder(ctx);

  // Create the exit block at the end of the function. If the function
  // returns values, the block takes them as arguments so the redirected
  // returns can forward their operands unchanged.
  Block *exitBlock = funcOp.addBlock();
  ArrayRef<Type> returnTypes = funcOp.getFunctionType().getResults();
  if (!returnTypes.empty()) {
    SmallVector<Location> argLocs(returnTypes.size(), loc);
    exitBlock->addArguments(returnTypes, argLocs);
  }

  // Replace every func.return with an unconditional branch to the exit block.
  for (func::ReturnOp returnOp : returnsToReplace) {
    builder.setInsertionPoint(returnOp);
    builder.create<cf::BranchOp>(returnOp.getLoc(), exitBlock,
                                 returnOp.getOperands());
    returnOp.erase();
  }

  // Measure each allocation in the exit block. A `veq` is measured as a whole
  // into a (sized when known) measurements value; a single `ref` gets a plain
  // measurement.
  // NOTE(review): if `quake.dealloc` ops can precede the returns when this
  // pass runs, these measurements would follow the deallocation. Confirm the
  // pipeline ordering.
  builder.setInsertionPointToEnd(exitBlock);
  auto measTy = quake::MeasureType::get(ctx);
  for (quake::AllocaOp alloca : allocations) {
    if (auto veqTy = dyn_cast<quake::VeqType>(alloca.getType())) {
      Type measurementsTy;
      if (veqTy.hasSpecifiedSize())
        measurementsTy = quake::MeasurementsType::get(ctx, veqTy.getSize());
      else
        measurementsTy = quake::MeasurementsType::getUnsized(ctx);
      builder.create<quake::MzOp>(loc, measurementsTy,
                                  ValueRange{alloca.getResult()});
    } else {
      builder.create<quake::MzOp>(loc, measTy, alloca.getResult());
    }
  }

  // Return the forwarded values (if any) through the exit block's arguments.
  builder.create<func::ReturnOp>(loc, exitBlock->getArguments());
  return success();
}
struct AddMeasurementsPass
    : public cudaq::opt::impl::AddMeasurementsBase<AddMeasurementsPass> {
  using AddMeasurementsBase::AddMeasurementsBase;

  /// Adds measurements of every qubit allocation to an entry-point function
  /// that allocates qubits but contains no measurement.
  ///
  /// The function is left untouched in any of these cases:
  ///  - it has no body;
  ///  - it lacks `cudaq::entryPointAttrName`;
  ///  - it carries `qubitMeasurementFeedback = true`;
  ///  - it already contains a measurement;
  ///  - it allocates no qubits.
  ///
  /// Signals pass failure if the rewrite cannot be applied.
  void runOnOperation() override {
    func::FuncOp func = getOperation();
    if (!func || func.empty())
      return;
    if (!func->hasAttr(cudaq::entryPointAttrName))
      return;

    // NOTE: Having a conditional on a measurement indicates that a measurement
    // is present, however, it does not guarantee that all the allocated qubits
    // are measured.
    auto feedback = func->getAttrOfType<BoolAttr>("qubitMeasurementFeedback");
    if (feedback && feedback.getValue())
      return;

    // Check if the function has any measurement operations, if yes, we don't
    // do anything. If not, then check if the function has any qubit
    // allocations, if yes, then we want to add measurements to it.
    // NOTE: Having an explicit measurement does not guarantee that all the
    // allocated qubits are measured.
    Analysis analysis(func);
    if (analysis.hasMeasurement || !analysis.hasQubitAlloca())
      return;

    LLVM_DEBUG(llvm::dbgs() << "Before adding measurements:\n" << *func);
    if (failed(addMeasurements(func, analysis.allocations, analysis.returns))) {
      signalPassFailure();
      return;
    }
    LLVM_DEBUG(llvm::dbgs() << "After adding measurements:\n" << *func);
  }
};
} // namespace