11from __future__ import annotations
22
3+ import itertools
34from abc import abstractmethod
45from dataclasses import dataclass , field
56from pathlib import Path
@@ -139,7 +140,14 @@ def fuzzy_column_match(pos: CodeRange, location: Location) -> bool:
139140
140141
141142class ResultSet (dict [str , dict [Path , list [Result ]]]):
143+ results_for_rule : dict [str , list [Result ]]
144+
145+ def __init__ (self , * args , ** kwargs ):
146+ super ().__init__ (* args , ** kwargs )
147+ self .results_for_rule = {}
148+
142149 def add_result (self , result : Result ):
150+ self .results_for_rule .setdefault (result .rule_id , []).append (result )
143151 for loc in result .locations :
144152 self .setdefault (result .rule_id , {}).setdefault (loc .file , []).append (result )
145153
@@ -157,23 +165,39 @@ def results_for_rule_and_file(
157165 """
158166 return self .get (rule_id , {}).get (file .relative_to (context .directory ), [])
159167
168+ def results_for_rules (self , rule_ids : list [str ]) -> list [Result ]:
169+ """
170+ Returns flat list of all results that match any of the given rule IDs.
171+ """
172+ return list (
173+ itertools .chain .from_iterable (
174+ self .results_for_rule .get (rule_id , []) for rule_id in rule_ids
175+ )
176+ )
177+
160178 def files_for_rule (self , rule_id : str ) -> list [Path ]:
161179 return list (self .get (rule_id , {}).keys ())
162180
163181 def all_rule_ids (self ) -> list [str ]:
164182 return list (self .keys ())
165183
166184 def __or__ (self , other ):
167- result = ResultSet ( super (). __or__ ( other ) )
185+ result = self . __class__ ( )
168186 for k in self .keys () | other .keys ():
169- result [k ] = list_dict_or (self [k ], other [k ])
187+ result [k ] = list_dict_or (self .get (k , {}), other .get (k , {}))
188+ result .results_for_rule = list_dict_or (
189+ self .results_for_rule , other .results_for_rule
190+ )
170191 return result
171192
193+ def __ior__ (self , other ):
194+ return self | other
195+
172196
173197def list_dict_or (
174198 dictionary : dict [Any , list [Any ]], other : dict [Any , list [Any ]]
175- ) -> dict [Path , list [Any ]]:
176- result_dict = other | dictionary
199+ ) -> dict [Any , list [Any ]]:
200+ result_dict = {}
177201 for k in other .keys () | dictionary .keys ():
178- result_dict [k ] = dictionary [ k ] + other [ k ]
202+ result_dict [k ] = dictionary . get ( k , []) + other . get ( k , [])
179203 return result_dict
0 commit comments