84 changes: 13 additions & 71 deletions label_studio/data_import/api.py
@@ -19,7 +19,6 @@
from django.utils.decorators import method_decorator
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from label_studio_sdk.label_interface import LabelInterface
from projects.models import Project, ProjectImport, ProjectReimport
from ranged_fileresponse import RangedFileResponse
from rest_framework import generics, status
@@ -267,31 +266,7 @@ def sync_import(self, request, project, preannotated_from_fields, commit_to_proj

if preannotated_from_fields:
# turn flat task JSONs {"column1": value, "column2": value} into {"data": {"column1"..}, "predictions": [{..."column2"}]
parsed_data = reformat_predictions(parsed_data, preannotated_from_fields, project)

# Conditionally validate predictions: skip when label config is default during project creation
if project.label_config_is_not_default:
validation_errors = []
li = LabelInterface(project.label_config)

for i, task in enumerate(parsed_data):
if 'predictions' in task:
for j, prediction in enumerate(task['predictions']):
try:
validation_errors_list = li.validate_prediction(prediction, return_errors=True)
if validation_errors_list:
for error in validation_errors_list:
validation_errors.append(f'Task {i}, prediction {j}: {error}')
except Exception as e:
error_msg = f'Task {i}, prediction {j}: Error validating prediction - {str(e)}'
validation_errors.append(error_msg)

if validation_errors:
error_message = f'Prediction validation failed ({len(validation_errors)} errors):\n'
for error in validation_errors:
error_message += f'- {error}\n'

raise ValidationError({'predictions': [error_message]})
parsed_data = reformat_predictions(parsed_data, preannotated_from_fields)

if commit_to_project:
# Immediately create project tasks and update project states and counters
@@ -509,55 +484,22 @@ def _create_legacy(self, project):
logger.debug(
f'Importing {len(self.request.data)} predictions to project {project} with {len(tasks_ids)} tasks (legacy mode)'
)

li = LabelInterface(project.label_config)

# Validate all predictions before creating any
validation_errors = []
predictions = []

for i, item in enumerate(self.request.data):
# Validate task ID
for item in self.request.data:
if item.get('task') not in tasks_ids:
validation_errors.append(
f'Prediction {i}: Invalid task ID {item.get("task")} - task not found in project'
raise ValidationError(
f'{item} contains invalid "task" field: corresponding task ID couldn\'t be retrieved '
f'from project {project} tasks'
)
continue

# Validate prediction using LabelInterface only
try:
validation_errors_list = li.validate_prediction(item, return_errors=True)

# If prediction is invalid, add error to validation_errors list and continue to next prediction
if validation_errors_list:
# Format errors for better readability
for error in validation_errors_list:
validation_errors.append(f'Prediction {i}: {error}')
continue

except Exception as e:
validation_errors.append(f'Prediction {i}: Error validating prediction - {str(e)}')
continue

# If prediction is valid, add it to predictions list to be created
try:
predictions.append(
Prediction(
task_id=item['task'],
project_id=project.id,
result=Prediction.prepare_prediction_result(item.get('result'), project),
score=item.get('score'),
model_version=item.get('model_version', 'undefined'),
)
predictions.append(
Prediction(
task_id=item['task'],
project_id=project.id,
result=Prediction.prepare_prediction_result(item.get('result'), project),
score=item.get('score'),
model_version=item.get('model_version', 'undefined'),
)
except Exception as e:
validation_errors.append(f'Prediction {i}: Failed to create prediction - {str(e)}')
continue

# If there are validation errors, raise them before creating any predictions
if validation_errors:
raise ValidationError(validation_errors)

)
predictions_obj = Prediction.objects.bulk_create(predictions, batch_size=settings.BATCH_SIZE)
start_job_async_or_sync(update_tasks_counters, Task.objects.filter(id__in=tasks_ids))
return Response({'created': len(predictions_obj)}, status=status.HTTP_201_CREATED)
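
Note: with the LabelInterface validation removed, _create_legacy reverts to failing fast on the first invalid task reference and no longer pre-validates prediction results before bulk_create. For context, a minimal sketch of the payload shape this legacy path accepts, inferred from the restored code above; names and values are illustrative assumptions:

    # Hypothetical payload for the legacy bulk prediction import; values are illustrative.
    # Each item must reference an existing task ID in the project; an unknown "task" ID
    # now raises ValidationError immediately instead of being collected into an error list.
    payload = [
        {
            'task': 123,                     # must be an existing task ID in the project
            'result': [                      # passed to Prediction.prepare_prediction_result()
                {
                    'from_name': 'label',
                    'to_name': 'text',
                    'type': 'choices',
                    'value': {'choices': ['Positive']},
                }
            ],
            'score': 0.95,                   # optional
            'model_version': 'my-model-v1',  # optional, defaults to 'undefined'
        },
    ]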
191 changes: 31 additions & 160 deletions label_studio/data_import/functions.py
@@ -7,9 +7,7 @@
from core.utils.common import load_func
from django.conf import settings
from django.db import transaction
from label_studio_sdk.label_interface import LabelInterface
from projects.models import ProjectImport, ProjectReimport, ProjectSummary
from rest_framework.exceptions import ValidationError
from tasks.models import Task
from users.models import User
from webhooks.models import WebhookAction
@@ -48,35 +46,7 @@ def async_import_background(

if project_import.preannotated_from_fields:
# turn flat task JSONs {"column1": value, "column2": value} into {"data": {"column1"..}, "predictions": [{..."column2"}]
tasks = reformat_predictions(tasks, project_import.preannotated_from_fields, project)

# Always validate predictions regardless of commit_to_project setting
if project.label_config_is_not_default:
validation_errors = []
li = LabelInterface(project.label_config)

for i, task in enumerate(tasks):
if 'predictions' in task:
for j, prediction in enumerate(task['predictions']):
try:
validation_errors_list = li.validate_prediction(prediction, return_errors=True)
if validation_errors_list:
for error in validation_errors_list:
validation_errors.append(f'Task {i}, prediction {j}: {error}')
except Exception as e:
error_msg = f'Task {i}, prediction {j}: Error validating prediction - {str(e)}'
validation_errors.append(error_msg)
logger.error(f'Exception during validation: {error_msg}')

if validation_errors:
error_message = f'Prediction validation failed ({len(validation_errors)} errors):\n'
for error in validation_errors:
error_message += f'- {error}\n'

project_import.error = error_message
project_import.status = ProjectImport.Status.FAILED
project_import.save()
return
tasks = reformat_predictions(tasks, project_import.preannotated_from_fields)

if project_import.commit_to_project:
with transaction.atomic():
@@ -86,41 +56,32 @@
# Immediately create project tasks and update project states and counters
serializer = ImportApiSerializer(data=tasks, many=True, context={'project': project})
serializer.is_valid(raise_exception=True)
tasks = serializer.save(project_id=project.id)
emit_webhooks_for_instance(user.active_organization, project, WebhookAction.TASKS_CREATED, tasks)

try:
tasks = serializer.save(project_id=project.id)
emit_webhooks_for_instance(user.active_organization, project, WebhookAction.TASKS_CREATED, tasks)

task_count = len(tasks)
annotation_count = len(serializer.db_annotations)
prediction_count = len(serializer.db_predictions)
# Update counters (like total_annotations) for new tasks and after bulk update tasks stats. It should be a
# single operation as counters affect bulk is_labeled update

recalculate_stats_counts = {
'task_count': task_count,
'annotation_count': annotation_count,
'prediction_count': prediction_count,
}

project.update_tasks_counters_and_task_states(
tasks_queryset=tasks,
maximum_annotations_changed=False,
overlap_cohort_percentage_changed=False,
tasks_number_changed=True,
recalculate_stats_counts=recalculate_stats_counts,
)
logger.info('Tasks bulk_update finished (async import)')

summary.update_data_columns(tasks)
# TODO: summary.update_created_annotations_and_labels
except Exception as e:
# Handle any other unexpected errors during task creation
error_message = f'Error creating tasks: {str(e)}'
project_import.error = error_message
project_import.status = ProjectImport.Status.FAILED
project_import.save()
return
task_count = len(tasks)
annotation_count = len(serializer.db_annotations)
prediction_count = len(serializer.db_predictions)
# Update counters (like total_annotations) for new tasks and after bulk update tasks stats. It should be a
# single operation as counters affect bulk is_labeled update

recalculate_stats_counts = {
'task_count': task_count,
'annotation_count': annotation_count,
'prediction_count': prediction_count,
}

project.update_tasks_counters_and_task_states(
tasks_queryset=tasks,
maximum_annotations_changed=False,
overlap_cohort_percentage_changed=False,
tasks_number_changed=True,
recalculate_stats_counts=recalculate_stats_counts,
)
logger.info('Tasks bulk_update finished (async import)')

summary.update_data_columns(tasks)
# TODO: summary.update_created_annotations_and_labels
else:
# Do nothing - just output file upload ids for further use
task_count = len(tasks)
@@ -159,103 +120,13 @@ def set_reimport_background_failure(job, connection, type, value, _):
)


def reformat_predictions(tasks, preannotated_from_fields, project=None):
"""
Transform flat task JSON objects into proper format with separate data and predictions fields.
Also validates the predictions to ensure they are properly formatted using LabelInterface.

Args:
tasks: List of task data
preannotated_from_fields: List of field names to convert to predictions
project: Optional project instance to determine correct to_name and type from label config
"""
def reformat_predictions(tasks, preannotated_from_fields):
new_tasks = []
validation_errors = []

# If project is provided, create LabelInterface to determine correct mappings
li = None
if project:
try:
li = LabelInterface(project.label_config)
except Exception as e:
logger.warning(f'Could not create LabelInterface for project {project.id}: {e}')

for task_index, task in enumerate(tasks):
for task in tasks:
if 'data' in task:
task_data = task['data']
else:
task_data = task

predictions = []
for field in preannotated_from_fields:
if field not in task_data:
validation_errors.append(f"Task {task_index}: Preannotated field '{field}' not found in task data")
continue

value = task_data[field]
if value is not None:
# Try to determine correct to_name and type from project configuration
to_name = 'text' # Default fallback
prediction_type = 'choices' # Default fallback

if li:
# Find a control tag that matches the field name
try:
control_tag = li.get_control(field)
# Use the control's to_name and determine type
if hasattr(control_tag, 'to_name') and control_tag.to_name:
to_name = (
control_tag.to_name[0]
if isinstance(control_tag.to_name, list)
else control_tag.to_name
)
prediction_type = control_tag.tag.lower()
except Exception:
# Control not found, use defaults
pass

# Create prediction from preannotated field
# Handle different types of values
if isinstance(value, dict):
# For complex structures like bounding boxes, use the value directly
prediction_value = value
else:
# For simple values, use the prediction_type as the key
# Handle cases where the type doesn't match the expected key
value_key = prediction_type
if prediction_type == 'textarea':
value_key = 'text'

# Most types expect lists, but some expect single values
if prediction_type in ['rating', 'number', 'datetime']:
prediction_value = {value_key: value}
else:
# Wrap in list for most types
prediction_value = {value_key: [value] if not isinstance(value, list) else value}

prediction = {
'result': [
{
'from_name': field,
'to_name': to_name,
'type': prediction_type,
'value': prediction_value,
}
],
'score': 1.0,
'model_version': 'preannotated',
}

predictions.append(prediction)

# Create new task structure
new_task = {'data': task_data, 'predictions': predictions}
new_tasks.append(new_task)

# If there are validation errors, raise them
if validation_errors:
raise ValidationError({'preannotated_fields': validation_errors})

task = task['data']
predictions = [{'result': task.pop(field)} for field in preannotated_from_fields]
new_tasks.append({'data': task, 'predictions': predictions})
return new_tasks
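
For reference, a minimal usage sketch of the restored reformat_predictions; the column name and prediction payload are illustrative assumptions, and the preannotated column is assumed to already hold a valid Label Studio result list, since its value is used directly as the prediction 'result':

    # Flat task with one preannotated column ('sentiment'); names are hypothetical.
    flat_tasks = [
        {
            'data': {
                'text': 'great product',
                'sentiment': [
                    {
                        'from_name': 'sentiment',
                        'to_name': 'text',
                        'type': 'choices',
                        'value': {'choices': ['Positive']},
                    }
                ],
            }
        }
    ]

    reformatted = reformat_predictions(flat_tasks, preannotated_from_fields=['sentiment'])
    # The preannotated column is popped out of the task data and wrapped as a prediction:
    # [{'data': {'text': 'great product'},
    #   'predictions': [{'result': [{'from_name': 'sentiment', ...}]}]}]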

