Skip to content

Commit b378f8f

Browse files
* Clean-up
* Handle the size for `measure_result` since it behaves differently on host and on device * Log result output Signed-off-by: Pradnya Khalate <pkhalate@nvidia.com>
1 parent 2ca8ce2 commit b378f8f

File tree

8 files changed

+43
-22
lines changed

8 files changed

+43
-22
lines changed

lib/Frontend/nvqpp/ConvertStmt.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,9 +359,8 @@ bool QuakeBridgeVisitor::VisitReturnStmt(clang::ReturnStmt *x) {
359359
};
360360
IRBuilder irb(builder);
361361
Value tySize;
362-
if (!cudaq::cc::isDynamicType(eleTy)) {
362+
if (!cudaq::cc::isDynamicType(eleTy))
363363
tySize = irb.getByteSizeOfType(loc, eleTy);
364-
}
365364
if (!tySize) {
366365
// TODO: we need to recursively create copies of all
367366
// dynamic memory used within the type. See the

lib/Optimizer/CodeGen/ConvertCCToLLVM.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,19 @@ void cudaq::opt::populateCCTypeConversions(LLVMTypeConverter *converter) {
8181
return LLVM::LLVMStructType::getLiteral(type.getContext(), members,
8282
type.getPacked());
8383
});
84-
converter->addConversion([](quake::MeasureType type) -> Type {
85-
auto ctx = type.getContext();
86-
auto i32Ty = IntegerType::get(ctx, 32);
87-
auto i64Ty = IntegerType::get(ctx, 64);
88-
return LLVM::LLVMStructType::getLiteral(ctx, {i32Ty, i64Ty});
89-
});
9084
}
9185

9286
std::size_t cudaq::opt::getDataSize(llvm::DataLayout &dataLayout, Type ty) {
9387
LLVMTypeConverter converter(ty.getContext());
9488
cudaq::opt::populateCCTypeConversions(&converter);
9589
auto llvmDialectTy = converter.convertType(ty);
90+
// `measure_result` -> struct conversion for size calculation
91+
converter.addConversion([](quake::MeasureType type) -> Type {
92+
auto ctx = type.getContext();
93+
auto i32Ty = IntegerType::get(ctx, 32);
94+
auto i64Ty = IntegerType::get(ctx, 64);
95+
return LLVM::LLVMStructType::getLiteral(ctx, {i32Ty, i64Ty});
96+
});
9697
llvm::LLVMContext context;
9798
LLVM::TypeToLLVMIRTranslator translator(context);
9899
auto llvmTy = translator.translateType(llvmDialectTy);
@@ -103,6 +104,13 @@ std::size_t cudaq::opt::getDataOffset(llvm::DataLayout &dataLayout, Type ty,
103104
std::size_t off) {
104105
LLVMTypeConverter converter(ty.getContext());
105106
cudaq::opt::populateCCTypeConversions(&converter);
107+
// `measure_result` -> struct conversion for size calculation
108+
converter.addConversion([](quake::MeasureType type) -> Type {
109+
auto ctx = type.getContext();
110+
auto i32Ty = IntegerType::get(ctx, 32);
111+
auto i64Ty = IntegerType::get(ctx, 64);
112+
return LLVM::LLVMStructType::getLiteral(ctx, {i32Ty, i64Ty});
113+
});
106114
auto llvmDialectTy = converter.convertType(ty);
107115
llvm::LLVMContext context;
108116
LLVM::TypeToLLVMIRTranslator translator(context);

lib/Optimizer/CodeGen/ConvertToQIR.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,8 @@ void cudaq::opt::initializeTypeConversions(LLVMTypeConverter &typeConverter) {
206206
return LLVM::LLVMStructType::getLiteral(type.getContext(), mems,
207207
/*packed=*/false);
208208
});
209-
typeConverter.addConversion([](quake::MeasureType type) {
210-
return IntegerType::get(type.getContext(), 1);
211-
});
209+
typeConverter.addConversion(
210+
[](quake::MeasureType type) { return getResultType(type.getContext()); });
212211
cudaq::opt::populateCCTypeConversions(&typeConverter);
213212
}
214213

lib/Optimizer/CodeGen/QuakeToLLVM.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,12 @@ class DiscriminateOpPattern
284284
LogicalResult
285285
matchAndRewrite(quake::DiscriminateOp discr, OpAdaptor adaptor,
286286
ConversionPatternRewriter &rewriter) const override {
287-
auto m = discr.getMeasurement();
288-
rewriter.replaceOp(discr, m);
287+
auto loc = discr.getLoc();
288+
Value resultPtr = adaptor.getMeasurement();
289+
auto i1Ty = rewriter.getI1Type();
290+
auto i1PtrTy = LLVM::LLVMPointerType::get(i1Ty);
291+
auto cast = rewriter.create<LLVM::BitcastOp>(loc, i1PtrTy, resultPtr);
292+
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(discr, i1Ty, cast);
289293
return success();
290294
}
291295
};
@@ -1144,11 +1148,7 @@ class MeasureRewrite : public ConvertOpToLLVMPattern<OP> {
11441148
loc, cudaq::opt::getResultType(context), symbolRef, ValueRange{args});
11451149
if (regName)
11461150
callOp->setAttr("registerName", regName);
1147-
auto i1Ty = rewriter.getI1Type();
1148-
auto i1PtrTy = LLVM::LLVMPointerType::get(i1Ty);
1149-
auto cast =
1150-
rewriter.create<LLVM::BitcastOp>(loc, i1PtrTy, callOp.getResult());
1151-
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(measure, i1Ty, cast);
1151+
rewriter.replaceOp(measure, callOp.getResult());
11521152

11531153
return success();
11541154
}

lib/Optimizer/CodeGen/ReturnToOutputLog.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "cudaq/Optimizer/CodeGen/QIRFunctionNames.h"
1515
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
1616
#include "cudaq/Optimizer/Dialect/CC/CCTypes.h"
17+
#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
1718
#include "llvm/ADT/TypeSwitch.h"
1819
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1920
#include "mlir/Transforms/Passes.h"
@@ -160,6 +161,15 @@ class ReturnRewrite : public OpRewritePattern<cudaq::cc::LogOutputOp> {
160161
}
161162
}
162163
})
164+
.Case([&](quake::MeasureType) {
165+
std::string labelStr = "result";
166+
if (prefix)
167+
labelStr = prefix->str();
168+
Value label = makeLabel(loc, rewriter, labelStr);
169+
rewriter.create<func::CallOp>(loc, TypeRange{},
170+
cudaq::opt::QIRRecordOutput,
171+
ArrayRef<Value>{val, label});
172+
})
163173
.Default([&](Type) {
164174
// If we reach here, we don't know how to handle this type.
165175
Value one = rewriter.create<arith::ConstantIntOp>(loc, 1, 64);
@@ -195,6 +205,8 @@ class ReturnRewrite : public OpRewritePattern<cudaq::cc::LogOutputOp> {
195205
if (auto arrTy = dyn_cast<cudaq::cc::StdvecType>(ty))
196206
return {std::string("array<") + translateType(arrTy.getElementType()) +
197207
std::string(" x ") + std::to_string(*vecSz) + std::string(">")};
208+
if (isa<quake::MeasureType>(ty))
209+
return {"Result"};
198210
return {"error"};
199211
}
200212

runtime/common/RecordLogParser.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ cudaq::RecordLogParser::getDataHandler(const std::string &dataType) {
240240
static details::DataHandler<double> f64Handler(
241241
std::make_unique<details::FloatConverter<double>>());
242242
// Map data type to the corresponding handler
243-
if (dataType == "measure_result")
243+
if (dataType == "result")
244244
return measureResultHandler;
245245
if (dataType == "i1")
246246
return boolHandler;
@@ -283,7 +283,7 @@ void cudaq::RecordLogParser::processSingleRecord(const std::string &recValue,
283283
// For result type, we don't use the record label (register name) as the type
284284
// annotation.
285285
if (currentOutput == OutputType::RESULT)
286-
label = "measure_result";
286+
label = "result";
287287
if (label.empty()) {
288288
if (currentOutput == OutputType::BOOL)
289289
label = "i1";

runtime/cudaq/builder/kernel_builder.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -794,8 +794,7 @@ QuakeValue applyMeasure(ImplicitLocOpBuilder &builder, Value value,
794794
measureResult =
795795
builder.template create<QuakeMeasureOp>(measTy, value).getMeasOut();
796796

797-
Value bits = builder.create<quake::DiscriminateOp>(resTy, measureResult);
798-
return QuakeValue(builder, bits);
797+
return QuakeValue(builder, measureResult);
799798
}
800799

801800
QuakeValue mx(ImplicitLocOpBuilder &builder, QuakeValue &qubitOrQvec,

runtime/cudaq/qis/measure_result.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ class measure_result {
4040
std::size_t getUniqueId() const { return uniqueId; }
4141

4242
// Operator overloads for conversions and comparisons
43+
#ifdef CUDAQ_LIBRARY_MODE
4344
operator bool() const { return __nvqpp__MeasureResultBoolConversion(result); }
45+
#else
46+
operator bool() const { return result == 1; }
47+
#endif
4448
explicit operator int() const { return result; }
4549
explicit operator double() const { return static_cast<double>(result); }
4650

0 commit comments

Comments
 (0)