Skip to content

Commit 802b28b

Browse files
committed
get lazy loading actually working for strds
1 parent bcd1f30 commit 802b28b

File tree

6 files changed

+340
-95
lines changed

6 files changed

+340
-95
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ Attributes:
153153
- [ ] Support `end_time`
154154
- [ ] Accept writing into a specific mapset (GRASS 8.5)
155155
- [ ] Accept non homogeneous 3D resolution in NS and EW dimensions (GRASS 8.5)
156-
- [ ] Lazy loading of all raster types
156+
- [x] Lazy loading of STDS on the time dimension
157157
- [ ] Properly test with lat-lon location
158158

159159
### Stretch goals

src/xarray_grass/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from xarray_grass.grass_interface import GrassConfig as GrassConfig
22
from xarray_grass.grass_interface import GrassInterface as GrassInterface
33
from xarray_grass.xarray_grass import GrassBackendEntrypoint as GrassBackendEntrypoint
4-
from xarray_grass.xarray_grass import GrassBackendArray as GrassBackendArray
4+
from xarray_grass.grass_backend_array import (
5+
GrassSTDSBackendArray as GrassSTDSBackendArray,
6+
)
57
from xarray_grass.to_grass import to_grass as to_grass
68
from xarray_grass.coord_utils import RegionData as RegionData
79

src/xarray_grass/grass_backend_array.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,25 +25,24 @@
2525
from xarray_grass.grass_interface import GrassInterface
2626

2727

28-
class GrassBackendArray(BackendArray):
29-
"""Lazy loading of grass arrays"""
28+
class GrassSTDSBackendArray(BackendArray):
29+
"""Lazy loading of grass Space-Time DataSets (multiple maps in time series)"""
3030

3131
def __init__(
3232
self,
3333
shape,
3434
dtype,
35-
# lock,
36-
map_id: str,
35+
map_list: list, # List of map metadata objects
3736
map_type: str,
3837
grass_interface: GrassInterface,
3938
):
4039
self.shape = shape
4140
self.dtype = dtype
4241
self._lock = threading.Lock()
43-
self.map_id = map_id
42+
self.map_list = map_list # List with .id attribute
4443
self.map_type = map_type # "raster" or "raster3d"
4544
self.grass_interface = grass_interface
46-
self._array: np.ndarray = None
45+
self._cached_maps = {} # Cache loaded maps by index
4746

4847
def __getitem__(self, key: xr.core.indexing.ExplicitIndexer) -> np.typing.ArrayLike:
4948
"""takes in input an index and returns a NumPy array"""
@@ -55,13 +54,44 @@ def __getitem__(self, key: xr.core.indexing.ExplicitIndexer) -> np.typing.ArrayL
5554
)
5655

5756
def _raw_indexing_method(self, key: tuple):
57+
"""Load only the maps needed for the requested slice"""
5858
with self._lock:
59-
if self._array is None:
60-
self._array = self._load_map()
61-
return self._array[key]
59+
# key is a tuple of slices/indices for each dimension
60+
# First dimension is time
61+
time_key = key[0] if key else slice(None)
62+
spatial_key = key[1:] if len(key) > 1 else ()
6263

63-
def _load_map(self):
64-
if self.map_type == "raster":
65-
return self.grass_interface.read_raster_map(self.map_id)
66-
else: # 'raster3d'
67-
return self.grass_interface.read_raster3d_map(self.map_id)
64+
# Determine which time indices are needed
65+
if isinstance(time_key, slice):
66+
time_indices = range(*time_key.indices(self.shape[0]))
67+
elif isinstance(time_key, int):
68+
time_indices = [time_key]
69+
else:
70+
time_indices = list(time_key)
71+
72+
# Load only the needed maps
73+
result_list = []
74+
for t_idx in time_indices:
75+
if t_idx not in self._cached_maps:
76+
map_data = self.map_list[t_idx]
77+
if self.map_type == "raster":
78+
self._cached_maps[t_idx] = self.grass_interface.read_raster_map(
79+
map_data.id
80+
)
81+
else: # 'raster3d'
82+
self._cached_maps[t_idx] = (
83+
self.grass_interface.read_raster3d_map(map_data.id)
84+
)
85+
86+
# Apply spatial indexing
87+
if spatial_key:
88+
result_list.append(self._cached_maps[t_idx][spatial_key])
89+
else:
90+
result_list.append(self._cached_maps[t_idx])
91+
92+
# Stack results along time dimension
93+
if len(result_list) == 1 and isinstance(time_key, int):
94+
# Single time slice requested as integer index
95+
return result_list[0]
96+
else:
97+
return np.stack(result_list, axis=0)

src/xarray_grass/grass_interface.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from grass.script import array as garray
2929
import grass.pygrass.utils as gutils
3030
from grass.pygrass import raster as graster
31-
from grass.pygrass.raster.abstract import Info
31+
from grass.pygrass.raster.abstract import Info, RasterAbstractBase
3232
import grass.temporal as tgis
3333

3434
from xarray_grass.coord_utils import (
@@ -230,11 +230,11 @@ def name_is_str3ds(self, name: str) -> bool:
230230
return bool(tgis.SpaceTimeRaster3DDataset(str3ds_id).is_in_db())
231231

232232
def name_is_raster(self, raster_name: str) -> bool:
233-
"""return True if the given name is a map in the grass database
234-
False if not
235-
"""
233+
"""return True if the given name is a raster map in the grass database."""
234+
# Using pygrass instead of gscript is at least 40x faster
236235
map_id = self.get_id_from_name(raster_name)
237-
return bool(gs.find_file(name=map_id, element="raster").get("file"))
236+
map_object = RasterAbstractBase(map_id)
237+
return map_object.exist()
238238

239239
def name_is_raster_3d(self, raster3d_name: str) -> bool:
240240
"""return True if the given name is a 3D raster in the grass database."""
@@ -261,11 +261,11 @@ def grass_dtype(self, dtype: str) -> str:
261261
@staticmethod
262262
def numpy_dtype(mtype: str) -> np.dtype:
263263
if mtype == "CELL":
264-
dtype = np.int64
264+
dtype = np.dtype("int64")
265265
elif mtype == "FCELL":
266-
dtype = np.float32
266+
dtype = np.dtype("float32")
267267
elif mtype == "DCELL":
268-
dtype = np.float64
268+
dtype = np.dtype("float64")
269269
else:
270270
raise ValueError(f"Unknown GRASS data type: {mtype}")
271271
return dtype

src/xarray_grass/xarray_grass.py

Lines changed: 76 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import xarray_grass
2626
from xarray_grass.grass_interface import GrassInterface
27-
from xarray_grass.grass_backend_array import GrassBackendArray
27+
from xarray_grass.grass_backend_array import GrassSTDSBackendArray
2828

2929

3030
class GrassBackendEntrypoint(BackendEntrypoint):
@@ -265,6 +265,8 @@ def open_grass_maps(
265265
data_array_list.append(data_array)
266266
if raise_on_not_found and any(not_found.values()):
267267
raise ValueError(f"Objects not found: {not_found}")
268+
269+
crs_wkt = gi.get_crs_wkt_str()
268270
finally:
269271
if session is not None:
270272
session.__exit__(None, None, None)
@@ -277,7 +279,7 @@ def open_grass_maps(
277279
data_array_dict = {da.name: da for da in data_array_list}
278280

279281
attrs = {
280-
"crs_wkt": gi.get_crs_wkt_str(),
282+
"crs_wkt": crs_wkt,
281283
"Conventions": "CF-1.13-draft",
282284
# "title": "",
283285
"history": f"{datetime.now(timezone.utc)}: Created with xarray-grass version {xarray_grass.__version__}",
@@ -347,9 +349,7 @@ def open_grass_raster_3d(raster_3d_name: str, grass_i: GrassInterface) -> xr.Dat
347349

348350

349351
def open_grass_strds(strds_name: str, grass_i: GrassInterface) -> xr.DataArray:
350-
"""must be called from within a grass session
351-
TODO: lazy loading
352-
"""
352+
"""Open a STRDS with lazy loading - data is only loaded when accessed"""
353353
strds_id = grass_i.get_id_from_name(strds_name)
354354
strds_name = grass_i.get_name_from_id(strds_id)
355355
x_coords, y_coords, _ = get_coordinates(grass_i, raster_3d=False).values()
@@ -360,45 +360,46 @@ def open_grass_strds(strds_name: str, grass_i: GrassInterface) -> xr.DataArray:
360360
time_unit = strds_infos.time_unit
361361
start_time_dim = f"start_time_{strds_name}"
362362
end_time_dim = f"end_time_{strds_name}"
363-
dims = [start_time_dim, "y", "x"]
364-
coordinates = dict.fromkeys(dims)
365-
coordinates["x"] = x_coords
366-
coordinates["y"] = y_coords
363+
367364
map_list = grass_i.list_maps_in_strds(strds_id)
368365
region = grass_i.get_region()
369-
array_list = []
370-
for map_data in map_list:
371-
# Lazy load the array
372-
backend_array = GrassBackendArray(
373-
shape=(region.rows, region.cols),
374-
dtype=map_data.dtype,
375-
map_id=map_data.id,
376-
map_type="raster",
377-
grass_interface=grass_i,
378-
)
379-
lazy_array = xr.core.indexing.LazilyIndexedArray(backend_array)
380-
# add time dimension at the beginning
381-
lazy_array_with_time = np.expand_dims(lazy_array, axis=0)
382-
383-
# ndarray = grass_i.read_raster_map(map_data.id)
384-
# # add time dimension at the beginning
385-
# ndarray = np.expand_dims(ndarray, axis=0)
386-
387-
coordinates[start_time_dim] = [map_data.start_time]
388-
coordinates[end_time_dim] = (start_time_dim, [map_data.end_time])
389-
390-
data_array = xr.DataArray(
391-
lazy_array_with_time,
392-
coords=coordinates,
393-
dims=dims,
394-
name=strds_name,
395-
)
396-
array_list.append(data_array)
397-
da_concat = xr.concat(array_list, dim=start_time_dim)
366+
367+
# Create a single backend array for the entire STRDS
368+
backend_array = GrassSTDSBackendArray(
369+
shape=(len(map_list), region.rows, region.cols),
370+
dtype=map_list[0].dtype,
371+
map_list=map_list,
372+
map_type="raster",
373+
grass_interface=grass_i,
374+
)
375+
lazy_array = xr.core.indexing.LazilyIndexedArray(backend_array)
376+
377+
# Create Variable with lazy array
378+
var = xr.Variable(dims=[start_time_dim, "y", "x"], data=lazy_array)
379+
380+
# Extract time coordinates
381+
start_times = [map_data.start_time for map_data in map_list]
382+
end_times = [map_data.end_time for map_data in map_list]
383+
384+
# Create coordinates
385+
coordinates = {
386+
"x": x_coords,
387+
"y": y_coords,
388+
start_time_dim: start_times,
389+
end_time_dim: (start_time_dim, end_times),
390+
}
391+
392+
# Convert to DataArray
393+
data_array = xr.DataArray(
394+
var,
395+
coords=coordinates,
396+
name=strds_name,
397+
)
398+
398399
# Add CF attributes
399400
r_infos = grass_i.get_raster_info(map_list[0].id)
400401
da_with_attrs = set_cf_coordinates(
401-
da_concat,
402+
data_array,
402403
gi=grass_i,
403404
is_3d=False,
404405
time_dims=[start_time_dim, end_time_dim],
@@ -414,7 +415,7 @@ def open_grass_strds(strds_name: str, grass_i: GrassInterface) -> xr.DataArray:
414415

415416

416417
def open_grass_str3ds(str3ds_name: str, grass_i: GrassInterface) -> xr.DataArray:
417-
"""Open a series of 3D raster maps.
418+
"""Open a STR3DS with lazy loading - data is only loaded when accessed
418419
TODO: Figure out what to do when the z value of the maps is time."""
419420
str3ds_id = grass_i.get_id_from_name(str3ds_name)
420421
str3ds_name = grass_i.get_name_from_id(str3ds_id)
@@ -426,43 +427,47 @@ def open_grass_str3ds(str3ds_name: str, grass_i: GrassInterface) -> xr.DataArray
426427
time_unit = strds_infos.time_unit
427428
start_time_dim = f"start_time_{str3ds_name}"
428429
end_time_dim = f"end_time_{str3ds_name}"
429-
dims = [start_time_dim, "z", "y_3d", "x_3d"]
430-
coordinates = dict.fromkeys(dims)
431-
coordinates["x_3d"] = x_coords
432-
coordinates["y_3d"] = y_coords
433-
coordinates["z"] = z_coords
430+
434431
map_list = grass_i.list_maps_in_str3ds(str3ds_id)
435432
region = grass_i.get_region()
436-
array_list = []
437-
for map_data in map_list:
438-
# Lazy load the map
439-
backend_array = GrassBackendArray(
440-
shape=(region.depths, region.rows3, region.cols3),
441-
dtype=map_data.dtype,
442-
map_id=map_data.id,
443-
map_type="raster3d",
444-
grass_interface=grass_i,
445-
)
446-
lazy_array = xr.core.indexing.LazilyIndexedArray(backend_array)
447-
# add time dimension at the beginning
448-
lazy_array_with_time = np.expand_dims(lazy_array, axis=0)
449-
450-
coordinates[start_time_dim] = [map_data.start_time]
451-
coordinates[end_time_dim] = (start_time_dim, [map_data.end_time])
452-
453-
data_array = xr.DataArray(
454-
lazy_array_with_time,
455-
coords=coordinates,
456-
dims=dims,
457-
name=str3ds_name,
458-
)
459-
array_list.append(data_array)
460433

461-
da_concat = xr.concat(array_list, dim=start_time_dim)
434+
# Create a single backend array for the entire STR3DS
435+
backend_array = GrassSTDSBackendArray(
436+
shape=(len(map_list), region.depths, region.rows3, region.cols3),
437+
dtype=map_list[0].dtype,
438+
map_list=map_list,
439+
map_type="raster3d",
440+
grass_interface=grass_i,
441+
)
442+
lazy_array = xr.core.indexing.LazilyIndexedArray(backend_array)
443+
444+
# Create Variable with lazy array
445+
var = xr.Variable(dims=[start_time_dim, "z", "y_3d", "x_3d"], data=lazy_array)
446+
447+
# Extract time coordinates
448+
start_times = [map_data.start_time for map_data in map_list]
449+
end_times = [map_data.end_time for map_data in map_list]
450+
451+
# Create coordinates
452+
coordinates = {
453+
"x_3d": x_coords,
454+
"y_3d": y_coords,
455+
"z": z_coords,
456+
start_time_dim: start_times,
457+
end_time_dim: (start_time_dim, end_times),
458+
}
459+
460+
# Convert to DataArray
461+
data_array = xr.DataArray(
462+
var,
463+
coords=coordinates,
464+
name=str3ds_name,
465+
)
466+
462467
# Add CF attributes
463468
r3_infos = grass_i.get_raster3d_info(map_list[0].id)
464469
da_with_attrs = set_cf_coordinates(
465-
da_concat,
470+
data_array,
466471
gi=grass_i,
467472
is_3d=True,
468473
z_unit=r3_infos["vertical_units"],

0 commit comments

Comments
 (0)