Skip to content

Commit 3c6821a

Browse files
* Retain variable names assigned to measurement results
* Disallow measure_result as argument to entry-point kernels * WIP Signed-off-by: Pradnya Khalate <pkhalate@nvidia.com>
1 parent 9135c07 commit 3c6821a

File tree

18 files changed

+255
-100
lines changed

18 files changed

+255
-100
lines changed

include/cudaq/Optimizer/Builder/Intrinsics.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ static constexpr const char llvmStackRestore[] = "llvm.stackrestore";
7272

7373
static constexpr const char cudaqConvertToInteger[] =
7474
"__nvqpp_cudaqConvertToInteger";
75+
76+
static constexpr const char cudaqConvertToBoolVector[] =
77+
"__nvqpp_cudaq_to_bool_vector";
78+
7579
/// Builder for lowering the clang AST to an IR for CUDA-Q. Lowering includes
7680
/// the transformation of both quantum and classical computation. Different
7781
/// features of the CUDA-Q programming model are lowered into different dialects

lib/Frontend/nvqpp/ASTBridge.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,19 @@ static bool hasAnyQubitTypes(FunctionType funcTy) {
7373
return false;
7474
}
7575

76+
// Check the builtin type `FunctionType` to see if it has any `MeasureType`
77+
// arguments.
78+
static bool hasMeasureResultArgs(FunctionType funcTy) {
79+
for (auto ty : funcTy.getInputs()) {
80+
if (isa<quake::MeasureType>(ty))
81+
return true;
82+
if (auto vecTy = dyn_cast<cudaq::cc::StdvecType>(ty))
83+
if (isa<quake::MeasureType>(vecTy.getElementType()))
84+
return true;
85+
}
86+
return false;
87+
}
88+
7689
// Remove the Itanium mangling "_ZTS" prefix. This is to match the name returned
7790
// by `typeid(TYPE).name()`.
7891
static std::string
@@ -640,6 +653,7 @@ void ASTBridgeAction::ASTBridgeConsumer::HandleTranslationUnit(
640653
// Flag func as a quantum kernel.
641654
func->setAttr(kernelAttrName, unitAttr);
642655
if ((!hasAnyQubitTypes(func.getFunctionType())) &&
656+
(!hasMeasureResultArgs(func.getFunctionType())) &&
643657
(!cudaq::ASTBridgeAction::ASTBridgeConsumer::isCustomOpGenerator(
644658
fdPair.second))) {
645659
// Flag func as an entry point to a quantum kernel.

lib/Frontend/nvqpp/ConvertDecl.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,9 @@ bool QuakeBridgeVisitor::interceptRecordDecl(clang::RecordDecl *x) {
169169
auto fnTy = cast<FunctionType>(popType());
170170
return pushType(cc::IndirectCallableType::get(fnTy));
171171
}
172+
// Measurement result type.
173+
if (name == "measure_result")
174+
return pushType(quake::MeasureType::get(ctx));
172175
if (!isInNamespace(x, "solvers") && !isInNamespace(x, "qec")) {
173176
auto loc = toLocation(x);
174177
TODO_loc(loc, "unhandled type, " + name + ", in cudaq namespace");
@@ -744,6 +747,11 @@ bool QuakeBridgeVisitor::VisitVarDecl(clang::VarDecl *x) {
744747
// and if so, find the mz and tag it with the variable name
745748
auto elementType = vecType.getElementType();
746749

750+
if (auto meas = initVec.getDefiningOp<quake::MeasurementInterface>()) {
751+
meas.setRegisterName(builder.getStringAttr(x->getName()));
752+
return true;
753+
}
754+
747755
// Drop out if this is not an i1
748756
if (!elementType.isIntOrFloat() ||
749757
elementType.getIntOrFloatBitWidth() != 1)
@@ -817,9 +825,8 @@ bool QuakeBridgeVisitor::VisitVarDecl(clang::VarDecl *x) {
817825

818826
// If this was an auto var = mz(q), then we want to know the
819827
// var name, as it will serve as the classical bit register name
820-
if (auto discr = initValue.getDefiningOp<quake::DiscriminateOp>())
821-
if (auto mz = discr.getMeasurement().getDefiningOp<quake::MzOp>())
822-
mz.setRegisterName(builder.getStringAttr(x->getName()));
828+
if (auto meas = initValue.getDefiningOp<quake::MeasurementInterface>())
829+
meas.setRegisterName(builder.getStringAttr(x->getName()));
823830

824831
assert(initValue && "initializer value must be lowered");
825832
if (isa<IntegerType>(initValue.getType()) && isa<IntegerType>(type)) {

lib/Frontend/nvqpp/ConvertExpr.cpp

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -658,14 +658,34 @@ bool QuakeBridgeVisitor::VisitCastExpr(clang::CastExpr *x) {
658658
}
659659
case clang::CastKind::CK_UserDefinedConversion: {
660660
auto sub = popValue();
661-
// castToTy is the converion function signature.
661+
// castToTy is the conversion function signature.
662662
castToTy = popType();
663663
if (isa<IntegerType>(castToTy) && isa<IntegerType>(sub.getType())) {
664664
auto locSub = toLocation(x->getSubExpr());
665665
bool result = intToIntCast(locSub, sub);
666666
assert(result && "integer conversion failed");
667667
return result;
668668
}
669+
auto i1Type = builder.getI1Type();
670+
671+
// Handle conversion of `measure_result` to `bool`.
672+
if (isa<quake::MeasureType>(sub.getType()))
673+
return pushValue(builder.create<quake::DiscriminateOp>(loc, i1Type, sub));
674+
675+
// Handle conversion of `std::vector<measure_result>` to `std::vector<bool>`
676+
if (auto vecTy = dyn_cast<cc::StdvecType>(sub.getType()))
677+
if (isa<quake::MeasureType>(vecTy.getElementType()))
678+
return pushValue(builder.create<quake::DiscriminateOp>(
679+
loc, cc::StdvecType::get(i1Type), sub));
680+
681+
// Handle pointer to `measure_result` (from vector element access)
682+
if (auto ptrTy = dyn_cast<cc::PointerType>(sub.getType()))
683+
if (isa<quake::MeasureType>(ptrTy.getElementType())) {
684+
auto loaded = builder.create<cc::LoadOp>(loc, sub);
685+
return pushValue(
686+
builder.create<quake::DiscriminateOp>(loc, i1Type, loaded));
687+
}
688+
669689
TODO_loc(loc, "unhandled user-defined implicit conversion");
670690
}
671691
case clang::CastKind::CK_ConstructorConversion: {
@@ -1015,7 +1035,7 @@ bool QuakeBridgeVisitor::VisitMaterializeTemporaryExpr(
10151035
// In those cases, there is nothing to materialize, so we can just pass the
10161036
// Value on the top of the stack.
10171037
if (isa<cc::CallableType, quake::VeqType, quake::RefType, cc::SpanLikeType,
1018-
quake::StateType>(ty))
1038+
quake::StateType, quake::MeasureType>(ty))
10191039
return true;
10201040

10211041
// If not one of the above special cases, then materialize the value to a
@@ -1520,7 +1540,13 @@ bool QuakeBridgeVisitor::VisitCallExpr(clang::CallExpr *x) {
15201540
auto funcArity = func->getNumParams();
15211541
SmallVector<Value> args = lastValues(funcArity);
15221542
if (isa<clang::CXXMethodDecl>(func)) {
1523-
[[maybe_unused]] auto thisPtrValue = popValue();
1543+
auto thisPtrValue = popValue();
1544+
1545+
// Handle `measure_result` conversion operators
1546+
if (isa<clang::CXXConversionDecl>(func) &&
1547+
isInClassInNamespace(func, "measure_result", "cudaq")) {
1548+
return pushValue(thisPtrValue);
1549+
}
15241550
}
15251551
auto calleeOp = popValue();
15261552

@@ -1646,7 +1672,6 @@ bool QuakeBridgeVisitor::VisitCallExpr(clang::CallExpr *x) {
16461672
}
16471673

16481674
if (funcName == "mx" || funcName == "my" || funcName == "mz") {
1649-
// Measurements always return a bool or a std::vector<bool>.
16501675
bool useStdvec =
16511676
(args.size() > 1) ||
16521677
(args.size() == 1 && isa<quake::VeqType>(args[0].getType()));
@@ -1660,11 +1685,8 @@ bool QuakeBridgeVisitor::VisitCallExpr(clang::CallExpr *x) {
16601685
return builder.create<quake::MyOp>(loc, measTy, args).getMeasOut();
16611686
return builder.create<quake::MzOp>(loc, measTy, args).getMeasOut();
16621687
}();
1663-
Type resTy = builder.getI1Type();
1664-
if (useStdvec)
1665-
resTy = cc::StdvecType::get(resTy);
1666-
return pushValue(
1667-
builder.create<quake::DiscriminateOp>(loc, resTy, measure));
1688+
// No more discrimination needed, just return the measurement.
1689+
return pushValue(measure);
16681690
}
16691691

16701692
// Handle the quantum gate set.
@@ -2136,6 +2158,28 @@ bool QuakeBridgeVisitor::VisitCallExpr(clang::CallExpr *x) {
21362158
.getResult(0));
21372159
}
21382160

2161+
if (funcName == "to_bool_vector") {
2162+
// args[0] is !cc.stdvec<!quake.measure> from mz()
2163+
auto arg = args[0];
2164+
// Insert discriminate if needed
2165+
if (auto vecTy = dyn_cast<cc::StdvecType>(arg.getType())) {
2166+
if (isa<quake::MeasureType>(vecTy.getElementType())) {
2167+
auto i1Ty = builder.getI1Type();
2168+
arg = builder.create<quake::DiscriminateOp>(
2169+
loc, cc::StdvecType::get(i1Ty), arg);
2170+
}
2171+
}
2172+
IRBuilder irBuilder(builder.getContext());
2173+
if (failed(irBuilder.loadIntrinsic(module, cudaqConvertToBoolVector))) {
2174+
reportClangError(x, mangler, "cannot load cudaqConvertToBoolVector");
2175+
return false;
2176+
}
2177+
return pushValue(builder
2178+
.create<func::CallOp>(loc, arg.getType(),
2179+
cudaqConvertToBoolVector, arg)
2180+
.getResult(0));
2181+
}
2182+
21392183
if (funcName == "slice_vector") {
21402184
auto svecTy = dyn_cast<cc::SpanLikeType>(args[0].getType());
21412185
auto eleTy = svecTy.getElementType();
@@ -3217,6 +3261,18 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
32173261
return pushValue(builder.create<cc::LoadOp>(loc, copyObj));
32183262
}
32193263

3264+
// Handle `measure_result` copy constructor
3265+
if (ctor->isCopyConstructor() &&
3266+
isInClassInNamespace(ctor, "measure_result", "cudaq")) {
3267+
// The source is a pointer to measure_result (!cc.ptr<!quake.measure>)
3268+
// Just load and return the value
3269+
assert(x->getNumArgs() == 1);
3270+
auto srcPtr = popValue();
3271+
if (auto ptrTy = dyn_cast<cc::PointerType>(srcPtr.getType()))
3272+
if (isa<quake::MeasureType>(ptrTy.getElementType()))
3273+
return pushValue(builder.create<cc::LoadOp>(loc, srcPtr));
3274+
}
3275+
32203276
// TODO: remove this when we can handle ctors more generally.
32213277
if (!ctor->isDefaultConstructor()) {
32223278
LLVM_DEBUG(llvm::dbgs() << ctorName << " - unhandled ctor:\n"; x->dump());

lib/Frontend/nvqpp/ConvertStmt.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,12 @@ bool QuakeBridgeVisitor::VisitReturnStmt(clang::ReturnStmt *x) {
361361
Value tySize;
362362
if (!cudaq::cc::isDynamicType(eleTy))
363363
tySize = irb.getByteSizeOfType(loc, eleTy);
364+
if (isa<quake::MeasureType>(eleTy)) {
365+
/// FIXME: Confirm that this is okay.
366+
tySize = irb.getByteSizeOfType(loc, builder.getI32Type());
367+
} else {
368+
tySize = irb.getByteSizeOfType(loc, eleTy);
369+
}
364370
if (!tySize) {
365371
// TODO: we need to recursively create copies of all
366372
// dynamic memory used within the type. See the

lib/Frontend/nvqpp/ConvertType.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,24 @@ static bool isFunctionCallable(Type t) {
124124
return false;
125125
}
126126

127+
static bool isMeasureResultType(Type t) { return isa<quake::MeasureType>(t); }
128+
129+
static bool isMeasureResultSequenceType(Type t) {
130+
if (auto vec = dyn_cast<cudaq::cc::SpanLikeType>(t)) {
131+
auto eleTy = vec.getElementType();
132+
return isMeasureResultType(eleTy) || isMeasureResultSequenceType(eleTy);
133+
}
134+
return isMeasureResultType(t);
135+
}
136+
127137
/// Return true if and only if \p t is a (simple) arithmetic type, an arithmetic
128138
/// sequence type (possibly dynamic in length), or a static product type of
129139
/// arithmetic types. Note that this means a product type with a dynamic
130140
/// sequence of arithmetic types is \em disallowed.
131141
static bool isKernelResultType(Type t) {
132142
return isArithmeticType(t) || isArithmeticSequenceType(t) ||
133-
isStaticArithmeticProductType(t);
143+
isStaticArithmeticProductType(t) || isMeasureResultType(t) ||
144+
isMeasureResultSequenceType(t);
134145
}
135146

136147
/// Return true if and only if \p t is a (simple) arithmetic type, an possibly

lib/Optimizer/Builder/Intrinsics.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,13 @@ static constexpr IntrinsicCode intrinsicTable[] = {
337337
func.func private @__nvqpp_cudaq_state_numberOfQubits(%p : !cc.ptr<!quake.state>) -> i64
338338
)#"},
339339

340+
{cudaq::cudaqConvertToBoolVector, {}, R"#(
341+
func.func private @__nvqpp_cudaq_to_bool_vector(%arg : !cc.stdvec<i1>) -> !cc.stdvec<i1> {
342+
// TODO: Implement
343+
return %arg : !cc.stdvec<i1>
344+
}
345+
)#"},
346+
340347
{cudaq::runtime::bindingDeconstructString,
341348
{},
342349
"func.func private @__nvqpp_deconstructString(!cc.ptr<i8>)"},

lib/Optimizer/CodeGen/ConvertToQIRAPI.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -579,8 +579,9 @@ struct DiscriminateOpRewrite
579579
ConversionPatternRewriter &rewriter) const override {
580580
auto loc = disc.getLoc();
581581
Value m = adaptor.getMeasurement();
582-
auto i1PtrTy = cudaq::cc::PointerType::get(rewriter.getI1Type());
583-
auto cast = rewriter.create<cudaq::cc::CastOp>(loc, i1PtrTy, m);
582+
auto resultTy = typeConverter->convertType(disc.getResult().getType());
583+
auto ptrTy = cudaq::cc::PointerType::get(resultTy);
584+
auto cast = rewriter.create<cudaq::cc::CastOp>(loc, ptrTy, m);
584585
rewriter.replaceOpWithNewOp<cudaq::cc::LoadOp>(disc, cast);
585586
return success();
586587
}

lib/Optimizer/Dialect/CC/CCOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#include "mlir/IR/OpImplementation.h"
2121
#include "mlir/IR/PatternMatch.h"
2222
#include "mlir/IR/TypeUtilities.h"
23+
/// FIXME: This seems wrong to add in `cc` dialect!
24+
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
2325

2426
using namespace mlir;
2527

@@ -77,6 +79,10 @@ Value cudaq::cc::getByteSizeOfType(OpBuilder &builder, Location loc, Type ty,
7779
// we're assuming pointers are 64 bits.
7880
return {8};
7981
})
82+
.Case([](quake::MeasureType) -> std::optional<std::int32_t> {
83+
/// FIXME: What should this be?
84+
return {4};
85+
})
8086
.Default({});
8187

8288
if (rawSize)

runtime/cudaq/qis/measure_result.h

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include <optional>
1112
#include <vector>
1213

1314
namespace cudaq {
@@ -26,24 +27,14 @@ class measure_result {
2627
int result = 0;
2728

2829
/// Unique integer for measure result identification
29-
std::size_t uniqueId = 0;
30+
std::optional<std::size_t> uniqueId = std::nullopt;
3031

3132
public:
32-
measure_result(int res, std::size_t id) : result(res), uniqueId(id) {}
3333
measure_result(int res) : result(res) {}
34+
measure_result(int res, std::size_t id) : result(res), uniqueId(id) {}
3435

3536
operator int() const { return result; }
3637
operator bool() const { return __nvqpp__MeasureResultBoolConversion(result); }
37-
38-
/// TODO: This needs to be revisited to support MLIR mode properly.
39-
static std::vector<bool>
40-
to_bool_vector(const std::vector<measure_result> &results) {
41-
std::vector<bool> boolResults;
42-
boolResults.reserve(results.size());
43-
for (const auto &res : results)
44-
boolResults.push_back(static_cast<bool>(res));
45-
return boolResults;
46-
}
4738
};
4839

4940
} // namespace cudaq

0 commit comments

Comments
 (0)