Commit e66ccf5

Add filtering by index (#848)
1 parent 15f08d0 commit e66ccf5

File tree

9 files changed (+777, -147 lines)


src/valor_lite/classification/manager.py

Lines changed: 54 additions & 32 deletions
@@ -144,18 +144,18 @@ def missing_prediction_labels(self) -> list[str]:
 
     def create_filter(
         self,
-        datum_ids: list[str] | None = None,
-        labels: list[str] | None = None,
+        datums: list[str] | NDArray[np.int32] | None = None,
+        labels: list[str] | NDArray[np.int32] | None = None,
     ) -> Filter:
         """
         Creates a filter object.
 
         Parameters
         ----------
-        datum_uids : list[str], optional
-            An optional list of string uids representing datums.
-        labels : list[str], optional
-            An optional list of labels.
+        datums : list[str] | NDArray[int32], optional
+            An optional list of string uids or integer indices representing datums.
+        labels : list[str] | NDArray[int32], optional
+            An optional list of strings or integer indices representing labels.
 
         Returns
         -------
@@ -165,50 +165,72 @@ def create_filter(
         # create datum mask
         n_pairs = self._detailed_pairs.shape[0]
         datum_mask = np.ones(n_pairs, dtype=np.bool_)
-        if datum_ids is not None:
-            if not datum_ids:
-                return Filter(
-                    datum_mask=np.zeros_like(datum_mask),
-                    valid_label_indices=None,
-                    metadata=Metadata(),
+        if datums is not None:
+            # convert to array of valid datum indices
+            if isinstance(datums, list):
+                datums = np.array(
+                    [self.datum_id_to_index[uid] for uid in datums],
+                    dtype=np.int32,
                 )
-            valid_datum_indices = np.array(
-                [self.datum_id_to_index[uid] for uid in datum_ids],
-                dtype=np.int32,
-            )
-            datum_mask = np.isin(
-                self._detailed_pairs[:, 0], valid_datum_indices
-            )
+
+            # return early if all data removed
+            if datums.size == 0:
+                raise EmptyFilterError("filter removes all datums")
+
+            # validate indices
+            if datums.max() >= len(self.index_to_datum_id):
+                raise ValueError(
+                    f"datum index '{datums.max()}' exceeds total number of datums"
+                )
+            elif datums.min() < 0:
+                raise ValueError(
+                    f"datum index '{datums.min()}' is a negative value"
+                )
+
+            # create datum mask
+            datum_mask = np.isin(self._detailed_pairs[:, 0], datums)
 
         # collect valid label indices
-        valid_label_indices = None
         if labels is not None:
-            if not labels:
-                return Filter(
-                    datum_mask=datum_mask,
-                    valid_label_indices=np.array([], dtype=np.int32),
-                    metadata=Metadata(),
+            # convert to array of valid label indices
+            if isinstance(labels, list):
+                labels = np.array(
+                    [self.label_to_index[label] for label in labels]
                 )
-            valid_label_indices = np.array(
-                [self.label_to_index[label] for label in labels] + [-1]
-            )
+
+            # return early if all data removed
+            if labels.size == 0:
+                raise EmptyFilterError("filter removes all labels")
+
+            # validate indices
+            if labels.max() >= len(self.index_to_label):
+                raise ValueError(
+                    f"label index '{labels.max()}' exceeds total number of labels"
+                )
+            elif labels.min() < 0:
+                raise ValueError(
+                    f"label index '{labels.min()}' is a negative value"
+                )
+
+            # add -1 to represent null labels which should not be filtered
+            labels = np.concatenate([labels, np.array([-1])])
 
         filtered_detailed_pairs, _ = filter_cache(
             detailed_pairs=self._detailed_pairs,
             datum_mask=datum_mask,
-            valid_label_indices=valid_label_indices,
+            valid_label_indices=labels,
             n_labels=self.metadata.number_of_labels,
         )
 
         number_of_datums = (
-            len(datum_ids)
-            if datum_ids is not None
+            datums.size
+            if datums is not None
             else self.metadata.number_of_datums
         )
 
         return Filter(
             datum_mask=datum_mask,
-            valid_label_indices=valid_label_indices,
+            valid_label_indices=labels,
             metadata=Metadata.create(
                 detailed_pairs=filtered_detailed_pairs,
                 number_of_datums=number_of_datums,
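
Usage sketch (not part of the commit): the calls below assume an already-constructed classification evaluator object named `evaluator`, whose class and setup are outside this diff, so treat that name as a placeholder. They illustrate the new signature, in which `datums` and `labels` accept either string identifiers or int32 index arrays, and invalid selections now raise instead of returning an empty Filter.

    import numpy as np

    # evaluator = ...  # hypothetical, pre-built classification evaluator (construction not shown in this diff)

    # Previous style, with the renamed parameter: string uids and label strings.
    f_by_uid = evaluator.create_filter(
        datums=["uid1", "uid2"],
        labels=["dog", "cat"],
    )

    # New style added by this commit: int32 index arrays.
    f_by_index = evaluator.create_filter(
        datums=np.array([0, 1], dtype=np.int32),
        labels=np.array([0], dtype=np.int32),
    )

    # An empty selection raises EmptyFilterError; a negative index or an index
    # >= the number of datums/labels raises ValueError (see the validation above).
    # A -1 sentinel is appended to `labels` internally so unmatched (null) labels
    # are never filtered out.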

src/valor_lite/object_detection/manager.py

Lines changed: 106 additions & 42 deletions
@@ -173,82 +173,146 @@ def metadata(self) -> Metadata:
 
     def create_filter(
         self,
-        datum_ids: list[str] | None = None,
-        groundtruth_ids: list[str] | None = None,
-        prediction_ids: list[str] | None = None,
-        labels: list[str] | None = None,
+        datums: list[str] | NDArray[np.int32] | None = None,
+        groundtruths: list[str] | NDArray[np.int32] | None = None,
+        predictions: list[str] | NDArray[np.int32] | None = None,
+        labels: list[str] | NDArray[np.int32] | None = None,
     ) -> Filter:
         """
         Creates a filter object.
 
         Parameters
         ----------
-        datum_uids : list[str], optional
-            An optional list of string uids representing datums to keep.
-        groundtruth_ids : list[str], optional
-            An optional list of string uids representing ground truth annotations to keep.
-        prediction_ids : list[str], optional
-            An optional list of string uids representing prediction annotations to keep.
-        labels : list[str], optional
-            An optional list of labels to keep.
+        datums : list[str] | NDArray[int32], optional
+            An optional list of string ids or indices representing datums to keep.
+        groundtruths : list[str] | NDArray[int32], optional
+            An optional list of string ids or indices representing ground truth annotations to keep.
+        predictions : list[str] | NDArray[int32], optional
+            An optional list of string ids or indices representing prediction annotations to keep.
+        labels : list[str] | NDArray[int32], optional
+            An optional list of labels or indices to keep.
         """
         mask_datums = np.ones(self._detailed_pairs.shape[0], dtype=np.bool_)
 
         # filter datums
-        if datum_ids is not None:
-            if not datum_ids:
-                raise EmptyFilterError("filter removes all datums")
-            valid_datum_indices = np.array(
-                [self.datum_id_to_index[uid] for uid in datum_ids],
-                dtype=np.int32,
-            )
-            mask_datums = np.isin(
-                self._detailed_pairs[:, 0], valid_datum_indices
-            )
+        if datums is not None:
+            # convert to indices
+            if isinstance(datums, list):
+                datums = np.array(
+                    [self.datum_id_to_index[uid] for uid in datums],
+                    dtype=np.int32,
+                )
+
+            # validate indices
+            if datums.size == 0:
+                raise EmptyFilterError(
+                    "filter removes all datums"
+                )
+            elif datums.min() < 0:
+                raise ValueError(
+                    f"datum index cannot be negative '{datums.min()}'"
+                )
+            elif datums.max() >= len(self.index_to_datum_id):
+                raise ValueError(
+                    f"datum index cannot exceed total number of datums '{datums.max()}'"
+                )
+
+            # apply to mask
+            mask_datums = np.isin(self._detailed_pairs[:, 0], datums)
 
         filtered_detailed_pairs = self._detailed_pairs[mask_datums]
         n_pairs = self._detailed_pairs[mask_datums].shape[0]
         mask_groundtruths = np.zeros(n_pairs, dtype=np.bool_)
         mask_predictions = np.zeros_like(mask_groundtruths)
 
         # filter by ground truth annotation ids
-        if groundtruth_ids is not None:
-            valid_groundtruth_indices = np.array(
-                [self.groundtruth_id_to_index[uid] for uid in groundtruth_ids],
-                dtype=np.int32,
-            )
+        if groundtruths is not None:
+            # convert to indices
+            if isinstance(groundtruths, list):
+                groundtruths = np.array(
+                    [
+                        self.groundtruth_id_to_index[uid]
+                        for uid in groundtruths
+                    ],
+                    dtype=np.int32,
+                )
+
+            # validate indices
+            if groundtruths.size == 0:
+                warnings.warn("filter removes all ground truths")
+            elif groundtruths.min() < 0:
+                raise ValueError(
+                    f"groundtruth annotation index cannot be negative '{groundtruths.min()}'"
+                )
+            elif groundtruths.max() >= len(self.index_to_groundtruth_id):
+                raise ValueError(
+                    f"groundtruth annotation index cannot exceed total number of groundtruths '{groundtruths.max()}'"
+                )
+
+            # apply to mask
             mask_groundtruths[
                 ~np.isin(
                     filtered_detailed_pairs[:, 1],
-                    valid_groundtruth_indices,
+                    groundtruths,
                 )
             ] = True
 
         # filter by prediction annotation ids
-        if prediction_ids is not None:
-            valid_prediction_indices = np.array(
-                [self.prediction_id_to_index[uid] for uid in prediction_ids],
-                dtype=np.int32,
-            )
+        if predictions is not None:
+            # convert to indices
+            if isinstance(predictions, list):
+                predictions = np.array(
+                    [self.prediction_id_to_index[uid] for uid in predictions],
+                    dtype=np.int32,
+                )
+
+            # validate indices
+            if predictions.size == 0:
+                warnings.warn("filter removes all predictions")
+            elif predictions.min() < 0:
+                raise ValueError(
+                    f"prediction annotation index cannot be negative '{predictions.min()}'"
+                )
+            elif predictions.max() >= len(self.index_to_prediction_id):
+                raise ValueError(
+                    f"prediction annotation index cannot exceed total number of predictions '{predictions.max()}'"
+                )
+
+            # apply to mask
             mask_predictions[
                 ~np.isin(
                     filtered_detailed_pairs[:, 2],
-                    valid_prediction_indices,
+                    predictions,
                 )
             ] = True
 
         # filter by labels
         if labels is not None:
-            if not labels:
+            # convert to indices
+            if isinstance(labels, list):
+                labels = np.array(
+                    [self.label_to_index[label] for label in labels]
+                )
+
+            # validate indices
+            if labels.size == 0:
                 raise EmptyFilterError("filter removes all labels")
-            valid_label_indices = np.array(
-                [self.label_to_index[label] for label in labels] + [-1]
-            )
+            elif labels.min() < 0:
+                raise ValueError(
+                    f"label index cannot be negative '{labels.min()}'"
+                )
+            elif labels.max() >= len(self.index_to_label):
+                raise ValueError(
+                    f"label index cannot exceed total number of labels '{labels.max()}'"
+                )
+
+            # apply to mask
+            labels = np.concatenate([labels, np.array([-1])])  # add null label
             mask_groundtruths[
-                ~np.isin(filtered_detailed_pairs[:, 3], valid_label_indices)
+                ~np.isin(filtered_detailed_pairs[:, 3], labels)
             ] = True
             mask_predictions[
-                ~np.isin(filtered_detailed_pairs[:, 4], valid_label_indices)
+                ~np.isin(filtered_detailed_pairs[:, 4], labels)
             ] = True
 
         filtered_detailed_pairs, _, _ = filter_cache(
@@ -260,8 +324,8 @@ def create_filter(
         )
 
         number_of_datums = (
-            len(datum_ids)
-            if datum_ids
+            datums.size
+            if datums is not None
             else np.unique(filtered_detailed_pairs[:, 0]).size
         )
 
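
Usage sketch (not part of the commit): the calls below assume an already-constructed object-detection evaluator object named `evaluator`, which this diff does not show, so the name and the example ids are placeholders. All four parameters now accept either string ids or int32 index arrays; per the validation above, empty `datums` or `labels` selections raise EmptyFilterError, empty `groundtruths` or `predictions` selections only emit a warning, and negative or out-of-range indices raise ValueError.

    import numpy as np

    # evaluator = ...  # hypothetical, pre-built object-detection evaluator (construction not shown in this diff)

    # Keep two datums and one ground-truth annotation by integer index.
    f_by_index = evaluator.create_filter(
        datums=np.array([0, 1], dtype=np.int32),
        groundtruths=np.array([3], dtype=np.int32),
    )

    # Styles can be mixed: string ids for predictions, an index array for labels.
    f_mixed = evaluator.create_filter(
        predictions=["pred_uid_7"],
        labels=np.array([2, 4], dtype=np.int32),
    )

    # Empty datums/labels selections raise EmptyFilterError, empty groundtruths/predictions
    # selections only warn, and negative or out-of-range indices raise ValueError.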

0 commit comments