Skip to content

Commit 3911cec

Browse files
authored
Merge pull request #341 from eclipse-qrisp/catalyst_device
Support for different Catalyst backends
2 parents 0b9af6e + ad0f00a commit 3911cec

File tree

3 files changed

+65
-8
lines changed

3 files changed

+65
-8
lines changed

src/qrisp/jasp/evaluation_tools/catalyst_interface.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,16 @@ def jaspr_to_catalyst_jaxpr(jaspr):
9999
return make_jaxpr(eval_jaxpr(jaspr, eqn_evaluator=catalyst_eqn_evaluator))(*args)
100100

101101

102-
def jaspr_to_catalyst_function(jaspr):
102+
def jaspr_to_catalyst_function(jaspr, device=None):
103103

104104
# This function takes a jaspr and returns a function that performs a sequence
105105
# of .bind calls of Catalyst primitives, such that the function (when compiled)
106106
# by Catalyst reproduces the semantics of jaspr
107107

108108
# Initiate Catalyst backend info
109-
device = qml.device("lightning.qubit", wires=0)
109+
if device==None:
110+
device = qml.device("lightning.qubit", wires=0)
111+
110112
backend_info = catalyst.device.extract_backend_info(device)
111113

112114
def catalyst_function(*args):
@@ -138,10 +140,10 @@ def catalyst_function(*args):
138140

139141

140142
@lru_cache(int(1e5))
141-
def jaspr_to_catalyst_qjit(jaspr, function_name="jaspr_function"):
143+
def jaspr_to_catalyst_qjit(jaspr, function_name="jaspr_function", device=None):
142144
# This function takes a jaspr and turns it into a Catalyst QJIT object.
143145
# Perform the code specified by the Catalyst developers
144-
catalyst_function = jaspr_to_catalyst_function(jaspr)
146+
catalyst_function = jaspr_to_catalyst_function(jaspr, device=device)
145147
catalyst_function.__name__ = function_name
146148
jit_object = catalyst.QJIT(catalyst_function, catalyst.CompileOptions())
147149
jit_object.jaxpr = make_jaxpr(catalyst_function)(

src/qrisp/jasp/evaluation_tools/catalyst_qjit.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from qrisp.jasp.jasp_expression import make_jaspr
2121

2222

23-
def qjit(function):
23+
def qjit(function=None, device=None):
2424
"""
2525
Decorator to leverage the jasp + Catalyst infrastructure to compile the given
2626
function to QIR and run it on the Catalyst QIR runtime.
@@ -29,12 +29,40 @@ def qjit(function):
2929
----------
3030
function : callable
3131
A function performing Qrisp code.
32+
device : object
33+
The `PennyLane device <https://docs.pennylane.ai/projects/catalyst/en/stable/dev/devices.html>`_ to execute the function.
34+
The default device is `"lightning.qubit" <https://docs.pennylane.ai/projects/lightning/en/stable/lightning_qubit/device.html>`_,
35+
a fast state-vector qubit simulator.
3236
3337
Returns
3438
-------
3539
callable
3640
A function executing the compiled code.
3741
42+
Notes
43+
-----
44+
45+
Lightning-GPU is compatible with systems featuring NVIDIA Volta (SM 7.0) GPUs or newer.
46+
It is specifically optimized for Linux environments on X86-64 or ARM64 architectures running CUDA-12.
47+
48+
To install Lightning-GPU with NVIDIA CUDA support, the following packages need to be installed
49+
50+
::
51+
52+
pip install custatevec_cu12
53+
pip install pennylane-lightning-gpu
54+
55+
56+
Pre-built wheels for Lightning-AMDGPU are available for AMD MI300 series GPUs and systems running ROCm 7.0 or newer.
57+
58+
::
59+
60+
pip install pennylane-lightning-amdgpu
61+
62+
If the setup uses an older version of ROCm or a different AMD GPU series, Lightning-AMDGPU must be built manually from source.
63+
64+
Installation instructions for different platforms are available at `pennylane.ai/install <https://pennylane.ai/install#high-performance-computing-and-gpus>`_.
65+
3866
Examples
3967
--------
4068
@@ -65,8 +93,31 @@ def test_fun(i):
6593
>>> test_fun(5)
6694
[array(7.25, dtype=float64)]
6795
96+
97+
For executing on "lightning.gpu" we specify the device:
98+
99+
::
100+
101+
import pennylane as qml
102+
from qrisp import *
103+
from qrisp.jasp import qjit
104+
105+
dev = qml.device("lightning.gpu", wires=0)
106+
107+
@qjit(device=dev)
108+
def test_fun(i):
109+
qv = QuantumFloat(i, -2)
110+
with invert():
111+
cx(qv[0], qv[qv.size-1])
112+
h(qv[0])
113+
meas_res = measure(qv)
114+
return meas_res + 3
115+
68116
"""
69117

118+
if function is None:
119+
return lambda x: qjit(x, device=device)
120+
70121
def jitted_function(*args):
71122

72123
if not hasattr(function, "jaspr_dict"):
@@ -79,7 +130,7 @@ def jitted_function(*args):
79130
function.jaspr_dict[signature] = make_jaspr(function)(*args)
80131

81132
return function.jaspr_dict[signature].qjit(
82-
*args, function_name=function.__name__
133+
*args, function_name=function.__name__, device=device
83134
)
84135

85136
return jitted_function

src/qrisp/jasp/jasp_expression/centerclass.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ def embedd(self, *args, name=None, inline=False):
670670
qs.abs_qc = new_abs_qc
671671
return res
672672

673-
def qjit(self, *args, function_name="jaspr_function"):
673+
def qjit(self, *args, function_name="jaspr_function", device=None):
674674
"""
675675
Leverages the Catalyst pipeline to compile a QIR representation of
676676
this function and executes that function using the Catalyst QIR runtime.
@@ -679,6 +679,10 @@ def qjit(self, *args, function_name="jaspr_function"):
679679
----------
680680
*args : iterable
681681
The arguments to call the function with.
682+
device : object
683+
The `PennyLane device <https://docs.pennylane.ai/projects/catalyst/en/stable/dev/devices.html>`_ to execute the function.
684+
The default device is `"lightning.qubit" <https://docs.pennylane.ai/projects/lightning/en/stable/lightning_qubit/device.html>`_,
685+
a fast state-vector qubit simulator.
682686
683687
Returns
684688
-------
@@ -691,7 +695,7 @@ def qjit(self, *args, function_name="jaspr_function"):
691695
jaspr_to_catalyst_qjit,
692696
)
693697

694-
qjit_obj = jaspr_to_catalyst_qjit(flattened_jaspr, function_name=function_name)
698+
qjit_obj = jaspr_to_catalyst_qjit(flattened_jaspr, function_name=function_name, device=device)
695699
res = qjit_obj.compiled_function(*args)
696700
if not isinstance(res, (tuple, list)):
697701
return res

0 commit comments

Comments
 (0)