Skip to content

Commit 15523ce

Browse files
authored
fix: UTC-459: Handling breadth first with evaluation enabled (#9076)
Co-authored-by: mcanu <mcanu@users.noreply.github.com>
1 parent 0fb9c95 commit 15523ce

File tree

3 files changed

+120
-29
lines changed

3 files changed

+120
-29
lines changed

label_studio/projects/functions/next_task.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from core.utils.common import conditional_atomic, db_is_not_sqlite, load_func
77
from core.utils.db import fast_first
88
from django.conf import settings
9-
from django.db.models import BooleanField, Case, Count, Exists, F, Max, OuterRef, Q, QuerySet, Value, When
9+
from django.db.models import Case, Count, Exists, F, Max, OuterRef, Q, QuerySet, When
1010
from django.db.models.fields import DecimalField
1111
from projects.functions.stream_history import add_stream_history
1212
from projects.models import Project
@@ -75,37 +75,32 @@ def _try_tasks_with_overlap(tasks: QuerySet[Task]) -> Tuple[Union[Task, None], Q
7575
return None, tasks.filter(overlap=1)
7676

7777

78-
def _try_breadth_first(
79-
tasks: QuerySet[Task], user: User, project: Project, attempt_gt_first: bool = False
80-
) -> Union[Task, None]:
78+
def _try_breadth_first(tasks: QuerySet[Task], user: User, project: Project) -> Union[Task, None]:
8179
"""Try to find tasks with maximum amount of annotations, since we are trying to label tasks as fast as possible"""
8280

83-
# Exclude ground truth annotations from the count when not in onboarding window
84-
# to prevent GT tasks from being prioritized via breadth-first logic
85-
annotation_filter = ~Q(annotations__completed_by=user)
86-
if not attempt_gt_first:
87-
annotation_filter &= ~Q(annotations__ground_truth=True)
81+
if project.annotator_evaluation_enabled:
82+
# When annotator evaluation is enabled, ground truth tasks accumulate overlap regardless of the maximum annotations setting.
83+
# If we include them, they will eventually be front-loaded by the breadth first logic.
84+
# So we exclude them from the candidates.
85+
# Onboarding tasks are served by _try_ground_truth.
86+
# When no in progress tasks are found by breadth first, the next step in the pipeline will serve the remaining GT tasks.
87+
tasks = _annotate_has_ground_truths(tasks)
88+
tasks = tasks.filter(has_ground_truths=False)
8889

89-
tasks = tasks.annotate(annotations_count=Count('annotations', filter=annotation_filter))
90+
tasks = tasks.annotate(annotations_count=Count('annotations', filter=~Q(annotations__completed_by=user)))
9091
max_annotations_count = tasks.aggregate(Max('annotations_count'))['annotations_count__max']
91-
if max_annotations_count == 0:
92-
# there is no any labeled tasks found
93-
return
94-
95-
# find any task with maximal amount of created annotations
96-
not_solved_tasks_labeling_started = tasks.annotate(
97-
reach_max_annotations_count=Case(
98-
When(annotations_count=max_annotations_count, then=Value(True)),
99-
default=Value(False),
100-
output_field=BooleanField(),
101-
)
102-
)
103-
not_solved_tasks_labeling_with_max_annotations = not_solved_tasks_labeling_started.filter(
104-
reach_max_annotations_count=True
105-
)
106-
if not_solved_tasks_labeling_with_max_annotations.exists():
107-
# try to complete tasks that are already in progress
108-
return _get_random_unlocked(not_solved_tasks_labeling_with_max_annotations, user)
92+
93+
if max_annotations_count == 0 or max_annotations_count is None:
94+
# No tasks with annotations, let the next step in the pipeline handle it
95+
return None
96+
97+
# Find tasks at the maximum amount of annotations
98+
candidates = tasks.filter(annotations_count=max_annotations_count)
99+
if candidates.exists():
100+
# Select randomly from candidates
101+
result = _get_random_unlocked(candidates, user)
102+
return result
103+
return None
109104

110105

111106
def _try_uncertainty_sampling(
@@ -289,7 +284,7 @@ def get_next_task_without_dm_queue(
289284
if not next_task and project.maximum_annotations > 1:
290285
# if there are already labeled tasks, but task.overlap is still < project.maximum_annotations, randomly sample from them
291286
logger.debug(f'User={user} tries depth first from prepared tasks')
292-
next_task = _try_breadth_first(not_solved_tasks, user, project, attempt_gt_first)
287+
next_task = _try_breadth_first(not_solved_tasks, user, project)
293288
if next_task:
294289
queue_info += (' & ' if queue_info else '') + 'Breadth first queue'
295290

label_studio/tasks/tests/factories.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class TaskFactory(factory.django.DjangoModelFactory):
1515
}
1616
)
1717
project = factory.SubFactory(load_func(settings.PROJECT_FACTORY))
18+
overlap = factory.LazyAttribute(lambda obj: obj.project.maximum_annotations if obj.project else 1)
1819

1920
class Meta:
2021
model = Task

label_studio/tests/test_next_task.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@
88
from core.redis import redis_healthcheck
99
from django.apps import apps
1010
from django.db.models import Q
11+
from django.test import TestCase
12+
from projects.functions.next_task import _try_breadth_first
1113
from projects.models import Project
14+
from projects.tests.factories import ProjectFactory
1215
from tasks.models import Annotation, Prediction, Task
16+
from tasks.tests.factories import AnnotationFactory, TaskFactory
17+
from users.tests.factories import UserFactory
1318

1419
from .utils import (
1520
_client_is_annotator,
@@ -1327,3 +1332,93 @@ def complete_task(annotator):
13271332
else:
13281333
assert not all_tasks_with_overlap_are_labeled
13291334
assert not all_tasks_without_overlap_are_not_labeled
1335+
1336+
1337+
class TestTryBreadthFirst(TestCase):
1338+
@classmethod
1339+
def setUpTestData(cls):
1340+
cls.user = UserFactory()
1341+
cls.other_user = UserFactory()
1342+
1343+
# Project with evaluation enabled
1344+
cls.project_with_eval = ProjectFactory(
1345+
maximum_annotations=3,
1346+
annotator_evaluation_enabled=True,
1347+
)
1348+
1349+
# Project without evaluation
1350+
cls.project_without_eval = ProjectFactory(
1351+
maximum_annotations=3,
1352+
annotator_evaluation_enabled=False,
1353+
)
1354+
1355+
def test_excludes_ground_truth_tasks_when_evaluation_enabled(self):
1356+
"""
1357+
Test that _try_breadth_first excludes GT tasks when annotator_evaluation_enabled=True.
1358+
"""
1359+
# Create tasks with varying annotation counts
1360+
task_1 = TaskFactory(project=self.project_with_eval) # 2 regular annotations (max)
1361+
task_2 = TaskFactory(project=self.project_with_eval) # 1 regular annotation
1362+
task_3_gt = TaskFactory(project=self.project_with_eval) # 3 annotations BUT has GT
1363+
1364+
# Add regular annotations to task_1 (should be selected)
1365+
AnnotationFactory.create_batch(2, task=task_1, ground_truth=False)
1366+
1367+
# Add regular annotation to task_2
1368+
AnnotationFactory(task=task_2, ground_truth=False)
1369+
1370+
# Add one GT annotation and two regular annotations to task_3_gt
1371+
AnnotationFactory(task=task_3_gt, ground_truth=True)
1372+
AnnotationFactory(task=task_3_gt, ground_truth=False)
1373+
AnnotationFactory(task=task_3_gt, ground_truth=False)
1374+
1375+
# Get all tasks
1376+
tasks = Task.objects.filter(project=self.project_with_eval)
1377+
1378+
# Execute
1379+
result = _try_breadth_first(tasks, self.user, self.project_with_eval)
1380+
1381+
# Assert: should return task_1 (max annotations, not GT), not task_3_gt
1382+
assert result == task_1
1383+
1384+
def test_includes_ground_truth_tasks_when_evaluation_disabled(self):
1385+
"""
1386+
Test that _try_breadth_first includes GT tasks when annotator_evaluation_enabled=False.
1387+
"""
1388+
# Create tasks with varying annotation counts
1389+
task_1 = TaskFactory(project=self.project_without_eval) # 2 regular annotations (max)
1390+
task_2 = TaskFactory(project=self.project_without_eval) # 1 regular annotation
1391+
task_3_gt = TaskFactory(project=self.project_without_eval) # 3 annotations BUT has GT
1392+
1393+
# Add regular annotations to task_1 (should NOT be selected here: task_3_gt has more annotations)
1394+
AnnotationFactory.create_batch(2, task=task_1, ground_truth=False)
1395+
1396+
# Add regular annotation to task_2
1397+
AnnotationFactory(task=task_2, ground_truth=False)
1398+
1399+
# Add one GT annotation and two regular annotations to task_3_gt
1400+
AnnotationFactory(task=task_3_gt, ground_truth=True)
1401+
AnnotationFactory(task=task_3_gt, ground_truth=False)
1402+
AnnotationFactory(task=task_3_gt, ground_truth=False)
1403+
1404+
# Get all tasks
1405+
tasks = Task.objects.filter(project=self.project_without_eval)
1406+
1407+
# Execute
1408+
result = _try_breadth_first(tasks, self.user, self.project_without_eval)
1409+
1410+
# Assert: should return task_3_gt (max annotations, GT), not task_1 or task_2
1411+
assert result == task_3_gt
1412+
1413+
def test_returns_none_when_no_tasks_with_annotations_and_evaluation_enabled(self):
1414+
1415+
task_gt = TaskFactory(project=self.project_with_eval)
1416+
AnnotationFactory(task=task_gt, ground_truth=True)
1417+
1418+
tasks = Task.objects.filter(project=self.project_with_eval)
1419+
1420+
# Execute
1421+
result = _try_breadth_first(tasks, self.user, self.project_with_eval)
1422+
1423+
# Assert: should return None
1424+
assert result is None

0 commit comments

Comments
 (0)