88"""
99
1010from __future__ import annotations
11+ from dataclasses import dataclass , field
1112from itertools import zip_longest
1213from pathlib import Path
1314import re
2223
2324if typing .TYPE_CHECKING :
2425 from typing import (Tuple , List , Dict , Optional , Union , Sequence , Iterator ,
25- Set )
26+ Set , Callable , Iterable )
2627 from os import PathLike
2728 from f90nml .namelist import Namelist
2829 from numpy import ndarray
@@ -431,9 +432,11 @@ def at_time(self, time: float, after: bool = False) -> Step:
431432 after )
432433 return self [self .sdat .tseries .isteps [itime ]]
433434
434- def filter (self , ** filters ):
435+ def filter (self , snap : bool = False , rprofs : bool = False ,
436+ fields : Optional [Iterable [str ]] = None ,
437+ func : Optional [Callable [[Step ], bool ]] = None ) -> _StepsView :
435438 """Build a _StepsView with requested filters."""
436- return self [:].filter (** filters )
439+ return self [:].filter (snap , rprofs , fields , func )
437440
438441
439442class _Snaps (_Steps ):
@@ -563,6 +566,41 @@ def _bind(self, isnap: int, istep: int):
563566 self .sdat .steps [istep ]._isnap = isnap
564567
565568
569+ @dataclass
570+ class _Filters :
571+ """Filters on a step view."""
572+
573+ snap : bool = False
574+ rprofs : bool = False
575+ fields : Set [str ] = field (default_factory = set )
576+ funcs : List [Callable [[Step ], bool ]] = field (default_factory = list )
577+
578+ def passes (self , step : Step ) -> bool :
579+ """Whether a given Step passes the filters."""
580+ if self .snap and step .isnap is None :
581+ return False
582+ if self .rprofs :
583+ try :
584+ _ = step .rprofs .centers
585+ except error .MissingDataError :
586+ return False
587+ if any (fld not in step .fields for fld in self .fields ):
588+ return False
589+ return all (func (step ) for func in self .funcs )
590+
591+ def __repr__ (self ) -> str :
592+ flts = []
593+ if self .snap :
594+ flts .append ('snap=True' )
595+ if self .rprofs :
596+ flts .append ('rprofs=True' )
597+ if self .fields :
598+ flts .append (f"fields={ self .fields !r} " )
599+ if self .funcs :
600+ flts .append (f"func={ self .funcs !r} " )
601+ return ', ' .join (flts )
602+
603+
566604class _StepsView :
567605 """Filtered iterator over steps or snaps.
568606
@@ -579,13 +617,7 @@ def __init__(self, steps_col: Union[_Steps, _Snaps], items):
579617 self ._col = steps_col
580618 self ._items = items
581619 self ._rprofs_averaged : Optional [_RprofsAveraged ] = None
582- self ._flt = {
583- 'snap' : False ,
584- 'rprofs' : False ,
585- 'fields' : [],
586- 'func' : lambda _ : True ,
587- }
588- self ._dflt_func = self ._flt ['func' ]
620+ self ._flt = _Filters ()
589621
590622 @property
591623 def rprofs_averaged (self ) -> _RprofsAveraged :
@@ -613,62 +645,51 @@ def stepstr(self) -> str:
613645
614646 def __repr__ (self ) -> str :
615647 rep = f'{ self ._col .sdat !r} .{ self .stepstr } '
616- flts = []
617- for flt in ('snap' , 'rprofs' , 'fields' ):
618- if self ._flt [flt ]:
619- flts .append (f'{ flt } ={ self ._flt [flt ]!r} ' )
620- if self ._flt ['func' ] is not self ._dflt_func :
621- flts .append (f"func={ self ._flt ['func' ]!r} " )
648+ flts = repr (self ._flt )
622649 if flts :
623- rep += '.filter({})' . format ( ', ' . join ( flts ))
650+ rep += f '.filter({ flts } )'
624651 return rep
625652
626- def _pass (self , item ) -> bool :
653+ def _pass (self , item : int ) -> bool :
627654 """Check whether an item passes the filters."""
628655 try :
629656 step = self ._col [item ]
630657 except KeyError :
631658 return False
632- okf = True
633- okf = okf and (not self ._flt ['snap' ] or step .isnap is not None )
634- if self ._flt ['rprofs' ]:
635- try :
636- _ = step .rprofs .centers
637- except error .MissingDataError :
638- return False
639- okf = okf and all (f in step .fields for f in self ._flt ['fields' ])
640- okf = okf and bool (self ._flt ['func' ](step ))
641- return okf
659+ return self ._flt .passes (step )
642660
643- def filter (self , ** filters ):
644- """Update filters with provided arguments.
661+ def filter (self , snap : bool = False , rprofs : bool = False ,
662+ fields : Optional [Iterable [str ]] = None ,
663+ func : Optional [Callable [[Step ], bool ]] = None ) -> _StepsView :
664+ """Add filters to the view.
645665
646- Note that filters are only resolved when the view is iterated, and
647- hence they do not compose. Each call to filter merely updates the
648- relevant filters. For example, with this code::
666+ Note that filters are only resolved when the view is iterated.
667+ Successive calls to :meth:`filter` compose. For example, with this
668+ code::
649669
650670 view = sdat.steps[500:].filter(rprofs=True, fields=['T'])
651- view.filter(fields=[])
671+ view.filter(fields=['eta' ])
652672
653673 the produced ``view``, when iterated, will generate the steps after the
654- 500-th that have radial profiles. The ``fields`` filter set in the
655- first line is emptied in the second line .
674+ 500-th that have radial profiles, and both the temperature and
675+ viscosity fields .
656676
657677 Args:
658- snap (bool): the step must be a snapshot to pass.
659- rprofs (bool): the step must have rprofs data to pass.
660- fields (list): list of fields that must be present to pass.
661- func (function): arbitrary function taking a
662- :class:`~stagpy._step.Step` as argument and returning a True
663- value if the step should pass the filter.
678+ snap: if true, the step must be a snapshot to pass.
679+ rprofs: if true, the step must have rprofs data to pass.
680+ fields: list of fields that must be present to pass.
681+ func: arbitrary function returning whether a step should pass the
682+ filter.
664683
665684 Returns:
666685 self.
667686 """
668- for flt , val in self ._flt .items ():
669- self ._flt [flt ] = filters .pop (flt , val )
670- if filters :
671- raise error .UnknownFiltersError (filters .keys ())
687+ self ._flt .snap = self ._flt .snap or snap
688+ self ._flt .rprofs = self ._flt .rprofs or rprofs
689+ if fields is not None :
690+ self ._flt .fields = self ._flt .fields .union (fields )
691+ if func is not None :
692+ self ._flt .funcs .append (func )
672693 return self
673694
674695 def __iter__ (self ) -> Iterator [Step ]:
0 commit comments