Skip to content

Commit 019f714

Browse files
committed
Use nanobind::ndarray for FiniteDifferenceLibrary interface
1 parent 55e7da2 commit 019f714

File tree

2 files changed

+58
-18
lines changed

2 files changed

+58
-18
lines changed

src/cilacc/FiniteDifferenceLibrary.cpp

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,14 @@ int fdiff_adjoint_periodic(float *outimagefull, const float *inimageXfull, const
538538
return 0;
539539
}
540540

541-
int fdiff4D(float *imagefull, float *gradCfull, float *gradZfull, float *gradYfull, float *gradXfull, size_t nc, size_t nz, size_t ny, size_t nx, int boundary, int direction, int nThreads)
541+
int fdiff4D(nb::ndarray<float> imagefull,
542+
nb::ndarray<float> gradCfull,
543+
nb::ndarray<float> gradZfull,
544+
nb::ndarray<float> gradYfull,
545+
nb::ndarray<float> gradXfull,
546+
size_t nc, size_t nz, size_t ny, size_t nx,
547+
int boundary, int direction,
548+
int nThreads)
542549
{
543550
int nThreads_initial;
544551
threads_setup(nThreads, &nThreads_initial);
@@ -547,22 +554,28 @@ int fdiff_adjoint_periodic(float *outimagefull, const float *inimageXfull, const
547554
if (boundary)
548555
{
549556
if (direction)
550-
status = fdiff_direct_periodic(imagefull, gradXfull, gradYfull, gradZfull, gradCfull, nx, ny, nz, nc);
557+
status = fdiff_direct_periodic(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), gradCfull.data(), nx, ny, nz, nc);
551558
else
552-
status = fdiff_adjoint_periodic(imagefull, gradXfull, gradYfull, gradZfull, gradCfull, nx, ny, nz, nc);
559+
status = fdiff_adjoint_periodic(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), gradCfull.data(), nx, ny, nz, nc);
553560
}
554561
else
555562
{
556563
if (direction)
557-
status = fdiff_direct_neumann(imagefull, gradXfull, gradYfull, gradZfull, gradCfull, nx, ny, nz, nc);
564+
status = fdiff_direct_neumann(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), gradCfull.data(), nx, ny, nz, nc);
558565
else
559-
status = fdiff_adjoint_neumann(imagefull, gradXfull, gradYfull, gradZfull, gradCfull, nx, ny, nz, nc);
566+
status = fdiff_adjoint_neumann(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), gradCfull.data(), nx, ny, nz, nc);
560567
}
561568

562569
omp_set_num_threads(nThreads_initial);
563570
return status;
564571
}
565-
int fdiff3D(float *imagefull, float *gradZfull, float *gradYfull, float *gradXfull, size_t nz, size_t ny, size_t nx, int boundary, int direction, int nThreads)
572+
int fdiff3D(nb::ndarray<float> imagefull,
573+
nb::ndarray<float> gradZfull,
574+
nb::ndarray<float> gradYfull,
575+
nb::ndarray<float> gradXfull,
576+
size_t nz, size_t ny, size_t nx,
577+
int boundary, int direction,
578+
int nThreads)
566579
{
567580
int nThreads_initial;
568581
threads_setup(nThreads, &nThreads_initial);
@@ -571,22 +584,27 @@ int fdiff3D(float *imagefull, float *gradZfull, float *gradYfull, float *gradXfu
571584
if (boundary)
572585
{
573586
if (direction)
574-
status = fdiff_direct_periodic(imagefull, gradXfull, gradYfull, gradZfull, NULL, nx, ny, nz, 1);
587+
status = fdiff_direct_periodic(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), NULL, nx, ny, nz, 1);
575588
else
576-
status = fdiff_adjoint_periodic(imagefull, gradXfull, gradYfull, gradZfull, NULL, nx, ny, nz, 1);
589+
status = fdiff_adjoint_periodic(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), NULL, nx, ny, nz, 1);
577590
}
578591
else
579592
{
580593
if (direction)
581-
status = fdiff_direct_neumann(imagefull, gradXfull, gradYfull, gradZfull, NULL, nx, ny, nz, 1);
594+
status = fdiff_direct_neumann(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), NULL, nx, ny, nz, 1);
582595
else
583-
status = fdiff_adjoint_neumann(imagefull, gradXfull, gradYfull, gradZfull, NULL, nx, ny, nz, 1);
596+
status = fdiff_adjoint_neumann(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), NULL, nx, ny, nz, 1);
584597
}
585598

586599
omp_set_num_threads(nThreads_initial);
587600
return status;
588601
}
589-
int fdiff2D(float *imagefull, float *gradYfull, float *gradXfull, size_t ny, size_t nx, int boundary, int direction, int nThreads)
602+
int fdiff2D(nb::ndarray<float> imagefull,
603+
nb::ndarray<float> gradYfull,
604+
nb::ndarray<float> gradXfull,
605+
size_t ny, size_t nx,
606+
int boundary, int direction,
607+
int nThreads)
590608
{
591609
int nThreads_initial;
592610
threads_setup(nThreads, &nThreads_initial);
@@ -595,16 +613,16 @@ int fdiff2D(float *imagefull, float *gradYfull, float *gradXfull, size_t ny, siz
595613
if (boundary)
596614
{
597615
if (direction)
598-
status = fdiff_direct_periodic(imagefull, gradXfull, gradYfull, NULL, NULL, nx, ny, 1, 1);
616+
status = fdiff_direct_periodic(imagefull.data(), gradXfull.data(), gradYfull.data(), NULL, NULL, nx, ny, 1, 1);
599617
else
600-
status = fdiff_adjoint_periodic(imagefull, gradXfull, gradYfull, NULL, NULL, nx, ny, 1, 1);
618+
status = fdiff_adjoint_periodic(imagefull.data(), gradXfull.data(), gradYfull.data(), NULL, NULL, nx, ny, 1, 1);
601619
}
602620
else
603621
{
604622
if (direction)
605-
status = fdiff_direct_neumann(imagefull, gradXfull, gradYfull, NULL, NULL, nx, ny, 1, 1);
623+
status = fdiff_direct_neumann(imagefull.data(), gradXfull.data(), gradYfull.data(), NULL, NULL, nx, ny, 1, 1);
606624
else
607-
status = fdiff_adjoint_neumann(imagefull, gradXfull, gradYfull, NULL, NULL, nx, ny, 1, 1);
625+
status = fdiff_adjoint_neumann(imagefull.data(), gradXfull.data(), gradYfull.data(), NULL, NULL, nx, ny, 1, 1);
608626
}
609627

610628
omp_set_num_threads(nThreads_initial);

src/cilacc/include/FiniteDifferenceLibrary.h

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,36 @@
2222
#include <omp.h>
2323
#include "utilities.h"
2424

25+
#include <nanobind/ndarray.h>
26+
namespace nb = nanobind;
27+
28+
2529
int fdiff_direct_neumann(const float *inimagefull, float *outimageXfull, float *outimageYfull, float *outimageZfull, float *outimageCfull, size_t nx, size_t ny, size_t nz, size_t nc);
2630
int fdiff_direct_periodic(const float *inimagefull, float *outimageXfull, float *outimageYfull, float *outimageZfull, float *outimageCfull, size_t nx, size_t ny, size_t nz, size_t nc);
2731
int fdiff_adjoint_neumann(float *outimagefull, const float *inimageXfull, const float *inimageYfull, const float *inimageZfull, const float *inimageCfull, size_t nx, size_t ny, size_t nz, size_t nc);
2832
int fdiff_adjoint_periodic(float *outimagefull, const float *inimageXfull, const float *inimageYfull, const float *inimageZfull, const float *inimageCfull, size_t nx, size_t ny, size_t nz, size_t nc);
2933

3034

3135
int openMPtest(int nThreads);
32-
int fdiff4D(float *imagefull, float *gradCfull, float *gradZfull, float *gradYfull, float *gradXfull, size_t nc, size_t nz, size_t ny, size_t nx, int boundary, int direction, int nThreads);
33-
int fdiff3D(float *imagefull, float *gradZfull, float *gradYfull, float *gradXfull, size_t nz, size_t ny, size_t nx, int boundary, int direction, int nThreads);
34-
int fdiff2D(float *imagefull, float *gradYfull, float *gradXfull, size_t ny, size_t nx, int boundary, int direction, int nThreads);
36+
int fdiff4D(nb::ndarray<float> imagefull,
37+
nb::ndarray<float> gradCfull,
38+
nb::ndarray<float> gradZfull,
39+
nb::ndarray<float> gradYfull,
40+
nb::ndarray<float> gradXfull,
41+
size_t nc, size_t nz, size_t ny,
42+
size_t nx, int boundary, int direction,
43+
int nThreads);
44+
int fdiff3D(nb::ndarray<float> imagefull,
45+
nb::ndarray<float> gradZfull,
46+
nb::ndarray<float> gradYfull,
47+
nb::ndarray<float> gradXfull,
48+
size_t nz, size_t ny, size_t nx,
49+
int boundary, int direction,
50+
int nThreads);
51+
int fdiff2D(nb::ndarray<float> imagefull,
52+
nb::ndarray<float> gradYfull,
53+
nb::ndarray<float> gradXfull,
54+
size_t ny, size_t nx,
55+
int boundary, int direction,
56+
int nThreads);
3557

0 commit comments

Comments
 (0)