Skip to content

Commit 55e7da2

Browse files
committed
Use nanobind::ndarray for axpby interface
1 parent ffd01d3 commit 55e7da2

File tree

4 files changed

+63
-89
lines changed

4 files changed

+63
-89
lines changed

Wrappers/Python/cil/framework/data_container.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,6 @@ def _axpby(self, a, b, y, out, dtype=numpy.float32, num_threads=NUM_THREADS):
553553
:type num_threads: int, optional, default 1/2 CPU of the system
554554
'''
555555

556-
c_float_p = ctypes.POINTER(ctypes.c_float)
557-
c_double_p = ctypes.POINTER(ctypes.c_double)
558556

559557
#convert a and b to numpy arrays and get the reference to the data (length = 1 or ndx.size)
560558
try:
@@ -593,19 +591,19 @@ def _axpby(self, a, b, y, out, dtype=numpy.float32, num_threads=NUM_THREADS):
593591
ndb = ndb.astype(dtype, casting='same_kind')
594592

595593
if dtype == numpy.float32:
596-
x_p = ndx.ctypes.data_as(c_float_p)
597-
y_p = ndy.ctypes.data_as(c_float_p)
598-
out_p = ndout.ctypes.data_as(c_float_p)
599-
a_p = nda.ctypes.data_as(c_float_p)
600-
b_p = ndb.ctypes.data_as(c_float_p)
594+
x_p = ndx
595+
y_p = ndy
596+
out_p = ndout
597+
a_p = nda
598+
b_p = ndb
601599
f = cilacc.saxpby
602600

603601
elif dtype == numpy.float64:
604-
x_p = ndx.ctypes.data_as(c_double_p)
605-
y_p = ndy.ctypes.data_as(c_double_p)
606-
out_p = ndout.ctypes.data_as(c_double_p)
607-
a_p = nda.ctypes.data_as(c_double_p)
608-
b_p = ndb.ctypes.data_as(c_double_p)
602+
x_p = ndx
603+
y_p = ndy
604+
out_p = ndout
605+
a_p = nda
606+
b_p = ndb
609607
f = cilacc.daxpby
610608

611609
else:

Wrappers/Python/cil/optimisation/operators/GradientOperator.py

Lines changed: 5 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -269,60 +269,8 @@ def adjoint(self, x, out=None):
269269

270270
import ctypes, platform
271271
from ctypes import util
272-
# check for the extension
273-
if platform.system() == 'Linux':
274-
dll = 'libcilacc.so'
275-
elif platform.system() == 'Windows':
276-
dll_file = 'cilacc.dll'
277-
dll = util.find_library(dll_file)
278-
elif platform.system() == 'Darwin':
279-
dll = 'libcilacc.dylib'
280-
else:
281-
raise ValueError('Not supported platform, ', platform.system())
282-
283-
cilacc = ctypes.cdll.LoadLibrary(dll)
284-
285-
c_float_p = ctypes.POINTER(ctypes.c_float)
286-
287-
cilacc.openMPtest.restypes = ctypes.c_int32
288-
cilacc.openMPtest.argtypes = [ctypes.c_int32]
289-
290-
cilacc.fdiff4D.restype = ctypes.c_int32
291-
cilacc.fdiff4D.argtypes = [ctypes.POINTER(ctypes.c_float),
292-
ctypes.POINTER(ctypes.c_float),
293-
ctypes.POINTER(ctypes.c_float),
294-
ctypes.POINTER(ctypes.c_float),
295-
ctypes.POINTER(ctypes.c_float),
296-
ctypes.c_size_t,
297-
ctypes.c_size_t,
298-
ctypes.c_size_t,
299-
ctypes.c_size_t,
300-
ctypes.c_int32,
301-
ctypes.c_int32,
302-
ctypes.c_int32]
303-
304-
cilacc.fdiff3D.restype = ctypes.c_int32
305-
cilacc.fdiff3D.argtypes = [ctypes.POINTER(ctypes.c_float),
306-
ctypes.POINTER(ctypes.c_float),
307-
ctypes.POINTER(ctypes.c_float),
308-
ctypes.POINTER(ctypes.c_float),
309-
ctypes.c_size_t,
310-
ctypes.c_size_t,
311-
ctypes.c_size_t,
312-
ctypes.c_int32,
313-
ctypes.c_int32,
314-
ctypes.c_int32]
315-
316-
cilacc.fdiff2D.restype = ctypes.c_int32
317-
cilacc.fdiff2D.argtypes = [ctypes.POINTER(ctypes.c_float),
318-
ctypes.POINTER(ctypes.c_float),
319-
ctypes.POINTER(ctypes.c_float),
320-
ctypes.c_size_t,
321-
ctypes.c_size_t,
322-
ctypes.c_int32,
323-
ctypes.c_int32,
324-
ctypes.c_int32]
325272

273+
import cil.cilacc as cilacc
326274

327275
class Gradient_C(LinearOperator):
328276

@@ -391,7 +339,7 @@ def ndarray_as_c_pointer(ndx):
391339
def direct(self, x, out=None):
392340

393341
ndx = np.asarray(x.as_array(), dtype=np.float32, order='C')
394-
x_p = Gradient_C.ndarray_as_c_pointer(ndx)
342+
x_p = ndx
395343

396344
if out is None:
397345
out = self.range_geometry().allocate(None)
@@ -404,7 +352,7 @@ def direct(self, x, out=None):
404352
ndout.insert(ind, out.get_item(0).as_array()) #insert channels dc at correct point for channel data
405353

406354
#pass list of all arguments
407-
arg1 = [Gradient_C.ndarray_as_c_pointer(ndout[i]) for i in range(len(ndout))]
355+
arg1 = [ndout[i] for i in range(len(ndout))]
408356
arg2 = [el for el in self.domain_shape]
409357
args = arg1 + arg2 + [self.bnd_cond, 1, self.num_threads]
410358
status = self.fd(x_p, *args)
@@ -437,7 +385,7 @@ def adjoint(self, x, out=None):
437385
out = self.domain_geometry().allocate(None)
438386

439387
ndout = np.asarray(out.as_array(), dtype=np.float32, order='C')
440-
out_p = Gradient_C.ndarray_as_c_pointer(ndout)
388+
out_p = ndout
441389

442390
if self.split is False:
443391
ndx = [el.as_array() for el in x.containers]
@@ -450,7 +398,7 @@ def adjoint(self, x, out=None):
450398
if el != 1:
451399
ndx[i]/=el
452400

453-
arg1 = [Gradient_C.ndarray_as_c_pointer(ndx[i]) for i in range(self.ndim)]
401+
arg1 = [ndx[i] for i in range(self.ndim)]
454402
arg2 = [el for el in self.domain_shape]
455403
args = arg1 + arg2 + [self.bnd_cond, 0, self.num_threads]
456404

src/cilacc/axpby.cpp

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,13 @@ int daxpby_asbv(const double * x, const double * y, double * out, double a, cons
104104
}
105105
return 0;
106106
}
107-
int saxpby(const float * x, const float * y, float * out, const float *a, int a_type, const float* b, int b_type, int64 size, int nThreads)
107+
108+
int saxpby(DataFloatInput x, DataFloatInput y,
109+
DataFloatOutput out,
110+
DataFloatInput a, int type_a,
111+
DataFloatInput b, int type_b,
112+
int64 size, int nThreads
113+
)
108114
{
109115
//type = 0 float
110116
//type = 1 array of floats
@@ -114,20 +120,24 @@ int saxpby(const float * x, const float * y, float * out, const float *a, int a_
114120
int nThreads_initial;
115121
threads_setup(nThreads, &nThreads_initial);
116122

117-
if (a_type == 0 && b_type == 0)
118-
saxpby_asbs(x, y, out, *a, *b, size, nThreads);
119-
else if (a_type == 1 && b_type == 1)
120-
saxpby_avbv(x, y, out, a, b, size, nThreads);
121-
else if (a_type == 0 && b_type == 1)
122-
saxpby_asbv(x, y, out, *a, b, size, nThreads);
123-
else if (a_type == 1 && b_type == 0)
124-
saxpby_asbv(y, x, out, *b, a, size, nThreads);
123+
if (type_a == 0 && type_b == 0)
124+
saxpby_asbs(x.data(), y.data(), out.data(), *a.data(), *b.data(), size, nThreads);
125+
else if (type_a == 1 && type_b == 1)
126+
saxpby_avbv(x.data(), y.data(), out.data(), a.data(), b.data(), size, nThreads);
127+
else if (type_a == 0 && type_b == 1)
128+
saxpby_asbv(x.data(), y.data(), out.data(), *a.data(), b.data(), size, nThreads);
129+
else if (type_a == 1 && type_b == 0)
130+
saxpby_asbv(y.data(), x.data(), out.data(), *b.data(), a.data(), size, nThreads);
125131

126132
omp_set_num_threads(nThreads_initial);
127133

128134
return 0;
129135
}
130-
int daxpby(const double * x, const double * y, double * out, const double *a, int a_type, const double* b, int b_type, int64 size, int nThreads)
136+
int daxpby(DataDoubleInput x, DataDoubleInput y,
137+
DataDoubleOutput out,
138+
DataDoubleInput a, int type_a,
139+
DataDoubleInput b, int type_b,
140+
int64 size, int nThreads)
131141
{
132142
//type = 0 double
133143
//type = 1 array of double
@@ -137,14 +147,14 @@ int daxpby(const double * x, const double * y, double * out, const double *a, in
137147
int nThreads_initial;
138148
threads_setup(nThreads, &nThreads_initial);
139149

140-
if (a_type == 0 && b_type == 0)
141-
daxpby_asbs(x, y, out, *a, *b, size, nThreads);
142-
else if (a_type == 1 && b_type == 1)
143-
daxpby_avbv(x, y, out, a, b, size, nThreads);
144-
else if (a_type == 0 && b_type == 1)
145-
daxpby_asbv(x, y, out, *a, b, size, nThreads);
146-
else if (a_type == 1 && b_type == 0)
147-
daxpby_asbv(y, x, out, *b, a, size, nThreads);
150+
if (type_a == 0 && type_b == 0)
151+
daxpby_asbs(x.data(), y.data(), out.data(), *a.data(), *b.data(), size, nThreads);
152+
else if (type_a == 1 && type_b == 1)
153+
daxpby_avbv(x.data(), y.data(), out.data(), a.data(), b.data(), size, nThreads);
154+
else if (type_a == 0 && type_b == 1)
155+
daxpby_asbv(x.data(), y.data(), out.data(), *a.data(), b.data(), size, nThreads);
156+
else if (type_a == 1 && type_b == 0)
157+
daxpby_asbv(y.data(), x.data(), out.data(), *b.data(), a.data(), size, nThreads);
148158

149159
omp_set_num_threads(nThreads_initial);
150160

src/cilacc/include/axpby.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,34 @@
2020
#include <stdio.h>
2121
#include <omp.h>
2222
#include "utilities.h"
23+
#include <nanobind/ndarray.h>
24+
25+
namespace nb = nanobind;
2326

2427
using int64 = long long;
2528

29+
using DataFloatInput = nb::ndarray<const float>;
30+
using DataDoubleInput = nb::ndarray<const double>;
31+
32+
using DataFloatOutput = nb::ndarray<float>;
33+
using DataDoubleOutput = nb::ndarray<double>;
34+
35+
2636
int saxpby_asbs(const float * x, const float * y, float * out, float a, float b, int64 size, int nThreads);
2737
int saxpby_avbv(const float * x, const float * y, float * out, const float * a, const float * b, int64 size, int nThreads);
2838
int saxpby_asbv(const float * x, const float * y, float * out, float a, const float * b, int64 size, int nThreads);
2939
int daxpby_asbs(const double * x, const double * y, double * out, double a, double b, int64 size, int nThreads);
3040
int daxpby_avbv(const double * x, const double * y, double * out, const double * a, const double * b, int64 size, int nThreads);
3141
int daxpby_asbv(const double * x, const double * y, double * out, double a, const double * b, int64 size, int nThreads);
3242

33-
int saxpby(const float * x, const float * y, float * out, const float * a, int type_a, const float * b, int type_b, int64 size, int nThreads);
34-
int daxpby(const double * x, const double * y, double * out, const double * a, int type_a, const double * b, int type_b, int64 size, int nThreads);
43+
int saxpby(DataFloatInput x, DataFloatInput y,
44+
DataFloatOutput out,
45+
DataFloatInput a, int type_a,
46+
DataFloatInput b, int type_b,
47+
int64 size, int nThreads);
48+
int daxpby(DataDoubleInput x, DataDoubleInput y,
49+
DataDoubleOutput out,
50+
DataDoubleInput a, int type_a,
51+
DataDoubleInput b, int type_b,
52+
int64 size, int nThreads);
3553

0 commit comments

Comments
 (0)