Skip to content

Commit 1726f43

Browse files
authored
Enable deepcopy for ExpData(View) (#2196)
Fixes a bug in `SwigPtrView.__deepcopy__` which did not produce a deep copy. Add `SwigPtrView.__eq__` to allow for comparison. The view objects are considered equal if the underlying viewed objects are equal. Fixes #2189.
1 parent 187362a commit 1726f43

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

python/sdist/amici/numpy.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ def __deepcopy__(self, memo):
137137
138138
:returns: SwigPtrView deep copy
139139
"""
140-
other = SwigPtrView(self._swigptr)
140+
# We assume we have a copy-ctor for the swigptr object
141+
other = self.__class__(copy.deepcopy(self._swigptr))
141142
other._field_names = copy.deepcopy(self._field_names)
142143
other._field_dimensions = copy.deepcopy(self._field_dimensions)
143144
other._cache = copy.deepcopy(self._cache)
@@ -151,6 +152,18 @@ def __repr__(self):
151152
"""
152153
return f"<{self.__class__.__name__}({self._swigptr})>"
153154

155+
def __eq__(self, other):
156+
"""
157+
Equality check
158+
159+
:param other: other object
160+
161+
:returns: whether other object is equal to this object
162+
"""
163+
if not isinstance(other, self.__class__):
164+
return False
165+
return self._swigptr == other._swigptr
166+
154167

155168
class ReturnDataView(SwigPtrView):
156169
"""
@@ -344,6 +357,9 @@ class ExpDataView(SwigPtrView):
344357
"""
345358
Interface class for C++ Exp Data objects that avoids possibly costly
346359
copies of member data.
360+
361+
NOTE: This currently assumes that the underlying :class:`ExpData`
362+
does not change after instantiating an :class:`ExpDataView`.
347363
"""
348364

349365
_field_names = [

python/tests/test_swig_interface.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numbers
88

99
import amici
10+
import numpy as np
1011

1112

1213
def test_version_number(pysb_example_presimulation_module):
@@ -451,3 +452,21 @@ def test_edata_equality_operator():
451452
# check that comparison with other types works
452453
# this is not implemented by swig by default
453454
assert e1 != 1
455+
456+
457+
def test_expdata_and_expdataview_are_deepcopyable():
458+
edata1 = amici.ExpData(3, 2, 3, range(4))
459+
edata1.setObservedData(np.zeros((3, 4)).flatten())
460+
461+
# ExpData
462+
edata2 = copy.deepcopy(edata1)
463+
assert edata1 == edata2
464+
assert edata1.this != edata2.this
465+
edata2.setTimepoints([0])
466+
assert edata1 != edata2
467+
468+
# ExpDataView
469+
ev1 = amici.ExpDataView(edata1)
470+
ev2 = copy.deepcopy(ev1)
471+
assert ev2._swigptr.this != ev1._swigptr.this
472+
assert ev1 == ev2

swig/edata.i

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ def __repr__(self):
7777

7878
def __eq__(self, other):
7979
return other.__class__ == self.__class__ and __eq__(self, other)
80+
81+
def __deepcopy__(self, memo):
82+
# invoke copy constructor
83+
return type(self)(self)
84+
8085
%}
8186
};
8287
%extend std::unique_ptr<amici::ExpData> {

0 commit comments

Comments
 (0)