@@ -173,82 +173,146 @@ def metadata(self) -> Metadata:
173173
174174 def create_filter (
175175 self ,
176- datum_ids : list [str ] | None = None ,
177- groundtruth_ids : list [str ] | None = None ,
178- prediction_ids : list [str ] | None = None ,
179- labels : list [str ] | None = None ,
176+ datums : list [str ] | NDArray [ np . int32 ] | None = None ,
177+ groundtruths : list [str ] | NDArray [ np . int32 ] | None = None ,
178+ predictions : list [str ] | NDArray [ np . int32 ] | None = None ,
179+ labels : list [str ] | NDArray [ np . int32 ] | None = None ,
180180 ) -> Filter :
181181 """
182182 Creates a filter object.
183183
184184 Parameters
185185 ----------
186- datum_uids : list[str], optional
187- An optional list of string uids representing datums to keep.
188- groundtruth_ids : list[str], optional
189- An optional list of string uids representing ground truth annotations to keep.
190- prediction_ids : list[str], optional
191- An optional list of string uids representing prediction annotations to keep.
192- labels : list[str], optional
193- An optional list of labels to keep.
186+ datum : list[str] | NDArray[int32 ], optional
187+ An optional list of string ids or indices representing datums to keep.
188+ groundtruth : list[str] | NDArray[int32 ], optional
189+ An optional list of string ids or indices representing ground truth annotations to keep.
190+ prediction : list[str] | NDArray[int32 ], optional
191+ An optional list of string ids or indices representing prediction annotations to keep.
192+ labels : list[str] | NDArray[int32] , optional
193+ An optional list of labels or indices to keep.
194194 """
195195 mask_datums = np .ones (self ._detailed_pairs .shape [0 ], dtype = np .bool_ )
196196
197197 # filter datums
198- if datum_ids is not None :
199- if not datum_ids :
200- raise EmptyFilterError ("filter removes all datums" )
201- valid_datum_indices = np .array (
202- [self .datum_id_to_index [uid ] for uid in datum_ids ],
203- dtype = np .int32 ,
204- )
205- mask_datums = np .isin (
206- self ._detailed_pairs [:, 0 ], valid_datum_indices
207- )
198+ if datums is not None :
199+ # convert to indices
200+ if isinstance (datums , list ):
201+ datums = np .array (
202+ [self .datum_id_to_index [uid ] for uid in datums ],
203+ dtype = np .int32 ,
204+ )
205+
206+ # validate indices
207+ if datums .size == 0 :
208+ raise EmptyFilterError (
209+ "filter removes all datums"
210+ ) # validate indices
211+ elif datums .min () < 0 :
212+ raise ValueError (
213+ f"datum index cannot be negative '{ datums .min ()} '"
214+ )
215+ elif datums .max () >= len (self .index_to_datum_id ):
216+ raise ValueError (
217+ f"datum index cannot exceed total number of datums '{ datums .max ()} '"
218+ )
219+
220+ # apply to mask
221+ mask_datums = np .isin (self ._detailed_pairs [:, 0 ], datums )
208222
209223 filtered_detailed_pairs = self ._detailed_pairs [mask_datums ]
210224 n_pairs = self ._detailed_pairs [mask_datums ].shape [0 ]
211225 mask_groundtruths = np .zeros (n_pairs , dtype = np .bool_ )
212226 mask_predictions = np .zeros_like (mask_groundtruths )
213227
214228 # filter by ground truth annotation ids
215- if groundtruth_ids is not None :
216- valid_groundtruth_indices = np .array (
217- [self .groundtruth_id_to_index [uid ] for uid in groundtruth_ids ],
218- dtype = np .int32 ,
219- )
229+ if groundtruths is not None :
230+ # convert to indices
231+ if isinstance (groundtruths , list ):
232+ groundtruths = np .array (
233+ [
234+ self .groundtruth_id_to_index [uid ]
235+ for uid in groundtruths
236+ ],
237+ dtype = np .int32 ,
238+ )
239+
240+ # validate indices
241+ if groundtruths .size == 0 :
242+ warnings .warn ("filter removes all ground truths" )
243+ elif groundtruths .min () < 0 :
244+ raise ValueError (
245+ f"groundtruth annotation index cannot be negative '{ groundtruths .min ()} '"
246+ )
247+ elif groundtruths .max () >= len (self .index_to_groundtruth_id ):
248+ raise ValueError (
249+ f"groundtruth annotation index cannot exceed total number of groundtruths '{ groundtruths .max ()} '"
250+ )
251+
252+ # apply to mask
220253 mask_groundtruths [
221254 ~ np .isin (
222255 filtered_detailed_pairs [:, 1 ],
223- valid_groundtruth_indices ,
256+ groundtruths ,
224257 )
225258 ] = True
226259
227260 # filter by prediction annotation ids
228- if prediction_ids is not None :
229- valid_prediction_indices = np .array (
230- [self .prediction_id_to_index [uid ] for uid in prediction_ids ],
231- dtype = np .int32 ,
232- )
261+ if predictions is not None :
262+ # convert to indices
263+ if isinstance (predictions , list ):
264+ predictions = np .array (
265+ [self .prediction_id_to_index [uid ] for uid in predictions ],
266+ dtype = np .int32 ,
267+ )
268+
269+ # validate indices
270+ if predictions .size == 0 :
271+ warnings .warn ("filter removes all predictions" )
272+ elif predictions .min () < 0 :
273+ raise ValueError (
274+ f"prediction annotation index cannot be negative '{ predictions .min ()} '"
275+ )
276+ elif predictions .max () >= len (self .index_to_prediction_id ):
277+ raise ValueError (
278+ f"prediction annotation index cannot exceed total number of predictions '{ predictions .max ()} '"
279+ )
280+
281+ # apply to mask
233282 mask_predictions [
234283 ~ np .isin (
235284 filtered_detailed_pairs [:, 2 ],
236- valid_prediction_indices ,
285+ predictions ,
237286 )
238287 ] = True
239288
240289 # filter by labels
241290 if labels is not None :
242- if not labels :
291+ # convert to indices
292+ if isinstance (labels , list ):
293+ labels = np .array (
294+ [self .label_to_index [label ] for label in labels ]
295+ )
296+
297+ # validate indices
298+ if labels .size == 0 :
243299 raise EmptyFilterError ("filter removes all labels" )
244- valid_label_indices = np .array (
245- [self .label_to_index [label ] for label in labels ] + [- 1 ]
246- )
300+ elif labels .min () < 0 :
301+ raise ValueError (
302+ f"label index cannot be negative '{ labels .min ()} '"
303+ )
304+ elif labels .max () >= len (self .index_to_label ):
305+ raise ValueError (
306+ f"label index cannot exceed total number of labels '{ labels .max ()} '"
307+ )
308+
309+ # apply to mask
310+ labels = np .concatenate ([labels , np .array ([- 1 ])]) # add null label
247311 mask_groundtruths [
248- ~ np .isin (filtered_detailed_pairs [:, 3 ], valid_label_indices )
312+ ~ np .isin (filtered_detailed_pairs [:, 3 ], labels )
249313 ] = True
250314 mask_predictions [
251- ~ np .isin (filtered_detailed_pairs [:, 4 ], valid_label_indices )
315+ ~ np .isin (filtered_detailed_pairs [:, 4 ], labels )
252316 ] = True
253317
254318 filtered_detailed_pairs , _ , _ = filter_cache (
@@ -260,8 +324,8 @@ def create_filter(
260324 )
261325
262326 number_of_datums = (
263- len ( datum_ids )
264- if datum_ids
327+ datums . size
328+ if datums is not None
265329 else np .unique (filtered_detailed_pairs [:, 0 ]).size
266330 )
267331
0 commit comments