Skip to content

Commit a6db3f5

Browse files
committed
Refactor filters mechanism
Filters now compose, and the API is type annotated properly.
1 parent 3a4c79d commit a6db3f5

File tree

2 files changed

+67
-59
lines changed

2 files changed

+67
-59
lines changed

stagpy/error.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import typing
55

66
if typing.TYPE_CHECKING:
7-
from typing import Sequence
87
from os import PathLike
98
from .stagyydata import StagyyData
109
from ._step import Step
@@ -194,18 +193,6 @@ def __init__(self, varname: str):
194193
super().__init__(varname)
195194

196195

197-
class UnknownFiltersError(StagpyError):
198-
"""Raised when invalid step filter is requested.
199-
200-
Attributes:
201-
filters: the invalid filter names.
202-
"""
203-
204-
def __init__(self, filters: Sequence[str]):
205-
self.filters = filters
206-
super().__init__(', '.join(repr(f) for f in filters))
207-
208-
209196
class UnknownFieldVarError(UnknownVarError):
210197
"""Raised when invalid field var is requested.
211198

stagpy/stagyydata.py

Lines changed: 67 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
from __future__ import annotations
11+
from dataclasses import dataclass, field
1112
from itertools import zip_longest
1213
from pathlib import Path
1314
import re
@@ -22,7 +23,7 @@
2223

2324
if typing.TYPE_CHECKING:
2425
from typing import (Tuple, List, Dict, Optional, Union, Sequence, Iterator,
25-
Set)
26+
Set, Callable, Iterable)
2627
from os import PathLike
2728
from f90nml.namelist import Namelist
2829
from numpy import ndarray
@@ -431,9 +432,11 @@ def at_time(self, time: float, after: bool = False) -> Step:
431432
after)
432433
return self[self.sdat.tseries.isteps[itime]]
433434

434-
def filter(self, **filters):
435+
def filter(self, snap: bool = False, rprofs: bool = False,
436+
fields: Optional[Iterable[str]] = None,
437+
func: Optional[Callable[[Step], bool]] = None) -> _StepsView:
435438
"""Build a _StepsView with requested filters."""
436-
return self[:].filter(**filters)
439+
return self[:].filter(snap, rprofs, fields, func)
437440

438441

439442
class _Snaps(_Steps):
@@ -563,6 +566,41 @@ def _bind(self, isnap: int, istep: int):
563566
self.sdat.steps[istep]._isnap = isnap
564567

565568

569+
@dataclass
570+
class _Filters:
571+
"""Filters on a step view."""
572+
573+
snap: bool = False
574+
rprofs: bool = False
575+
fields: Set[str] = field(default_factory=set)
576+
funcs: List[Callable[[Step], bool]] = field(default_factory=list)
577+
578+
def passes(self, step: Step) -> bool:
579+
"""Whether a given Step passes the filters."""
580+
if self.snap and step.isnap is None:
581+
return False
582+
if self.rprofs:
583+
try:
584+
_ = step.rprofs.centers
585+
except error.MissingDataError:
586+
return False
587+
if any(fld not in step.fields for fld in self.fields):
588+
return False
589+
return all(func(step) for func in self.funcs)
590+
591+
def __repr__(self) -> str:
592+
flts = []
593+
if self.snap:
594+
flts.append('snap=True')
595+
if self.rprofs:
596+
flts.append('rprofs=True')
597+
if self.fields:
598+
flts.append(f"fields={self.fields!r}")
599+
if self.funcs:
600+
flts.append(f"func={self.funcs!r}")
601+
return ', '.join(flts)
602+
603+
566604
class _StepsView:
567605
"""Filtered iterator over steps or snaps.
568606
@@ -579,13 +617,7 @@ def __init__(self, steps_col: Union[_Steps, _Snaps], items):
579617
self._col = steps_col
580618
self._items = items
581619
self._rprofs_averaged: Optional[_RprofsAveraged] = None
582-
self._flt = {
583-
'snap': False,
584-
'rprofs': False,
585-
'fields': [],
586-
'func': lambda _: True,
587-
}
588-
self._dflt_func = self._flt['func']
620+
self._flt = _Filters()
589621

590622
@property
591623
def rprofs_averaged(self) -> _RprofsAveraged:
@@ -613,62 +645,51 @@ def stepstr(self) -> str:
613645

614646
def __repr__(self) -> str:
615647
rep = f'{self._col.sdat!r}.{self.stepstr}'
616-
flts = []
617-
for flt in ('snap', 'rprofs', 'fields'):
618-
if self._flt[flt]:
619-
flts.append(f'{flt}={self._flt[flt]!r}')
620-
if self._flt['func'] is not self._dflt_func:
621-
flts.append(f"func={self._flt['func']!r}")
648+
flts = repr(self._flt)
622649
if flts:
623-
rep += '.filter({})'.format(', '.join(flts))
650+
rep += f'.filter({flts})'
624651
return rep
625652

626-
def _pass(self, item) -> bool:
653+
def _pass(self, item: int) -> bool:
627654
"""Check whether an item passes the filters."""
628655
try:
629656
step = self._col[item]
630657
except KeyError:
631658
return False
632-
okf = True
633-
okf = okf and (not self._flt['snap'] or step.isnap is not None)
634-
if self._flt['rprofs']:
635-
try:
636-
_ = step.rprofs.centers
637-
except error.MissingDataError:
638-
return False
639-
okf = okf and all(f in step.fields for f in self._flt['fields'])
640-
okf = okf and bool(self._flt['func'](step))
641-
return okf
659+
return self._flt.passes(step)
642660

643-
def filter(self, **filters):
644-
"""Update filters with provided arguments.
661+
def filter(self, snap: bool = False, rprofs: bool = False,
662+
fields: Optional[Iterable[str]] = None,
663+
func: Optional[Callable[[Step], bool]] = None) -> _StepsView:
664+
"""Add filters to the view.
645665
646-
Note that filters are only resolved when the view is iterated, and
647-
hence they do not compose. Each call to filter merely updates the
648-
relevant filters. For example, with this code::
666+
Note that filters are only resolved when the view is iterated.
667+
Successive calls to :meth:`filter` compose. For example, with this
668+
code::
649669
650670
view = sdat.steps[500:].filter(rprofs=True, fields=['T'])
651-
view.filter(fields=[])
671+
view.filter(fields=['eta'])
652672
653673
the produced ``view``, when iterated, will generate the steps after the
654-
500-th that have radial profiles. The ``fields`` filter set in the
655-
first line is emptied in the second line.
674+
500-th that have radial profiles, and both the temperature and
675+
viscosity fields.
656676
657677
Args:
658-
snap (bool): the step must be a snapshot to pass.
659-
rprofs (bool): the step must have rprofs data to pass.
660-
fields (list): list of fields that must be present to pass.
661-
func (function): arbitrary function taking a
662-
:class:`~stagpy._step.Step` as argument and returning a True
663-
value if the step should pass the filter.
678+
snap: if true, the step must be a snapshot to pass.
679+
rprofs: if true, the step must have rprofs data to pass.
680+
fields: list of fields that must be present to pass.
681+
func: arbitrary function returning whether a step should pass the
682+
filter.
664683
665684
Returns:
666685
self.
667686
"""
668-
for flt, val in self._flt.items():
669-
self._flt[flt] = filters.pop(flt, val)
670-
if filters:
671-
raise error.UnknownFiltersError(filters.keys())
687+
self._flt.snap = self._flt.snap or snap
688+
self._flt.rprofs = self._flt.rprofs or rprofs
689+
if fields is not None:
690+
self._flt.fields = self._flt.fields.union(fields)
691+
if func is not None:
692+
self._flt.funcs.append(func)
672693
return self
673694

674695
def __iter__(self) -> Iterator[Step]:

0 commit comments

Comments
 (0)