|
15 | 15 | from unittest.mock import patch |
16 | 16 |
|
17 | 17 | import pytest |
| 18 | +from data_import.api import ImportPredictionsAPI |
18 | 19 | from data_import.functions import reformat_predictions |
19 | 20 | from data_import.serializers import ImportApiSerializer |
20 | 21 | from django.contrib.auth import get_user_model |
21 | 22 | from organizations.tests.factories import OrganizationFactory |
22 | 23 | from projects.tests.factories import ProjectFactory |
23 | 24 | from rest_framework.exceptions import ValidationError |
| 25 | +from rest_framework.test import APIRequestFactory, force_authenticate |
24 | 26 | from tasks.models import Annotation, Prediction, Task |
25 | 27 | from tasks.tests.factories import TaskFactory |
26 | 28 | from users.tests.factories import UserFactory |
@@ -95,6 +97,90 @@ def test_valid_prediction_creation(self): |
95 | 97 | assert prediction.score == 0.95 |
96 | 98 | assert prediction.model_version == 'v1.0' |
97 | 99 |
|
| 100 | + @patch('tasks.serializers.flag_set', return_value=True) |
| 101 | + @patch('tasks.serializers.LabelInterface') |
| 102 | + def test_import_tasks_sanitizes_prediction_before_validation(self, mock_li_cls, _mock_flag_set): |
| 103 | + """ImportApiSerializer must strip export-only keys before validate_prediction().""" |
| 104 | + mock_li = mock_li_cls.return_value |
| 105 | + |
| 106 | + def _validate_prediction(payload, return_errors=True): |
| 107 | + if 'state' in payload: |
| 108 | + return ['Unexpected field: state'] |
| 109 | + return [] |
| 110 | + |
| 111 | + mock_li.validate_prediction.side_effect = _validate_prediction |
| 112 | + tasks = [ |
| 113 | + { |
| 114 | + 'data': {'text': 'Sanitize before validate'}, |
| 115 | + 'predictions': [ |
| 116 | + { |
| 117 | + 'state': 'CREATED', |
| 118 | + 'id': 111, |
| 119 | + 'result': [ |
| 120 | + { |
| 121 | + 'from_name': 'sentiment', |
| 122 | + 'to_name': 'text', |
| 123 | + 'type': 'choices', |
| 124 | + 'value': {'choices': ['positive']}, |
| 125 | + } |
| 126 | + ], |
| 127 | + 'score': 0.9, |
| 128 | + 'model_version': 'mv-sanitize', |
| 129 | + } |
| 130 | + ], |
| 131 | + } |
| 132 | + ] |
| 133 | + |
| 134 | + serializer = ImportApiSerializer(data=tasks, many=True, context={'project': self.project}) |
| 135 | + assert serializer.is_valid(), serializer.errors |
| 136 | + created_tasks = serializer.save(project_id=self.project.id) |
| 137 | + assert len(created_tasks) == 1 |
| 138 | + |
| 139 | + @patch( |
| 140 | + 'data_import.api.flag_set', |
| 141 | + side_effect=lambda flag_name, user='auto', **kwargs: ( |
| 142 | + flag_name == 'fflag_feat_utc_210_prediction_validation_15082025' |
| 143 | + ), |
| 144 | + ) |
| 145 | + @patch('data_import.api.LabelInterface') |
| 146 | + def test_import_predictions_endpoint_sanitizes_payload_before_validation(self, mock_li_cls, _mock_flag_set): |
| 147 | + """Bulk import API should sanitize payload before LabelInterface.validate_prediction().""" |
| 148 | + mock_li = mock_li_cls.return_value |
| 149 | + |
| 150 | + def _validate_prediction(payload, return_errors=True): |
| 151 | + if 'state' in payload: |
| 152 | + return ['Unexpected field: state'] |
| 153 | + return [] |
| 154 | + |
| 155 | + mock_li.validate_prediction.side_effect = _validate_prediction |
| 156 | + request_factory = APIRequestFactory() |
| 157 | + payload = [ |
| 158 | + { |
| 159 | + 'state': 'CREATED', |
| 160 | + 'id': 222, |
| 161 | + 'result': [ |
| 162 | + { |
| 163 | + 'from_name': 'sentiment', |
| 164 | + 'to_name': 'text', |
| 165 | + 'type': 'choices', |
| 166 | + 'value': {'choices': ['neutral']}, |
| 167 | + } |
| 168 | + ], |
| 169 | + 'score': 0.5, |
| 170 | + 'model_version': 'mv-sanitize-endpoint', |
| 171 | + 'task': self.task.id, |
| 172 | + } |
| 173 | + ] |
| 174 | + request = request_factory.post( |
| 175 | + f'/api/projects/{self.project.id}/import/predictions', |
| 176 | + data=payload, |
| 177 | + format='json', |
| 178 | + ) |
| 179 | + force_authenticate(request, user=self.user) |
| 180 | + response = ImportPredictionsAPI.as_view()(request, pk=self.project.id) |
| 181 | + assert response.status_code == 201 |
| 182 | + assert response.data['created'] == 1 |
| 183 | + |
98 | 184 | def test_invalid_prediction_missing_result(self): |
99 | 185 | """Test validation fails when prediction is missing result field.""" |
100 | 186 | tasks = [ |
|
0 commit comments