@@ -275,53 +275,25 @@ def all_story_names(cls) -> Sequence[str]:
275275 "PressBenchmarkStoryT" , bound = PressBenchmarkStory )
276276
277277
278- class PressBenchmarkStoryFilter (StoryFilter [PressBenchmarkStoryT ],
279- Generic [PressBenchmarkStoryT ]):
280- """
281- Filter stories by name or regexp.
278+ class RegexFilter ():
282279
283- Syntax:
284- "all" Include all stories (defaults to story_names).
285- "name" Include story with the given name.
286- "-name" Exclude story with the given name'
287- "foo.*" Include stories whose name matches the regexp.
288- "-foo.*" Exclude stories whose name matches the regexp.
289-
290- These patterns can be combined:
291- [".*", "-foo", "-bar"] Includes all except the "foo" and "bar" story
292- """
293-
294- @classmethod
295- @override
296- def kwargs_from_cli (cls , args : argparse .Namespace ) -> dict [str , Any ]:
297- kwargs = super ().kwargs_from_cli (args )
298- kwargs ["separate" ] = args .separate
299- kwargs ["url" ] = args .custom_benchmark_url
300- return kwargs
301-
302- def __init__ (self ,
303- story_cls : Type [PressBenchmarkStoryT ],
304- patterns : Sequence [str ],
305- args : Optional [argparse .Namespace ] = None ,
306- separate : bool = False ,
307- url : Optional [str ] = None ) -> None :
308- self .url : str | None = url
280+ def __init__ (self , all_names : Sequence [str ], default_names : Sequence [str ]):
281+ self ._all_names : dict [str , None ] = dict .fromkeys (all_names )
282+ self ._default_names : dict [str , None ] = dict .fromkeys (default_names )
309283 self ._selected_names : OrderedSet [str ] = OrderedSet ()
310- super ().__init__ (story_cls , patterns , args , separate )
311- assert issubclass (self .story_cls , PressBenchmarkStory )
312- for name in self ._known_names :
284+ for name in self ._all_names :
313285 assert name , "Invalid empty story name"
314286 assert not name .startswith ("-" ), (
315287 f"Known story names cannot start with '-', but got '{ name } '." )
316288 assert not name == "all" , "Known story name cannot match 'all'."
317289
318- @override
319- def process_all (self , patterns : Sequence [str ]) -> None :
290+ def process_all (self , patterns : Sequence [str ]) -> OrderedSet [str ]:
320291 if not isinstance (patterns , (list , tuple )):
321292 raise ValueError ("Expected Sequence of story name or patterns "
322293 f"but got '{ type (patterns )} '." )
323294 for pattern in patterns :
324295 self .process_pattern (pattern )
296+ return self ._selected_names
325297
326298 def process_pattern (self , pattern : str ) -> None :
327299 if pattern .startswith ("-" ):
@@ -343,12 +315,11 @@ def _pattern_to_regexp(self, pattern: str) -> re.Pattern:
343315 if pattern == "all" :
344316 return re .compile (".*" )
345317 if pattern == "default" :
346- default_story_names = self .story_cls .default_story_names ()
347- if default_story_names == self .story_cls .all_story_names ():
318+ if self ._default_names == self ._all_names :
348319 return re .compile (".*" )
349- joined_names = "|" .join (re .escape (name ) for name in default_story_names )
320+ joined_names = "|" .join (re .escape (name ) for name in self . _default_names )
350321 return re .compile (f"^({ joined_names } )$" )
351- if pattern in self ._known_names :
322+ if pattern in self ._all_names :
352323 return re .compile (re .escape (pattern ))
353324 return re .compile (pattern )
354325
@@ -378,13 +349,13 @@ def _remove_matching(self, regexp: re.Pattern, original_pattern: str) -> None:
378349 def _regexp_match (self , regexp : re .Pattern ,
379350 original_pattern : str ) -> list [str ]:
380351 substories = [
381- substory for substory in self ._known_names if regexp .fullmatch (substory )
352+ substory for substory in self ._all_names if regexp .fullmatch (substory )
382353 ]
383354 if not substories :
384355 substories = self ._regexp_match_ignorecase (regexp )
385356 if not substories :
386357 return self ._handle_no_match (original_pattern )
387- if len (substories ) == len (self ._known_names ) and self ._selected_names :
358+ if len (substories ) == len (self ._all_names ) and self ._selected_names :
388359 raise ValueError (f"'{ original_pattern } ' matched all and overrode all"
389360 "previously filtered story names." )
390361 return substories
@@ -394,20 +365,62 @@ def _regexp_match_ignorecase(self, regexp: re.Pattern) -> list[str]:
394365 "No matching stories, using case-insensitive fallback regexp." )
395366 iregexp : re .Pattern = re .compile (regexp .pattern , flags = re .IGNORECASE )
396367 return [
397- substory for substory in self ._known_names
398- if iregexp .fullmatch (substory )
368+ substory for substory in self ._all_names if iregexp .fullmatch (substory )
399369 ]
400370
401371 def _handle_no_match (self , original_pattern : str ) -> list [str ]:
402372 choices_ms , alternative = close_matches_message (original_pattern ,
403- self ._known_names )
373+ self ._all_names )
404374 error_message : str = f"'{ original_pattern } ' didn't match any stories."
405375 error_message += choices_ms
406376 if alternative :
407377 logging .error (error_message )
408378 return [alternative ]
409379 raise ValueError (error_message )
410380
381+
382+ class PressBenchmarkStoryFilter (StoryFilter [PressBenchmarkStoryT ],
383+ Generic [PressBenchmarkStoryT ]):
384+ """
385+ Filter stories by name or regexp.
386+
387+ Syntax:
388+ "all" Include all stories (defaults to story_names).
389+ "name" Include story with the given name.
390+ "-name" Exclude story with the given name'
391+ "foo.*" Include stories whose name matches the regexp.
392+ "-foo.*" Exclude stories whose name matches the regexp.
393+
394+ These patterns can be combined:
395+ [".*", "-foo", "-bar"] Includes all except the "foo" and "bar" story
396+ """
397+
398+ @classmethod
399+ @override
400+ def kwargs_from_cli (cls , args : argparse .Namespace ) -> dict [str , Any ]:
401+ kwargs = super ().kwargs_from_cli (args )
402+ kwargs ["separate" ] = args .separate
403+ kwargs ["url" ] = args .custom_benchmark_url
404+ return kwargs
405+
406+ def __init__ (self ,
407+ story_cls : Type [PressBenchmarkStoryT ],
408+ patterns : Sequence [str ],
409+ args : Optional [argparse .Namespace ] = None ,
410+ separate : bool = False ,
411+ url : Optional [str ] = None ) -> None :
412+ self .url : str | None = url
413+ self ._selected_names : OrderedSet [str ] = OrderedSet ()
414+ super ().__init__ (story_cls , patterns , args , separate )
415+ assert issubclass (self .story_cls , PressBenchmarkStory )
416+
417+ @override
418+ def process_all (self , patterns : Sequence [str ]) -> None :
419+ regex_filter = RegexFilter (
420+ all_names = self .story_cls .all_story_names (),
421+ default_names = self .story_cls .default_story_names ())
422+ self ._selected_names = regex_filter .process_all (patterns )
423+
411424 @override
412425 def create_stories (self , separate : bool ) -> Sequence [PressBenchmarkStoryT ]:
413426 names = list (self ._selected_names )
0 commit comments