Skip to content

Commit c18f55a

Browse files
committed
isolate logic to cache fields in FieldCache
1 parent 56f42b2 commit c18f55a

File tree

4 files changed

+95
-56
lines changed

4 files changed

+95
-56
lines changed

src/stagpy/_caching.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from __future__ import annotations
2+
3+
import typing
4+
from collections import deque
5+
from dataclasses import dataclass
6+
from functools import cached_property
7+
8+
if typing.TYPE_CHECKING:
9+
from .datatypes import Field
10+
11+
12+
@dataclass(frozen=True)
13+
class FieldCache:
14+
"""FIFO cache of [Field][]s.
15+
16+
If `maxsize` is None, entries are never evicted from the cache.
17+
"""
18+
19+
maxsize: int | None
20+
21+
@cached_property
22+
def _stack(self) -> deque[tuple[int, str]]:
23+
return deque()
24+
25+
@cached_property
26+
def _data(self) -> dict[tuple[int, str], Field]:
27+
return {}
28+
29+
def _prune(self) -> None:
30+
if self.maxsize is None:
31+
return
32+
while len(self._stack) > self.maxsize:
33+
elt = self._stack.popleft()
34+
del self._data[elt]
35+
36+
def resize(self, new_size: int | None) -> None:
37+
object.__setattr__(self, "maxsize", new_size)
38+
self._prune()
39+
40+
def insert(self, istep: int, name: str, field: Field) -> None:
41+
key = (istep, name)
42+
if key not in self._data:
43+
self._stack.append(key)
44+
self._data[key] = field
45+
self._prune()
46+
47+
def get(self, istep: int, name: str) -> Field | None:
48+
return self._data.get((istep, name))
49+
50+
def evict_istep(self, istep: int) -> None:
51+
to_keep = []
52+
for key in self._stack:
53+
if key[0] == istep:
54+
del self._data[key]
55+
else:
56+
to_keep.append(key)
57+
self._stack.clear()
58+
self._stack.extend(to_keep)
59+
assert len(self._stack) == len(self._data)

src/stagpy/stagyydata.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from . import _helpers, error, phyvars, stagyyparsers, step
2222
from . import datatypes as dt
23+
from ._caching import FieldCache
2324
from .parfile import StagyyPar
2425
from .stagyyparsers import FieldXmf, TracersXmf
2526
from .step import Step
@@ -332,9 +333,7 @@ def __getitem__(self, istep: int | slice | Sequence[StepIndex]) -> Step | StepsV
332333

333334
def __delitem__(self, istep: int | None) -> None:
334335
if istep is not None and istep in self._data:
335-
self.sdat._collected_fields = [
336-
(i, f) for i, f in self.sdat._collected_fields if i != istep
337-
]
336+
self.sdat._field_cache.evict_istep(istep)
338337
del self._data[istep]
339338

340339
def __len__(self) -> int:
@@ -668,9 +667,6 @@ def __init__(self, path: PathLike, read_parameters_dat: bool = True):
668667
self.steps = Steps(self)
669668
self.snaps = Snaps(self)
670669
self._read_parameters_dat = read_parameters_dat
671-
self._nfields_max: int | None = 50
672-
# list of (istep, field_name) in memory
673-
self._collected_fields: list[tuple[int, str]] = []
674670

675671
def __repr__(self) -> str:
676672
return f"StagyyData({self.path!r})"
@@ -752,23 +748,17 @@ def _files(self) -> set[Path]:
752748
return set(out_dir.iterdir())
753749
return set()
754750

755-
@property
756-
def nfields_max(self) -> int | None:
757-
"""Maximum number of scalar fields kept in memory.
751+
def set_nfields_max(self, nfields: int | None) -> None:
752+
"""Adjust maximum number of scalar fields kept in memory.
758753
759754
Setting this to a value lower or equal to 5 raises a
760755
[stagpy.error.InvalidNfieldsError][]. Set this to `None` if
761756
you do not want any limit on the number of scalar fields kept in
762757
memory. Defaults to 50.
763758
"""
764-
return self._nfields_max
765-
766-
@nfields_max.setter
767-
def nfields_max(self, nfields: int | None) -> None:
768-
"""Check nfields > 5 or None."""
769759
if nfields is not None and nfields <= 5:
770760
raise error.InvalidNfieldsError(nfields)
771-
self._nfields_max = nfields
761+
self._field_cache.resize(nfields)
772762

773763
def filename(
774764
self,
@@ -811,3 +801,7 @@ def _binfiles_set(self, isnap: int) -> set[Path]:
811801
for fstem in phyvars.FIELD_FILES
812802
)
813803
return possible_files & self._files
804+
805+
@cached_property
806+
def _field_cache(self) -> FieldCache:
807+
return FieldCache(maxsize=50)

src/stagpy/step.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from numpy.typing import NDArray
2626
from pandas import DataFrame, Series
2727

28+
from ._caching import FieldCache
2829
from .datatypes import Varf
2930
from .stagyydata import StagyyData
3031

@@ -292,28 +293,41 @@ def __init__(
292293
self._extra = extravars
293294
self._files = files
294295
self._filesh5 = filesh5
295-
self._data: dict[str, Field] = {}
296296
super().__init__()
297297

298+
@cached_property
299+
def _all_vars(self) -> set[str]:
300+
return set(self._vars.keys()).union(self._extra.keys())
301+
302+
@cached_property
303+
def _cache(self) -> FieldCache:
304+
return self.step.sdat._field_cache
305+
298306
def __getitem__(self, name: str) -> Field:
299-
if name in self._data:
300-
return self._data[name]
301-
if name in self._vars:
302-
fld_names, parsed_data = self._get_raw_data(name)
303-
elif name in self._extra:
304-
self._data[name] = self._extra[name](self.step)
305-
return self._data[name]
306-
else:
307+
if name not in self._all_vars:
307308
raise error.UnknownFieldVarError(name)
309+
310+
maybe_fld = self._cache.get(self.step.istep, name)
311+
if maybe_fld is not None:
312+
return maybe_fld
313+
314+
if name in self._extra:
315+
fld = self._extra[name](self.step)
316+
self._cache.insert(self.step.istep, name, fld)
317+
return fld
318+
319+
# requested field is one of self._vars
320+
fld_names, parsed_data = self._get_raw_data(name)
308321
if parsed_data is None:
309322
raise error.MissingDataError(
310323
f"Missing field {name} in step {self.step.istep}"
311324
)
312325
header, fields = parsed_data
313326
self._cropped__header = header
314-
for fld_name, fld in zip(fld_names, fields):
315-
self._set(fld_name, fld)
316-
return self._data[name]
327+
for fld_name, fld_vals in zip(fld_names, fields):
328+
fld = Field(fld_vals, self._vars[fld_name])
329+
self._cache.insert(self.step.istep, fld_name, fld)
330+
return self[name]
317331

318332
@cached_property
319333
def _present_fields(self) -> list[str]:
@@ -371,20 +385,6 @@ def _get_raw_data(self, name: str) -> tuple[list[str], Any]:
371385
break
372386
return list_fvar, parsed_data
373387

374-
def _set(self, name: str, fld: NDArray) -> None:
375-
sdat = self.step.sdat
376-
col_fld = sdat._collected_fields
377-
col_fld.append((self.step.istep, name))
378-
if sdat.nfields_max is not None:
379-
while len(col_fld) > sdat.nfields_max:
380-
istep, fld_name = col_fld.pop(0)
381-
del sdat.steps[istep].fields[fld_name]
382-
self._data[name] = Field(fld, self._vars[name])
383-
384-
def __delitem__(self, name: str) -> None:
385-
if name in self._data:
386-
del self._data[name]
387-
388388
@cached_property
389389
def _header(self) -> dict[str, Any] | None:
390390
if self.step.isnap is None:

tests/test_stagyydata.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,9 @@ def test_sdat_path(example_dir: Path, sdat: StagyyData) -> None:
1111
assert sdat.path == example_dir
1212

1313

14-
def test_sdat_deflt_nfields_max(sdat: StagyyData) -> None:
15-
assert sdat.nfields_max == 50
16-
17-
18-
def test_sdat_set_nfields_max(sdat: StagyyData) -> None:
19-
sdat.nfields_max = 6
20-
assert sdat.nfields_max == 6
21-
22-
23-
def test_sdat_set_nfields_max_none(sdat: StagyyData) -> None:
24-
sdat.nfields_max = None
25-
assert sdat.nfields_max is None
26-
27-
2814
def test_sdat_set_nfields_max_invalid(sdat: StagyyData) -> None:
2915
with pytest.raises(stagpy.error.InvalidNfieldsError) as err:
30-
sdat.nfields_max = 5
16+
sdat.set_nfields_max(5)
3117
assert err.value.nfields == 5
3218

3319

0 commit comments

Comments
 (0)