Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
a57e6c5
floyd warshall
ritu-thombre99 Jul 22, 2025
3d8695f
numm qubit for printing
ritu-thombre99 Jul 22, 2025
3d2ed3b
track with main
ritu-thombre99 Aug 5, 2025
ab18f54
formatting
ritu-thombre99 Aug 5, 2025
72036f4
clang format
ritu-thombre99 Aug 8, 2025
83d5b33
fix ci errors
ritu-thombre99 Aug 8, 2025
70a11b4
fix ci errors
ritu-thombre99 Aug 8, 2025
7cffa7b
fix linting errors
ritu-thombre99 Aug 8, 2025
3eb866f
device capabilities as an optional arg
ritu-thombre99 Aug 8, 2025
1ca60b2
fix formatting again
ritu-thombre99 Aug 8, 2025
4ae4a46
modularize and options
ritu-thombre99 Aug 13, 2025
4227d89
change qubit types from int to QubitIdType
ritu-thombre99 Aug 13, 2025
f4c7dc9
TODO: fix rtdptr dependancy
ritu-thombre99 Aug 21, 2025
c1d738f
refractoring
ritu-thombre99 Aug 22, 2025
db26be6
fix the end permutation
ritu-thombre99 Aug 22, 2025
5142ce1
permute end mapping
ritu-thombre99 Aug 24, 2025
28016b3
remove prints
ritu-thombre99 Aug 26, 2025
fa429f2
add checks for multiqubit gates in runtime
ritu-thombre99 Aug 28, 2025
5864e30
track with main
ritu-thombre99 Aug 28, 2025
17d42c8
clang formatting
ritu-thombre99 Aug 28, 2025
a5ac864
clang format
ritu-thombre99 Aug 28, 2025
4c4dc80
remove redefinitions
ritu-thombre99 Aug 28, 2025
f55c8c5
[no ci] bump nightly version
Oct 11, 2025
23b0f6a
Merge branch 'PennyLaneAI:main' into main
ritu-thombre99 Oct 11, 2025
2dfc770
track with main
ritu-thombre99 Oct 11, 2025
31b014a
return to trivial mapping before getting state
ritu-thombre99 Oct 12, 2025
f078e39
return to trivial mapping before getting state
ritu-thombre99 Oct 12, 2025
428779c
remove NullQubit print operations
ritu-thombre99 Oct 13, 2025
10d263c
check if only 1 and 2 qubit gates are present when performing qubit r…
ritu-thombre99 Oct 13, 2025
b52cfbb
integration tests
ritu-thombre99 Oct 13, 2025
d40f0de
fix format
ritu-thombre99 Oct 13, 2025
cd9ca02
fix format
ritu-thombre99 Oct 13, 2025
dbc635a
docstring in tests
ritu-thombre99 Oct 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions frontend/catalyst/device/qjit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class BackendInfo:

# pylint: disable=too-many-branches
@debug_logger
def extract_backend_info(device: qml.devices.QubitDevice) -> BackendInfo:
def extract_backend_info(device: qml.devices.QubitDevice, device_capabilities=None) -> BackendInfo:
"""Extract the backend info from a quantum device. The device is expected to carry a reference
to a valid TOML config file."""

Expand Down Expand Up @@ -192,6 +192,7 @@ def extract_backend_info(device: qml.devices.QubitDevice) -> BackendInfo:
for k, v in getattr(device, "device_kwargs", {}).items():
if k not in device_kwargs: # pragma: no branch
device_kwargs[k] = v
device_kwargs["coupling_map"] = getattr(device_capabilities, "coupling_map")

return BackendInfo(dname, device_name, device_lpath, device_kwargs)

Expand Down Expand Up @@ -308,9 +309,9 @@ class QJITDevice(qml.devices.Device):

@staticmethod
@debug_logger
def extract_backend_info(device) -> BackendInfo:
def extract_backend_info(device, device_capabilities=None) -> BackendInfo:
"""Wrapper around extract_backend_info in the runtime module."""
return extract_backend_info(device)
return extract_backend_info(device, device_capabilities)

@debug_logger_init
def __init__(self, original_device):
Expand All @@ -319,14 +320,40 @@ def __init__(self, original_device):
for key, value in original_device.__dict__.items():
self.__setattr__(key, value)

check_device_wires(original_device.wires)

super().__init__(wires=original_device.wires)
if (original_device.wires is not None) and any(
isinstance(wire_label, tuple) and (len(wire_label) >= 2)
for wire_label in original_device.wires.labels
):
wires_from_cmap = set()
for wire_label in original_device.wires.labels:
wires_from_cmap.add(wire_label[0])
wires_from_cmap.add(wire_label[1])
wires_from_cmap = qml.wires.Wires(list(wires_from_cmap))
# check_device_wires(wires_from_cmap) not called
# since automatic qubit management
Comment on lines +332 to +333
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think your functionality is restricted to automatic management. The number of wires used by the qnode is still known, it's just that there's some connectivity between them.

The action by your functionality is just inserting SWAP gates right? So it doesn't require allocation of new qubits. The IR would look like the IR with a known number of wires.

super().__init__(wires=wires_from_cmap, shots=original_device.shots)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

core PL has deprecated setting shots on devices

Suggested change
super().__init__(wires=wires_from_cmap, shots=original_device.shots)
super().__init__(wires=wires_from_cmap)

else:
check_device_wires(original_device.wires)
super().__init__(wires=original_device.wires, shots=original_device.shots)

# Capability loading
device_capabilities = get_device_capabilities(original_device, self.original_device.shots)
device_capabilities = get_device_capabilities(original_device)

# TODO: This is a temporary measure to ensure consistency of behaviour. Remove this
# when customizable multi-pathway decomposition is implemented. (Epic 74474)
if hasattr(original_device, "_to_matrix_ops"):
_to_matrix_ops = getattr(original_device, "_to_matrix_ops")
setattr(device_capabilities, "to_matrix_ops", _to_matrix_ops)
if _to_matrix_ops and not device_capabilities.supports_operation("QubitUnitary"):
raise CompileError(
"The device that specifies to_matrix_ops must support QubitUnitary."
)
if original_device.wires is not None:
setattr(device_capabilities, "coupling_map", original_device.wires.labels)
Comment on lines +351 to +352
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple thoughts here:

  1. An immediate issue with this logic here is that it would interfere with the regular way of wires, i.e. qml.device("lightning.qubit", wires=3), because 3 is also not None. You could do the same check as above, i.e. check that each entry in the Wires object is a tuple, but...
  2. ... I'm not so sure about the UI. So I guess your UI is to take in the map through the wires ketword on the device? This only happens to work because in PennyLane, the device wires only need to be a list of hashable objects, which are then treated as labels. However in Catalyst we simplify this a bit, by letting the device wires kwarg to just mean "the number of wires this circuit uses". This UI breaks that assumption, and could cause difficulties in other places (for example, there's work going on about formalizing device capacity and qnode algorithm wires, which will likely change the device wire UI).

If the entire purpose is to just have a UI to take in the wire map from the user, I would suggest taking in it through its own dedicated decorator. The decorator would be on the device, and its entire purpose is just set the coupling map:

dev = qml.device("lightning.qubit", wires=3)
dev = catalyst.set_coupling_map(dev, [(0,1),(1,2),(2,3)])

@qjit
@qml.qnode(dev)
def circuit():
   ...

This way, your functionality is properly modulated. The set_coupling_map function could live in a separate file, e.g. routing.py, and that module would take care of, e.g. verifiying the map makes sense on the device, etc.

Implementation wise, this also prevents the need to piggy back the map on the capabilities (which I see (a) you need to pass into other functions now, and (b) has its own complicated song and dance already). Your function could just add the map to the device_kwargs of the PL device it takes in. This is because in extract_backend_info, at the end all PL device kwargs will be moved onto the qjit device:

qjit_device.py

def extract_backend_info(device: qml.devices.QubitDevice):
   device_kwargs = {}
   ...
    for k, v in getattr(device, "device_kwargs", {}).items():
        if k not in device_kwargs:  # pragma: no branch
            device_kwargs[k] = v
    return BackendInfo(...., device_kwargs)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course, a restriction of doing it through kwargs in the first place is that the coupling map must be static. However as a first draft we don't need to be concerned over this. The verification in routing.py can just check this as well.

else:
setattr(device_capabilities, "coupling_map", None)

backend = QJITDevice.extract_backend_info(original_device)
backend = QJITDevice.extract_backend_info(original_device, device_capabilities)

self.backend_name = backend.c_interface_name
self.backend_lib = backend.lpath
Expand Down
120 changes: 120 additions & 0 deletions frontend/test/pytest/test_routing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Integration tests for routing at runtime"""

from functools import partial

import pennylane as qml
import pytest
from pennylane import numpy as np
from pennylane.transforms.transpile import transpile

Check notice on line 22 in frontend/test/pytest/test_routing.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_routing.py#L22

Unused transpile imported from pennylane.transforms.transpile (unused-import)


def qfunc_ops(wires, x, y, z):

Check notice on line 25 in frontend/test/pytest/test_routing.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_routing.py#L25

Missing function or method docstring (missing-function-docstring)
qml.Hadamard(wires=wires[0])
qml.RZ(z, wires=wires[2])
qml.CNOT(wires=[wires[2], wires[0]])
qml.CNOT(wires=[wires[1], wires[0]])
qml.RX(x, wires=wires[0])
qml.CNOT(wires=[wires[0], wires[2]])
qml.RZ(-z, wires=wires[2])
qml.RX(y, wires=wires[0])
qml.PauliY(wires=wires[2])
qml.CY(wires=[wires[1], wires[2]])


# pylint: disable=too-many-public-methods
class TestRouting:
"""Unit tests for testing routing function at runtime"""

all_to__all_device = qml.device("lightning.qubit")
linear_device = qml.device("lightning.qubit", wires=[(0, 1), (1, 2)])

input_devices = ((all_to__all_device, linear_device),)

@pytest.mark.parametrize("all_to__all_device, linear_device", input_devices)
def test_state_invariance_under_routing(self, all_to__all_device, linear_device):
"""test that transpile does not alter output for state measurement"""
def circuit(wires, x, y, z):
qfunc_ops(wires, x, y, z)
return qml.state()

all_to_all_qnode = qml.qjit(qml.QNode(circuit, all_to__all_device))
linear_qnode = qml.qjit(qml.QNode(circuit, linear_device))

assert np.allclose(
all_to_all_qnode([0, 1, 2], 0.1, 0.2, 0.3), linear_qnode([0, 1, 2], 0.1, 0.2, 0.3)
)

@pytest.mark.parametrize("all_to__all_device, linear_device", input_devices)
def test_probs_invariance_under_routing(self, all_to__all_device, linear_device):
"""test that transpile does not alter output for probs measurement"""
def circuit(wires, x, y, z):
qfunc_ops(wires, x, y, z)
return qml.probs()

all_to_all_qnode = qml.qjit(qml.QNode(circuit, all_to__all_device))
linear_qnode = qml.qjit(qml.QNode(circuit, linear_device))

assert np.allclose(
all_to_all_qnode([0, 1, 2], 0.1, 0.2, 0.3), linear_qnode([0, 1, 2], 0.1, 0.2, 0.3)
)

@pytest.mark.parametrize("all_to__all_device, linear_device", input_devices)
def test_sample_invariance_under_routing(self, all_to__all_device, linear_device):
"""test that transpile does not alter output for sample measurement"""
def circuit(wires, x, y, z):
qfunc_ops(wires, x, y, z)
return qml.sample()

all_to_all_qnode = qml.qjit(
partial(qml.set_shots, shots=10)(qml.QNode(circuit, all_to__all_device)), seed=37
)
linear_qnode = qml.qjit(
partial(qml.set_shots, shots=10)(qml.QNode(circuit, linear_device)), seed=37
)
assert np.allclose(
all_to_all_qnode([0, 1, 2], 0.1, 0.2, 0.3), linear_qnode([0, 1, 2], 0.1, 0.2, 0.3)
)

@pytest.mark.parametrize("all_to__all_device, linear_device", input_devices)
def test_counts_invariance_under_routing(self, all_to__all_device, linear_device):
"""test that transpile does not alter output for counts measurement"""
def circuit(wires, x, y, z):
qfunc_ops(wires, x, y, z)
return qml.counts()

all_to_all_qnode = qml.qjit(
partial(qml.set_shots, shots=10)(qml.QNode(circuit, all_to__all_device)), seed=37
)
linear_qnode = qml.qjit(
partial(qml.set_shots, shots=10)(qml.QNode(circuit, linear_device)), seed=37
)
assert np.allclose(
all_to_all_qnode([0, 1, 2], 0.1, 0.2, 0.3), linear_qnode([0, 1, 2], 0.1, 0.2, 0.3)
)

@pytest.mark.parametrize("all_to__all_device, linear_device", input_devices)
def test_expvals_invariance_under_routing(self, all_to__all_device, linear_device):
"""test that transpile does not alter output for expectation value measurement"""
def circuit(wires, x, y, z):
qfunc_ops(wires, x, y, z)
return qml.expval(qml.X(0) @ qml.Y(1)), qml.var(qml.Z(2))

all_to_all_qnode = qml.qjit(qml.QNode(circuit, all_to__all_device))
linear_qnode = qml.qjit(qml.QNode(circuit, linear_device))
assert np.allclose(
all_to_all_qnode([0, 1, 2], 0.1, 0.2, 0.3), linear_qnode([0, 1, 2], 0.1, 0.2, 0.3)
)
28 changes: 25 additions & 3 deletions runtime/lib/capi/ExecutionContext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include "Exception.hpp"
#include "QuantumDevice.hpp"
#include "Routing.hpp"
#include "Types.h"

namespace Catalyst::Runtime {
Expand Down Expand Up @@ -170,6 +171,8 @@ class RTDevice {

std::unique_ptr<SharedLibraryManager> rtd_dylib{nullptr};
std::unique_ptr<QuantumDevice> rtd_qdevice{nullptr};
// device specific routing pass pointer.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment isn't really necessary, there's already the variable name : )

std::unique_ptr<RoutingPass> RUNTIME_ROUTER{nullptr};

RTDeviceStatus status{RTDeviceStatus::Inactive};

Expand Down Expand Up @@ -224,6 +227,20 @@ class RTDevice {
_pl2runtime_device_info(rtd_lib, rtd_name);
}

explicit RTDevice(std::string_view _rtd_lib, std::string_view _rtd_name,
std::string_view _rtd_kwargs, bool _auto_qubit_management,
std::string_view coupling_map_str)
: rtd_lib(_rtd_lib), rtd_name(_rtd_name), rtd_kwargs(_rtd_kwargs),
auto_qubit_management(_auto_qubit_management)
{
// Extract coupling map from the kwargs passed
// If coupling map is provided then it takes in the form {...,'couplingMap' ((a,b),(b,c))}
// else {...,'couplingMap' (a,b,c)}
if (coupling_map_str.find("((") != std::string::npos)
RUNTIME_ROUTER = std::make_unique<RoutingPass>(coupling_map_str);
Comment on lines +239 to +240
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a reminder that we always use braces, even if it's one-line : )

// Don't do this 
if (blah)
   thing;

// Do this
if (blah) {
   thing;
}

@mlxd I'm spreading the good word

_pl2runtime_device_info(rtd_lib, rtd_name);
}

~RTDevice() = default;
RTDevice(const RTDevice &other) = delete;
RTDevice &operator=(const RTDevice &other) = delete;
Expand Down Expand Up @@ -264,6 +281,10 @@ class RTDevice {
void setDeviceStatus(RTDeviceStatus new_status) noexcept { status = new_status; }

bool getQubitManagementMode() { return auto_qubit_management; }
[[nodiscard]] auto getRuntimeRouter() -> std::unique_ptr<RoutingPass> &
{
return RUNTIME_ROUTER;
}

[[nodiscard]] auto getDeviceStatus() const -> RTDeviceStatus { return status; }

Expand Down Expand Up @@ -320,13 +341,14 @@ class ExecutionContext final {
}

[[nodiscard]] auto getOrCreateDevice(std::string_view rtd_lib, std::string_view rtd_name,
std::string_view rtd_kwargs, bool auto_qubit_management)
std::string_view rtd_kwargs, bool auto_qubit_management,
std::string_view coupling_map_str = {})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise std::optional

-> const std::shared_ptr<RTDevice> &
{
std::lock_guard<std::mutex> lock(pool_mu);

auto device =
std::make_shared<RTDevice>(rtd_lib, rtd_name, rtd_kwargs, auto_qubit_management);
auto device = std::make_shared<RTDevice>(rtd_lib, rtd_name, rtd_kwargs,
auto_qubit_management, coupling_map_str);

const size_t key = device_pool.size();
for (size_t i = 0; i < key; i++) {
Expand Down
Loading
Loading