Skip to content

Commit 45b8c68

Browse files
* Properly(?) return measure_result type - may need clean up
- Factory and Marshal * Can use `auto` now * NVQIR function for logging result record seems to have issues Signed-off-by: Pradnya Khalate <pkhalate@nvidia.com>
1 parent b378f8f commit 45b8c68

File tree

13 files changed

+101
-145
lines changed

13 files changed

+101
-145
lines changed

lib/Frontend/nvqpp/ASTBridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ listReachableFunctions(clang::CallGraphNode *cgn) {
5353
// Does `ty` refer to a Quake quantum type? This also checks custom recursive
5454
// types. It does not check builtin recursive types; e.g., `!llvm.ptr<T>`.
5555
static bool isQubitType(Type ty) {
56-
if (quake::isQuakeType(ty))
56+
if (quake::isQuantumType(ty))
5757
return true;
5858
// FIXME: next if case is a bug.
5959
if (auto vecTy = dyn_cast<cudaq::cc::StdvecType>(ty))

lib/Frontend/nvqpp/ConvertExpr.cpp

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ maybeUnpackOperands(OpBuilder &builder, Location loc, ValueRange operands,
9999
return std::make_pair(targets, SmallVector<Value>{});
100100
}
101101

102+
static Value emitDiscriminate(OpBuilder &builder, Location loc, Value val) {
103+
if (isa<quake::MeasureType>(val.getType()))
104+
return builder.create<quake::DiscriminateOp>(loc, builder.getI1Type(), val);
105+
return val;
106+
}
107+
102108
namespace {
103109
// Type used to specialize the buildOp function. This extends the cases below by
104110
// prefixing a single parameter value to the list of arguments for cases 1
@@ -637,11 +643,7 @@ bool QuakeBridgeVisitor::VisitCastExpr(clang::CastExpr *x) {
637643
}
638644
case clang::CastKind::CK_IntegralToFloating: {
639645
auto value = popValue();
640-
// If source is `!quake.measure`, discriminate it first
641-
if (isa<quake::MeasureType>(value.getType())) {
642-
value = builder.create<quake::DiscriminateOp>(loc, builder.getI1Type(),
643-
value);
644-
}
646+
value = emitDiscriminate(builder, loc, value);
645647
auto mode =
646648
(x->getSubExpr()->getType()->isUnsignedIntegerOrEnumerationType())
647649
? cudaq::cc::CastOpMode::Unsigned
@@ -651,24 +653,14 @@ bool QuakeBridgeVisitor::VisitCastExpr(clang::CastExpr *x) {
651653
}
652654
case clang::CastKind::CK_IntegralToBoolean: {
653655
auto last = popValue();
654-
// If the value is `!quake.measure`, discriminate it first
655-
if (isa<quake::MeasureType>(last.getType())) {
656-
last =
657-
builder.create<quake::DiscriminateOp>(loc, builder.getI1Type(), last);
658-
return pushValue(last);
659-
}
656+
last = emitDiscriminate(builder, loc, last);
660657
Value zero = builder.create<arith::ConstantIntOp>(loc, 0, last.getType());
661658
return pushValue(builder.create<arith::CmpIOp>(
662659
loc, arith::CmpIPredicate::ne, last, zero));
663660
}
664661
case clang::CastKind::CK_FloatingToBoolean: {
665662
auto last = popValue();
666-
// If the value is `!quake.measure`, discriminate it first
667-
if (isa<quake::MeasureType>(last.getType())) {
668-
last =
669-
builder.create<quake::DiscriminateOp>(loc, builder.getI1Type(), last);
670-
return pushValue(last);
671-
}
663+
last = emitDiscriminate(builder, loc, last);
672664
Value zero = opt::factory::createFloatConstant(
673665
loc, builder, 0.0, cast<FloatType>(last.getType()));
674666
return pushValue(builder.create<arith::CmpFOp>(
@@ -687,7 +679,7 @@ bool QuakeBridgeVisitor::VisitCastExpr(clang::CastExpr *x) {
687679
auto i1Type = builder.getI1Type();
688680
// Handle conversion of `measure_result`
689681
if (isa<quake::MeasureType>(sub.getType())) {
690-
auto i1Val = builder.create<quake::DiscriminateOp>(loc, i1Type, sub);
682+
auto i1Val = emitDiscriminate(builder, loc, sub);
691683
// Convert to `int`
692684
if (isa<IntegerType>(castToTy))
693685
return pushValue(
@@ -860,12 +852,8 @@ bool QuakeBridgeVisitor::VisitBinaryOperator(clang::BinaryOperator *x) {
860852
rhs = maybeLoadValue(rhs);
861853
lhs = maybeLoadValue(lhs);
862854
// Discriminate measure types before comparison
863-
if (isa<quake::MeasureType>(lhs.getType()))
864-
lhs =
865-
builder.create<quake::DiscriminateOp>(loc, builder.getI1Type(), lhs);
866-
if (isa<quake::MeasureType>(rhs.getType()))
867-
rhs =
868-
builder.create<quake::DiscriminateOp>(loc, builder.getI1Type(), rhs);
855+
lhs = emitDiscriminate(builder, loc, lhs);
856+
rhs = emitDiscriminate(builder, loc, rhs);
869857
// Floating point comparison?
870858
if (isa<FloatType>(lhs.getType())) {
871859
arith::CmpFPredicate pred;
@@ -945,10 +933,8 @@ bool QuakeBridgeVisitor::VisitBinaryOperator(clang::BinaryOperator *x) {
945933
rhs = maybeLoadValue(rhs);
946934
lhs = maybeLoadValue(lhs);
947935
// Discriminate measure types before arithmetic
948-
if (isa<quake::MeasureType>(lhs.getType()))
949-
lhs = builder.create<quake::DiscriminateOp>(loc, builder.getI1Type(), lhs);
950-
if (isa<quake::MeasureType>(rhs.getType()))
951-
rhs = builder.create<quake::DiscriminateOp>(loc, builder.getI1Type(), rhs);
936+
lhs = emitDiscriminate(builder, loc, lhs);
937+
rhs = emitDiscriminate(builder, loc, rhs);
952938
castToSameType(builder, loc, x->getLHS()->getType().getTypePtrOrNull(), lhs,
953939
x->getRHS()->getType().getTypePtrOrNull(), rhs);
954940
switch (x->getOpcode()) {
@@ -1036,10 +1022,7 @@ bool QuakeBridgeVisitor::TraverseConditionalOperator(
10361022
if (!TraverseStmt(x->getCond()))
10371023
return false;
10381024
auto condVal = popValue();
1039-
// Discriminate if condition is `!quake.measure`
1040-
if (isa<quake::MeasureType>(condVal.getType()))
1041-
condVal = builder.create<quake::DiscriminateOp>(loc, builder.getI1Type(),
1042-
condVal);
1025+
condVal = emitDiscriminate(builder, loc, condVal);
10431026
Type resultTy = builder.getI64Type();
10441027

10451028
// Create shared lambda for the x->getTrueExpr() and x->getFalseExpr()
@@ -1622,12 +1605,8 @@ bool QuakeBridgeVisitor::VisitCallExpr(clang::CallExpr *x) {
16221605
if (isa<cc::PointerType>(rhs.getType()))
16231606
rhs = builder.create<cc::LoadOp>(loc, rhs);
16241607
// Discriminate measure types
1625-
if (isa<quake::MeasureType>(lhs.getType()))
1626-
lhs = builder.create<quake::DiscriminateOp>(loc, builder.getI1Type(),
1627-
lhs);
1628-
if (isa<quake::MeasureType>(rhs.getType()))
1629-
rhs = builder.create<quake::DiscriminateOp>(loc, builder.getI1Type(),
1630-
rhs);
1608+
lhs = emitDiscriminate(builder, loc, lhs);
1609+
rhs = emitDiscriminate(builder, loc, rhs);
16311610
// Choose predicate based on operator
16321611
auto pred = (opKind == clang::OO_EqualEqual) ? arith::CmpIPredicate::eq
16331612
: arith::CmpIPredicate::ne;

lib/Optimizer/Builder/Factory.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ Type genBufferType(Type ty) {
7171
assert(!cudaq::cc::isDynamicType(ty) && "must be a type of static extent");
7272
return ty;
7373
}
74+
if (isa<quake::MeasureType>(ty)) {
75+
auto i32Ty = IntegerType::get(ctx, 32);
76+
auto i64Ty = IntegerType::get(ctx, 64);
77+
return cudaq::cc::StructType::get(ctx, {i32Ty, i64Ty});
78+
}
7479
return ty;
7580
}
7681

@@ -430,6 +435,13 @@ Type factory::convertToHostSideType(Type ty, ModuleOp mod) {
430435
return cc::PointerType::get(factory::stlVectorType(
431436
IntegerType::get(ctx, /*FIXME sizeof a pointer?*/ 64)));
432437
}
438+
if (isa<quake::MeasureType>(ty)) {
439+
auto *ctx = ty.getContext();
440+
auto i32Ty = IntegerType::get(ctx, 32);
441+
auto i64Ty = IntegerType::get(ctx, 64);
442+
// Return the `measure_result` struct {int result, size_t uniqueId}
443+
return cc::StructType::get(ctx, {i32Ty, i64Ty});
444+
}
433445
return ty;
434446
}
435447

@@ -644,7 +656,7 @@ FunctionType factory::toHostSideFuncType(FunctionType funcTy, bool addThisPtr,
644656
hasSRet = true;
645657
} else {
646658
assert(funcTy.getNumResults() == 1);
647-
resultTy = funcTy.getResult(0);
659+
resultTy = convertToHostSideType(funcTy.getResult(0), module);
648660
}
649661
}
650662
// If this kernel is a plain old function or a static member function, we

lib/Optimizer/Builder/Intrinsics.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,9 @@ static constexpr IntrinsicCode intrinsicTable[] = {
457457
)#"},
458458
{cudaq::opt::QIRIntegerRecordOutput, {}, R"#(
459459
func.func private @__quantum__rt__int_record_output(i64, !cc.ptr<i8>)
460+
)#"},
461+
{cudaq::opt::QIRRecordOutput, {}, R"#(
462+
func.func private @__quantum__rt__result_record_output(!cc.ptr<!llvm.struct<"Result", opaque>>, !cc.ptr<i8>)
460463
)#"},
461464
{cudaq::opt::QIRTupleRecordOutput, {}, R"#(
462465
func.func private @__quantum__rt__tuple_record_output(i64, !cc.ptr<i8>)

lib/Optimizer/Builder/Marshal.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,9 @@ std::pair<bool, func::FuncOp> cudaq::opt::marshal::lookupHostEntryPointFunc(
790790
// No host entry point needed.
791791
return {false, func::FuncOp{}};
792792
}
793+
if (!funcOp->hasAttr(cudaq::entryPointAttrName)) {
794+
return {false, func::FuncOp{}};
795+
}
793796
if (auto *decl = module.lookupSymbol(mangledEntryPointName))
794797
if (auto func = dyn_cast<func::FuncOp>(decl)) {
795798
func.eraseBody();

lib/Optimizer/CodeGen/ConvertCCToLLVM.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -81,19 +81,19 @@ void cudaq::opt::populateCCTypeConversions(LLVMTypeConverter *converter) {
8181
return LLVM::LLVMStructType::getLiteral(type.getContext(), members,
8282
type.getPacked());
8383
});
84-
}
85-
86-
std::size_t cudaq::opt::getDataSize(llvm::DataLayout &dataLayout, Type ty) {
87-
LLVMTypeConverter converter(ty.getContext());
88-
cudaq::opt::populateCCTypeConversions(&converter);
89-
auto llvmDialectTy = converter.convertType(ty);
9084
// `measure_result` -> struct conversion for size calculation
91-
converter.addConversion([](quake::MeasureType type) -> Type {
85+
converter->addConversion([](quake::MeasureType type) -> Type {
9286
auto ctx = type.getContext();
9387
auto i32Ty = IntegerType::get(ctx, 32);
9488
auto i64Ty = IntegerType::get(ctx, 64);
9589
return LLVM::LLVMStructType::getLiteral(ctx, {i32Ty, i64Ty});
9690
});
91+
}
92+
93+
std::size_t cudaq::opt::getDataSize(llvm::DataLayout &dataLayout, Type ty) {
94+
LLVMTypeConverter converter(ty.getContext());
95+
cudaq::opt::populateCCTypeConversions(&converter);
96+
auto llvmDialectTy = converter.convertType(ty);
9797
llvm::LLVMContext context;
9898
LLVM::TypeToLLVMIRTranslator translator(context);
9999
auto llvmTy = translator.translateType(llvmDialectTy);
@@ -104,13 +104,6 @@ std::size_t cudaq::opt::getDataOffset(llvm::DataLayout &dataLayout, Type ty,
104104
std::size_t off) {
105105
LLVMTypeConverter converter(ty.getContext());
106106
cudaq::opt::populateCCTypeConversions(&converter);
107-
// `measure_result` -> struct conversion for size calculation
108-
converter.addConversion([](quake::MeasureType type) -> Type {
109-
auto ctx = type.getContext();
110-
auto i32Ty = IntegerType::get(ctx, 32);
111-
auto i64Ty = IntegerType::get(ctx, 64);
112-
return LLVM::LLVMStructType::getLiteral(ctx, {i32Ty, i64Ty});
113-
});
114107
auto llvmDialectTy = converter.convertType(ty);
115108
llvm::LLVMContext context;
116109
LLVM::TypeToLLVMIRTranslator translator(context);

lib/Optimizer/CodeGen/ReturnToOutputLog.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,22 @@ class ReturnRewrite : public OpRewritePattern<cudaq::cc::LogOutputOp> {
170170
cudaq::opt::QIRRecordOutput,
171171
ArrayRef<Value>{val, label});
172172
})
173+
.Case([&](cudaq::cc::PointerType ptrTy) {
174+
// Check if this is a pointer to %Result (converted MeasureType)
175+
if (auto structTy =
176+
dyn_cast<LLVM::LLVMStructType>(ptrTy.getElementType()))
177+
if (structTy.isIdentified() && structTy.getName() == "Result") {
178+
// Handle as measure result
179+
std::string labelStr = "result";
180+
if (prefix)
181+
labelStr = prefix->str();
182+
Value label = makeLabel(loc, rewriter, labelStr);
183+
rewriter.create<func::CallOp>(loc, TypeRange{},
184+
cudaq::opt::QIRRecordOutput,
185+
ArrayRef<Value>{val, label});
186+
return;
187+
}
188+
})
173189
.Default([&](Type) {
174190
// If we reach here, we don't know how to handle this type.
175191
Value one = rewriter.create<arith::ConstantIntOp>(loc, 1, 64);
@@ -236,6 +252,11 @@ struct ReturnToOutputLogPass
236252
signalPassFailure();
237253
return;
238254
}
255+
if (failed(irBuilder.loadIntrinsic(module, cudaq::opt::QIRRecordOutput))) {
256+
module.emitError("could not load QIR result record output function.");
257+
signalPassFailure();
258+
return;
259+
}
239260
if (failed(irBuilder.loadIntrinsic(module, cudaq::opt::QISTrap))) {
240261
module.emitError("could not load QIR trap function.");
241262
signalPassFailure();

runtime/common/RecordLogParser.cpp

Lines changed: 2 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -105,77 +105,6 @@ void cudaq::RecordLogParser::handleOutput(
105105
const std::string &recValue = entries[2];
106106
std::string recLabel = (entries.size() == 4) ? entries[3] : "";
107107
cudaq::trim(recLabel);
108-
if (recType == "RESULT") {
109-
// Sample-type QIR output, where we have an array of `RESULT` per shot. For
110-
// example,
111-
// START
112-
// OUTPUT RESULT 1 r00000
113-
// ....
114-
// OUTPUT RESULT 1 r00009
115-
// END 0
116-
117-
currentOutput = OutputType::RESULT;
118-
const bool isUninitializedContainer =
119-
(containerMeta.m_type == ContainerType::NONE) ||
120-
(containerMeta.m_type == ContainerType::ARRAY &&
121-
containerMeta.elementCount == 0);
122-
if (isUninitializedContainer) {
123-
// NOTE: This is a temporary workaround until all backends consistently
124-
// use the new transformation pass that wraps result records inside an
125-
// array record output. For now, we permit "naked" RESULT records, i.e.,
126-
// if the QIR produced by a sampled kernel emits a sequence of RESULT
127-
// records without enclosing them in an ARRAY, we interpret them
128-
// collectively as an array of results.
129-
// NOTE: This assumption prevents us from correctly supporting `run` with
130-
// `qir-base` profile.
131-
containerMeta.m_type = ContainerType::ARRAY;
132-
containerMeta.elementCount =
133-
std::stoul(metadata[ResultCountMetadataName]);
134-
containerMeta.arrayType = "i1";
135-
preallocateArray();
136-
}
137-
138-
// Note: For ordered schema, we expect the results are sequential in the
139-
// same order that mz operations are called. This may include results in
140-
// named registers (specified in kernel code) and other auto-generated
141-
// register names. If index cannot be extracted from the label, we fall back
142-
// to using this mechanism.
143-
auto idxLabel = std::to_string(containerMeta.processedElements);
144-
145-
// Get the index from the label, if feasible.
146-
/// TODO: The `sample` API should be updated to not allow explicit
147-
/// measurement operations in the kernel when targeting hardware backends.
148-
// Until then, we handle both cases here - auto-generated labels like
149-
// r00000, r00001, ... and named results like result%0, result%1, ...
150-
if (!recLabel.empty()) {
151-
std::size_t percentPos = recLabel.find('%');
152-
if (percentPos != std::string::npos) {
153-
idxLabel = recLabel.substr(percentPos + 1);
154-
}
155-
// This logic is fragile; for example user may have only one mz assigned
156-
// to variable like r00001 and it will be interpreted as index 1, and
157-
// cause `Array index out of bounds` error. The proper fix is to disallow
158-
// explicit mz operations in sampled kernels. Also, `run` is appropriate
159-
// for getting sub-register results.
160-
else if (recLabel.size() == 6 && recLabel[0] == 'r') {
161-
// check that the last 5 characters are all digits
162-
bool allDigits = true;
163-
for (std::size_t i = 1; i < 6; ++i) {
164-
if (recLabel[i] < '0' || recLabel[i] > '9') {
165-
allDigits = false;
166-
break;
167-
}
168-
}
169-
if (allDigits) {
170-
idxLabel = recLabel.substr(1);
171-
}
172-
}
173-
}
174-
175-
processArrayEntry(recValue, fmt::format("[{}]", idxLabel));
176-
containerMeta.processedElements++;
177-
return;
178-
}
179108
if (recType == "ARRAY") {
180109
containerMeta.m_type = ContainerType::ARRAY;
181110
containerMeta.elementCount = std::stoul(recValue);
@@ -196,6 +125,8 @@ void cudaq::RecordLogParser::handleOutput(
196125
}
197126
return;
198127
}
128+
if (recType == "RESULT")
129+
currentOutput = OutputType::RESULT;
199130
if (recType == "BOOL")
200131
currentOutput = OutputType::BOOL;
201132
else if (recType == "INT")

runtime/nvqir/NVQIR.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,23 @@ void __quantum__rt__result_record_output(Result *r, int8_t *name) {
675675
if (ctx && ctx->name == "run") {
676676

677677
std::string regName(reinterpret_cast<const char *>(name));
678+
679+
// Base profile - value stored in measRes2Val
680+
auto iter = measRes2Val.find(r);
681+
if (iter != measRes2Val.end()) {
682+
quantumRTGenericRecordOutput("RESULT", (iter->second ? 1 : 0),
683+
regName.c_str());
684+
return;
685+
}
686+
687+
// Full QIR - r is ResultOne or ResultZero (measurement already happened)
688+
if (r == ResultOne || r == ResultZero) {
689+
bool val = (r == ResultOne);
690+
quantumRTGenericRecordOutput("RESULT", (val ? 1 : 0), regName.c_str());
691+
return;
692+
}
693+
694+
// Fallback - use qubit mapping and re-measure (legacy behavior)
678695
auto qI = qubitToSizeT(measRes2QB[r]);
679696
auto b = nvqir::getCircuitSimulatorInternal()->mz(qI, regName);
680697
quantumRTGenericRecordOutput("RESULT", (b ? 1 : 0), regName.c_str());

runtime/nvqir/QIRTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ void __nvqpp_cleanup_arrays();
8989
}
9090

9191
// Results
92+
/// FIXME: What should this be?
9293
using Result = bool;
9394
static const Result ResultZeroVal = false;
9495
static const Result ResultOneVal = true;

0 commit comments

Comments
 (0)