Skip to content

Commit e12883f

Browse files
authored
Improvement to MergeCircuit and ReorderMeasurements for Parameters (#157)
Adds additional improvements for handling input parameters in the MergeCircuit pass and the ReorderMeasurement pass. 1. The passes will now move more parameters related operations out of the way when attempting to reorder measurements and merge circuits. 2. The `MergeCircuitsPass::mergeCallCircuits` has been simplified. 3. This PR also adds the `--include-source` command line options as a debug flag. If this flag is passed when the direct input method is also used then the input source will be stored in the output payload at `manifest\input.{qasm|mlir}` depending on the file format.
1 parent 4b0ded7 commit e12883f

File tree

14 files changed

+388
-72
lines changed

14 files changed

+388
-72
lines changed

include/Dialect/QUIR/Utils/Utils.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,12 @@ llvm::Optional<Operation *> nextQuantumOpOrNull(Operation *op);
9696
/// it is of type OpType, otherwise return null
9797
// TODO: Should be replaced by an analysis compatable struct.
9898
template <class OpType>
99-
llvm::Optional<OpType> nextQuantumOpOrNullOfType(Operation *op);
99+
llvm::Optional<OpType> nextQuantumOpOrNullOfType(Operation *op) {
100+
auto nextOperation = nextQuantumOpOrNull(op);
101+
if (nextOperation && isa<OpType>(*nextOperation))
102+
return dyn_cast<OpType>(*nextOperation);
103+
return llvm::None;
104+
}
100105

101106
/// Get the previous Op that has the CPTPOp or UnitaryOp trait, or return null
102107
/// if none found
@@ -107,7 +112,12 @@ llvm::Optional<Operation *> prevQuantumOpOrNull(Operation *op);
107112
/// it if it is of type OpType, otherwise return null
108113
// TODO: Should be replaced by an analysis compatable struct.
109114
template <class OpType>
110-
llvm::Optional<OpType> prevQuantumOpOrNullOfType(Operation *op);
115+
llvm::Optional<OpType> prevQuantumOpOrNullOfType(Operation *op) {
116+
auto prevOperation = prevQuantumOpOrNull(op);
117+
if (prevOperation && isa<OpType>(*prevOperation))
118+
return dyn_cast<OpType>(*prevOperation);
119+
return llvm::None;
120+
}
111121

112122
/// Get the next Op that has the CPTPOp or UnitaryOp trait, or is control flow
113123
/// (has the RegionBranchOpInterface::Trait), or return null if none found

include/Payload/Payload.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class Payload {
6262
// write all files in plaintext to the stream
6363
virtual void writePlain(std::ostream &stream) = 0;
6464
virtual void writePlain(llvm::raw_ostream &stream) = 0;
65+
virtual void addFile(llvm::StringRef filename, llvm::StringRef str) = 0;
6566

6667
const std::string &getName() const { return name; }
6768
const std::string &getPrefix() const { return prefix; }

lib/API/api.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ static llvm::cl::opt<bool> plaintextPayload(
106106
"plaintext-payload", llvm::cl::desc("Write the payload in plaintext"),
107107
llvm::cl::init(false), llvm::cl::cat(qssc::config::getQSSCCategory()));
108108

109+
static llvm::cl::opt<bool> includeSourceInPayload(
110+
"include-source", llvm::cl::desc("Write the input source into the payload"),
111+
llvm::cl::init(false), llvm::cl::cat(qssc::config::getQSSCCategory()));
112+
109113
namespace {
110114
enum InputType { NONE, QASM, MLIR, QOBJ };
111115
} // anonymous namespace
@@ -672,6 +676,14 @@ compile_(int argc, char const **argv, std::string *outputString,
672676
}
673677

674678
if (emitAction == Action::GenQEM) {
679+
680+
if (includeSourceInPayload && directInput) {
681+
if (inputType == InputType::QASM)
682+
payload->addFile("manifest/input.qasm", inputSource + "\n");
683+
else if (inputType == InputType::MLIR)
684+
payload->addFile("manifest/input.mlir", inputSource + "\n");
685+
}
686+
675687
if (auto err = generateQEM_(target, std::move(payload), moduleOp, ostream))
676688
return err;
677689
}

lib/Conversion/OQ3ToStandard/OQ3ToStandard.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
#include "Conversion/OQ3ToStandard/OQ3ToStandard.h"
16+
#include "Dialect/QUIR/IR/QUIROps.h"
1617
#include "Dialect/QUIR/Transforms/Passes.h"
1718

1819
#include "Dialect/OQ3/IR/OQ3Ops.h"
@@ -345,6 +346,27 @@ struct RemoveConvertedNilCastsPattern : public OQ3ToStandardConversion<CastOp> {
345346

346347
}; // struct RemoveConvertedNilCastsPattern
347348

349+
struct CastFromFloatConstPattern : public OQ3ToStandardConversion<CastOp> {
350+
using OQ3ToStandardConversion<CastOp>::OQ3ToStandardConversion;
351+
352+
LogicalResult
353+
matchAndRewrite(CastOp op, OpAdaptor adaptor,
354+
ConversionPatternRewriter &rewriter) const override {
355+
356+
auto constOp = op.arg().getDefiningOp<mlir::arith::ConstantOp>();
357+
if (!constOp)
358+
return failure();
359+
360+
auto floatAttr = constOp.getValue().dyn_cast<mlir::FloatAttr>();
361+
if (!floatAttr)
362+
return failure();
363+
364+
rewriter.replaceOp(op, {adaptor.arg()});
365+
return success();
366+
} // CastFromFloatConstPattern
367+
368+
}; // struct RemoveConvertedNilCastsPattern
369+
348370
struct RemoveI1ToCBitCastsPattern : public OQ3ToStandardConversion<CastOp> {
349371
using OQ3ToStandardConversion<CastOp>::OQ3ToStandardConversion;
350372

@@ -419,6 +441,7 @@ void oq3::populateOQ3ToStandardConversionPatterns(
419441
CastIndexToIntPattern,
420442
RemoveConvertedNilCastsPattern,
421443
RemoveI1ToCBitCastsPattern,
444+
CastFromFloatConstPattern,
422445
WideningIntCastsPattern>(patterns.getContext(), typeConverter);
423446
// clang-format on
424447
}

lib/Conversion/QUIRToStandard/TypeConversion.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ Optional<Type> QuirTypeConverter::convertAngleType(Type t) {
6363
// for function types in func defs and calls
6464
return intType;
6565
}
66+
if (auto floatType = t.dyn_cast<FloatType>()) {
67+
// MUST return the converted type as itself to mark legal
68+
// for function types in func defs and calls
69+
return floatType;
70+
}
6671
return llvm::None;
6772
} // convertAngleType
6873

0 commit comments

Comments
 (0)