-
Notifications
You must be signed in to change notification settings - Fork 57
Runtime router #2117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Runtime router #2117
Changes from all commits
a57e6c5
3d8695f
3d2ed3b
ab18f54
72036f4
83d5b33
70a11b4
7cffa7b
3eb866f
1ca60b2
4ae4a46
4227d89
f4c7dc9
c1d738f
db26be6
5142ce1
28016b3
fa429f2
5864e30
17d42c8
a5ac864
4c4dc80
f55c8c5
23b0f6a
2dfc770
31b014a
f078e39
428779c
10d263c
b52cfbb
d40f0de
cd9ca02
dbc635a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -142,7 +142,7 @@ class BackendInfo: | |||||
|
|
||||||
| # pylint: disable=too-many-branches | ||||||
| @debug_logger | ||||||
| def extract_backend_info(device: qml.devices.QubitDevice) -> BackendInfo: | ||||||
| def extract_backend_info(device: qml.devices.QubitDevice, device_capabilities=None) -> BackendInfo: | ||||||
| """Extract the backend info from a quantum device. The device is expected to carry a reference | ||||||
| to a valid TOML config file.""" | ||||||
|
|
||||||
|
|
@@ -192,6 +192,7 @@ def extract_backend_info(device: qml.devices.QubitDevice) -> BackendInfo: | |||||
| for k, v in getattr(device, "device_kwargs", {}).items(): | ||||||
| if k not in device_kwargs: # pragma: no branch | ||||||
| device_kwargs[k] = v | ||||||
| device_kwargs["coupling_map"] = getattr(device_capabilities, "coupling_map") | ||||||
|
|
||||||
| return BackendInfo(dname, device_name, device_lpath, device_kwargs) | ||||||
|
|
||||||
|
|
@@ -308,9 +309,9 @@ class QJITDevice(qml.devices.Device): | |||||
|
|
||||||
| @staticmethod | ||||||
| @debug_logger | ||||||
| def extract_backend_info(device) -> BackendInfo: | ||||||
| def extract_backend_info(device, device_capabilities=None) -> BackendInfo: | ||||||
| """Wrapper around extract_backend_info in the runtime module.""" | ||||||
| return extract_backend_info(device) | ||||||
| return extract_backend_info(device, device_capabilities) | ||||||
|
|
||||||
| @debug_logger_init | ||||||
| def __init__(self, original_device): | ||||||
|
|
@@ -319,14 +320,40 @@ def __init__(self, original_device): | |||||
| for key, value in original_device.__dict__.items(): | ||||||
| self.__setattr__(key, value) | ||||||
|
|
||||||
| check_device_wires(original_device.wires) | ||||||
|
|
||||||
| super().__init__(wires=original_device.wires) | ||||||
| if (original_device.wires is not None) and any( | ||||||
| isinstance(wire_label, tuple) and (len(wire_label) >= 2) | ||||||
| for wire_label in original_device.wires.labels | ||||||
| ): | ||||||
| wires_from_cmap = set() | ||||||
| for wire_label in original_device.wires.labels: | ||||||
| wires_from_cmap.add(wire_label[0]) | ||||||
| wires_from_cmap.add(wire_label[1]) | ||||||
| wires_from_cmap = qml.wires.Wires(list(wires_from_cmap)) | ||||||
| # check_device_wires(wires_from_cmap) not called | ||||||
| # since automatic qubit management | ||||||
| super().__init__(wires=wires_from_cmap, shots=original_device.shots) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. core PL has deprecated setting shots on devices
Suggested change
|
||||||
| else: | ||||||
| check_device_wires(original_device.wires) | ||||||
| super().__init__(wires=original_device.wires, shots=original_device.shots) | ||||||
|
|
||||||
| # Capability loading | ||||||
| device_capabilities = get_device_capabilities(original_device, self.original_device.shots) | ||||||
| device_capabilities = get_device_capabilities(original_device) | ||||||
|
|
||||||
| # TODO: This is a temporary measure to ensure consistency of behaviour. Remove this | ||||||
| # when customizable multi-pathway decomposition is implemented. (Epic 74474) | ||||||
| if hasattr(original_device, "_to_matrix_ops"): | ||||||
| _to_matrix_ops = getattr(original_device, "_to_matrix_ops") | ||||||
| setattr(device_capabilities, "to_matrix_ops", _to_matrix_ops) | ||||||
| if _to_matrix_ops and not device_capabilities.supports_operation("QubitUnitary"): | ||||||
| raise CompileError( | ||||||
| "The device that specifies to_matrix_ops must support QubitUnitary." | ||||||
| ) | ||||||
| if original_device.wires is not None: | ||||||
| setattr(device_capabilities, "coupling_map", original_device.wires.labels) | ||||||
|
Comment on lines
+351
to
+352
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A couple thoughts here:
If the entire purpose is to just have a UI to take in the wire map from the user, I would suggest taking in it through its own dedicated decorator. The decorator would be on the device, and its entire purpose is just set the coupling map: dev = qml.device("lightning.qubit", wires=3)
dev = catalyst.set_coupling_map(dev, [(0,1),(1,2),(2,3)])
@qjit
@qml.qnode(dev)
def circuit():
...This way, your functionality is properly modulated. The Implementation wise, this also prevents the need to piggy back the map on the capabilities (which I see (a) you need to pass into other functions now, and (b) has its own complicated song and dance already). Your function could just add the map to the qjit_device.py
def extract_backend_info(device: qml.devices.QubitDevice):
device_kwargs = {}
...
for k, v in getattr(device, "device_kwargs", {}).items():
if k not in device_kwargs: # pragma: no branch
device_kwargs[k] = v
return BackendInfo(...., device_kwargs)
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Of course, a restriction of doing it through kwargs in the first place is that the coupling map must be static. However as a first draft we don't need to be concerned over this. The verification in |
||||||
| else: | ||||||
| setattr(device_capabilities, "coupling_map", None) | ||||||
|
|
||||||
| backend = QJITDevice.extract_backend_info(original_device) | ||||||
| backend = QJITDevice.extract_backend_info(original_device, device_capabilities) | ||||||
|
|
||||||
| self.backend_name = backend.c_interface_name | ||||||
| self.backend_lib = backend.lpath | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,120 @@ | ||
| # Copyright 2025 Xanadu Quantum Technologies Inc. | ||
|
|
||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
|
|
||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Integration tests for routing at runtime""" | ||
|
|
||
| from functools import partial | ||
|
|
||
| import pennylane as qml | ||
| import pytest | ||
| from pennylane import numpy as np | ||
| from pennylane.transforms.transpile import transpile | ||
|
|
||
|
|
||
| def qfunc_ops(wires, x, y, z): | ||
| qml.Hadamard(wires=wires[0]) | ||
| qml.RZ(z, wires=wires[2]) | ||
| qml.CNOT(wires=[wires[2], wires[0]]) | ||
| qml.CNOT(wires=[wires[1], wires[0]]) | ||
| qml.RX(x, wires=wires[0]) | ||
| qml.CNOT(wires=[wires[0], wires[2]]) | ||
| qml.RZ(-z, wires=wires[2]) | ||
| qml.RX(y, wires=wires[0]) | ||
| qml.PauliY(wires=wires[2]) | ||
| qml.CY(wires=[wires[1], wires[2]]) | ||
|
|
||
|
|
||
| # pylint: disable=too-many-public-methods | ||
| class TestRouting: | ||
| """Unit tests for testing routing function at runtime""" | ||
|
|
||
| all_to__all_device = qml.device("lightning.qubit") | ||
| linear_device = qml.device("lightning.qubit", wires=[(0, 1), (1, 2)]) | ||
|
|
||
| input_devices = ((all_to__all_device, linear_device),) | ||
|
|
||
| @pytest.mark.parametrize("all_to__all_device, linear_device", input_devices) | ||
| def test_state_invariance_under_routing(self, all_to__all_device, linear_device): | ||
| """test that transpile does not alter output for state measurement""" | ||
| def circuit(wires, x, y, z): | ||
| qfunc_ops(wires, x, y, z) | ||
| return qml.state() | ||
|
|
||
| all_to_all_qnode = qml.qjit(qml.QNode(circuit, all_to__all_device)) | ||
| linear_qnode = qml.qjit(qml.QNode(circuit, linear_device)) | ||
|
|
||
| assert np.allclose( | ||
| all_to_all_qnode([0, 1, 2], 0.1, 0.2, 0.3), linear_qnode([0, 1, 2], 0.1, 0.2, 0.3) | ||
| ) | ||
|
|
||
| @pytest.mark.parametrize("all_to__all_device, linear_device", input_devices) | ||
| def test_probs_invariance_under_routing(self, all_to__all_device, linear_device): | ||
| """test that transpile does not alter output for probs measurement""" | ||
| def circuit(wires, x, y, z): | ||
| qfunc_ops(wires, x, y, z) | ||
| return qml.probs() | ||
|
|
||
| all_to_all_qnode = qml.qjit(qml.QNode(circuit, all_to__all_device)) | ||
| linear_qnode = qml.qjit(qml.QNode(circuit, linear_device)) | ||
|
|
||
| assert np.allclose( | ||
| all_to_all_qnode([0, 1, 2], 0.1, 0.2, 0.3), linear_qnode([0, 1, 2], 0.1, 0.2, 0.3) | ||
| ) | ||
|
|
||
| @pytest.mark.parametrize("all_to__all_device, linear_device", input_devices) | ||
| def test_sample_invariance_under_routing(self, all_to__all_device, linear_device): | ||
| """test that transpile does not alter output for sample measurement""" | ||
| def circuit(wires, x, y, z): | ||
| qfunc_ops(wires, x, y, z) | ||
| return qml.sample() | ||
|
|
||
| all_to_all_qnode = qml.qjit( | ||
| partial(qml.set_shots, shots=10)(qml.QNode(circuit, all_to__all_device)), seed=37 | ||
| ) | ||
| linear_qnode = qml.qjit( | ||
| partial(qml.set_shots, shots=10)(qml.QNode(circuit, linear_device)), seed=37 | ||
| ) | ||
| assert np.allclose( | ||
| all_to_all_qnode([0, 1, 2], 0.1, 0.2, 0.3), linear_qnode([0, 1, 2], 0.1, 0.2, 0.3) | ||
| ) | ||
|
|
||
| @pytest.mark.parametrize("all_to__all_device, linear_device", input_devices) | ||
| def test_counts_invariance_under_routing(self, all_to__all_device, linear_device): | ||
| """test that transpile does not alter output for counts measurement""" | ||
| def circuit(wires, x, y, z): | ||
| qfunc_ops(wires, x, y, z) | ||
| return qml.counts() | ||
|
|
||
| all_to_all_qnode = qml.qjit( | ||
| partial(qml.set_shots, shots=10)(qml.QNode(circuit, all_to__all_device)), seed=37 | ||
| ) | ||
| linear_qnode = qml.qjit( | ||
| partial(qml.set_shots, shots=10)(qml.QNode(circuit, linear_device)), seed=37 | ||
| ) | ||
| assert np.allclose( | ||
| all_to_all_qnode([0, 1, 2], 0.1, 0.2, 0.3), linear_qnode([0, 1, 2], 0.1, 0.2, 0.3) | ||
| ) | ||
|
|
||
| @pytest.mark.parametrize("all_to__all_device, linear_device", input_devices) | ||
| def test_expvals_invariance_under_routing(self, all_to__all_device, linear_device): | ||
| """test that transpile does not alter output for expectation value measurement""" | ||
| def circuit(wires, x, y, z): | ||
| qfunc_ops(wires, x, y, z) | ||
| return qml.expval(qml.X(0) @ qml.Y(1)), qml.var(qml.Z(2)) | ||
|
|
||
| all_to_all_qnode = qml.qjit(qml.QNode(circuit, all_to__all_device)) | ||
| linear_qnode = qml.qjit(qml.QNode(circuit, linear_device)) | ||
| assert np.allclose( | ||
| all_to_all_qnode([0, 1, 2], 0.1, 0.2, 0.3), linear_qnode([0, 1, 2], 0.1, 0.2, 0.3) | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,7 @@ | |
|
|
||
| #include "Exception.hpp" | ||
| #include "QuantumDevice.hpp" | ||
| #include "Routing.hpp" | ||
| #include "Types.h" | ||
|
|
||
| namespace Catalyst::Runtime { | ||
|
|
@@ -170,6 +171,8 @@ class RTDevice { | |
|
|
||
| std::unique_ptr<SharedLibraryManager> rtd_dylib{nullptr}; | ||
| std::unique_ptr<QuantumDevice> rtd_qdevice{nullptr}; | ||
| // device specific routing pass pointer. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This comment isn't really necessary, there's already the variable name : ) |
||
| std::unique_ptr<RoutingPass> RUNTIME_ROUTER{nullptr}; | ||
|
|
||
| RTDeviceStatus status{RTDeviceStatus::Inactive}; | ||
|
|
||
|
|
@@ -224,6 +227,20 @@ class RTDevice { | |
| _pl2runtime_device_info(rtd_lib, rtd_name); | ||
| } | ||
|
|
||
| explicit RTDevice(std::string_view _rtd_lib, std::string_view _rtd_name, | ||
| std::string_view _rtd_kwargs, bool _auto_qubit_management, | ||
| std::string_view coupling_map_str) | ||
| : rtd_lib(_rtd_lib), rtd_name(_rtd_name), rtd_kwargs(_rtd_kwargs), | ||
| auto_qubit_management(_auto_qubit_management) | ||
| { | ||
| // Extract coupling map from the kwargs passed | ||
| // If coupling map is provided then it takes in the form {...,'couplingMap' ((a,b),(b,c))} | ||
| // else {...,'couplingMap' (a,b,c)} | ||
| if (coupling_map_str.find("((") != std::string::npos) | ||
| RUNTIME_ROUTER = std::make_unique<RoutingPass>(coupling_map_str); | ||
|
Comment on lines
+239
to
+240
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a reminder that we always use braces, even if it's one-line : ) // Don't do this
if (blah)
thing;
// Do this
if (blah) {
thing;
}@mlxd I'm spreading the good word |
||
| _pl2runtime_device_info(rtd_lib, rtd_name); | ||
| } | ||
|
|
||
| ~RTDevice() = default; | ||
| RTDevice(const RTDevice &other) = delete; | ||
| RTDevice &operator=(const RTDevice &other) = delete; | ||
|
|
@@ -264,6 +281,10 @@ class RTDevice { | |
| void setDeviceStatus(RTDeviceStatus new_status) noexcept { status = new_status; } | ||
|
|
||
| bool getQubitManagementMode() { return auto_qubit_management; } | ||
| [[nodiscard]] auto getRuntimeRouter() -> std::unique_ptr<RoutingPass> & | ||
| { | ||
| return RUNTIME_ROUTER; | ||
| } | ||
|
|
||
| [[nodiscard]] auto getDeviceStatus() const -> RTDeviceStatus { return status; } | ||
|
|
||
|
|
@@ -320,13 +341,14 @@ class ExecutionContext final { | |
| } | ||
|
|
||
| [[nodiscard]] auto getOrCreateDevice(std::string_view rtd_lib, std::string_view rtd_name, | ||
| std::string_view rtd_kwargs, bool auto_qubit_management) | ||
| std::string_view rtd_kwargs, bool auto_qubit_management, | ||
| std::string_view coupling_map_str = {}) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Likewise std::optional |
||
| -> const std::shared_ptr<RTDevice> & | ||
| { | ||
| std::lock_guard<std::mutex> lock(pool_mu); | ||
|
|
||
| auto device = | ||
| std::make_shared<RTDevice>(rtd_lib, rtd_name, rtd_kwargs, auto_qubit_management); | ||
| auto device = std::make_shared<RTDevice>(rtd_lib, rtd_name, rtd_kwargs, | ||
| auto_qubit_management, coupling_map_str); | ||
|
|
||
| const size_t key = device_pool.size(); | ||
| for (size_t i = 0; i < key; i++) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think your functionality is restricted to automatic management. The number of wires used by the qnode is still known, it's just that there's some connectivity between them.
The action by your functionality is just inserting SWAP gates right? So it doesn't require allocation of new qubits. The IR would look like the IR with a known number of wires.