Skip to content

Commit d4132dc

Browse files
Fixed cython interface to c++ energy classes
1 parent 0a431b6 commit d4132dc

File tree

3 files changed

+55
-20
lines changed

3 files changed

+55
-20
lines changed

fidimag/common/c_clib.pyx

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -285,12 +285,35 @@ cdef extern from "c_energy.h":
285285
# __cinit__ and __dealloc__ methods which are guaranteed to be called exactly
286286
# once upon creation and deletion of the Python instance.
287287

288-
cdef class PyExchangeEnergy:
289-
cdef ExchangeEnergy *thisptr
288+
cdef class PyEnergy:
289+
cdef Energy *thisptr
290290
# Try cinit:
291-
def __cinit__(self, double [:] A):
291+
def __cinit__(self):
292+
# Should we allocate?
292293
# self.thisptr = new ExchangeEnergy()
293-
self.thisptr.init(&A[0])
294+
print("In Python A")
295+
296+
# We need to inherit from a cdef-ined base class to make the inheritance to
297+
# work. This way the constructors are called properly
298+
cdef class PyExchangeEnergy(PyEnergy):
299+
cdef ExchangeEnergy *derivedptr
300+
# Try cinit:
301+
def __cinit__(self, double [:] A):
302+
print("In Python B")
303+
304+
self.derivedptr = new ExchangeEnergy()
305+
self.derivedptr.init(&A[0])
306+
307+
if self.thisptr:
308+
print("in B: deallocating old A")
309+
del self.thisptr
310+
311+
# DEBUG: check contents of the A array
312+
# def printA(self):
313+
# lst = []
314+
# for i in range(4):
315+
# lst.append(self.derivedptr.A[i])
316+
# print(lst)
294317

295318
# We could use another constructor if we use this method:
296319
# def __cinit__(self):
@@ -299,13 +322,14 @@ cdef class PyExchangeEnergy:
299322
# def __dealloc__(self):
300323
# if type(self) is PyExchangeEnergy:
301324
# del self.thisptr
302-
# Necessary?
325+
326+
# Necessary?:
303327
def compute_field(self, t):
304-
self.thisptr.compute_field(t)
328+
self.derivedptr.compute_field(t)
305329
def compute_energy(self, time):
306-
return self.thisptr.compute_energy()
330+
return self.derivedptr.compute_energy()
307331
def setup(self, nx, ny, nz, dx, dy, dz, unit_length,
308332
double [:] spin, double [:] Ms, double [:] Ms_inv):
309-
return self.thisptr.setup(nx, ny, nz, dx, dy, dz, unit_length,
310-
&spin[0], &Ms[0], &Ms_inv[0])
333+
return self.derivedptr.setup(nx, ny, nz, dx, dy, dz, unit_length,
334+
&spin[0], &Ms[0], &Ms_inv[0])
311335

native/include/c_energy.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
class Energy {
55
public:
6-
Energy();
6+
Energy() {std::cout << "In A; at " << this << "\n";};
7+
~Energy() {std::cout << "Killing A\n";};
78
bool set_up;
89
int nx, ny, nz, n;
910
double dx, dy, dz;
@@ -16,17 +17,20 @@ class Energy {
1617
double *coordinates;
1718
int *ngbs;
1819
double compute_energy();
19-
void setup(int nx, int ny, int nz, double dx, double dy, double dz, double unit_length, double *spin, double *Ms, double *Ms_inv);
20-
virtual void compute_field(double t) = 0;
20+
void setup(int nx, int ny, int nz, double dx, double dy, double dz,
21+
double unit_length, double *spin, double *Ms, double *Ms_inv,
22+
double *coordinates, double *ngbs,
23+
double *energy, double *field
24+
);
25+
virtual void compute_field(double t) {};
2126
};
2227

2328
class ExchangeEnergy : public Energy {
2429
public:
25-
ExchangeEnergy() {};
30+
ExchangeEnergy() {std::cout << "In B; at " << this << "\n";};
2631
void init(double *A) {
27-
std::cout << "Here A[0] = " << A[0] << "\n";
2832
this->set_up = false;
29-
// this->A = A;
33+
this->A = A;
3034
}
3135
double *A;
3236
void compute_field(double t);

native/src/c_energy.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,28 @@ double Energy::compute_energy() {
1010
}
1111

1212

13-
void Energy::setup(int nx, int ny, int nz,
14-
double dx, double dy, double dz,
15-
double unit_length,
16-
double *spin, double *Ms,
17-
double *Ms_inv) {
13+
void Energy::setup(int nx, int ny, int nz, double dx, double dy, double dz,
14+
double unit_length, double *spin, double *Ms, double *Ms_inv,
15+
double *coordinates, double *ngbs,
16+
double *energy, double *field
17+
); {
1818
this->nx = nx;
1919
this->ny = ny;
2020
this->nz = nz;
2121
this->dx = dx;
2222
this->dy = dy;
2323
this->dz = dz;
2424
this->unit_length = unit_length;
25+
26+
// Arrays
2527
this->spin = spin;
2628
this->Ms = Ms;
2729
this->Ms_inv = Ms_inv;
30+
this->coordinates = coordinates;
31+
this->ngbs = ngbs;
32+
this ->energy->energy;
33+
this->field = field;
34+
2835
set_up = true;
2936
}
3037

0 commit comments

Comments
 (0)