Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/rc_sync.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 0

# Import Numpy/Pybind11 because it is imported in setup.py
- name: Install Numpy and Pybind11
run: |
Expand Down
8 changes: 5 additions & 3 deletions doc/releases/changelog-0.13.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
Similar to PennyLane's behaviour, this experimental feature will fall back to the old system
whenever the graph cannot find decomposition rules for all unsupported operators in the program,
and a ``UserWarning`` is raised.
[(#2099)](https://github.com/PennyLaneAI/catalyst/pull/2099)
[(#2091)](https://github.com/PennyLaneAI/catalyst/pull/2091)
[(#2029)](https://github.com/PennyLaneAI/catalyst/pull/2029)
[(#2001)](https://github.com/PennyLaneAI/catalyst/pull/2001)
[(#2029)](https://github.com/PennyLaneAI/catalyst/pull/2029)
[(#2068)](https://github.com/PennyLaneAI/catalyst/pull/2068)
[(#2091)](https://github.com/PennyLaneAI/catalyst/pull/2091)
[(#2099)](https://github.com/PennyLaneAI/catalyst/pull/2099)


* Catalyst now supports dynamic wire allocation with ``qml.allocate()`` and
``qml.deallocate()`` when program capture is enabled.
Expand Down
42 changes: 42 additions & 0 deletions frontend/test/pytest/from_plxpr/test_capture_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,48 @@ def circuit(x: float, y: float, z: float):

assert jnp.allclose(circuit(1.5, 2.5, 3.5), capture_result)

def test_transform_graph_decompose_workflow(self, backend):
"""Test the integration for a circuit with a 'decompose' graph transform."""

# Capture enabled

qml.capture.enable()
qml.decomposition.enable_graph()

@qjit(target="mlir")
@partial(qml.transforms.decompose, gate_set=[qml.RX, qml.RY, qml.RZ])
@qml.qnode(qml.device(backend, wires=2))
def captured_circuit(x: float, y: float, z: float):
m = qml.measure(0)

@qml.cond(m)
def cond_fn():
qml.Rot(x, y, z, 0)

cond_fn()
return qml.expval(qml.PauliZ(0))

capture_result = captured_circuit(1.5, 2.5, 3.5)

qml.decomposition.disable_graph()
qml.capture.disable()

# Capture disabled
@qjit
@partial(qml.transforms.decompose, gate_set=[qml.RX, qml.RY, qml.RZ])
@qml.qnode(qml.device(backend, wires=2))
def circuit(x: float, y: float, z: float):
m = catalyst.measure(0)

@catalyst.cond(m)
def cond_fn():
qml.Rot(x, y, z, 0)

cond_fn()
return qml.expval(qml.PauliZ(0))

assert jnp.allclose(circuit(1.5, 2.5, 3.5), capture_result)

def test_transform_map_wires_workflow(self, backend):
"""Test the integration for a circuit with a 'map_wires' transform."""

Expand Down
141 changes: 67 additions & 74 deletions mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/ValueRange.h"

#include "Quantum/IR/QuantumInterfaces.h"
#include "Quantum/IR/QuantumOps.h"
#include "Quantum/Transforms/Patterns.h"

Expand All @@ -40,6 +41,7 @@ namespace quantum {
/// - A runtime Value (for dynamic indices computed at runtime)
/// - An IntegerAttr (for compile-time constant indices)
/// - Invalid/uninitialized (represented by std::monostate)
/// And a qreg value to represent the qreg that the index belongs to
///
/// The struct uses std::variant to ensure only one type is active at a time,
/// preventing invalid states.
Expand All @@ -54,17 +56,21 @@ namespace quantum {
/// Value idx = dynamicIdx.getValue(); // Get the Value
/// }
/// }
struct QubitIndex {
class QubitIndex {
private:
// use monostate to represent the invalid index
std::variant<std::monostate, Value, IntegerAttr> index;
Value qreg;

QubitIndex() : index(std::monostate()) {}
QubitIndex(Value val) : index(val) {}
QubitIndex(IntegerAttr attr) : index(attr) {}
public:
QubitIndex() : index(std::monostate()), qreg(nullptr) {}
QubitIndex(Value val, Value qreg) : index(val), qreg(qreg) {}
QubitIndex(IntegerAttr attr, Value qreg) : index(attr), qreg(qreg) {}

bool isValue() const { return std::holds_alternative<Value>(index); }
bool isAttr() const { return std::holds_alternative<IntegerAttr>(index); }
operator bool() const { return isValue() || isAttr(); }
Value getReg() const { return qreg; }
Value getValue() const { return isValue() ? std::get<Value>(index) : nullptr; }
IntegerAttr getAttr() const { return isAttr() ? std::get<IntegerAttr>(index) : nullptr; }
};
Expand All @@ -76,25 +82,16 @@ class OpSignatureAnalyzer {
public:
OpSignatureAnalyzer() = delete;
OpSignatureAnalyzer(CustomOp op, bool enableQregMode)
: signature(OpSignature{
.params = op.getParams(),
.inQubits = op.getInQubits(),
.inCtrlQubits = op.getInCtrlQubits(),
.inCtrlValues = op.getInCtrlValues(),
.outQubits = op.getOutQubits(),
.outCtrlQubits = op.getOutCtrlQubits(),
})
: signature(OpSignature{.params = op.getParams(),
.inQubits = op.getNonCtrlQubitOperands(),
.inCtrlQubits = op.getCtrlQubitOperands(),
.inCtrlValues = op.getCtrlValueOperands(),
.outQubits = op.getNonCtrlQubitResults(),
.outCtrlQubits = op.getCtrlQubitResults()})
{
if (!enableQregMode)
return;

signature.sourceQreg = getSourceQreg(signature.inQubits.front());
if (!signature.sourceQreg) {
op.emitError("Cannot get source qreg");
isValid = false;
return;
}

// input wire indices
for (Value qubit : signature.inQubits) {
const QubitIndex index = getExtractIndex(qubit);
Expand All @@ -117,13 +114,35 @@ class OpSignatureAnalyzer {
signature.inCtrlWireIndices.emplace_back(index);
}

assert((signature.inWireIndices.size() + signature.inCtrlWireIndices.size()) > 0 &&
"inWireIndices or inCtrlWireIndices should not be empty");

// Output qubit indices are the same as input qubit indices
signature.outQubitIndices = signature.inWireIndices;
signature.outCtrlQubitIndices = signature.inCtrlWireIndices;
}

operator bool() const { return isValid; }

Value getUpdatedQreg(PatternRewriter &rewriter, Location loc)
{
// FIXME: This will cause an issue when the decomposition function has cross-qreg
// inputs and outputs. Now, we just assume has only one qreg input, the global one exists.
// raise an error if the qreg is not the same
Value qreg = signature.inWireIndices[0].getReg();

bool sameQreg = true;
for (const auto &index : signature.inWireIndices) {
sameQreg &= index.getReg() == qreg;
}
for (const auto &index : signature.inCtrlWireIndices) {
sameQreg &= index.getReg() == qreg;
}

assert(sameQreg && "The qreg of the input wires should be the same");
return qreg;
}

// Prepare the operands for calling the decomposition function
// There are two cases:
// 1. The first input is a qreg, which means the decomposition function is a qreg mode function
Expand All @@ -144,15 +163,8 @@ class OpSignatureAnalyzer {

int operandIdx = 0;
if (isa<quantum::QuregType>(funcInputs[0])) {
Value updatedQreg = signature.sourceQreg;
for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) {
const QubitIndex &index = signature.inWireIndices[i];
updatedQreg =
rewriter.create<quantum::InsertOp>(loc, updatedQreg.getType(), updatedQreg,
index.getValue(), index.getAttr(), qubit);
}
operands[operandIdx++] = getUpdatedQreg(rewriter, loc);

operands[operandIdx++] = updatedQreg;
if (!signature.params.empty()) {
auto [startIdx, endIdx] =
findParamTypeRange(funcInputs, signature.params.size(), operandIdx);
Expand All @@ -163,16 +175,12 @@ class OpSignatureAnalyzer {
}
}

if (!signature.inWireIndices.empty()) {
operands[operandIdx] = fromTensorOrAsIs(signature.inWireIndices,
funcInputs[operandIdx], rewriter, loc);
operandIdx++;
}

if (!signature.inCtrlWireIndices.empty()) {
operands[operandIdx] = fromTensorOrAsIs(signature.inCtrlWireIndices,
funcInputs[operandIdx], rewriter, loc);
operandIdx++;
for (const auto &indices : {signature.inWireIndices, signature.inCtrlWireIndices}) {
if (!indices.empty()) {
operands[operandIdx] =
fromTensorOrAsIs(indices, funcInputs[operandIdx], rewriter, loc);
operandIdx++;
}
}
}
else {
Expand Down Expand Up @@ -218,18 +226,16 @@ class OpSignatureAnalyzer {

SmallVector<Value> newResults;
rewriter.setInsertionPointAfter(callOp);
for (const QubitIndex &index : signature.outQubitIndices) {
auto extractOp = rewriter.create<quantum::ExtractOp>(
callOp.getLoc(), rewriter.getType<quantum::QubitType>(), qreg, index.getValue(),
index.getAttr());
newResults.emplace_back(extractOp.getResult());
}
for (const QubitIndex &index : signature.outCtrlQubitIndices) {
auto extractOp = rewriter.create<quantum::ExtractOp>(
callOp.getLoc(), rewriter.getType<quantum::QubitType>(), qreg, index.getValue(),
index.getAttr());
newResults.emplace_back(extractOp.getResult());

for (const auto &indices : {signature.outQubitIndices, signature.outCtrlQubitIndices}) {
for (const auto &index : indices) {
auto extractOp = rewriter.create<quantum::ExtractOp>(
callOp.getLoc(), rewriter.getType<quantum::QubitType>(), qreg, index.getValue(),
index.getAttr());
newResults.emplace_back(extractOp.getResult());
}
}

return newResults;
}

Expand All @@ -245,7 +251,6 @@ class OpSignatureAnalyzer {
ValueRange outCtrlQubits;

// Qreg mode specific information
Value sourceQreg = nullptr;
SmallVector<QubitIndex> inWireIndices;
SmallVector<QubitIndex> inCtrlWireIndices;
SmallVector<QubitIndex> outQubitIndices;
Expand Down Expand Up @@ -333,39 +338,21 @@ class OpSignatureAnalyzer {
return values.front();
}

Value getSourceQreg(Value qubit)
{
while (qubit) {
if (auto extractOp = qubit.getDefiningOp<quantum::ExtractOp>()) {
return extractOp.getQreg();
}

if (auto customOp = dyn_cast_or_null<quantum::CustomOp>(qubit.getDefiningOp())) {
if (customOp.getQubitOperands().empty()) {
break;
}
qubit = customOp.getQubitOperands()[0];
}
}

return nullptr;
}

QubitIndex getExtractIndex(Value qubit)
{
while (qubit) {
if (auto extractOp = qubit.getDefiningOp<quantum::ExtractOp>()) {
if (Value idx = extractOp.getIdx()) {
return QubitIndex(idx);
return QubitIndex(idx, extractOp.getQreg());
}
if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) {
return QubitIndex(idxAttr);
return QubitIndex(idxAttr, extractOp.getQreg());
}
}

if (auto customOp = dyn_cast_or_null<quantum::CustomOp>(qubit.getDefiningOp())) {
auto qubitOperands = customOp.getQubitOperands();
auto qubitResults = customOp.getQubitResults();
if (auto gate = dyn_cast_or_null<quantum::QuantumGate>(qubit.getDefiningOp())) {
auto qubitOperands = gate.getQubitOperands();
auto qubitResults = gate.getQubitResults();
auto it =
llvm::find_if(qubitResults, [&](Value result) { return result == qubit; });

Expand All @@ -377,6 +364,10 @@ class OpSignatureAnalyzer {
}
}
}
else if (auto measureOp = dyn_cast_or_null<quantum::MeasureOp>(qubit.getDefiningOp())) {
qubit = measureOp.getInQubit();
continue;
}

break;
}
Expand All @@ -394,7 +385,8 @@ struct DecomposeLoweringRewritePattern : public OpRewritePattern<CustomOp> {
DecomposeLoweringRewritePattern(MLIRContext *context,
const llvm::StringMap<func::FuncOp> &registry,
const llvm::StringSet<llvm::MallocAllocator> &gateSet)
: OpRewritePattern(context), decompositionRegistry(registry), targetGateSet(gateSet)
: OpRewritePattern<CustomOp>(context), decompositionRegistry(registry),
targetGateSet(gateSet)
{
}

Expand All @@ -421,11 +413,12 @@ struct DecomposeLoweringRewritePattern : public OpRewritePattern<CustomOp> {
assert(decompFunc.getFunctionType().getNumResults() >= 1 &&
"Decomposition function must have at least one result");

rewriter.setInsertionPointAfter(op);

auto enableQreg = isa<quantum::QuregType>(decompFunc.getFunctionType().getInput(0));
auto analyzer = OpSignatureAnalyzer(op, enableQreg);
assert(analyzer && "Analyzer should be valid");

rewriter.setInsertionPointAfter(op);
auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc());
auto callOp =
rewriter.create<func::CallOp>(op.getLoc(), decompFunc.getFunctionType().getResults(),
Expand Down
46 changes: 46 additions & 0 deletions mlir/test/Quantum/DecomposeLoweringTest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,49 @@ module @cnot_alternative_decomposition {
return %out_qubits_2#0, %out_qubits_4 : !quantum.bit, !quantum.bit
}
}

// -----

module @mcm_example {
func.func public @test_mcm_hadamard() -> tensor<2xf64> {
%0 = quantum.alloc( 1) : !quantum.reg
%1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
%mres, %out_qubit = quantum.measure %1 : i1, !quantum.bit
%2 = quantum.insert %0[ 0], %out_qubit : !quantum.reg, !quantum.bit

// CHECK: [[RZ_QUBIT:%.+]] = quantum.custom "RZ"([[CST_0:%.+]])
// CHECK: [[RY_QUBIT:%.+]] = quantum.custom "RY"([[CST_1:%.+]]) [[RZ_QUBIT]] : !quantum.bit
// CHECK: [[REG_1:%.+]] = quantum.insert [[REG:%.+]][[[EXTRACTED:%.+]]], [[RY_QUBIT]] : !quantum.reg, !quantum.bit
// CHECK-NOT: quantum.custom "Hadamard"
%3 = quantum.extract %2[ 0] : !quantum.reg -> !quantum.bit
%out_qubits = quantum.custom "Hadamard"() %3 : !quantum.bit
%4 = quantum.insert %2[ 0], %out_qubits : !quantum.reg, !quantum.bit

%5 = quantum.compbasis qreg %4 : !quantum.obs
%6 = quantum.probs %5 : tensor<2xf64>
quantum.dealloc %4 : !quantum.reg
return %6 : tensor<2xf64>
}

// Decomposition function should be applied and removed from the module
// CHECK-NOT: func.func public @rz_ry
func.func public @rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage<internal>, num_wires = 1 : i64, target_gate = "Hadamard"} {
%cst = arith.constant 3.1415926535897931 : f64
%cst_0 = arith.constant 1.5707963267948966 : f64
%0 = stablehlo.slice %arg1 [0:1] : (tensor<1xi64>) -> tensor<1xi64>
%1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor<i64>
%extracted = tensor.extract %1[] : tensor<i64>
%2 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit
%out_qubits = quantum.custom "RZ"(%cst_0) %2 : !quantum.bit
%3 = stablehlo.slice %arg1 [0:1] : (tensor<1xi64>) -> tensor<1xi64>
%4 = stablehlo.reshape %3 : (tensor<1xi64>) -> tensor<i64>
%extracted_1 = tensor.extract %1[] : tensor<i64>
%5 = quantum.insert %arg0[%extracted_1], %out_qubits : !quantum.reg, !quantum.bit
%extracted_2 = tensor.extract %4[] : tensor<i64>
%6 = quantum.extract %5[%extracted_2] : !quantum.reg -> !quantum.bit
%out_qubits_3 = quantum.custom "RY"(%cst) %6 : !quantum.bit
%extracted_4 = tensor.extract %4[] : tensor<i64>
%7 = quantum.insert %5[%extracted_4], %out_qubits_3 : !quantum.reg, !quantum.bit
return %7 : !quantum.reg
}
}