Skip to content

Commit 6beb733

Browse files
committed
Release v1.0.9 (#214)
1 parent 94aa57c commit 6beb733

File tree

16 files changed

+957
-470
lines changed

16 files changed

+957
-470
lines changed

docs/src/installation.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ The Python package can be installed with pip by simply running
1111
1212
pip install sphericart
1313
14-
This basic package makes use of NumPy. Implementations supporting PyTorch and JAX can be installed with
14+
This basic package makes use of NumPy and CuPy. Implementations supporting PyTorch and JAX can be installed with
1515

1616
.. code-block:: bash
1717
@@ -37,9 +37,10 @@ Before installing the JAX version of ``sphericart``, make sure you already have
3737
library installed according to the `official JAX installation instructions
3838
<https://jax.readthedocs.io/en/latest/installation.html>`_.
3939

40-
In addition, if you want to use the CUDA functionalities of sphericart (either with torch
41-
or JAX), make sure you have installed the CUDA toolkit and set up the environment variables
42-
``CUDA_HOME``, ``LD_LIBRARY_FLAGS``, and ``PATH`` accordingly.
40+
In addition, if you want to use the CUDA functionalities of sphericart (with CuPy,
41+
torch or JAX), make sure you have installed the CUDA toolkit
42+
and set up the environment variables ``CUDA_HOME``, ``LD_LIBRARY_FLAGS``, and ``PATH``
43+
accordingly. In case you want to use CuPy, it should be installed separately.
4344

4445

4546
Julia package

docs/src/python-api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
Python API
22
==========
33

4+
The Python calculators accept either NumPy arrays on CPU or CuPy arrays on CUDA.
5+
46
.. autoclass:: sphericart.SphericalHarmonics
57
:members:
68

docs/src/python-examples.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,8 @@ difference that, in line with pythonic mores, the
88

99
.. literalinclude:: ../../examples/python/example.py
1010
:language: python
11+
12+
The same calculators also accept CuPy arrays:
13+
14+
.. literalinclude:: ../../examples/python/cupy.py
15+
:language: python

examples/python/cupy.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import argparse
2+
3+
import cupy as cp
4+
import numpy as np
5+
6+
import sphericart
7+
8+
docstring = """
9+
An example of the CuPy interface of the `sphericart` library.
10+
11+
Computes Cartesian spherical harmonics on the GPU for a random array of 3D
12+
points and compares the result against the NumPy CPU backend.
13+
"""
14+
15+
16+
def sphericart_cupy_example(l_max=10, n_samples=10000):
17+
xyz_cpu = np.random.rand(n_samples, 3)
18+
xyz_gpu = cp.asarray(xyz_cpu)
19+
20+
sh_calculator = sphericart.SphericalHarmonics(l_max)
21+
22+
sh_cpu = sh_calculator.compute(xyz_cpu)
23+
sh_gpu = sh_calculator.compute(xyz_gpu)
24+
25+
print(
26+
"CPU vs GPU relative error: %12.8e"
27+
% (np.linalg.norm(sh_cpu - cp.asnumpy(sh_gpu)) / np.linalg.norm(sh_cpu))
28+
)
29+
30+
31+
if __name__ == "__main__":
32+
parser = argparse.ArgumentParser(description=docstring)
33+
34+
parser.add_argument("-l", type=int, default=10, help="maximum angular momentum")
35+
parser.add_argument("-s", type=int, default=1000, help="number of samples")
36+
37+
args = parser.parse_args()
38+
39+
sphericart_cupy_example(args.l, args.s)

python/src/sphericart/_c_lib.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,22 @@ class sphericart_solid_harmonics_calculator_f_t(ctypes.c_void_p):
2222
pass
2323

2424

25+
class sphericart_cuda_spherical_harmonics_calculator_t(ctypes.c_void_p):
26+
pass
27+
28+
29+
class sphericart_cuda_spherical_harmonics_calculator_f_t(ctypes.c_void_p):
30+
pass
31+
32+
33+
class sphericart_cuda_solid_harmonics_calculator_t(ctypes.c_void_p):
34+
pass
35+
36+
37+
class sphericart_cuda_solid_harmonics_calculator_f_t(ctypes.c_void_p):
38+
pass
39+
40+
2541
def setup_functions(lib):
2642
lib.sphericart_spherical_harmonics_new.restype = (
2743
sphericart_spherical_harmonics_calculator_t
@@ -221,6 +237,168 @@ def setup_functions(lib):
221237
sphericart_solid_harmonics_calculator_f_t,
222238
]
223239

240+
lib.sphericart_cuda_spherical_harmonics_new.restype = (
241+
sphericart_cuda_spherical_harmonics_calculator_t
242+
)
243+
lib.sphericart_cuda_spherical_harmonics_new.argtypes = [ctypes.c_size_t]
244+
245+
lib.sphericart_cuda_spherical_harmonics_new_f.restype = (
246+
sphericart_cuda_spherical_harmonics_calculator_f_t
247+
)
248+
lib.sphericart_cuda_spherical_harmonics_new_f.argtypes = [ctypes.c_size_t]
249+
250+
lib.sphericart_cuda_spherical_harmonics_delete.restype = None
251+
lib.sphericart_cuda_spherical_harmonics_delete.argtypes = [
252+
sphericart_cuda_spherical_harmonics_calculator_t
253+
]
254+
255+
lib.sphericart_cuda_spherical_harmonics_delete_f.restype = None
256+
lib.sphericart_cuda_spherical_harmonics_delete_f.argtypes = [
257+
sphericart_cuda_spherical_harmonics_calculator_f_t
258+
]
259+
260+
lib.sphericart_cuda_spherical_harmonics_compute_array.restype = None
261+
lib.sphericart_cuda_spherical_harmonics_compute_array.argtypes = [
262+
sphericart_cuda_spherical_harmonics_calculator_t,
263+
ctypes.POINTER(ctypes.c_double),
264+
ctypes.c_size_t,
265+
ctypes.POINTER(ctypes.c_double),
266+
ctypes.c_void_p,
267+
]
268+
269+
lib.sphericart_cuda_spherical_harmonics_compute_array_f.restype = None
270+
lib.sphericart_cuda_spherical_harmonics_compute_array_f.argtypes = [
271+
sphericart_cuda_spherical_harmonics_calculator_f_t,
272+
ctypes.POINTER(ctypes.c_float),
273+
ctypes.c_size_t,
274+
ctypes.POINTER(ctypes.c_float),
275+
ctypes.c_void_p,
276+
]
277+
278+
lib.sphericart_cuda_spherical_harmonics_compute_array_with_gradients.restype = None
279+
lib.sphericart_cuda_spherical_harmonics_compute_array_with_gradients.argtypes = [
280+
sphericart_cuda_spherical_harmonics_calculator_t,
281+
ctypes.POINTER(ctypes.c_double),
282+
ctypes.c_size_t,
283+
ctypes.POINTER(ctypes.c_double),
284+
ctypes.POINTER(ctypes.c_double),
285+
ctypes.c_void_p,
286+
]
287+
288+
lib.sphericart_cuda_spherical_harmonics_compute_array_with_gradients_f.restype = (
289+
None
290+
)
291+
lib.sphericart_cuda_spherical_harmonics_compute_array_with_gradients_f.argtypes = [
292+
sphericart_cuda_spherical_harmonics_calculator_f_t,
293+
ctypes.POINTER(ctypes.c_float),
294+
ctypes.c_size_t,
295+
ctypes.POINTER(ctypes.c_float),
296+
ctypes.POINTER(ctypes.c_float),
297+
ctypes.c_void_p,
298+
]
299+
300+
lib.sphericart_cuda_spherical_harmonics_compute_array_with_hessians.restype = None
301+
lib.sphericart_cuda_spherical_harmonics_compute_array_with_hessians.argtypes = [
302+
sphericart_cuda_spherical_harmonics_calculator_t,
303+
ctypes.POINTER(ctypes.c_double),
304+
ctypes.c_size_t,
305+
ctypes.POINTER(ctypes.c_double),
306+
ctypes.POINTER(ctypes.c_double),
307+
ctypes.POINTER(ctypes.c_double),
308+
ctypes.c_void_p,
309+
]
310+
311+
lib.sphericart_cuda_spherical_harmonics_compute_array_with_hessians_f.restype = None
312+
lib.sphericart_cuda_spherical_harmonics_compute_array_with_hessians_f.argtypes = [
313+
sphericart_cuda_spherical_harmonics_calculator_f_t,
314+
ctypes.POINTER(ctypes.c_float),
315+
ctypes.c_size_t,
316+
ctypes.POINTER(ctypes.c_float),
317+
ctypes.POINTER(ctypes.c_float),
318+
ctypes.POINTER(ctypes.c_float),
319+
ctypes.c_void_p,
320+
]
321+
322+
lib.sphericart_cuda_solid_harmonics_new.restype = (
323+
sphericart_cuda_solid_harmonics_calculator_t
324+
)
325+
lib.sphericart_cuda_solid_harmonics_new.argtypes = [ctypes.c_size_t]
326+
327+
lib.sphericart_cuda_solid_harmonics_new_f.restype = (
328+
sphericart_cuda_solid_harmonics_calculator_f_t
329+
)
330+
lib.sphericart_cuda_solid_harmonics_new_f.argtypes = [ctypes.c_size_t]
331+
332+
lib.sphericart_cuda_solid_harmonics_delete.restype = None
333+
lib.sphericart_cuda_solid_harmonics_delete.argtypes = [
334+
sphericart_cuda_solid_harmonics_calculator_t
335+
]
336+
337+
lib.sphericart_cuda_solid_harmonics_delete_f.restype = None
338+
lib.sphericart_cuda_solid_harmonics_delete_f.argtypes = [
339+
sphericart_cuda_solid_harmonics_calculator_f_t
340+
]
341+
342+
lib.sphericart_cuda_solid_harmonics_compute_array.restype = None
343+
lib.sphericart_cuda_solid_harmonics_compute_array.argtypes = [
344+
sphericart_cuda_solid_harmonics_calculator_t,
345+
ctypes.POINTER(ctypes.c_double),
346+
ctypes.c_size_t,
347+
ctypes.POINTER(ctypes.c_double),
348+
ctypes.c_void_p,
349+
]
350+
351+
lib.sphericart_cuda_solid_harmonics_compute_array_f.restype = None
352+
lib.sphericart_cuda_solid_harmonics_compute_array_f.argtypes = [
353+
sphericart_cuda_solid_harmonics_calculator_f_t,
354+
ctypes.POINTER(ctypes.c_float),
355+
ctypes.c_size_t,
356+
ctypes.POINTER(ctypes.c_float),
357+
ctypes.c_void_p,
358+
]
359+
360+
lib.sphericart_cuda_solid_harmonics_compute_array_with_gradients.restype = None
361+
lib.sphericart_cuda_solid_harmonics_compute_array_with_gradients.argtypes = [
362+
sphericart_cuda_solid_harmonics_calculator_t,
363+
ctypes.POINTER(ctypes.c_double),
364+
ctypes.c_size_t,
365+
ctypes.POINTER(ctypes.c_double),
366+
ctypes.POINTER(ctypes.c_double),
367+
ctypes.c_void_p,
368+
]
369+
370+
lib.sphericart_cuda_solid_harmonics_compute_array_with_gradients_f.restype = None
371+
lib.sphericart_cuda_solid_harmonics_compute_array_with_gradients_f.argtypes = [
372+
sphericart_cuda_solid_harmonics_calculator_f_t,
373+
ctypes.POINTER(ctypes.c_float),
374+
ctypes.c_size_t,
375+
ctypes.POINTER(ctypes.c_float),
376+
ctypes.POINTER(ctypes.c_float),
377+
ctypes.c_void_p,
378+
]
379+
380+
lib.sphericart_cuda_solid_harmonics_compute_array_with_hessians.restype = None
381+
lib.sphericart_cuda_solid_harmonics_compute_array_with_hessians.argtypes = [
382+
sphericart_cuda_solid_harmonics_calculator_t,
383+
ctypes.POINTER(ctypes.c_double),
384+
ctypes.c_size_t,
385+
ctypes.POINTER(ctypes.c_double),
386+
ctypes.POINTER(ctypes.c_double),
387+
ctypes.POINTER(ctypes.c_double),
388+
ctypes.c_void_p,
389+
]
390+
391+
lib.sphericart_cuda_solid_harmonics_compute_array_with_hessians_f.restype = None
392+
lib.sphericart_cuda_solid_harmonics_compute_array_with_hessians_f.argtypes = [
393+
sphericart_cuda_solid_harmonics_calculator_f_t,
394+
ctypes.POINTER(ctypes.c_float),
395+
ctypes.c_size_t,
396+
ctypes.POINTER(ctypes.c_float),
397+
ctypes.POINTER(ctypes.c_float),
398+
ctypes.POINTER(ctypes.c_float),
399+
ctypes.c_void_p,
400+
]
401+
224402

225403
class LibraryFinder(object):
226404
def __init__(self):

python/src/sphericart/_dispatch.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import ctypes
2+
3+
import numpy as np
4+
5+
6+
try:
7+
import cupy as cp
8+
from cupy import ndarray as cupy_ndarray
9+
except ImportError:
10+
cp = None
11+
12+
class cupy_ndarray: # type: ignore[no-redef]
13+
pass
14+
15+
16+
def is_array(array):
17+
return isinstance(array, np.ndarray) or isinstance(array, cupy_ndarray)
18+
19+
20+
def is_cupy_array(array):
21+
return isinstance(array, cupy_ndarray)
22+
23+
24+
def make_contiguous(array):
25+
if isinstance(array, np.ndarray):
26+
return np.ascontiguousarray(array)
27+
if isinstance(array, cupy_ndarray):
28+
return cp.ascontiguousarray(array)
29+
raise TypeError(f"only numpy and cupy arrays are supported, found {type(array)}")
30+
31+
32+
def empty_like(shape, array):
33+
if isinstance(array, np.ndarray):
34+
return np.empty(shape, dtype=array.dtype)
35+
if isinstance(array, cupy_ndarray):
36+
return cp.empty(shape, dtype=array.dtype)
37+
raise TypeError(f"only numpy and cupy arrays are supported, found {type(array)}")
38+
39+
40+
def get_pointer(array):
41+
if array.dtype == np.float32:
42+
ptr_type = ctypes.POINTER(ctypes.c_float)
43+
elif array.dtype == np.float64:
44+
ptr_type = ctypes.POINTER(ctypes.c_double)
45+
else:
46+
raise TypeError(
47+
f"only float32 and float64 arrays are supported, found {array.dtype}"
48+
)
49+
50+
if isinstance(array, np.ndarray):
51+
return array.ctypes.data_as(ptr_type)
52+
if isinstance(array, cupy_ndarray):
53+
return ctypes.cast(array.data.ptr, ptr_type)
54+
raise TypeError(f"only numpy and cupy arrays are supported, found {type(array)}")
55+
56+
57+
def get_cuda_stream(array):
58+
if not isinstance(array, cupy_ndarray):
59+
return None
60+
return ctypes.c_void_p(cp.cuda.get_current_stream().ptr)

0 commit comments

Comments
 (0)