Skip to content

Commit 1f91a4d

Browse files
authored
MultiFab: Static to Member Math Methods (#301)
Replace `static` member functions that with regular member functions that take store the result/destination in `self`. (Assume `dst` is `self`.) - [x] `FabArray<FArrayBox>` - [x] `MultiFab` Close #296
1 parent 8e46613 commit 1f91a4d

File tree

2 files changed

+184
-105
lines changed

2 files changed

+184
-105
lines changed

src/Base/MultiFab.cpp

Lines changed: 177 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <AMReX_FabArrayBase.H>
1313
#include <AMReX_FabFactory.H>
1414
#include <AMReX_MultiFab.H>
15+
#include <AMReX_iMultiFab.H>
1516

1617
#include <memory>
1718
#include <string>
@@ -215,23 +216,69 @@ void init_MultiFab(py::module &m)
215216
py::arg("comp"), py::arg("ncomp"), py::arg("nghost")
216217
)
217218

218-
.def_static("saxpy",
219-
py::overload_cast< FabArray<FArrayBox> &, Real, FabArray<FArrayBox> const &, int, int, int, IntVect const & >(&FabArray<FArrayBox>::template Saxpy<FArrayBox>),
220-
py::arg("y"), py::arg("a"), py::arg("x"), py::arg("xcomp"), py::arg("ycomp"), py::arg("ncomp"), py::arg("nghost"),
221-
"y += a*x"
222-
)
223-
.def_static("xpay",
224-
py::overload_cast< FabArray<FArrayBox> &, Real, FabArray<FArrayBox> const &, int, int, int, IntVect const & >(&FabArray<FArrayBox>::template Xpay<FArrayBox>),
225-
py::arg("y"), py::arg("a"), py::arg("x"), py::arg("xcomp"), py::arg("ycomp"), py::arg("ncomp"), py::arg("nghost"),
226-
"y = x + a*y"
227-
)
228-
.def_static("lin_comb",
229-
py::overload_cast< FabArray<FArrayBox> &, Real, FabArray<FArrayBox> const &, int, Real, FabArray<FArrayBox> const &, int, int, int, IntVect const & >(&FabArray<FArrayBox>::template LinComb<FArrayBox>),
230-
py::arg("dst"),
219+
.def("saxpy",
220+
[](FabArray<FArrayBox> & dst, Real a, FabArray<FArrayBox> const & x, int x_comp, int comp, int ncomp, IntVect const & nghost)
221+
{
222+
FabArray<FArrayBox>::Saxpy(dst, a, x, x_comp, comp, ncomp, nghost);
223+
},
224+
py::arg("a"), py::arg("x"), py::arg("x_comp"), py::arg("comp"), py::arg("ncomp"), py::arg("nghost"),
225+
"self += a * x\n\n"
226+
"Parameters\n"
227+
"----------\n"
228+
"a : scalar a\n"
229+
"x : FabArray x\n"
230+
"x_comp : starting component of x\n"
231+
"comp : starting component of self\n"
232+
"ncomp : number of components\n"
233+
"nghost : number of ghost cells"
234+
)
235+
.def("xpay",
236+
[](FabArray<FArrayBox> & self, Real a, FabArray<FArrayBox> const & x, int x_comp, int comp, int ncomp, IntVect const & nghost)
237+
{
238+
FabArray<FArrayBox>::Xpay(self, a, x, x_comp, comp, ncomp, nghost);
239+
},
240+
py::arg("a"), py::arg("x"), py::arg("xcomp"), py::arg("comp"), py::arg("ncomp"), py::arg("nghost"),
241+
"self = x + a * self\n\n"
242+
"Parameters\n"
243+
"----------\n"
244+
"a : scalar a\n"
245+
"x : FabArray x\n"
246+
"x_comp : starting component of x\n"
247+
"comp : starting component of self\n"
248+
"ncomp : number of components\n"
249+
"nghost : number of ghost cells"
250+
)
251+
.def("lin_comb",
252+
[](
253+
FabArray<FArrayBox> & dst,
254+
Real a, FabArray<FArrayBox> const & x, int x_comp,
255+
Real b, FabArray<FArrayBox> const & y, int y_comp,
256+
int comp, int ncomp, IntVect const & nghost)
257+
{
258+
FabArray<FArrayBox>::LinComb(dst, a, x, x_comp, b, y, y_comp, comp, ncomp, nghost);
259+
},
231260
py::arg("a"), py::arg("x"), py::arg("xcomp"),
232261
py::arg("b"), py::arg("y"), py::arg("ycomp"),
233-
py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
234-
"dst = a*x + b*y"
262+
py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
263+
"self = a * x + b * y\n\n"
264+
"Parameters\n"
265+
"----------\n"
266+
"a : float\n"
267+
" scalar a\n"
268+
"x : FabArray\n"
269+
"xcomp : int\n"
270+
" starting component of x\n"
271+
"b : float\n"
272+
" scalar b\n"
273+
"y : FabArray\n"
274+
"ycomp : int\n"
275+
" starting component of y\n"
276+
"comp : int\n"
277+
" starting component of self\n"
278+
"numcomp : int\n"
279+
" number of components\n"
280+
"nghost : int\n"
281+
" number of ghost cells"
235282
)
236283

237284
.def("sum",
@@ -730,123 +777,164 @@ void init_MultiFab(py::module &m)
730777
)
731778

732779
/* static (standalone) simple math functions */
733-
.def_static("dot",
734-
py::overload_cast< MultiFab const &, int, MultiFab const &, int, int, int, bool >(&MultiFab::Dot),
735-
py::arg("x"), py::arg("xcomp"),
736-
py::arg("y"), py::arg("ycomp"),
780+
.def("dot",
781+
[](MultiFab const & self, int comp, MultiFab const & y, int y_comp, int numcomp, int nghost, bool local) {
782+
return MultiFab::Dot(self, comp, y, y_comp, numcomp, nghost, local);
783+
},
784+
py::arg("comp"),
785+
py::arg("y"), py::arg("y_comp"),
737786
py::arg("numcomp"), py::arg("nghost"), py::arg("local")=false,
738-
"Returns the dot product of two MultiFabs."
787+
"Returns the dot product of self with another MultiFab."
739788
)
740-
.def_static("dot",
741-
py::overload_cast< MultiFab const &, int, int, int, bool >(&MultiFab::Dot),
742-
py::arg("x"), py::arg("xcomp"),
789+
.def("dot",
790+
[](MultiFab const & self, int comp, int numcomp, int nghost, bool local) {
791+
return MultiFab::Dot(self, comp, numcomp, nghost, local);
792+
},
793+
py::arg("comp"),
794+
py::arg("numcomp"), py::arg("nghost"), py::arg("local")=false,
795+
"Returns the dot product with itself."
796+
)
797+
/** TODO: Bind iMultiFab
798+
.def("dot",
799+
[](MultiFab const& self, const iMultiFab& mask, int comp, MultiFab const& y, int y_comp, int numcomp, int nghost, bool local) {
800+
return MultiFab::Dot(mask, self, comp, y, y_comp, numcomp, nghost, local);
801+
},
802+
py::arg("mask"), py::arg("comp"), py::arg("y"), py::arg("y_comp"),
743803
py::arg("numcomp"), py::arg("nghost"), py::arg("local")=false,
744-
"Returns the dot product of a MultiFab with itself."
804+
"Returns the dot product of self with another MultiFab where the mask is valid."
745805
)
746-
//.def_static("dot", py::overload_cast< iMultiFab const&, const MultiFab&, int, MultiFab const&, int, int, int, bool >(&MultiFab::Dot))
806+
*/
747807

748-
.def_static("add",
749-
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, int >(&MultiFab::Add),
750-
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
751-
"Add src to dst including nghost ghost cells.\n"
808+
.def("add",
809+
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, int nghost) {
810+
MultiFab::Add(self, src, srccomp, comp, numcomp, nghost);
811+
},
812+
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
813+
"Add src to self including nghost ghost cells.\n"
752814
"The two MultiFabs MUST have the same underlying BoxArray."
753815
)
754-
.def_static("add",
755-
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, IntVect const & >(&MultiFab::Add),
756-
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
757-
"Add src to dst including nghost ghost cells.\n"
816+
.def("add",
817+
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, IntVect const & nghost) {
818+
MultiFab::Add(self, src, srccomp, comp, numcomp, nghost);
819+
},
820+
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
821+
"Add src to self including nghost ghost cells.\n"
758822
"The two MultiFabs MUST have the same underlying BoxArray."
759823
)
760824

761-
.def_static("subtract",
762-
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, int >(&MultiFab::Subtract),
763-
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
764-
"Subtract src from dst including nghost ghost cells.\n"
825+
.def("subtract",
826+
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, int nghost) {
827+
MultiFab::Subtract(self, src, srccomp, comp, numcomp, nghost);
828+
},
829+
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
830+
"Subtract src from self including nghost ghost cells.\n"
765831
"The two MultiFabs MUST have the same underlying BoxArray."
766832
)
767-
.def_static("subtract",
768-
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, IntVect const & >(&MultiFab::Subtract),
769-
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
770-
"Subtract src from dst including nghost ghost cells.\n"
833+
.def("subtract",
834+
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, IntVect const & nghost) {
835+
MultiFab::Subtract(self, src, srccomp, comp, numcomp, nghost);
836+
},
837+
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
838+
"Subtract src from self including nghost ghost cells.\n"
771839
"The two MultiFabs MUST have the same underlying BoxArray."
772840
)
773841

774-
.def_static("multiply",
775-
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, int >(&MultiFab::Multiply),
776-
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
777-
"Multiply dst by src including nghost ghost cells.\n"
842+
.def("multiply",
843+
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, int nghost) {
844+
MultiFab::Multiply(self, src, srccomp, comp, numcomp, nghost);
845+
},
846+
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
847+
"Multiply self by src including nghost ghost cells.\n"
778848
"The two MultiFabs MUST have the same underlying BoxArray."
779849
)
780-
.def_static("multiply",
781-
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, IntVect const & >(&MultiFab::Multiply),
782-
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
783-
"Multiply dst by src including nghost ghost cells.\n"
850+
.def("multiply",
851+
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, IntVect const & nghost) {
852+
MultiFab::Multiply(self, src, srccomp, comp, numcomp, nghost);
853+
},
854+
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
855+
"Multiply self by src including nghost ghost cells.\n"
784856
"The two MultiFabs MUST have the same underlying BoxArray."
785857
)
786858

787-
.def_static("divide",
788-
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, int >(&MultiFab::Divide),
789-
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
790-
"Divide dst by src including nghost ghost cells.\n"
859+
.def("divide",
860+
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, int nghost) {
861+
MultiFab::Divide(self, src, srccomp, comp, numcomp, nghost);
862+
},
863+
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
864+
"Divide self by src including nghost ghost cells.\n"
791865
"The two MultiFabs MUST have the same underlying BoxArray."
792866
)
793-
.def_static("divide",
794-
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, IntVect const & >(&MultiFab::Divide),
795-
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
796-
"Divide dst by src including nghost ghost cells.\n"
867+
.def("divide",
868+
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, IntVect const & nghost) {
869+
MultiFab::Divide(self, src, srccomp, comp, numcomp, nghost);
870+
},
871+
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
872+
"Divide self by src including nghost ghost cells.\n"
797873
"The two MultiFabs MUST have the same underlying BoxArray."
798874
)
799875

800-
.def_static("swap",
801-
py::overload_cast< MultiFab &, MultiFab &, int, int, int, int >(&MultiFab::Swap),
802-
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
803-
"Swap from src to dst including nghost ghost cells.\n"
876+
.def("swap",
877+
[](MultiFab & self, MultiFab & src, int srccomp, int comp, int numcomp, int nghost) {
878+
MultiFab::Swap(self, src, srccomp, comp, numcomp, nghost);
879+
},
880+
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
881+
"Swap from src to self including nghost ghost cells.\n"
804882
"The two MultiFabs MUST have the same underlying BoxArray.\n"
805883
"The swap is local."
806884
)
807-
.def_static("swap",
808-
py::overload_cast< MultiFab &, MultiFab &, int, int, int, IntVect const & >(&MultiFab::Swap),
809-
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
810-
"Swap from src to dst including nghost ghost cells.\n"
885+
.def("swap",
886+
[](MultiFab & self, MultiFab & src, int srccomp, int comp, int numcomp, IntVect const & nghost) {
887+
MultiFab::Swap(self, src, srccomp, comp, numcomp, nghost);
888+
},
889+
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
890+
"Swap from src to self including nghost ghost cells.\n"
811891
"The two MultiFabs MUST have the same underlying BoxArray.\n"
812892
"The swap is local."
813893
)
814894

815-
.def_static("saxpy",
816-
// py::overload_cast< MultiFab &, Real, MultiFab const &, int, int, int, int >(&MultiFab::Saxpy)
817-
static_cast<void (*)(MultiFab &, Real, MultiFab const &, int, int, int, int)>(&MultiFab::Saxpy),
818-
py::arg("dst"), py::arg("a"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
819-
"dst += a*src"
895+
.def("saxpy",
896+
[](MultiFab & self, Real a, MultiFab const & src, int srccomp, int comp, int numcomp, int nghost) {
897+
MultiFab::Saxpy(self, a, src, srccomp, comp, numcomp, nghost);
898+
},
899+
py::arg("a"), py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
900+
"self += a * src"
820901
)
821902

822-
.def_static("xpay",
823-
// py::overload_cast< MultiFab &, Real, MultiFab const &, int, int, int, int >(&MultiFab::Xpay)
824-
static_cast<void (*)(MultiFab &, Real, MultiFab const &, int, int, int, int)>(&MultiFab::Xpay),
825-
py::arg("dst"), py::arg("a"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
826-
"dst = src + a*dst"
903+
.def("xpay",
904+
[](MultiFab & self, Real a, MultiFab const & src, int srccomp, int comp, int numcomp, int nghost) {
905+
MultiFab::Xpay(self, a, src, srccomp, comp, numcomp, nghost);
906+
},
907+
py::arg("a"), py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
908+
"self = src + a * self"
827909
)
828910

829-
.def_static("lin_comb",
830-
// py::overload_cast< MultiFab &, Real, MultiFab const &, int, Real, MultiFab const &, int, int, int, int >(&MultiFab::LinComb)
831-
static_cast<void (*)(MultiFab &, Real, MultiFab const &, int, Real, MultiFab const &, int, int, int, int)>(&MultiFab::LinComb),
832-
py::arg("dst"),
911+
.def("lin_comb",
912+
[](MultiFab & self, Real a, MultiFab const & x, int x_comp, Real b, MultiFab const & y, int y_comp, int comp, int numcomp, int nghost) {
913+
MultiFab::LinComb(self, a, x, x_comp, b, y, y_comp, comp, numcomp, nghost);
914+
},
833915
py::arg("a"), py::arg("x"), py::arg("x_comp"),
834916
py::arg("b"), py::arg("y"), py::arg("y_comp"),
835-
py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
836-
"dst = a*x + b*y"
917+
py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
918+
"self = a * x + b * y"
837919
)
838920

839-
.def_static("add_product",
840-
py::overload_cast< MultiFab &, MultiFab const &, int, MultiFab const &, int, int, int, int >(&MultiFab::AddProduct),
841-
py::arg("dst"),
921+
.def("add_product",
922+
[](MultiFab & self, MultiFab const & src1, int comp1, MultiFab const & src2, int comp2, int comp, int numcomp, int nghost) {
923+
MultiFab::AddProduct(self, src1, comp1, src2, comp2, comp, numcomp, nghost);
924+
},
842925
py::arg("src1"), py::arg("comp1"),
843926
py::arg("src2"), py::arg("comp2"),
844-
py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
845-
"dst += src1*src2"
927+
py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
928+
"self += src1 * src2"
846929
)
847-
.def_static("add_product",
848-
py::overload_cast< MultiFab &, MultiFab const &, int, MultiFab const &, int, int, int, IntVect const & >(&MultiFab::AddProduct),
849-
"dst += src1*src2"
930+
.def("add_product",
931+
[](MultiFab & self, MultiFab const & src1, int comp1, MultiFab const & src2, int comp2, int comp, int numcomp, IntVect const & nghost) {
932+
MultiFab::AddProduct(self, src1, comp1, src2, comp2, comp, numcomp, nghost);
933+
},
934+
py::arg("src1"), py::arg("comp1"),
935+
py::arg("src2"), py::arg("comp2"),
936+
py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
937+
"self += src1 * src2"
850938
)
851939

852940
/* simple data validity checks */

tests/test_multifab.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -240,30 +240,21 @@ def test_mfab_ops(boxarr, distmap, nghost):
240240
src.set_val(30.0, 2, 1)
241241
dst.set_val(0.0, 0, 1)
242242

243-
# dst.add(src, 2, 0, 1, nghost)
244-
# dst.subtract(src, 1, 0, 1, nghost)
245-
# dst.multiply(src, 0, 0, 1, nghost)
246-
# dst.divide(src, 1, 0, 1, nghost)
247-
248-
dst.add(dst, src, 2, 0, 1, nghost)
249-
dst.subtract(dst, src, 1, 0, 1, nghost)
250-
dst.multiply(dst, src, 0, 0, 1, nghost)
251-
dst.divide(dst, src, 1, 0, 1, nghost)
243+
dst.add(src, 2, 0, 1, nghost)
244+
dst.subtract(src, 1, 0, 1, nghost)
245+
dst.multiply(src, 0, 0, 1, nghost)
246+
dst.divide(src, 1, 0, 1, nghost)
252247

253248
print(dst.min(0))
254249
np.testing.assert_allclose(dst.min(0), 5.0)
255250
np.testing.assert_allclose(dst.max(0), 5.0)
256251

257-
# dst.xpay(2.0, src, 0, 0, 1, nghost)
258-
# dst.saxpy(2.0, src, 1, 0, 1, nghost)
259-
dst.xpay(dst, 2.0, src, 0, 0, 1, nghost)
260-
dst.saxpy(dst, 2.0, src, 1, 0, 1, nghost)
252+
dst.xpay(2.0, src, 0, 0, 1, nghost)
253+
dst.saxpy(2.0, src, 1, 0, 1, nghost)
261254
np.testing.assert_allclose(dst.min(0), 60.0)
262255
np.testing.assert_allclose(dst.max(0), 60.0)
263256

264-
# dst.lin_comb(6.0, src, 1,
265-
# 1.0, src, 2, 0, 1, nghost)
266-
dst.lin_comb(dst, 6.0, src, 1, 1.0, src, 2, 0, 1, nghost)
257+
dst.lin_comb(6.0, src, 1, 1.0, src, 2, 0, 1, nghost)
267258
np.testing.assert_allclose(dst.min(0), 150.0)
268259
np.testing.assert_allclose(dst.max(0), 150.0)
269260

0 commit comments

Comments
 (0)