
Commit 144fd7c

triklozoid and farioas authored
fix: PLT-838: Reimport memory usage optimization (#8105)
Co-authored-by: Sergey Zhuk <[email protected]>
1 parent e65787a · commit 144fd7c

4 files changed: +302 additions, -41 deletions

label_studio/core/settings/base.py

Lines changed: 2 additions & 0 deletions
@@ -555,6 +555,8 @@
 # Total size of task data (in bytes) to process per batch - used to calculate dynamic batch sizes
 # For example: if task data is 10MB, batch will be ~5 tasks to stay under 50MB limit
 TASK_DATA_PER_BATCH = int(get_env('TASK_DATA_PER_BATCH', 50 * 1024 * 1024))  # 50 MB in bytes
+# Batch size for streaming reimport operations to reduce memory usage
+REIMPORT_BATCH_SIZE = int(get_env('REIMPORT_BATCH_SIZE', 1000))
 # Batch size for processing prediction imports to avoid memory issues with large datasets
 PREDICTION_IMPORT_BATCH_SIZE = int(get_env('PREDICTION_IMPORT_BATCH_SIZE', 500))
 PROJECT_TITLE_MIN_LEN = 3
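
As context for the two settings above: TASK_DATA_PER_BATCH is a byte budget used elsewhere to derive dynamic batch sizes, while the new REIMPORT_BATCH_SIZE is a fixed task count consumed by the streaming reimport introduced below. A minimal sketch of the byte-budget arithmetic described in the comment (the helper name is ours, not code from this commit):

# Illustrative only: the arithmetic behind the settings comment above.
# estimate_batch_size() is a hypothetical helper, not part of this commit.
TASK_DATA_PER_BATCH = 50 * 1024 * 1024  # 50 MB, the default above

def estimate_batch_size(per_task_bytes: int) -> int:
    """Roughly how many tasks fit under the per-batch byte budget."""
    return max(1, TASK_DATA_PER_BATCH // max(1, per_task_bytes))

print(estimate_batch_size(10 * 1024 * 1024))  # ~5 tasks when each task carries 10 MB of data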

label_studio/data_import/functions.py

Lines changed: 168 additions & 41 deletions
@@ -3,10 +3,12 @@
 import traceback
 from typing import Callable, Optional
 
+from core.feature_flags import flag_set
 from core.utils.common import load_func
 from django.conf import settings
 from django.db import transaction
 from projects.models import ProjectImport, ProjectReimport, ProjectSummary
+from tasks.models import Task
 from users.models import User
 from webhooks.models import WebhookAction
 from webhooks.utils import emit_webhooks_for_instance
@@ -131,6 +133,125 @@ def reformat_predictions(tasks, preannotated_from_fields):
 post_process_reimport = load_func(settings.POST_PROCESS_REIMPORT)
 
 
+def _async_reimport_background_streaming(reimport, project, organization_id, user):
+    """Streaming version of reimport that processes tasks in batches to reduce memory usage"""
+    try:
+        # Get batch size from settings or use default
+        batch_size = settings.REIMPORT_BATCH_SIZE
+
+        # Initialize counters
+        total_task_count = 0
+        total_annotation_count = 0
+        total_prediction_count = 0
+        all_found_formats = {}
+        all_data_columns = set()
+        all_created_task_ids = []
+
+        # Remove old tasks once before starting
+        with transaction.atomic():
+            project.remove_tasks_by_file_uploads(reimport.file_upload_ids)
+
+        # Process tasks in batches
+        batch_number = 0
+        for batch_tasks, batch_formats, batch_columns in FileUpload.load_tasks_from_uploaded_files_streaming(
+            project, reimport.file_upload_ids, files_as_tasks_list=reimport.files_as_tasks_list, batch_size=batch_size
+        ):
+            if not batch_tasks:
+                logger.info(f'Empty batch received for reimport {reimport.id}')
+                continue
+
+            batch_number += 1
+            logger.info(f'Processing batch {batch_number} with {len(batch_tasks)} tasks for reimport {reimport.id}')
+
+            # Process batch in transaction
+            with transaction.atomic():
+                # Lock summary for update to avoid race conditions
+                summary = ProjectSummary.objects.select_for_update().get(project=project)
+
+                # Serialize and save batch
+                serializer = ImportApiSerializer(
+                    data=batch_tasks, many=True, context={'project': project, 'user': user}
+                )
+                serializer.is_valid(raise_exception=True)
+                batch_db_tasks = serializer.save(project_id=project.id)
+
+                # Collect task IDs for later use
+                all_created_task_ids.extend([t.id for t in batch_db_tasks])
+
+                # Update batch counters
+                batch_task_count = len(batch_db_tasks)
+                batch_annotation_count = len(serializer.db_annotations)
+                batch_prediction_count = len(serializer.db_predictions)
+
+                total_task_count += batch_task_count
+                total_annotation_count += batch_annotation_count
+                total_prediction_count += batch_prediction_count
+
+                # Update formats and columns
+                all_found_formats.update(batch_formats)
+                if batch_columns:
+                    if not all_data_columns:
+                        all_data_columns = batch_columns
+                    else:
+                        all_data_columns &= batch_columns
+
+                # Update data columns in summary
+                summary.update_data_columns(batch_db_tasks)
+
+            logger.info(
+                f'Batch {batch_number} processed successfully: {batch_task_count} tasks, '
+                f'{batch_annotation_count} annotations, {batch_prediction_count} predictions'
+            )
+
+        # After all batches are processed, emit webhooks and update task states once
+        if all_created_task_ids:
+            logger.info(
+                f'Finalizing reimport: emitting webhooks and updating task states for {len(all_created_task_ids)} tasks'
+            )
+
+            # Emit webhooks for all tasks at once (passing list of IDs)
+            emit_webhooks_for_instance(organization_id, project, WebhookAction.TASKS_CREATED, all_created_task_ids)
+
+            # Update task states for all tasks at once
+            all_tasks_queryset = Task.objects.filter(id__in=all_created_task_ids)
+            recalculate_stats_counts = {
+                'task_count': total_task_count,
+                'annotation_count': total_annotation_count,
+                'prediction_count': total_prediction_count,
+            }
+
+            project.update_tasks_counters_and_task_states(
+                tasks_queryset=all_tasks_queryset,
+                maximum_annotations_changed=False,
+                overlap_cohort_percentage_changed=False,
+                tasks_number_changed=True,
+                recalculate_stats_counts=recalculate_stats_counts,
+            )
+            logger.info('Tasks bulk_update finished (async streaming reimport)')
+
+        # Update reimport with final statistics
+        reimport.task_count = total_task_count
+        reimport.annotation_count = total_annotation_count
+        reimport.prediction_count = total_prediction_count
+        reimport.found_formats = all_found_formats
+        reimport.data_columns = list(all_data_columns)
+        reimport.status = ProjectReimport.Status.COMPLETED
+        reimport.save()
+
+        logger.info(f'Streaming reimport {reimport.id} completed: {total_task_count} tasks imported')
+
+        # Run post-processing
+        post_process_reimport(reimport)
+
+    except Exception as e:
+        logger.error(f'Error in streaming reimport {reimport.id}: {str(e)}', exc_info=True)
+        reimport.status = ProjectReimport.Status.FAILED
+        reimport.traceback = traceback.format_exc()
+        reimport.error = str(e)
+        reimport.save()
+        raise
+
+
 def async_reimport_background(reimport_id, organization_id, user, **kwargs):
 
     with transaction.atomic():
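
The function above keeps memory bounded by holding only one batch of task dicts at a time: each batch is validated and saved in its own transaction with the ProjectSummary row locked, cheap aggregates (counters, formats, created task IDs) are carried across batches, and webhooks plus counter recalculation run once at the end. A framework-free sketch of that accumulation pattern (load_batches and process_batch are illustrative stand-ins, not Label Studio code):

from typing import Iterable, Iterator


def load_batches(tasks: Iterable[dict], batch_size: int) -> Iterator[list]:
    """Yield tasks in lists of at most batch_size, so only one batch lives in memory."""
    batch = []
    for task in tasks:
        batch.append(task)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def process_batch(batch: list) -> int:
    """Stand-in for the per-batch transaction; returns how many rows were 'written'."""
    return len(batch)


total_task_count = 0
all_created_ids = []
for batch in load_batches(({'id': i, 'data': {'text': f'task {i}'}} for i in range(2500)), batch_size=1000):
    total_task_count += process_batch(batch)        # per-batch work
    all_created_ids.extend(t['id'] for t in batch)  # cheap bookkeeping kept across batches

print(total_task_count, len(all_created_ids))  # 2500 2500; finalization happens once, after the loop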
@@ -147,50 +268,56 @@ def async_reimport_background(reimport_id, organization_id, user, **kwargs):
 
     project = reimport.project
 
-    tasks, found_formats, data_columns = FileUpload.load_tasks_from_uploaded_files(
-        reimport.project, reimport.file_upload_ids, files_as_tasks_list=reimport.files_as_tasks_list
-    )
+    # Check feature flag for memory improvement
+    if flag_set('fflag_fix_back_plt_838_reimport_memory_improvement_05082025_short', user='auto'):
+        logger.info(f'Using streaming reimport for project {project.id}')
+        _async_reimport_background_streaming(reimport, project, organization_id, user)
+    else:
+        # Original implementation
+        tasks, found_formats, data_columns = FileUpload.load_tasks_from_uploaded_files(
+            reimport.project, reimport.file_upload_ids, files_as_tasks_list=reimport.files_as_tasks_list
+        )
 
-    with transaction.atomic():
-        # Lock summary for update to avoid race conditions
-        summary = ProjectSummary.objects.select_for_update().get(project=project)
+        with transaction.atomic():
+            # Lock summary for update to avoid race conditions
+            summary = ProjectSummary.objects.select_for_update().get(project=project)
 
-        project.remove_tasks_by_file_uploads(reimport.file_upload_ids)
-        serializer = ImportApiSerializer(data=tasks, many=True, context={'project': project, 'user': user})
-        serializer.is_valid(raise_exception=True)
-        tasks = serializer.save(project_id=project.id)
-        emit_webhooks_for_instance(organization_id, project, WebhookAction.TASKS_CREATED, tasks)
+            project.remove_tasks_by_file_uploads(reimport.file_upload_ids)
+            serializer = ImportApiSerializer(data=tasks, many=True, context={'project': project, 'user': user})
+            serializer.is_valid(raise_exception=True)
+            tasks = serializer.save(project_id=project.id)
+            emit_webhooks_for_instance(organization_id, project, WebhookAction.TASKS_CREATED, tasks)
 
-        task_count = len(tasks)
-        annotation_count = len(serializer.db_annotations)
-        prediction_count = len(serializer.db_predictions)
-
-        recalculate_stats_counts = {
-            'task_count': task_count,
-            'annotation_count': annotation_count,
-            'prediction_count': prediction_count,
-        }
-
-        # Update counters (like total_annotations) for new tasks and after bulk update tasks stats. It should be a
-        # single operation as counters affect bulk is_labeled update
-        project.update_tasks_counters_and_task_states(
-            tasks_queryset=tasks,
-            maximum_annotations_changed=False,
-            overlap_cohort_percentage_changed=False,
-            tasks_number_changed=True,
-            recalculate_stats_counts=recalculate_stats_counts,
-        )
-        logger.info('Tasks bulk_update finished (async reimport)')
+            task_count = len(tasks)
+            annotation_count = len(serializer.db_annotations)
+            prediction_count = len(serializer.db_predictions)
 
-        summary.update_data_columns(tasks)
-        # TODO: summary.update_created_annotations_and_labels
+            recalculate_stats_counts = {
+                'task_count': task_count,
+                'annotation_count': annotation_count,
+                'prediction_count': prediction_count,
+            }
+
+            # Update counters (like total_annotations) for new tasks and after bulk update tasks stats. It should be a
+            # single operation as counters affect bulk is_labeled update
+            project.update_tasks_counters_and_task_states(
+                tasks_queryset=tasks,
+                maximum_annotations_changed=False,
+                overlap_cohort_percentage_changed=False,
+                tasks_number_changed=True,
+                recalculate_stats_counts=recalculate_stats_counts,
+            )
+            logger.info('Tasks bulk_update finished (async reimport)')
+
+            summary.update_data_columns(tasks)
+            # TODO: summary.update_created_annotations_and_labels
 
-    reimport.task_count = task_count
-    reimport.annotation_count = annotation_count
-    reimport.prediction_count = prediction_count
-    reimport.found_formats = found_formats
-    reimport.data_columns = list(data_columns)
-    reimport.status = ProjectReimport.Status.COMPLETED
-    reimport.save()
+        reimport.task_count = task_count
+        reimport.annotation_count = annotation_count
+        reimport.prediction_count = prediction_count
+        reimport.found_formats = found_formats
+        reimport.data_columns = list(data_columns)
+        reimport.status = ProjectReimport.Status.COMPLETED
+        reimport.save()
 
-    post_process_reimport(reimport)
+        post_process_reimport(reimport)
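
The gate above only routes reimports through `_async_reimport_background_streaming` when the `fflag_fix_back_plt_838_reimport_memory_improvement_05082025_short` flag is on; otherwise the original in-memory path runs unchanged. A stripped-down model of that dispatch, with an environment-variable stub standing in for `core.feature_flags.flag_set` (the stub and both handler names are ours, not Label Studio code):

import os


def is_flag_on(flag_name: str) -> bool:
    """Stand-in for flag_set(); here the flag is simply read from the environment."""
    return os.getenv(flag_name, 'false').lower() == 'true'


def streaming_reimport(batches) -> int:
    return sum(len(batch) for batch in batches)  # consumes one batch at a time


def legacy_reimport(batches) -> int:
    tasks = [task for batch in batches for task in batch]  # materializes everything first
    return len(tasks)


def reimport(batches) -> int:
    if is_flag_on('FFLAG_REIMPORT_MEMORY_IMPROVEMENT'):
        return streaming_reimport(batches)
    return legacy_reimport(batches)


batches = ([{'data': {'text': f'task {i}'}} for i in range(5)] for _ in range(3))
print(reimport(batches))  # 15 on either path; only the streaming path avoids holding all 15 at once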

label_studio/data_import/models.py

Lines changed: 57 additions & 0 deletions
@@ -197,6 +197,63 @@ def load_tasks_from_uploaded_files(
 
         return tasks, dict(Counter(fileformats)), common_data_fields
 
+    @classmethod
+    def load_tasks_from_uploaded_files_streaming(
+        cls, project, file_upload_ids=None, formats=None, files_as_tasks_list=True, batch_size=5000
+    ):
+        """Stream tasks from uploaded files in batches to reduce memory usage"""
+        fileformats = []
+        common_data_fields = set()
+        batch = []
+        total_yielded = 0
+
+        # scan all files
+        file_uploads = FileUpload.objects.filter(project=project)
+        if file_upload_ids:
+            file_uploads = file_uploads.filter(id__in=file_upload_ids)
+
+        for file_upload in file_uploads:
+            file_format = file_upload.format
+            if formats and file_format not in formats:
+                continue
+
+            new_tasks = file_upload.read_tasks(files_as_tasks_list)
+            fileformats.append(file_format)
+
+            # Validate data fields consistency
+            if new_tasks:
+                new_data_fields = set(new_tasks[0]['data'].keys())
+                if not common_data_fields:
+                    common_data_fields = new_data_fields
+                elif not common_data_fields.intersection(new_data_fields):
+                    raise ValidationError(
+                        _old_vs_new_data_keys_inconsistency_message(
+                            new_data_fields, common_data_fields, file_upload.file.name
+                        )
+                    )
+                else:
+                    common_data_fields &= new_data_fields
+
+            # Add file_upload_id to tasks and batch them
+            for task in new_tasks:
+                task['file_upload_id'] = file_upload.id
+                batch.append(task)
+
+                # Yield batch when it reaches the size limit
+                if len(batch) >= batch_size:
+                    yield batch, dict(Counter(fileformats)), common_data_fields
+                    total_yielded += len(batch)
+                    batch = []
+
+        # Yield remaining tasks if any
+        if batch:
+            yield batch, dict(Counter(fileformats)), common_data_fields
+            total_yielded += len(batch)
+
+        # If no tasks were yielded, return empty batch with metadata
+        if total_yielded == 0:
+            yield [], dict(Counter(fileformats)), common_data_fields
+
 
 def _old_vs_new_data_keys_inconsistency_message(new_data_keys, old_data_keys, current_file):
     new_data_keys_list = ','.join(new_data_keys)
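
Each yield from `load_tasks_from_uploaded_files_streaming` is a `(batch_tasks, found_formats, common_data_fields)` tuple whose metadata reflects only the files scanned so far, which is why the streaming reimport above merges formats and intersects columns across batches. A framework-free model of that contract (stream_tasks and the sample files are illustrative, not repository code):

from collections import Counter


def stream_tasks(files, batch_size=2):
    """Yield (batch, formats_so_far, common_fields_so_far), mimicking the generator's shape."""
    fileformats, common_fields, batch = [], set(), []
    for file_format, tasks in files:
        fileformats.append(file_format)
        fields = set(tasks[0]['data']) if tasks else set()
        common_fields = fields if not common_fields else common_fields & fields
        for task in tasks:
            batch.append(task)
            if len(batch) >= batch_size:
                yield batch, dict(Counter(fileformats)), common_fields
                batch = []
    if batch:
        yield batch, dict(Counter(fileformats)), common_fields


files = [
    ('.json', [{'data': {'text': 'a'}}, {'data': {'text': 'b'}}, {'data': {'text': 'c'}}]),
    ('.csv', [{'data': {'text': 'd'}}]),
]
for tasks, formats, columns in stream_tasks(files, batch_size=2):
    print(len(tasks), formats, sorted(columns))
# 2 {'.json': 1} ['text']            <- metadata covers only the first file so far
# 2 {'.json': 1, '.csv': 1} ['text']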
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+"""Test streaming import functionality for memory optimization"""
+from unittest.mock import MagicMock, patch
+
+import pytest
+from data_import.models import FileUpload
+from organizations.tests.factories import OrganizationFactory
+from projects.tests.factories import ProjectFactory
+from users.tests.factories import UserFactory
+
+pytestmark = pytest.mark.django_db
+
+
+class TestStreamingImport:
+    @pytest.fixture
+    def user(self):
+        return UserFactory()
+
+    @pytest.fixture
+    def organization(self):
+        return OrganizationFactory()
+
+    @pytest.fixture
+    def project(self, user, organization):
+        return ProjectFactory(
+            created_by=user,
+            organization=organization,
+            label_config='<View><Text name="text" value="$text"/><Choices name="label" toName="text"><Choice value="pos"/><Choice value="neg"/></Choices></View>',
+        )
+
+    def test_load_tasks_from_uploaded_files_streaming_basic(self, user, project):
+        """Test basic streaming functionality with small batches"""
+        # Mock file upload objects
+        with patch.object(FileUpload.objects, 'filter') as mock_filter:
+            mock_file_upload1 = MagicMock()
+            mock_file_upload1.format = '.json'
+            mock_file_upload1.id = 1
+            mock_file_upload1.read_tasks.return_value = [{'data': {'text': f'Task {i}'}} for i in range(10)]
+
+            mock_file_upload2 = MagicMock()
+            mock_file_upload2.format = '.json'
+            mock_file_upload2.id = 2
+            mock_file_upload2.read_tasks.return_value = [{'data': {'text': f'Task {i+10}'}} for i in range(10)]
+
+            mock_filter.return_value = [mock_file_upload1, mock_file_upload2]
+
+            # Test streaming with batch size 5
+            batches = list(FileUpload.load_tasks_from_uploaded_files_streaming(project, batch_size=5))
+
+            # Should have 4 batches (20 tasks / 5 per batch)
+            assert len(batches) == 4
+
+            # Check batch sizes
+            assert len(batches[0][0]) == 5  # First batch tasks
+            assert len(batches[1][0]) == 5  # Second batch tasks
+            assert len(batches[2][0]) == 5  # Third batch tasks
+            assert len(batches[3][0]) == 5  # Fourth batch tasks
+
+            # Check that all tasks have file_upload_id
+            for batch_tasks, _, _ in batches:
+                for task in batch_tasks:
+                    assert 'file_upload_id' in task
+
+    def test_load_tasks_from_uploaded_files_streaming_empty(self, project):
+        """Test streaming with no file uploads"""
+        # Mock empty file uploads
+        with patch.object(FileUpload.objects, 'filter') as mock_filter:
+            mock_filter.return_value = []
+
+            batches = list(FileUpload.load_tasks_from_uploaded_files_streaming(project, batch_size=5))
+
+            # Should have one empty batch with metadata
+            assert len(batches) == 1
+            assert len(batches[0][0]) == 0  # Empty tasks
+            assert batches[0][1] == {}  # Empty formats
+            assert batches[0][2] == set()  # Empty columns