|
19 | 19 | from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar, |
20 | 20 | PointerArray, Lock, PThreadArray, SharedData, Timer, |
21 | 21 | DeviceID, NPThreads, ThreadID, TempFunction, Indirection, |
22 | | - FIndexed, ComponentAccess) |
| 22 | + FIndexed, ComponentAccess, DefaultDimension) |
23 | 23 | from devito.types.basic import BoundSymbol, AbstractSymbol |
24 | 24 | from devito.tools import EnrichedTuple |
25 | 25 | from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer, |
|
29 | 29 | TimeAxis, RickerSource, Receiver) |
30 | 30 |
|
31 | 31 |
|
| 32 | +class SparseFirst(SparseFunction): |
| 33 | + |
| 34 | + _sparse_position = 0 |
| 35 | + |
| 36 | + |
32 | 37 | class SD(SubDomain): |
33 | 38 | name = 'sd' |
34 | 39 |
|
@@ -181,6 +186,32 @@ def test_precomputed_sparse_function(self, mode, pickle): |
181 | 186 | assert sf.dtype == new_sf.dtype |
182 | 187 | assert sf.npoint == new_sf.npoint == 3 |
183 | 188 |
|
| 189 | + def test_sparse_first(self, pickle): |
| 190 | + |
| 191 | + dr = Dimension("cd") |
| 192 | + ds = DefaultDimension("ps", default_value=3) |
| 193 | + grid = Grid((11, 11)) |
| 194 | + sf = SparseFirst(name="s", grid=grid, npoint=2, |
| 195 | + dimensions=(dr, ds), shape=(2, 3), |
| 196 | + coordinates=[[.5, .5], [.2, .2]]) |
| 197 | + sf.data[0] = 1. |
| 198 | + |
| 199 | + pkl_sf = pickle.dumps(sf) |
| 200 | + new_sf = pickle.loads(pkl_sf) |
| 201 | + |
| 202 | + # .data is initialized, so it should have been pickled too |
| 203 | + assert np.all(sf.data[0] == 1.) |
| 204 | + assert np.all(new_sf.data[0] == 1.) |
| 205 | + assert new_sf.interpolation == sf.interpolation |
| 206 | + |
| 207 | + # coordinates should also have been pickled |
| 208 | + assert np.all(sf.coordinates.data == new_sf.coordinates.data) |
| 209 | + |
| 210 | + assert sf.space_order == new_sf.space_order |
| 211 | + assert sf.dtype == new_sf.dtype |
| 212 | + assert sf.npoint == new_sf.npoint |
| 213 | + assert sf.shape == new_sf.shape |
| 214 | + |
184 | 215 | def test_alias_sparse_function(self, pickle): |
185 | 216 | grid = Grid(shape=(3,)) |
186 | 217 | sf = SparseFunction(name='sf', grid=grid, npoint=3, space_order=2, |
|
0 commit comments