
Commit d1591bb

A-Artemis and leoll2 authored
Added dataset filtering by labels (#4989)
Co-authored-by: Leonardo Lai <[email protected]>
1 parent cc32240 commit d1591bb

File tree

6 files changed, +229 -7 lines changed


application/backend/app/api/routers/datasets.py

Lines changed: 4 additions & 2 deletions
@@ -97,22 +97,23 @@ def add_dataset_item(
         status.HTTP_200_OK: {"description": "List of available dataset items", "model": DatasetItemsWithPagination},
     },
 )
-def list_dataset_items(
+def list_dataset_items( # noqa: PLR0913
     project: Annotated[ProjectView, Depends(get_project)],
     dataset_service: Annotated[DatasetService, Depends(get_dataset_service)],
     limit: Annotated[int, Query(ge=1, le=MAX_DATASET_ITEMS_NUMBER_RETURNED)] = DEFAULT_DATASET_ITEMS_NUMBER_RETURNED,
     offset: Annotated[int, Query(ge=0)] = 0,
     start_date: Annotated[datetime | None, Query()] = None,
     end_date: Annotated[datetime | None, Query()] = None,
     annotation_status: Annotated[DatasetItemAnnotationStatus | None, Query()] = None,
+    labels: Annotated[list[UUID] | None, Query()] = None,
 ) -> DatasetItemsWithPagination:
     """List the available dataset items and their metadata. This endpoint supports pagination."""
     if start_date is not None and end_date is not None and start_date > end_date:
         raise HTTPException(
             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Start date must be before end date."
         )
     total = dataset_service.count_dataset_items(
-        project=project, start_date=start_date, end_date=end_date, annotation_status=annotation_status
+        project=project, start_date=start_date, end_date=end_date, annotation_status=annotation_status, label_ids=labels
     )
     dataset_items = dataset_service.list_dataset_items(
         project=project,
@@ -121,6 +122,7 @@ def list_dataset_items(
         start_date=start_date,
         end_date=end_date,
         annotation_status=annotation_status,
+        label_ids=labels,
     )
     return DatasetItemsWithPagination(
         items=[DatasetItemView.model_validate(dataset_item, from_attributes=True) for dataset_item in dataset_items],
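For a rough sense of how the new filter is consumed, the following sketch queries the endpoint with two label ids; the base URL, route path, project id, and the use of httpx are illustrative assumptions, not part of this change. FastAPI collects the repeated labels query parameters into the list[UUID] argument shown above, and the labels are combined with OR semantics (see the integration tests below).

import httpx

# Hypothetical deployment details -- adjust to the real route and project id.
BASE_URL = "http://localhost:8000"
PROJECT_ID = "11111111-1111-1111-1111-111111111111"

response = httpx.get(
    f"{BASE_URL}/projects/{PROJECT_ID}/dataset/items",  # assumed path for list_dataset_items
    params={
        "limit": 20,
        "labels": [  # repeated query parameter -> list[UUID] on the server
            "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
            "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb",
        ],
    },
)
response.raise_for_status()
print(response.json()["items"])  # DatasetItemsWithPagination payload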

application/backend/app/repositories/dataset_item_repo.py

Lines changed: 12 additions & 1 deletion
@@ -61,10 +61,18 @@ def count(
         start_date: datetime | None = None,
         end_date: datetime | None = None,
         annotation_status: str | None = None,
+        label_ids: list[str] | None = None,
     ) -> int:
-        stmt = select(func.count()).select_from(DatasetItemDB).where(DatasetItemDB.project_id == self.project_id)
+        # When the query involves a JOIN (e.g. when filtering by labels), count distinct items to avoid duplicates
+        if label_ids:
+            select_fn = func.count(func.distinct(DatasetItemDB.id))
+        else:
+            select_fn = func.count()
+        stmt = select(select_fn).select_from(DatasetItemDB).where(DatasetItemDB.project_id == self.project_id)
         stmt = self._apply_date_filters(stmt, start_date, end_date)
         stmt = self._apply_annotation_status_filter(stmt, annotation_status)
+        if label_ids:
+            stmt = stmt.join(DatasetItemLabelDB).where(DatasetItemLabelDB.label_id.in_(label_ids))
         return self.db.scalar(stmt) or 0

     def list_items(
@@ -74,10 +82,13 @@ def list_items(
         start_date: datetime | None = None,
         end_date: datetime | None = None,
         annotation_status: str | None = None,
+        label_ids: list[str] | None = None,
     ) -> list[DatasetItemDB]:
         stmt = self._base_select()
         stmt = self._apply_date_filters(stmt, start_date, end_date)
         stmt = self._apply_annotation_status_filter(stmt, annotation_status)
+        if label_ids:
+            stmt = stmt.join(DatasetItemLabelDB).where(DatasetItemLabelDB.label_id.in_(label_ids)).distinct()
         stmt = stmt.order_by(DatasetItemDB.created_at.desc()).offset(offset).limit(limit)
         return list(self.db.scalars(stmt).all())
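The switch to COUNT(DISTINCT ...) matters because the join against DatasetItemLabelDB produces one row per (item, label) match, so an item carrying several of the requested labels would otherwise be counted more than once. Below is a self-contained sketch of that effect with SQLAlchemy and an in-memory SQLite database; the table and column names are illustrative stand-ins, not the project's actual models.

from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Item(Base):  # stand-in for DatasetItemDB
    __tablename__ = "item"
    id = Column(Integer, primary_key=True)

class ItemLabel(Base):  # stand-in for DatasetItemLabelDB
    __tablename__ = "item_label"
    id = Column(Integer, primary_key=True)
    item_id = Column(Integer, ForeignKey("item.id"))
    label_id = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # One item that matches both requested labels
    session.add(Item(id=1))
    session.add_all([ItemLabel(item_id=1, label_id="a"), ItemLabel(item_id=1, label_id="b")])
    session.commit()

    labels = ["a", "b"]
    naive = session.scalar(
        select(func.count()).select_from(Item).join(ItemLabel).where(ItemLabel.label_id.in_(labels))
    )
    deduped = session.scalar(
        select(func.count(func.distinct(Item.id)))
        .select_from(Item)
        .join(ItemLabel)
        .where(ItemLabel.label_id.in_(labels))
    )
    print(naive, deduped)  # 2 1 -- only the distinct count reflects the single matching item

The list_items path handles the same duplication with .distinct() on the row query instead.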

application/backend/app/services/dataset_service.py

Lines changed: 8 additions & 1 deletion
@@ -159,10 +159,14 @@ def count_dataset_items(
         start_date: datetime | None = None,
         end_date: datetime | None = None,
         annotation_status: str | None = None,
+        label_ids: list[UUID] | None = None,
     ) -> int:
         """Get number of available dataset items (within date range if specified)"""
         repo = DatasetItemRepository(project_id=str(project.id), db=self._db_session)
-        return repo.count(start_date=start_date, end_date=end_date, annotation_status=annotation_status)
+        label_ids_str = [str(label_id) for label_id in label_ids] if label_ids else None
+        return repo.count(
+            start_date=start_date, end_date=end_date, annotation_status=annotation_status, label_ids=label_ids_str
+        )

     def list_dataset_items(
         self,
@@ -172,9 +176,11 @@ def list_dataset_items(
         start_date: datetime | None = None,
         end_date: datetime | None = None,
         annotation_status: str | None = None,
+        label_ids: list[UUID] | None = None,
     ) -> list[DatasetItem]:
         """Get information about available dataset items"""
         repo = DatasetItemRepository(project_id=str(project.id), db=self._db_session)
+        label_ids_str = [str(label_id) for label_id in label_ids] if label_ids else None
         return [
             DatasetItem.model_validate(db)
             for db in repo.list_items(
@@ -183,6 +189,7 @@ def list_dataset_items(
                 start_date=start_date,
                 end_date=end_date,
                 annotation_status=annotation_status,
+                label_ids=label_ids_str,
             )
         ]
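At the service layer the label filter stays UUID-typed and is converted to strings only when handed to the repository. A minimal caller-side sketch, assuming a DatasetService instance and a project with labels are already in hand (as in the integration tests below):

from uuid import UUID

label_ids: list[UUID] = [label.id for label in project.task.labels[:2]]

total = dataset_service.count_dataset_items(project=project, label_ids=label_ids)
items = dataset_service.list_dataset_items(project=project, label_ids=label_ids, limit=50, offset=0)
assert len(items) <= total  # the count covers every match; the listing is paginated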

application/backend/pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -87,6 +87,9 @@ target-version = "py313"
 line-length = 120
 exclude = [".venv*"]

+[tool.ruff.format]
+line-ending = "lf"
+
 [tool.ruff.lint]
 select = ["ARG", "E", "F", "I", "N", "UP", "YTT", "ASYNC", "S", "COM", "C4", "FA", "PIE", "PYI", "Q", "RSE", "RET", "SIM",
     "TID", "PL", "RUF", "C90", "D103", "ANN001", "ANN201", "ANN205"]

application/backend/tests/integration/services/test_dataset_service.py

Lines changed: 182 additions & 0 deletions
@@ -258,6 +258,82 @@ def _create_annotations(label_id: UUID) -> list[DatasetItemAnnotation]:
     return _create_annotations


+@pytest.fixture
+def fxt_project_with_labeled_dataset_items(
+    fxt_project_with_pipeline, db_session
+) -> tuple[ProjectView, list[DatasetItemDB]]:
+    """Fixture to create a project with multiple labeled dataset items for testing label filtering."""
+    project, _ = fxt_project_with_pipeline
+
+    # Ensure we have at least 2 labels
+    assert len(project.task.labels) >= 2, "Project must have at least 2 labels for this fixture"
+
+    label_0_id = str(project.task.labels[0].id)
+    label_1_id = str(project.task.labels[1].id)
+
+    configs = [
+        # Item 0: No annotations
+        {"name": "item_no_labels", "format": "jpg", "size": 1024, "width": 1024, "height": 768, "subset": "unassigned"},
+        # Item 1: Has label_0
+        {
+            "name": "item_label_0",
+            "format": "jpg",
+            "size": 1024,
+            "width": 1024,
+            "height": 768,
+            "subset": "unassigned",
+            "annotation_data": [{"labels": [{"id": label_0_id}], "shape": {"type": "full_image"}}],
+        },
+        # Item 2: Has label_1
+        {
+            "name": "item_label_1",
+            "format": "jpg",
+            "size": 1024,
+            "width": 1024,
+            "height": 768,
+            "subset": "unassigned",
+            "annotation_data": [{"labels": [{"id": label_1_id}], "shape": {"type": "full_image"}}],
+        },
+        # Item 3: Has both label_0 and label_1
+        {
+            "name": "item_both_labels",
+            "format": "jpg",
+            "size": 1024,
+            "width": 1024,
+            "height": 768,
+            "subset": "unassigned",
+            "annotation_data": [
+                {
+                    "labels": [{"id": label_0_id}],
+                    "shape": {"type": "rectangle", "x": 0, "y": 0, "width": 10, "height": 10},
+                },
+                {
+                    "labels": [{"id": label_1_id}],
+                    "shape": {"type": "rectangle", "x": 20, "y": 20, "width": 10, "height": 10},
+                },
+            ],
+        },
+    ]
+
+    db_dataset_items = []
+    for config in configs:
+        dataset_item = DatasetItemDB(**config)
+        dataset_item.project_id = str(project.id)
+        dataset_item.created_at = datetime.fromisoformat("2025-02-01T00:00:00Z")
+        db_dataset_items.append(dataset_item)
+    db_session.add_all(db_dataset_items)
+    db_session.flush()
+
+    # Link labels to dataset items
+    db_session.add(DatasetItemLabelDB(dataset_item_id=db_dataset_items[1].id, label_id=label_0_id))
+    db_session.add(DatasetItemLabelDB(dataset_item_id=db_dataset_items[2].id, label_id=label_1_id))
+    db_session.add(DatasetItemLabelDB(dataset_item_id=db_dataset_items[3].id, label_id=label_0_id))
+    db_session.add(DatasetItemLabelDB(dataset_item_id=db_dataset_items[3].id, label_id=label_1_id))
+    db_session.flush()
+
+    return project, db_dataset_items
+
+
 class TestDatasetServiceIntegration:
     """Integration tests for DatasetService."""

@@ -891,3 +967,109 @@ def test_annotation_status_filter_verifies_data_correctness(
         for item in to_review_items:
             assert item.annotation_data is not None
             assert item.user_reviewed is False
+
+    def test_list_dataset_items_filter_by_single_label(
+        self,
+        fxt_dataset_service: DatasetService,
+        fxt_project_with_labeled_dataset_items: tuple[ProjectView, list[DatasetItemDB]],
+    ):
+        """Test listing dataset items filtered by a single label."""
+        project, db_dataset_items = fxt_project_with_labeled_dataset_items
+        label_0_id = project.task.labels[0].id
+
+        # Filter by label_0 - should return items 1 and 3 (item_label_0 and item_both_labels)
+        dataset_items = fxt_dataset_service.list_dataset_items(
+            project=project,
+            label_ids=[label_0_id],
+        )
+
+        assert len(dataset_items) == 2
+        item_names = {item.name for item in dataset_items}
+        assert item_names == {"item_label_0", "item_both_labels"}
+
+    def test_list_dataset_items_filter_by_multiple_labels(
+        self,
+        fxt_dataset_service: DatasetService,
+        fxt_project_with_labeled_dataset_items: tuple[ProjectView, list[DatasetItemDB]],
+    ):
+        """Test listing dataset items filtered by multiple labels (OR logic)."""
+        project, db_dataset_items = fxt_project_with_labeled_dataset_items
+        label_0_id = project.task.labels[0].id
+        label_1_id = project.task.labels[1].id
+
+        # Filter by label_0 OR label_1 - should return items 1, 2, and 3
+        dataset_items = fxt_dataset_service.list_dataset_items(
+            project=project,
+            label_ids=[label_0_id, label_1_id],
+        )
+
+        assert len(dataset_items) == 3
+        item_names = {item.name for item in dataset_items}
+        assert item_names == {"item_label_0", "item_label_1", "item_both_labels"}
+
+    def test_list_dataset_items_filter_by_nonexistent_label(
+        self,
+        fxt_dataset_service: DatasetService,
+        fxt_project_with_labeled_dataset_items: tuple[ProjectView, list[DatasetItemDB]],
+    ):
+        """Test listing dataset items filtered by a nonexistent label."""
+        project, db_dataset_items = fxt_project_with_labeled_dataset_items
+        nonexistent_label_id = uuid4()
+
+        # Filter by nonexistent label - should return empty list
+        dataset_items = fxt_dataset_service.list_dataset_items(
+            project=project,
+            label_ids=[nonexistent_label_id],
+        )
+
+        assert len(dataset_items) == 0
+
+    def test_count_dataset_items_filter_by_single_label(
+        self,
+        fxt_dataset_service: DatasetService,
+        fxt_project_with_labeled_dataset_items: tuple[ProjectView, list[DatasetItemDB]],
+    ):
+        """Test counting dataset items filtered by a single label."""
+        project, db_dataset_items = fxt_project_with_labeled_dataset_items
+        label_0_id = project.task.labels[0].id
+
+        # Count items with label_0 - should return 2
+        count = fxt_dataset_service.count_dataset_items(
+            project=project,
+            label_ids=[label_0_id],
+        )
+
+        assert count == 2
+
+    def test_count_dataset_items_filter_by_multiple_labels(
+        self,
+        fxt_dataset_service: DatasetService,
+        fxt_project_with_labeled_dataset_items: tuple[ProjectView, list[DatasetItemDB]],
+    ):
+        """Test counting dataset items filtered by multiple labels (OR logic)."""
+        project, db_dataset_items = fxt_project_with_labeled_dataset_items
+        label_0_id = project.task.labels[0].id
+        label_1_id = project.task.labels[1].id
+
+        # Count items with label_0 OR label_1 - should return 3
+        count = fxt_dataset_service.count_dataset_items(
+            project=project,
+            label_ids=[label_0_id, label_1_id],
+        )
+
+        assert count == 3
+
+    def test_list_dataset_items_no_label_filter(
+        self,
+        fxt_dataset_service: DatasetService,
+        fxt_project_with_labeled_dataset_items: tuple[ProjectView, list[DatasetItemDB]],
+    ):
+        """Test listing dataset items without label filter returns all items."""
+        project, db_dataset_items = fxt_project_with_labeled_dataset_items
+
+        # No filter - should return all 4 items
+        dataset_items = fxt_dataset_service.list_dataset_items(project=project)
+
+        assert len(dataset_items) == 4
+        item_names = {item.name for item in dataset_items}
+        assert item_names == {"item_no_labels", "item_label_0", "item_label_1", "item_both_labels"}

application/backend/tests/unit/routers/test_datasets.py

Lines changed: 20 additions & 3 deletions
@@ -108,10 +108,20 @@ def test_list_dataset_items(self, fxt_get_project, fxt_dataset_item, fxt_dataset

         assert response.status_code == status.HTTP_200_OK
         fxt_dataset_service.count_dataset_items.assert_called_once_with(
-            project=fxt_get_project, start_date=None, end_date=None, annotation_status=None
+            project=fxt_get_project,
+            start_date=None,
+            end_date=None,
+            annotation_status=None,
+            label_ids=None,
         )
         fxt_dataset_service.list_dataset_items.assert_called_once_with(
-            project=fxt_get_project, limit=10, offset=0, start_date=None, end_date=None, annotation_status=None
+            project=fxt_get_project,
+            limit=10,
+            offset=0,
+            start_date=None,
+            end_date=None,
+            annotation_status=None,
+            label_ids=None,
         )

     def test_list_dataset_items_filtering_and_pagination(
@@ -130,6 +140,7 @@ def test_list_dataset_items_filtering_and_pagination(
             start_date=datetime(2025, 1, 9, 0, 0, 0, tzinfo=ZoneInfo("UTC")),
             end_date=datetime(2025, 12, 31, 23, 59, 59, tzinfo=ZoneInfo("UTC")),
             annotation_status=None,
+            label_ids=None,
         )
         fxt_dataset_service.list_dataset_items.assert_called_once_with(
             project=fxt_get_project,
@@ -138,6 +149,7 @@ def test_list_dataset_items_filtering_and_pagination(
             start_date=datetime(2025, 1, 9, 0, 0, 0, tzinfo=ZoneInfo("UTC")),
             end_date=datetime(2025, 12, 31, 23, 59, 59, tzinfo=ZoneInfo("UTC")),
             annotation_status=None,
+            label_ids=None,
         )

     @pytest.mark.parametrize("limit", [1000, 0, -20])
@@ -174,7 +186,11 @@ def test_list_dataset_items_with_annotation_status(

         assert response.status_code == status.HTTP_200_OK
         fxt_dataset_service.count_dataset_items.assert_called_once_with(
-            project=fxt_get_project, start_date=None, end_date=None, annotation_status=annotation_status
+            project=fxt_get_project,
+            start_date=None,
+            end_date=None,
+            annotation_status=annotation_status,
+            label_ids=None,
        )
         fxt_dataset_service.list_dataset_items.assert_called_once_with(
             project=fxt_get_project,
@@ -183,6 +199,7 @@ def test_list_dataset_items_with_annotation_status(
             start_date=None,
             end_date=None,
             annotation_status=annotation_status,
+            label_ids=None,
         )

     @pytest.mark.parametrize(
