Skip to content

Commit 7ae709f

Browse files
fix: FIT-1392: exported annotations cannot be reimported as predictions (#9396)
1 parent 4772544 commit 7ae709f

File tree

4 files changed

+102
-0
lines changed

4 files changed

+102
-0
lines changed

label_studio/data_import/api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from rest_framework.views import APIView
3434
from tasks.functions import update_tasks_counters
3535
from tasks.models import Prediction, Task
36+
from tasks.serializers import sanitize_prediction_import_payload
3637
from users.models import User
3738
from webhooks.models import WebhookAction
3839
from webhooks.utils import emit_webhooks_for_instance
@@ -537,6 +538,7 @@ def _create_memory_efficient(self, project):
537538
# Build predictions for this batch
538539
batch_predictions = []
539540
for item in batch_items:
541+
item = sanitize_prediction_import_payload(item)
540542
task_id = item.get('task')
541543

542544
if task_id not in existing_task_ids:
@@ -586,6 +588,7 @@ def _create_legacy(self, project):
586588
predictions = []
587589

588590
for i, item in enumerate(self.request.data):
591+
item = sanitize_prediction_import_payload(item)
589592
# Validate task ID
590593
if item.get('task') not in tasks_ids:
591594
if flag_set('fflag_feat_utc_210_prediction_validation_15082025', user='auto'):

label_studio/data_import/functions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from projects.models import ProjectImport, ProjectReimport, ProjectSummary
1313
from rest_framework.exceptions import ValidationError
1414
from tasks.models import Task
15+
from tasks.serializers import sanitize_prediction_import_payload
1516
from users.models import User
1617
from webhooks.models import WebhookAction
1718
from webhooks.utils import emit_webhooks_for_instance
@@ -71,6 +72,7 @@ def async_import_background(
7172
if 'predictions' in task:
7273
for j, prediction in enumerate(task['predictions']):
7374
try:
75+
prediction = sanitize_prediction_import_payload(prediction)
7476
validation_errors_list = li.validate_prediction(prediction, return_errors=True)
7577
if validation_errors_list:
7678
for error in validation_errors_list:
@@ -451,6 +453,7 @@ def _async_import_background_streaming(project_import, user):
451453
if 'predictions' in task:
452454
for j, prediction in enumerate(task['predictions']):
453455
try:
456+
prediction = sanitize_prediction_import_payload(prediction)
454457
validation_errors_list = li.validate_prediction(prediction, return_errors=True)
455458
if validation_errors_list:
456459
for error in validation_errors_list:

label_studio/tasks/serializers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license."""
22

33
import logging
4+
from collections.abc import MutableMapping
45

56
import ujson as json
67
from core.current_request import CurrentContext, get_current_request
@@ -32,6 +33,14 @@
3233
logger = logging.getLogger(__name__)
3334

3435

36+
def sanitize_prediction_import_payload(prediction):
    """Drop only FSM `state` from prediction import payloads.

    The payload is modified in place and the very same object is returned, so
    the call can be used either as a statement or as an expression
    (``item = sanitize_prediction_import_payload(item)``).

    Args:
        prediction: A prediction payload, normally a dict. Any value that is
            not a ``MutableMapping`` (``None``, a list, a string, ...) is
            returned untouched.

    Returns:
        The same ``prediction`` object that was passed in; when it is a
        mutable mapping, its ``state`` key (if any) has been removed.
    """
    if not isinstance(prediction, MutableMapping):
        # Nothing to strip from non-mapping items; hand them back unchanged.
        return prediction
    # A default is supplied so payloads without `state` do not raise KeyError.
    prediction.pop('state', None)
    return prediction
42+
43+
3544
class PredictionQuerySerializer(serializers.Serializer):
3645
task = serializers.IntegerField(required=False, help_text='Task ID to filter predictions')
3746
project = serializers.IntegerField(required=False, help_text='Project ID to filter predictions')
@@ -515,6 +524,7 @@ def add_predictions(self, task_predictions):
515524
# Validate prediction only when project label config is not default
516525
if should_validate:
517526
try:
527+
prediction = sanitize_prediction_import_payload(prediction)
518528
li = LabelInterface(self.project.label_config) if should_validate else None
519529
validation_errors_list = li.validate_prediction(prediction, return_errors=True)
520530

label_studio/tests/test_prediction_validation.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
from unittest.mock import patch
1616

1717
import pytest
18+
from data_import.api import ImportPredictionsAPI
1819
from data_import.functions import reformat_predictions
1920
from data_import.serializers import ImportApiSerializer
2021
from django.contrib.auth import get_user_model
2122
from organizations.tests.factories import OrganizationFactory
2223
from projects.tests.factories import ProjectFactory
2324
from rest_framework.exceptions import ValidationError
25+
from rest_framework.test import APIRequestFactory, force_authenticate
2426
from tasks.models import Annotation, Prediction, Task
2527
from tasks.tests.factories import TaskFactory
2628
from users.tests.factories import UserFactory
@@ -95,6 +97,90 @@ def test_valid_prediction_creation(self):
9597
assert prediction.score == 0.95
9698
assert prediction.model_version == 'v1.0'
9799

100+
@patch('tasks.serializers.flag_set', return_value=True)
@patch('tasks.serializers.LabelInterface')
def test_import_tasks_sanitizes_prediction_before_validation(self, mock_li_cls, _mock_flag_set):
    """ImportApiSerializer must strip the export-only `state` key before validate_prediction().

    `LabelInterface` is replaced by a mock whose `validate_prediction` reports an
    error for any payload that still contains `state`; every feature flag looked
    up through `tasks.serializers.flag_set` is forced on.
    """
    mock_li = mock_li_cls.return_value

    def _validate_prediction(payload, return_errors=True):
        # Stand-in validator: fail only when `state` reached validation.
        if 'state' in payload:
            return ['Unexpected field: state']
        return []

    mock_li.validate_prediction.side_effect = _validate_prediction
    # One task carrying a single prediction that includes `state`.
    tasks = [
        {
            'data': {'text': 'Sanitize before validate'},
            'predictions': [
                {
                    'state': 'CREATED',
                    'id': 111,
                    'result': [
                        {
                            'from_name': 'sentiment',
                            'to_name': 'text',
                            'type': 'choices',
                            'value': {'choices': ['positive']},
                        }
                    ],
                    'score': 0.9,
                    'model_version': 'mv-sanitize',
                }
            ],
        }
    ]

    serializer = ImportApiSerializer(data=tasks, many=True, context={'project': self.project})
    assert serializer.is_valid(), serializer.errors
    # NOTE(review): presumably prediction validation runs during save() for
    # self.project's label config -- confirm against the class fixture.
    created_tasks = serializer.save(project_id=self.project.id)
    assert len(created_tasks) == 1
138+
139+
@patch(
    'data_import.api.flag_set',
    side_effect=lambda flag_name, user='auto', **kwargs: (
        flag_name == 'fflag_feat_utc_210_prediction_validation_15082025'
    ),
)
@patch('data_import.api.LabelInterface')
def test_import_predictions_endpoint_sanitizes_payload_before_validation(self, mock_li_cls, _mock_flag_set):
    """Bulk import API should sanitize payload before LabelInterface.validate_prediction().

    Only the prediction-validation feature flag is reported as enabled; every
    other flag queried through `data_import.api.flag_set` evaluates to False.
    `LabelInterface` is mocked so that a payload still containing `state` is
    reported as invalid.
    """
    mock_li = mock_li_cls.return_value

    def _validate_prediction(payload, return_errors=True):
        # Stand-in validator: fail only when `state` reached validation.
        if 'state' in payload:
            return ['Unexpected field: state']
        return []

    mock_li.validate_prediction.side_effect = _validate_prediction
    request_factory = APIRequestFactory()
    # A single prediction for the fixture task, including `state`.
    payload = [
        {
            'state': 'CREATED',
            'id': 222,
            'result': [
                {
                    'from_name': 'sentiment',
                    'to_name': 'text',
                    'type': 'choices',
                    'value': {'choices': ['neutral']},
                }
            ],
            'score': 0.5,
            'model_version': 'mv-sanitize-endpoint',
            'task': self.task.id,
        }
    ]
    request = request_factory.post(
        f'/api/projects/{self.project.id}/import/predictions',
        data=payload,
        format='json',
    )
    force_authenticate(request, user=self.user)
    # Call the view directly instead of going through URL routing.
    response = ImportPredictionsAPI.as_view()(request, pk=self.project.id)
    assert response.status_code == 201
    assert response.data['created'] == 1
183+
98184
def test_invalid_prediction_missing_result(self):
99185
"""Test validation fails when prediction is missing result field."""
100186
tasks = [

0 commit comments

Comments
 (0)