Skip to content

Commit 44f2dab

Browse files
* Retain variable names assigned to measurement results
* Disallow measure_result as argument to entry-point kernels Signed-off-by: Pradnya Khalate <[email protected]>
1 parent 873f99c commit 44f2dab

File tree

3 files changed

+86
-42
lines changed

3 files changed

+86
-42
lines changed

lib/Frontend/nvqpp/ASTBridge.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,19 @@ static bool hasAnyQubitTypes(FunctionType funcTy) {
7373
return false;
7474
}
7575

76+
// Check the builtin type `FunctionType` to see if it has any `MeasureType`
77+
// arguments.
78+
static bool hasMeasureResultArgs(FunctionType funcTy) {
79+
for (auto ty : funcTy.getInputs()) {
80+
if (isa<quake::MeasureType>(ty))
81+
return true;
82+
if (auto vecTy = dyn_cast<cudaq::cc::StdvecType>(ty))
83+
if (isa<quake::MeasureType>(vecTy.getElementType()))
84+
return true;
85+
}
86+
return false;
87+
}
88+
7689
// Remove the Itanium mangling "_ZTS" prefix. This is to match the name returned
7790
// by `typeid(TYPE).name()`.
7891
static std::string
@@ -640,6 +653,7 @@ void ASTBridgeAction::ASTBridgeConsumer::HandleTranslationUnit(
640653
// Flag func as a quantum kernel.
641654
func->setAttr(kernelAttrName, unitAttr);
642655
if ((!hasAnyQubitTypes(func.getFunctionType())) &&
656+
(!hasMeasureResultArgs(func.getFunctionType())) &&
643657
(!cudaq::ASTBridgeAction::ASTBridgeConsumer::isCustomOpGenerator(
644658
fdPair.second))) {
645659
// Flag func as an entry point to a quantum kernel.

lib/Frontend/nvqpp/ConvertDecl.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,11 @@ bool QuakeBridgeVisitor::VisitVarDecl(clang::VarDecl *x) {
747747
// and if so, find the mz and tag it with the variable name
748748
auto elementType = vecType.getElementType();
749749

750+
if (auto meas = initVec.getDefiningOp<quake::MeasurementInterface>()) {
751+
meas.setRegisterName(builder.getStringAttr(x->getName()));
752+
return true;
753+
}
754+
750755
// Drop out if this is not an i1
751756
if (!elementType.isIntOrFloat() ||
752757
elementType.getIntOrFloatBitWidth() != 1)
@@ -820,8 +825,8 @@ bool QuakeBridgeVisitor::VisitVarDecl(clang::VarDecl *x) {
820825

821826
// If this was an auto var = mz(q), then we want to know the
822827
// var name, as it will serve as the classical bit register name
823-
if (auto mz = initValue.getDefiningOp<quake::MzOp>())
824-
mz.setRegisterName(builder.getStringAttr(x->getName()));
828+
if (auto meas = initValue.getDefiningOp<quake::MeasurementInterface>())
829+
meas.setRegisterName(builder.getStringAttr(x->getName()));
825830

826831
assert(initValue && "initializer value must be lowered");
827832
if (isa<IntegerType>(initValue.getType()) && isa<IntegerType>(type)) {

test/AST-Quake/mz.cpp

Lines changed: 65 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ struct S {
1818
};
1919

2020
// clang-format off
21-
// CHECK-LABEL: func.func @__nvqpp__mlirgen__S() attributes
22-
// CHECK: %[[VAL_2:.*]] = quake.alloca !quake.veq<20>
23-
// CHECK: quake.mz %[[VAL_2]] : (!quake.veq<20>) -> !cc.stdvec<!quake.measure>
21+
// CHECK-LABEL: func.func @__nvqpp__mlirgen__S() attributes {"cudaq-entrypoint", "cudaq-kernel"} {
22+
// CHECK: %[[VAL_0:.*]] = quake.alloca !quake.veq<20>
23+
// CHECK: %[[VAL_1:.*]] = quake.mz %[[VAL_0]] : (!quake.veq<20>) -> !cc.stdvec<!quake.measure>
2424
// CHECK: return
2525
// CHECK: }
2626
// clang-format on
2727

2828
struct VectorOfStaticVeq {
29-
std::vector<bool> operator()() __qpu__ {
29+
std::vector<cudaq::measure_result> operator()() __qpu__ {
3030
cudaq::qubit q1;
3131
cudaq::qvector reg1(4);
3232
cudaq::qvector reg2(2);
@@ -35,23 +35,49 @@ struct VectorOfStaticVeq {
3535
}
3636
};
3737

38-
// CHECK-LABEL: func.func @__nvqpp__mlirgen__VectorOfStaticVeq() -> !cc.stdvec<i1> attributes {
39-
// CHECK: %[[VAL_11:.*]] = arith.constant 1 : i64
40-
// CHECK: %[[VAL_0:.*]] = quake.alloca !quake.ref
41-
// CHECK: %[[VAL_3:.*]] = quake.alloca !quake.veq<4>
42-
// CHECK: %[[VAL_6:.*]] = quake.alloca !quake.veq<2>
43-
// CHECK: %[[VAL_7:.*]] = quake.alloca !quake.ref
44-
// CHECK: %[[VAL_81:.*]] = quake.mz %[[VAL_0]], %[[VAL_3]], %[[VAL_6]], %[[VAL_7]] : (!quake.ref, !quake.veq<4>, !quake.veq<2>, !quake.ref) -> !cc.stdvec<!quake.measure>
45-
// CHECK: %[[VAL_8:.*]] = quake.discriminate %[[VAL_81]] :
46-
// CHECK: %[[VAL_9:.*]] = cc.stdvec_data %[[VAL_8]] : (!cc.stdvec<i1>) -> !cc.ptr<i8>
47-
// CHECK: %[[VAL_10:.*]] = cc.stdvec_size %[[VAL_8]] : (!cc.stdvec<i1>) -> i64
48-
// CHECK: %[[VAL_12:.*]] = call @__nvqpp_vectorCopyCtor(%[[VAL_9]], %[[VAL_10]], %[[VAL_11]]) : (!cc.ptr<i8>, i64, i64) -> !cc.ptr<i8>
49-
// CHECK: %[[VAL_13:.*]] = cc.stdvec_init %[[VAL_12]], %[[VAL_10]] : (!cc.ptr<i8>, i64) -> !cc.stdvec<i1>
50-
// CHECK: return %[[VAL_13]] : !cc.stdvec<i1>
38+
// CHECK-LABEL: func.func @__nvqpp__mlirgen__VectorOfStaticVeq() -> !cc.stdvec<!quake.measure> attributes {"cudaq-kernel"} {
39+
// CHECK: %[[VAL_0:.*]] = arith.constant 4 : i64
40+
// CHECK: %[[VAL_1:.*]] = quake.alloca !quake.ref
41+
// CHECK: %[[VAL_2:.*]] = quake.alloca !quake.veq<4>
42+
// CHECK: %[[VAL_3:.*]] = quake.alloca !quake.veq<2>
43+
// CHECK: %[[VAL_4:.*]] = quake.alloca !quake.ref
44+
// CHECK: %[[VAL_5:.*]] = quake.mz %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : (!quake.ref, !quake.veq<4>, !quake.veq<2>, !quake.ref) -> !cc.stdvec<!quake.measure>
45+
// CHECK: %[[VAL_6:.*]] = cc.stdvec_data %[[VAL_5]] : (!cc.stdvec<!quake.measure>) -> !cc.ptr<i8>
46+
// CHECK: %[[VAL_7:.*]] = cc.stdvec_size %[[VAL_5]] : (!cc.stdvec<!quake.measure>) -> i64
47+
// CHECK: %[[VAL_8:.*]] = call @__nvqpp_vectorCopyCtor(%[[VAL_6]], %[[VAL_7]], %[[VAL_0]]) : (!cc.ptr<i8>, i64, i64) -> !cc.ptr<i8>
48+
// CHECK: %[[VAL_9:.*]] = cc.stdvec_init %[[VAL_8]], %[[VAL_7]] : (!cc.ptr<i8>, i64) -> !cc.stdvec<!quake.measure>
49+
// CHECK: return %[[VAL_9]] : !cc.stdvec<!quake.measure>
50+
// CHECK: }
51+
52+
struct VectorOfStaticVeq_Bool {
53+
std::vector<bool> operator()() __qpu__ {
54+
cudaq::qubit q1;
55+
cudaq::qvector reg1(4);
56+
cudaq::qvector reg2(2);
57+
cudaq::qubit q2;
58+
auto res = mz(q1, reg1, reg2, q2);
59+
return cudaq::to_bool_vector(res);
60+
}
61+
};
62+
63+
// CHECK-LABEL: func.func @__nvqpp__mlirgen__VectorOfStaticVeq_Bool() -> !cc.stdvec<i1> attributes {"cudaq-entrypoint", "cudaq-kernel"} {
64+
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64
65+
// CHECK: %[[VAL_1:.*]] = quake.alloca !quake.ref
66+
// CHECK: %[[VAL_2:.*]] = quake.alloca !quake.veq<4>
67+
// CHECK: %[[VAL_3:.*]] = quake.alloca !quake.veq<2>
68+
// CHECK: %[[VAL_4:.*]] = quake.alloca !quake.ref
69+
// CHECK: %[[VAL_5:.*]] = quake.mz %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] name "res" : (!quake.ref, !quake.veq<4>, !quake.veq<2>, !quake.ref) -> !cc.stdvec<!quake.measure>
70+
// CHECK: %[[VAL_6:.*]] = quake.discriminate %[[VAL_5]] : (!cc.stdvec<!quake.measure>) -> !cc.stdvec<i1>
71+
// CHECK: %[[VAL_7:.*]] = call @__nvqpp_cudaq_to_bool_vector(%[[VAL_6]]) : (!cc.stdvec<i1>) -> !cc.stdvec<i1>
72+
// CHECK: %[[VAL_8:.*]] = cc.stdvec_data %[[VAL_7]] : (!cc.stdvec<i1>) -> !cc.ptr<i8>
73+
// CHECK: %[[VAL_9:.*]] = cc.stdvec_size %[[VAL_7]] : (!cc.stdvec<i1>) -> i64
74+
// CHECK: %[[VAL_10:.*]] = call @__nvqpp_vectorCopyCtor(%[[VAL_8]], %[[VAL_9]], %[[VAL_0]]) : (!cc.ptr<i8>, i64, i64) -> !cc.ptr<i8>
75+
// CHECK: %[[VAL_11:.*]] = cc.stdvec_init %[[VAL_10]], %[[VAL_9]] : (!cc.ptr<i8>, i64) -> !cc.stdvec<i1>
76+
// CHECK: return %[[VAL_11]] : !cc.stdvec<i1>
5177
// CHECK: }
5278

5379
struct VectorOfDynamicVeq {
54-
std::vector<bool> operator()(unsigned i, unsigned j) __qpu__ {
80+
std::vector<cudaq::measure_result> operator()(unsigned i, unsigned j) __qpu__ {
5581
cudaq::qubit q1;
5682
cudaq::qvector reg1(i);
5783
cudaq::qvector reg2(j);
@@ -61,26 +87,25 @@ struct VectorOfDynamicVeq {
6187
};
6288

6389
// CHECK-LABEL: func.func @__nvqpp__mlirgen__VectorOfDynamicVeq(
64-
// CHECK-SAME: %[[VAL_0:.*]]: i32{{.*}}, %[[VAL_1:.*]]: i32{{.*}}) -> !cc.stdvec<i1> attributes {
65-
// CHECK: %[[VAL_15:.*]] = arith.constant 1 : i64
66-
// CHECK: %[[VAL_2:.*]] = cc.alloca i32
67-
// CHECK: cc.store %[[VAL_0]], %[[VAL_2]] : !cc.ptr<i32>
68-
// CHECK: %[[VAL_3:.*]] = cc.alloca i32
69-
// CHECK: cc.store %[[VAL_1]], %[[VAL_3]] : !cc.ptr<i32>
70-
// CHECK: %[[VAL_4:.*]] = quake.alloca !quake.ref
71-
// CHECK: %[[VAL_5:.*]] = cc.load %[[VAL_2]] : !cc.ptr<i32>
72-
// CHECK: %[[VAL_6:.*]] = cc.cast unsigned %[[VAL_5]] : (i32) -> i64
73-
// CHECK: %[[VAL_7:.*]] = quake.alloca !quake.veq<?>[%[[VAL_6]] : i64]
74-
// CHECK: %[[VAL_8:.*]] = cc.load %[[VAL_3]] : !cc.ptr<i32>
75-
// CHECK: %[[VAL_9:.*]] = cc.cast unsigned %[[VAL_8]] : (i32) -> i64
76-
// CHECK: %[[VAL_10:.*]] = quake.alloca !quake.veq<?>[%[[VAL_9]] : i64]
77-
// CHECK: %[[VAL_11:.*]] = quake.alloca !quake.ref
78-
// CHECK: %[[VAL_112:.*]] = quake.mz %[[VAL_4]], %[[VAL_7]], %[[VAL_10]], %[[VAL_11]] : (!quake.ref, !quake.veq<?>, !quake.veq<?>, !quake.ref) -> !cc.stdvec<!quake.measure>
79-
// CHECK: %[[VAL_12:.*]] = quake.discriminate %[[VAL_112]] :
80-
// CHECK: %[[VAL_13:.*]] = cc.stdvec_data %[[VAL_12]] : (!cc.stdvec<i1>) -> !cc.ptr<i8>
81-
// CHECK: %[[VAL_14:.*]] = cc.stdvec_size %[[VAL_12]] : (!cc.stdvec<i1>) -> i64
82-
// CHECK: %[[VAL_16:.*]] = call @__nvqpp_vectorCopyCtor(%[[VAL_13]], %[[VAL_14]], %[[VAL_15]]) : (!cc.ptr<i8>, i64, i64) -> !cc.ptr<i8>
83-
// CHECK: %[[VAL_17:.*]] = cc.stdvec_init %[[VAL_16]], %[[VAL_14]] : (!cc.ptr<i8>, i64) -> !cc.stdvec<i1>
84-
// CHECK: return %[[VAL_17]] : !cc.stdvec<i1>
90+
// CHECK-SAME: %[[VAL_0:.*]]: i32 loc("mz.cpp":65:3),
91+
// CHECK-SAME: %[[VAL_1:.*]]: i32 loc("mz.cpp":65:3)) -> !cc.stdvec<!quake.measure> attributes {"cudaq-kernel"} {
92+
// CHECK: %[[VAL_2:.*]] = arith.constant 4 : i64 loc(#loc20)
93+
// CHECK: %[[VAL_3:.*]] = cc.alloca i32 loc(#loc21)
94+
// CHECK: cc.store %[[VAL_0]], %[[VAL_3]] : !cc.ptr<i32> loc(#loc21)
95+
// CHECK: %[[VAL_4:.*]] = cc.alloca i32 loc(#loc22)
96+
// CHECK: cc.store %[[VAL_1]], %[[VAL_4]] : !cc.ptr<i32> loc(#loc22)
97+
// CHECK: %[[VAL_5:.*]] = quake.alloca !quake.ref loc(#loc23)
98+
// CHECK: %[[VAL_6:.*]] = cc.load %[[VAL_3]] : !cc.ptr<i32> loc(#loc21)
99+
// CHECK: %[[VAL_7:.*]] = cc.cast unsigned %[[VAL_6]] : (i32) -> i64 loc(#loc24)
100+
// CHECK: %[[VAL_8:.*]] = quake.alloca !quake.veq<?>{{\[}}%[[VAL_7]] : i64] loc(#loc25)
101+
// CHECK: %[[VAL_9:.*]] = cc.load %[[VAL_4]] : !cc.ptr<i32> loc(#loc22)
102+
// CHECK: %[[VAL_10:.*]] = cc.cast unsigned %[[VAL_9]] : (i32) -> i64 loc(#loc26)
103+
// CHECK: %[[VAL_11:.*]] = quake.alloca !quake.veq<?>{{\[}}%[[VAL_10]] : i64] loc(#loc27)
104+
// CHECK: %[[VAL_12:.*]] = quake.alloca !quake.ref loc(#loc28)
105+
// CHECK: %[[VAL_13:.*]] = quake.mz %[[VAL_5]], %[[VAL_8]], %[[VAL_11]], %[[VAL_12]] : (!quake.ref, !quake.veq<?>, !quake.veq<?>, !quake.ref) -> !cc.stdvec<!quake.measure> loc(#loc29)
106+
// CHECK: %[[VAL_14:.*]] = cc.stdvec_data %[[VAL_13]] : (!cc.stdvec<!quake.measure>) -> !cc.ptr<i8> loc(#loc20)
107+
// CHECK: %[[VAL_15:.*]] = cc.stdvec_size %[[VAL_13]] : (!cc.stdvec<!quake.measure>) -> i64 loc(#loc20)
108+
// CHECK: %[[VAL_16:.*]] = call @__nvqpp_vectorCopyCtor(%[[VAL_14]], %[[VAL_15]], %[[VAL_2]]) : (!cc.ptr<i8>, i64, i64) -> !cc.ptr<i8> loc(#loc20)
109+
// CHECK: %[[VAL_17:.*]] = cc.stdvec_init %[[VAL_16]], %[[VAL_15]] : (!cc.ptr<i8>, i64) -> !cc.stdvec<!quake.measure> loc(#loc20)
110+
// CHECK: return %[[VAL_17]] : !cc.stdvec<!quake.measure>
85111
// CHECK: }
86-

0 commit comments

Comments
 (0)