Skip to content

Commit 9e30fe0

Browse files
committed
api: fix pickling of custom sparse functions with more than one dim
1 parent 19d61fe commit 9e30fe0

File tree

3 files changed

+57
-14
lines changed

3 files changed

+57
-14
lines changed

devito/types/dense.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ class DiscreteFunction(AbstractFunction, ArgProvider, Differentiable):
6363
The type of the underlying data object.
6464
"""
6565

66-
__rkwargs__ = AbstractFunction.__rkwargs__ + ('staggered', 'coefficients')
66+
__rkwargs__ = AbstractFunction.__rkwargs__ + \
67+
('shape_global', 'staggered', 'coefficients')
6768

6869
def __init_finalize__(self, *args, function=None, **kwargs):
6970
# Now that *all* __X_setup__ hooks have been called, we can let the
@@ -1009,7 +1010,7 @@ class Function(DiscreteFunction):
10091010
is_autopaddable = True
10101011

10111012
__rkwargs__ = (DiscreteFunction.__rkwargs__ +
1012-
('space_order', 'shape_global', 'dimensions'))
1013+
('space_order', 'dimensions'))
10131014

10141015
def _cache_meta(self):
10151016
# Attach additional metadata to self's cache entry

devito/types/sparse.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,19 @@ def __shape_setup__(cls, **kwargs):
9595
# A Grid must have been provided
9696
if grid is None:
9797
raise TypeError('Need `grid` argument')
98-
shape = kwargs.get('shape')
98+
shape = kwargs.get('shape', kwargs.get('shape_global'))
9999
dimensions = kwargs.get('dimensions')
100100
npoint = kwargs.get('npoint', kwargs.get('npoint_global'))
101101
glb_npoint = SparseDistributor.decompose(npoint, grid.distributor)
102+
# Plain SparseFunction construction with npoint.
102103
if shape is None:
103104
loc_shape = (glb_npoint[grid.distributor.myrank],)
105+
# No dimensions is only possible through rebuild, the shape is from
106+
# the existing function
107+
elif dimensions is None:
108+
loc_shape = list(shape)
109+
# For safety, ensure the distributed sparse dimension is correct
110+
loc_shape[cls._sparse_position] = glb_npoint[grid.distributor.myrank]
104111
else:
105112
loc_shape = []
106113
assert len(dimensions) == len(shape)
@@ -111,6 +118,7 @@ def __shape_setup__(cls, **kwargs):
111118
loc_shape.append(grid.size_map[d].loc)
112119
else:
113120
loc_shape.append(s)
121+
114122
return tuple(loc_shape)
115123

116124
def __fd_setup__(self):
@@ -733,16 +741,19 @@ def time_dim(self):
733741

734742
@classmethod
735743
def __shape_setup__(cls, **kwargs):
736-
shape = kwargs.get('shape')
737-
if shape is None:
738-
nt = kwargs.get('nt')
739-
if not isinstance(nt, int):
740-
raise TypeError('Need `nt` int argument')
741-
if nt <= 0:
742-
raise ValueError('`nt` must be > 0')
743-
744-
shape = list(AbstractSparseFunction.__shape_setup__(**kwargs))
745-
shape.insert(cls._time_position, nt)
744+
shape = list(AbstractSparseFunction.__shape_setup__(**kwargs))
745+
dimensions = as_tuple(kwargs.get('dimensions'))
746+
if dimensions is None or len(shape) == len(dimensions):
747+
# Shape has already been setup, for example via rebuild
748+
return tuple(shape)
749+
750+
nt = kwargs.get('nt')
751+
if not isinstance(nt, int):
752+
raise TypeError('Need `nt` int argument')
753+
if nt <= 0:
754+
raise ValueError('`nt` must be > 0')
755+
756+
shape.insert(cls._time_position, nt)
746757

747758
return tuple(shape)
748759

tests/test_pickle.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar,
2020
PointerArray, Lock, PThreadArray, SharedData, Timer,
2121
DeviceID, NPThreads, ThreadID, TempFunction, Indirection,
22-
FIndexed, ComponentAccess)
22+
FIndexed, ComponentAccess, DefaultDimension)
2323
from devito.types.basic import BoundSymbol, AbstractSymbol
2424
from devito.tools import EnrichedTuple
2525
from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer,
@@ -29,6 +29,11 @@
2929
TimeAxis, RickerSource, Receiver)
3030

3131

32+
class SparseFirst(SparseFunction):
33+
34+
_sparse_position = 0
35+
36+
3237
class SD(SubDomain):
3338
name = 'sd'
3439

@@ -181,6 +186,32 @@ def test_precomputed_sparse_function(self, mode, pickle):
181186
assert sf.dtype == new_sf.dtype
182187
assert sf.npoint == new_sf.npoint == 3
183188

189+
def test_sparse_first(self, pickle):
190+
191+
dr = Dimension("cd")
192+
ds = DefaultDimension("ps", default_value=3)
193+
grid = Grid((11, 11))
194+
sf = SparseFirst(name="s", grid=grid, npoint=2,
195+
dimensions=(dr, ds), shape=(2, 3),
196+
coordinates=[[.5, .5], [.2, .2]])
197+
sf.data[0] = 1.
198+
199+
pkl_sf = pickle.dumps(sf)
200+
new_sf = pickle.loads(pkl_sf)
201+
202+
# .data is initialized, so it should have been pickled too
203+
assert np.all(sf.data[0] == 1.)
204+
assert np.all(new_sf.data[0] == 1.)
205+
assert new_sf.interpolation == sf.interpolation
206+
207+
# coordinates should also have been pickled
208+
assert np.all(sf.coordinates.data == new_sf.coordinates.data)
209+
210+
assert sf.space_order == new_sf.space_order
211+
assert sf.dtype == new_sf.dtype
212+
assert sf.npoint == new_sf.npoint
213+
assert sf.shape == new_sf.shape
214+
184215
def test_alias_sparse_function(self, pickle):
185216
grid = Grid(shape=(3,))
186217
sf = SparseFunction(name='sf', grid=grid, npoint=3, space_order=2,

0 commit comments

Comments
 (0)