Skip to content

Commit 6bac08f

Browse files
committed
✨ update QuartzToQIR to use int64_t for indexing and enhance register handling
Signed-off-by: burgholzer <[email protected]>
1 parent 939a66a commit 6bac08f

File tree

1 file changed

+18
-30
lines changed

1 file changed

+18
-30
lines changed

mlir/lib/Conversion/QuartzToQIR/QuartzToQIR.cpp

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,14 @@ namespace {
6666
*/
6767
struct LoweringState : QIRMetadata {
6868
/// Map from qubit index to pointer value for reuse
69-
DenseMap<size_t, Value> ptrMap;
69+
DenseMap<int64_t, Value> ptrMap;
7070

7171
/// Map from classical result index to pointer value for reuse
72-
DenseMap<size_t, Value> resultPtrMap;
72+
DenseMap<int64_t, Value> resultPtrMap;
7373

7474
/// Map from (register_name, register_index) to result pointer
7575
/// This allows caching result pointers for measurements with register info
76-
DenseMap<std::pair<StringRef, size_t>, Value> registerResultMap;
77-
78-
/// Sequence of measurements to record in output block
79-
/// Each entry: (result_ptr, register_name, register_index)
80-
SmallVector<std::tuple<Value, StringRef, size_t>> measurementSequence;
76+
DenseMap<std::pair<StringRef, int64_t>, Value> registerResultMap;
8177
};
8278

8379
/**
@@ -152,7 +148,7 @@ struct ConvertQuartzStaticQIR final : StatefulOpConversionPattern<StaticOp> {
152148
matchAndRewrite(StaticOp op, OpAdaptor /*adaptor*/,
153149
ConversionPatternRewriter& rewriter) const override {
154150
auto* ctx = getContext();
155-
const auto index = static_cast<size_t>(op.getIndex());
151+
const auto index = op.getIndex();
156152

157153
// Get or create a pointer to the qubit
158154
if (getState().ptrMap.contains(index)) {
@@ -161,7 +157,7 @@ struct ConvertQuartzStaticQIR final : StatefulOpConversionPattern<StaticOp> {
161157
} else {
162158
// Create constant and inttoptr operations
163159
const auto constantOp = rewriter.create<LLVM::ConstantOp>(
164-
op.getLoc(), rewriter.getI64IntegerAttr(static_cast<int64_t>(index)));
160+
op.getLoc(), rewriter.getI64IntegerAttr(index));
165161
const auto intToPtrOp = rewriter.replaceOpWithNewOp<LLVM::IntToPtrOp>(
166162
op, LLVM::LLVMPointerType::get(ctx), constantOp->getResult(0));
167163

@@ -337,17 +333,15 @@ struct ConvertQuartzMeasureQIR final : StatefulOpConversionPattern<MeasureOp> {
337333
auto* ctx = getContext();
338334
const auto ptrType = LLVM::LLVMPointerType::get(ctx);
339335
auto& state = getState();
340-
auto& numResults = state.numResults;
336+
const auto numResults = static_cast<int64_t>(state.numResults);
341337
auto& resultPtrMap = state.resultPtrMap;
342338
auto& registerResultMap = state.registerResultMap;
343-
auto& measurementSequence = state.measurementSequence;
344339

345340
// Get or create result pointer value
346341
Value resultValue;
347342
if (op.getRegisterName() && op.getRegisterSize() && op.getRegisterIndex()) {
348343
const auto registerName = op.getRegisterName().value();
349-
const auto registerIndex =
350-
static_cast<size_t>(op.getRegisterIndex().value());
344+
const auto registerIndex = op.getRegisterIndex().value();
351345
const auto key = std::make_pair(registerName, registerIndex);
352346

353347
if (const auto it = registerResultMap.find(key);
@@ -364,24 +358,20 @@ struct ConvertQuartzMeasureQIR final : StatefulOpConversionPattern<MeasureOp> {
364358
.getResult();
365359
resultPtrMap[numResults] = resultValue;
366360
registerResultMap.insert({key, resultValue});
367-
numResults++;
361+
state.numResults++;
368362
}
369-
370-
// Track this measurement for output recording
371-
measurementSequence.emplace_back(resultValue, registerName,
372-
registerIndex);
373363
} else {
374364
// No register info - assign sequential result pointer
375365
const auto constantOp = rewriter.create<LLVM::ConstantOp>(
376-
op.getLoc(), rewriter.getI64IntegerAttr(static_cast<int64_t>(
377-
numResults))); // Sequential result index
366+
op.getLoc(),
367+
rewriter.getI64IntegerAttr(numResults)); // Sequential result index
378368
resultValue = rewriter
379369
.create<LLVM::IntToPtrOp>(op.getLoc(), ptrType,
380370
constantOp->getResult(0))
381371
.getResult();
382372
resultPtrMap[numResults] = resultValue;
383-
measurementSequence.emplace_back(resultValue, "c", numResults);
384-
numResults++;
373+
registerResultMap.insert({{"c", numResults}, resultValue});
374+
state.numResults++;
385375
}
386376

387377
// Create mz (measure) call: mz(qubit, result)
@@ -578,7 +568,7 @@ struct QuartzToQIR final : impl::QuartzToQIRBase<QuartzToQIR> {
578568
*/
579569
static void addOutputRecording(LLVM::LLVMFuncOp& main, MLIRContext* ctx,
580570
LoweringState* state) {
581-
if (state->measurementSequence.empty()) {
571+
if (state->registerResultMap.empty()) {
582572
return; // No measurements to record
583573
}
584574

@@ -592,16 +582,14 @@ struct QuartzToQIR final : impl::QuartzToQIRBase<QuartzToQIR> {
592582
builder.setInsertionPoint(&outputBlock.back());
593583

594584
// Group measurements by register
595-
llvm::StringMap<SmallVector<std::pair<size_t, Value>>> registerGroups;
596-
for (const auto& [resultPtr, regName, regIdx] :
597-
state->measurementSequence) {
598-
if (!regName.empty()) {
599-
registerGroups[regName].emplace_back(regIdx, resultPtr);
600-
}
585+
llvm::StringMap<SmallVector<std::pair<int64_t, Value>>> registerGroups;
586+
for (const auto& [key, resultPtr] : state->registerResultMap) {
587+
const auto& [registerName, registerIndex] = key;
588+
registerGroups[registerName].emplace_back(registerIndex, resultPtr);
601589
}
602590

603591
// Sort registers by name for deterministic output
604-
SmallVector<std::pair<StringRef, SmallVector<std::pair<size_t, Value>>>>
592+
SmallVector<std::pair<StringRef, SmallVector<std::pair<int64_t, Value>>>>
605593
sortedRegisters;
606594
for (auto& [name, measurements] : registerGroups) {
607595
sortedRegisters.emplace_back(name, std::move(measurements));

0 commit comments

Comments
 (0)