Skip to content

Commit 2c6bd6d

Browse files
authored
Add subset filtering for dataset items (#5004)
1 parent 3896f35 commit 2c6bd6d

File tree

5 files changed

+332
-4
lines changed

5 files changed

+332
-4
lines changed

application/backend/app/api/routers/datasets.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
SetDatasetItemAnnotations,
1818
)
1919
from app.core.models import Pagination
20-
from app.models import DatasetItemAnnotationStatus
20+
from app.models import DatasetItemAnnotationStatus, DatasetItemSubset
2121
from app.schemas import ProjectView
2222
from app.services import DatasetService, ResourceNotFoundError
2323
from app.services.dataset_service import AnnotationValidationError, InvalidImageError, SubsetAlreadyAssignedError
@@ -106,14 +106,20 @@ def list_dataset_items( # noqa: PLR0913
106106
end_date: Annotated[datetime | None, Query()] = None,
107107
annotation_status: Annotated[DatasetItemAnnotationStatus | None, Query()] = None,
108108
labels: Annotated[list[UUID] | None, Query()] = None,
109+
subset: Annotated[DatasetItemSubset | None, Query()] = None,
109110
) -> DatasetItemsWithPagination:
110111
"""List the available dataset items and their metadata. This endpoint supports pagination."""
111112
if start_date is not None and end_date is not None and start_date > end_date:
112113
raise HTTPException(
113114
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Start date must be before end date."
114115
)
115116
total = dataset_service.count_dataset_items(
116-
project=project, start_date=start_date, end_date=end_date, annotation_status=annotation_status, label_ids=labels
117+
project=project,
118+
start_date=start_date,
119+
end_date=end_date,
120+
annotation_status=annotation_status,
121+
label_ids=labels,
122+
subset=subset,
117123
)
118124
dataset_items = dataset_service.list_dataset_items(
119125
project=project,
@@ -123,6 +129,7 @@ def list_dataset_items( # noqa: PLR0913
123129
end_date=end_date,
124130
annotation_status=annotation_status,
125131
label_ids=labels,
132+
subset=subset,
126133
)
127134
return DatasetItemsWithPagination(
128135
items=[DatasetItemView.model_validate(dataset_item, from_attributes=True) for dataset_item in dataset_items],

application/backend/app/repositories/dataset_item_repo.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ def _apply_annotation_status_filter(stmt: Select, annotation_status: str | None
5050
stmt = stmt.where(DatasetItemDB.annotation_data.is_not(None), DatasetItemDB.user_reviewed.is_(False))
5151
return stmt
5252

53+
@staticmethod
54+
def _apply_subset_filter(stmt: Select, subset: str | None = None) -> Select:
55+
"""Apply subset filter to a select statement."""
56+
if subset is not None:
57+
stmt = stmt.where(DatasetItemDB.subset == subset)
58+
return stmt
59+
5360
def save(self, dataset_item_db: DatasetItemDB) -> DatasetItemDB:
5461
dataset_item_db.updated_at = datetime.now(UTC)
5562
self.db.add(dataset_item_db)
@@ -62,6 +69,7 @@ def count(
6269
end_date: datetime | None = None,
6370
annotation_status: str | None = None,
6471
label_ids: list[str] | None = None,
72+
subset: str | None = None,
6573
) -> int:
6674
# When the query involves a JOIN (e.g. when filtering by labels), count distinct items to avoid duplicates
6775
if label_ids:
@@ -71,6 +79,7 @@ def count(
7179
stmt = select(select_fn).select_from(DatasetItemDB).where(DatasetItemDB.project_id == self.project_id)
7280
stmt = self._apply_date_filters(stmt, start_date, end_date)
7381
stmt = self._apply_annotation_status_filter(stmt, annotation_status)
82+
stmt = self._apply_subset_filter(stmt, subset)
7483
if label_ids:
7584
stmt = stmt.join(DatasetItemLabelDB).where(DatasetItemLabelDB.label_id.in_(label_ids))
7685
return self.db.scalar(stmt) or 0
@@ -83,10 +92,12 @@ def list_items(
8392
end_date: datetime | None = None,
8493
annotation_status: str | None = None,
8594
label_ids: list[str] | None = None,
95+
subset: str | None = None,
8696
) -> list[DatasetItemDB]:
8797
stmt = self._base_select()
8898
stmt = self._apply_date_filters(stmt, start_date, end_date)
8999
stmt = self._apply_annotation_status_filter(stmt, annotation_status)
100+
stmt = self._apply_subset_filter(stmt, subset)
90101
if label_ids:
91102
stmt = stmt.join(DatasetItemLabelDB).where(DatasetItemLabelDB.label_id.in_(label_ids)).distinct()
92103
stmt = stmt.order_by(DatasetItemDB.created_at.desc()).offset(offset).limit(limit)

application/backend/app/services/dataset_service.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,20 @@ def count_dataset_items(
158158
end_date: datetime | None = None,
159159
annotation_status: str | None = None,
160160
label_ids: list[UUID] | None = None,
161+
subset: str | None = None,
161162
) -> int:
162163
"""Get number of available dataset items (within date range if specified)"""
163164
repo = DatasetItemRepository(project_id=str(project.id), db=self._db_session)
164165
label_ids_str = [str(label_id) for label_id in label_ids] if label_ids else None
165166
return repo.count(
166-
start_date=start_date, end_date=end_date, annotation_status=annotation_status, label_ids=label_ids_str
167+
start_date=start_date,
168+
end_date=end_date,
169+
annotation_status=annotation_status,
170+
label_ids=label_ids_str,
171+
subset=subset,
167172
)
168173

169-
def list_dataset_items(
174+
def list_dataset_items( # noqa: PLR0913
170175
self,
171176
project: ProjectView,
172177
limit: int = 20,
@@ -175,6 +180,7 @@ def list_dataset_items(
175180
end_date: datetime | None = None,
176181
annotation_status: str | None = None,
177182
label_ids: list[UUID] | None = None,
183+
subset: str | None = None,
178184
) -> list[DatasetItem]:
179185
"""Get information about available dataset items"""
180186
repo = DatasetItemRepository(project_id=str(project.id), db=self._db_session)
@@ -188,6 +194,7 @@ def list_dataset_items(
188194
end_date=end_date,
189195
annotation_status=annotation_status,
190196
label_ids=label_ids_str,
197+
subset=subset,
191198
)
192199
]
193200

application/backend/tests/integration/services/test_dataset_service.py

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,122 @@ def fxt_project_with_labeled_dataset_items(
331331
return project, db_dataset_items
332332

333333

334+
@pytest.fixture
335+
def fxt_project_with_subset_items(fxt_project_with_pipeline, db_session) -> tuple[ProjectView, list[DatasetItemDB]]:
336+
"""Fixture with dataset items covering all subset types."""
337+
project, _ = fxt_project_with_pipeline
338+
339+
# Unassigned items
340+
unassigned_items = [
341+
DatasetItemDB(
342+
name="unassigned1",
343+
format="jpg",
344+
size=1024,
345+
width=1024,
346+
height=768,
347+
subset=DatasetItemSubset.UNASSIGNED,
348+
user_reviewed=False,
349+
project_id=str(project.id),
350+
created_at=datetime.fromisoformat("2025-02-01T00:00:00Z"),
351+
),
352+
DatasetItemDB(
353+
name="unassigned2",
354+
format="jpg",
355+
size=1024,
356+
width=1024,
357+
height=768,
358+
subset=DatasetItemSubset.UNASSIGNED,
359+
user_reviewed=False,
360+
project_id=str(project.id),
361+
created_at=datetime.fromisoformat("2025-02-02T00:00:00Z"),
362+
),
363+
]
364+
365+
# Training items
366+
training_items = [
367+
DatasetItemDB(
368+
name="training1",
369+
format="jpg",
370+
size=1024,
371+
width=1024,
372+
height=768,
373+
subset=DatasetItemSubset.TRAINING,
374+
user_reviewed=False,
375+
project_id=str(project.id),
376+
created_at=datetime.fromisoformat("2025-02-03T00:00:00Z"),
377+
),
378+
DatasetItemDB(
379+
name="training2",
380+
format="jpg",
381+
size=1024,
382+
width=1024,
383+
height=768,
384+
subset=DatasetItemSubset.TRAINING,
385+
user_reviewed=False,
386+
project_id=str(project.id),
387+
created_at=datetime.fromisoformat("2025-02-04T00:00:00Z"),
388+
),
389+
DatasetItemDB(
390+
name="training3",
391+
format="jpg",
392+
size=1024,
393+
width=1024,
394+
height=768,
395+
subset=DatasetItemSubset.TRAINING,
396+
user_reviewed=False,
397+
project_id=str(project.id),
398+
created_at=datetime.fromisoformat("2025-02-05T00:00:00Z"),
399+
),
400+
]
401+
402+
# Validation items
403+
validation_items = [
404+
DatasetItemDB(
405+
name="validation1",
406+
format="jpg",
407+
size=1024,
408+
width=1024,
409+
height=768,
410+
subset=DatasetItemSubset.VALIDATION,
411+
user_reviewed=False,
412+
project_id=str(project.id),
413+
created_at=datetime.fromisoformat("2025-02-06T00:00:00Z"),
414+
),
415+
DatasetItemDB(
416+
name="validation2",
417+
format="jpg",
418+
size=1024,
419+
width=1024,
420+
height=768,
421+
subset=DatasetItemSubset.VALIDATION,
422+
user_reviewed=False,
423+
project_id=str(project.id),
424+
created_at=datetime.fromisoformat("2025-02-07T00:00:00Z"),
425+
),
426+
]
427+
428+
# Testing items
429+
testing_items = [
430+
DatasetItemDB(
431+
name="testing1",
432+
format="jpg",
433+
size=1024,
434+
width=1024,
435+
height=768,
436+
subset=DatasetItemSubset.TESTING,
437+
user_reviewed=False,
438+
project_id=str(project.id),
439+
created_at=datetime.fromisoformat("2025-02-08T00:00:00Z"),
440+
),
441+
]
442+
443+
db_dataset_items = [*unassigned_items, *training_items, *validation_items, *testing_items]
444+
db_session.add_all(db_dataset_items)
445+
db_session.flush()
446+
447+
return project, db_dataset_items
448+
449+
334450
class TestDatasetServiceIntegration:
335451
"""Integration tests for DatasetService."""
336452

@@ -1069,3 +1185,155 @@ def test_list_dataset_items_no_label_filter(
10691185
assert len(dataset_items) == 4
10701186
item_names = {item.name for item in dataset_items}
10711187
assert item_names == {"item_no_labels", "item_label_0", "item_label_1", "item_both_labels"}
1188+
1189+
@pytest.mark.parametrize(
1190+
"subset, expected_count",
1191+
[
1192+
(None, 8), # All items
1193+
("unassigned", 2), # 2 unassigned items
1194+
("training", 3), # 3 training items
1195+
("validation", 2), # 2 validation items
1196+
("testing", 1), # 1 testing item
1197+
],
1198+
)
1199+
def test_count_dataset_items_with_subset(
1200+
self,
1201+
fxt_dataset_service: DatasetService,
1202+
fxt_project_with_subset_items: tuple[ProjectView, list[DatasetItemDB]],
1203+
subset: str | None,
1204+
expected_count: int,
1205+
) -> None:
1206+
"""Test counting dataset items with subset filter."""
1207+
project, db_dataset_items = fxt_project_with_subset_items
1208+
1209+
count = fxt_dataset_service.count_dataset_items(project=project, subset=subset)
1210+
1211+
assert count == expected_count
1212+
1213+
@pytest.mark.parametrize(
1214+
"subset, expected_names",
1215+
[
1216+
(
1217+
None,
1218+
[
1219+
"unassigned1",
1220+
"unassigned2",
1221+
"training1",
1222+
"training2",
1223+
"training3",
1224+
"validation1",
1225+
"validation2",
1226+
"testing1",
1227+
],
1228+
),
1229+
("unassigned", ["unassigned1", "unassigned2"]),
1230+
("training", ["training1", "training2", "training3"]),
1231+
("validation", ["validation1", "validation2"]),
1232+
("testing", ["testing1"]),
1233+
],
1234+
)
1235+
def test_list_dataset_items_with_subset(
1236+
self,
1237+
fxt_dataset_service: DatasetService,
1238+
fxt_project_with_subset_items: tuple[ProjectView, list[DatasetItemDB]],
1239+
subset: str | None,
1240+
expected_names: list[str],
1241+
) -> None:
1242+
"""Test listing dataset items with subset filter."""
1243+
project, db_dataset_items = fxt_project_with_subset_items
1244+
1245+
dataset_items = fxt_dataset_service.list_dataset_items(
1246+
project=project,
1247+
limit=20,
1248+
offset=0,
1249+
subset=subset,
1250+
)
1251+
1252+
assert len(dataset_items) == len(expected_names)
1253+
actual_names = sorted([item.name for item in dataset_items])
1254+
assert actual_names == sorted(expected_names)
1255+
1256+
@pytest.mark.parametrize(
1257+
"subset, limit, offset, expected_count",
1258+
[
1259+
("unassigned", 1, 0, 1), # First page of unassigned
1260+
("unassigned", 1, 1, 1), # Second page of unassigned
1261+
("unassigned", 1, 2, 0), # Beyond available unassigned items
1262+
("training", 2, 0, 2), # First page of training
1263+
("training", 2, 2, 1), # Second page of training (only 1 left)
1264+
("validation", 10, 0, 2), # All validation items
1265+
("testing", 10, 0, 1), # All testing items
1266+
],
1267+
)
1268+
def test_list_dataset_items_with_subset_pagination(
1269+
self,
1270+
fxt_dataset_service: DatasetService,
1271+
fxt_project_with_subset_items: tuple[ProjectView, list[DatasetItemDB]],
1272+
subset: str | None,
1273+
limit: int,
1274+
offset: int,
1275+
expected_count: int,
1276+
) -> None:
1277+
"""Test listing dataset items with subset filter and pagination."""
1278+
project, db_dataset_items = fxt_project_with_subset_items
1279+
1280+
dataset_items = fxt_dataset_service.list_dataset_items(
1281+
project=project,
1282+
limit=limit,
1283+
offset=offset,
1284+
subset=subset,
1285+
)
1286+
1287+
assert len(dataset_items) == expected_count
1288+
1289+
def test_subset_filter_verifies_data_correctness(
1290+
self,
1291+
fxt_dataset_service: DatasetService,
1292+
fxt_project_with_subset_items: tuple[ProjectView, list[DatasetItemDB]],
1293+
) -> None:
1294+
"""Test that subset filter returns items with correct subset values."""
1295+
project, db_dataset_items = fxt_project_with_subset_items
1296+
1297+
# Unassigned items should have subset=unassigned
1298+
unassigned_items = fxt_dataset_service.list_dataset_items(
1299+
project=project,
1300+
limit=20,
1301+
offset=0,
1302+
subset="unassigned",
1303+
)
1304+
assert len(unassigned_items) == 2
1305+
for item in unassigned_items:
1306+
assert item.subset == DatasetItemSubset.UNASSIGNED
1307+
1308+
# Training items should have subset=training
1309+
training_items = fxt_dataset_service.list_dataset_items(
1310+
project=project,
1311+
limit=20,
1312+
offset=0,
1313+
subset="training",
1314+
)
1315+
assert len(training_items) == 3
1316+
for item in training_items:
1317+
assert item.subset == DatasetItemSubset.TRAINING
1318+
1319+
# Validation items should have subset=validation
1320+
validation_items = fxt_dataset_service.list_dataset_items(
1321+
project=project,
1322+
limit=20,
1323+
offset=0,
1324+
subset="validation",
1325+
)
1326+
assert len(validation_items) == 2
1327+
for item in validation_items:
1328+
assert item.subset == DatasetItemSubset.VALIDATION
1329+
1330+
# Testing items should have subset=testing
1331+
testing_items = fxt_dataset_service.list_dataset_items(
1332+
project=project,
1333+
limit=20,
1334+
offset=0,
1335+
subset="testing",
1336+
)
1337+
assert len(testing_items) == 1
1338+
for item in testing_items:
1339+
assert item.subset == DatasetItemSubset.TESTING

0 commit comments

Comments
 (0)