Skip to content

Commit 84d802b

Browse files
authored
Merge pull request #27 from akkaouim/labs-mbw-v3
fix: stream-parse-and-store to prevent OOM on large CSV downloads
2 parents bad5a13 + 01edacc commit 84d802b

File tree

6 files changed

+484
-196
lines changed

6 files changed

+484
-196
lines changed

commcare_connect/labs/analysis/backends/csv_parsing.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import io
1010
import json
1111
import logging
12+
from collections.abc import Generator
1213

1314
import pandas as pd
1415

@@ -210,3 +211,33 @@ def parse_csv_bytes(
210211
logger.info(f"Parsed {len(visits)} visits (full mode)")
211212

212213
return visits
214+
215+
216+
def parse_csv_file_chunks(
    csv_path: str,
    opportunity_id: int,
    chunksize: int = 1000,
) -> Generator[list[dict], None, None]:
    """
    Parse a CSV file on disk in chunks, yielding one batch of visit dicts per chunk.

    The file is read directly from ``csv_path`` by pandas, ``chunksize`` rows
    at a time, so only one chunk of rows is materialized at once (unlike
    parse_csv_bytes(), which takes the whole CSV as an in-memory bytes object).

    Malformed lines are not fatal: ``on_bad_lines="warn"`` makes pandas emit
    a warning and skip them.

    Args:
        csv_path: Path to CSV file on disk
        opportunity_id: Opportunity ID (fallback if not in CSV)
        chunksize: Number of rows per chunk (default 1000)

    Yields:
        Lists of visit dicts (with form_json), one list per chunk

    Note:
        The final "Parsed N visits" log line is only emitted when the
        generator is consumed to exhaustion.
    """
    total_parsed = 0

    # Use the reader as a context manager so the underlying file handle is
    # closed deterministically, including when the consumer stops iterating
    # early or a row conversion raises (generator close/GC exits the `with`).
    with pd.read_csv(csv_path, chunksize=chunksize, on_bad_lines="warn") as reader:
        for chunk in reader:
            batch = [
                _row_to_visit_dict(row, opportunity_id, include_form_json=True)
                for _, row in chunk.iterrows()
            ]
            total_parsed += len(batch)
            yield batch

    logger.info(f"Parsed {total_parsed} visits from file (chunked)")

commcare_connect/labs/analysis/backends/sql/backend.py

Lines changed: 115 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import json
99
import logging
10+
import os
11+
import tempfile
1012
from collections.abc import Generator
1113
from datetime import date, datetime
1214
from decimal import Decimal
@@ -18,7 +20,7 @@
1820
from django.http import HttpRequest
1921
from django.utils.dateparse import parse_date
2022

21-
from commcare_connect.labs.analysis.backends.csv_parsing import parse_csv_bytes
23+
from commcare_connect.labs.analysis.backends.csv_parsing import parse_csv_bytes, parse_csv_file_chunks
2224
from commcare_connect.labs.analysis.backends.sql.cache import SQLCacheManager
2325
from commcare_connect.labs.analysis.backends.sql.query_builder import execute_flw_aggregation, execute_visit_extraction
2426
from commcare_connect.labs.analysis.config import AnalysisPipelineConfig, CacheStage
@@ -27,7 +29,7 @@
2729
logger = logging.getLogger(__name__)
2830

2931

30-
def _model_to_visit_dict(row) -> dict:
32+
def _model_to_visit_dict(row, skip_form_json=False) -> dict:
3133
"""Convert RawVisitCache model instance to visit dict."""
3234
return {
3335
"id": row.visit_id,
@@ -43,7 +45,7 @@ def _model_to_visit_dict(row) -> dict:
4345
"location": row.location,
4446
"flagged": row.flagged,
4547
"flag_reason": row.flag_reason,
46-
"form_json": row.form_json,
48+
"form_json": {} if skip_form_json else row.form_json,
4749
"completed_work": row.completed_work,
4850
"status_modified_date": row.status_modified_date.isoformat() if row.status_modified_date else None,
4951
"review_status": row.review_status,
@@ -150,20 +152,30 @@ def stream_raw_visits(
150152
Stream raw visit data with progress events.
151153
152154
SQL backend checks RawVisitCache first. If hit, yields immediately.
153-
Otherwise streams download from API with progress.
155+
Otherwise streams download to temp file, then parses and stores in
156+
memory-efficient batches (1000 rows at a time).
157+
158+
On cache hit, yields slim dicts (no form_json) since the raw data
159+
is already in the database for SQL extraction.
160+
161+
On cache miss, downloads to temp file (0 bytes in Python memory),
162+
parses CSV in chunks, stores each chunk to DB with form_json,
163+
and yields slim dicts (form_json stripped after DB storage).
164+
Peak memory: ~50 MB instead of ~2 GB.
154165
"""
155166
cache_manager = SQLCacheManager(opportunity_id, config=None)
156167

157168
# Check SQL cache first
158169
if not force_refresh and expected_visit_count:
159170
if cache_manager.has_valid_raw_cache(expected_visit_count, tolerance_pct=tolerance_pct):
160171
logger.info(f"[SQL] Raw cache HIT for opp {opportunity_id}")
161-
visit_dicts = self._load_from_cache(cache_manager, skip_form_json=False, filter_visit_ids=None)
172+
# Load slim dicts (no form_json) — SQL extraction reads from DB directly
173+
visit_dicts = self._load_from_cache(cache_manager, skip_form_json=True, filter_visit_ids=None)
162174
yield ("cached", visit_dicts)
163175
return
164176

165-
# Cache miss - stream download from API
166-
logger.info(f"[SQL] Raw cache MISS for opp {opportunity_id}, streaming from API")
177+
# Cache miss - stream download to temp file (0 bytes in Python memory)
178+
logger.info(f"[SQL] Raw cache MISS for opp {opportunity_id}, streaming to temp file")
167179

168180
url = f"{settings.CONNECT_PRODUCTION_URL}/export/opportunity/{opportunity_id}/user_visits/"
169181
headers = {
@@ -174,58 +186,100 @@ def stream_raw_visits(
174186
# Use shared progress interval from SSE streaming module
175187
from commcare_connect.labs.analysis.sse_streaming import DOWNLOAD_PROGRESS_INTERVAL_BYTES
176188

177-
chunks = []
178-
bytes_downloaded = 0
179189
progress_interval = DOWNLOAD_PROGRESS_INTERVAL_BYTES # 5MB progress intervals
190+
csv_tmpfile = None
180191

181192
try:
182-
with httpx.stream("GET", url, headers=headers, timeout=580.0) as response:
183-
response.raise_for_status()
184-
total_bytes = int(response.headers.get("content-length", 0))
185-
last_progress_at = 0
193+
# Download directly to temp file — never hold CSV bytes in memory
194+
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as f:
195+
csv_tmpfile = f.name
196+
raw_line_count = 0
197+
bytes_downloaded = 0
198+
199+
try:
200+
with httpx.stream("GET", url, headers=headers, timeout=580.0) as response:
201+
response.raise_for_status()
202+
total_bytes = int(response.headers.get("content-length", 0))
203+
last_progress_at = 0
204+
205+
for chunk in response.iter_bytes(chunk_size=65536):
206+
f.write(chunk)
207+
raw_line_count += chunk.count(b"\n")
208+
bytes_downloaded = response.num_bytes_downloaded
209+
210+
if bytes_downloaded - last_progress_at >= progress_interval:
211+
yield ("progress", bytes_downloaded, total_bytes)
212+
last_progress_at = bytes_downloaded
213+
214+
# Always yield final progress to ensure UI shows 100%
215+
if bytes_downloaded > last_progress_at:
216+
yield ("progress", bytes_downloaded, total_bytes)
217+
218+
except httpx.TimeoutException as e:
219+
logger.error(f"[SQL] Timeout downloading for opp {opportunity_id}: {e}")
220+
sentry_sdk.capture_exception(e)
221+
raise RuntimeError("Connect API timeout") from e
222+
223+
csv_size = os.path.getsize(csv_tmpfile)
224+
logger.info(
225+
f"[SQL] Download complete: {csv_size} bytes on disk, "
226+
f"{raw_line_count} raw lines (expect ~{expected_visit_count}+1 if complete)"
227+
)
186228

187-
for chunk in response.iter_bytes(chunk_size=65536):
188-
chunks.append(chunk)
189-
# Use num_bytes_downloaded to track actual network traffic (compressed bytes)
190-
bytes_downloaded = response.num_bytes_downloaded
229+
# Yield status before slow CSV parsing so frontend can show progress
230+
yield ("parsing", csv_size, raw_line_count)
191231

192-
# Yield progress every 5MB for real-time UI updates
193-
if bytes_downloaded - last_progress_at >= progress_interval:
194-
yield ("progress", bytes_downloaded, total_bytes)
195-
last_progress_at = bytes_downloaded
232+
# Parse and store in streaming batches (memory-efficient)
233+
_, slim_dicts = self._parse_and_store_streaming(
234+
csv_tmpfile, opportunity_id, raw_line_count
235+
)
196236

197-
# Always yield final progress to ensure UI shows 100%
198-
if bytes_downloaded > last_progress_at:
199-
yield ("progress", bytes_downloaded, total_bytes)
237+
yield ("complete", slim_dicts)
200238

201-
except httpx.TimeoutException as e:
202-
logger.error(f"[SQL] Timeout downloading for opp {opportunity_id}: {e}")
203-
sentry_sdk.capture_exception(e)
204-
raise RuntimeError("Connect API timeout") from e
239+
finally:
240+
if csv_tmpfile and os.path.exists(csv_tmpfile):
241+
os.unlink(csv_tmpfile)
205242

206-
csv_bytes = b"".join(chunks)
243+
def _parse_and_store_streaming(
244+
self, csv_path: str, opportunity_id: int, raw_line_count: int
245+
) -> tuple[int, list[dict]]:
246+
"""
247+
Parse CSV from file and store to DB in streaming batches.
207248
208-
# Count raw CSV lines to diagnose truncation vs parsing issues
209-
raw_line_count = csv_bytes.count(b"\n")
210-
logger.info(
211-
f"[SQL] Download complete: {len(csv_bytes)} bytes, "
212-
f"{raw_line_count} raw lines (expect ~{expected_visit_count}+1 if complete)"
213-
)
249+
For each chunk of 1000 rows:
250+
1. Parse rows from CSV (with form_json) — ~15 MB per chunk
251+
2. Store to RawVisitCache via bulk_create
252+
3. Strip form_json from dicts — frees ~15 MB
253+
4. Append slim dicts to result list — ~200 bytes per dict
214254
215-
# Yield status before slow CSV parsing so frontend can show progress
216-
yield ("parsing", len(csv_bytes), raw_line_count)
255+
Returns:
256+
(visit_count, slim_dicts) where slim_dicts have form_json={}
257+
Peak memory: ~50 MB instead of ~2 GB
258+
"""
259+
cache_manager = SQLCacheManager(opportunity_id, config=None)
260+
estimated_count = max(0, raw_line_count - 1)
217261

218-
visit_dicts = parse_csv_bytes(csv_bytes, opportunity_id, skip_form_json=False)
219-
if len(visit_dicts) != raw_line_count - 1: # -1 for header
220-
logger.warning(
221-
f"[SQL] CSV parsing dropped rows: {raw_line_count - 1} raw data lines "
222-
f"but only {len(visit_dicts)} parsed. Delta: {raw_line_count - 1 - len(visit_dicts)} lost"
223-
)
262+
# Clear existing cache and prepare for batched inserts
263+
cache_manager.store_raw_visits_start(estimated_count)
264+
265+
slim_dicts = []
266+
actual_count = 0
267+
268+
for batch in parse_csv_file_chunks(csv_path, opportunity_id, chunksize=1000):
269+
# Store full dicts (with form_json) to DB
270+
cache_manager.store_raw_visits_batch(batch)
271+
actual_count += len(batch)
224272

225-
# NOTE: Don't store to SQL here - let process_and_cache handle it.
226-
# Storage happens in process_and_cache after pipeline yields "Processing X visits..."
273+
# Strip form_json to save memory, keep slim versions for pipeline
274+
for v in batch:
275+
v["form_json"] = {}
276+
slim_dicts.extend(batch)
277+
278+
# Atomically make rows visible with accurate count
279+
cache_manager.store_raw_visits_finalize(actual_count)
227280

228-
yield ("complete", visit_dicts)
281+
logger.info(f"[SQL] Streamed {actual_count} visits to DB, keeping {len(slim_dicts)} slim dicts")
282+
return actual_count, slim_dicts
229283

230284
def has_valid_raw_cache(self, opportunity_id: int, expected_visit_count: int, tolerance_pct: int = 100) -> bool:
231285
"""Check if valid raw cache exists in SQL."""
@@ -250,10 +304,7 @@ def _load_from_cache(
250304

251305
visits = []
252306
for row in qs.iterator():
253-
visit = _model_to_visit_dict(row)
254-
if skip_form_json:
255-
visit["form_json"] = {}
256-
visits.append(visit)
307+
visits.append(_model_to_visit_dict(row, skip_form_json=skip_form_json))
257308

258309
logger.info(f"[SQL] Loaded {len(visits)} visits from RawVisitCache")
259310
return visits
@@ -395,26 +446,35 @@ def process_and_cache(
395446
config: AnalysisPipelineConfig,
396447
opportunity_id: int,
397448
visit_dicts: list[dict],
449+
skip_raw_store: bool = False,
398450
) -> FLWAnalysisResult | VisitAnalysisResult:
399451
"""
400452
Process visits using SQL and cache results.
401453
402454
For VISIT_LEVEL:
403-
1. Store raw visits in SQL
455+
1. Store raw visits in SQL (unless skip_raw_store=True)
404456
2. Execute visit extraction query (no aggregation)
405457
3. Cache computed visits and return VisitAnalysisResult
406458
407459
For AGGREGATED:
408-
1. Store raw visits in SQL
460+
1. Store raw visits in SQL (unless skip_raw_store=True)
409461
2. Execute FLW aggregation query
410462
3. Cache and return FLWAnalysisResult
463+
464+
Args:
465+
skip_raw_store: If True, skip storing raw visits (already stored
466+
during streaming parse or already in cache from a cache hit).
467+
visit_dicts are only used for len() when this is True.
411468
"""
412469
cache_manager = SQLCacheManager(opportunity_id, config)
413470
visit_count = len(visit_dicts)
414471

415-
# Step 1: Store raw visits to SQL (idempotent - replaces existing)
416-
logger.info(f"[SQL] Storing {visit_count} raw visits to SQL")
417-
cache_manager.store_raw_visits(visit_dicts, visit_count)
472+
# Step 1: Store raw visits to SQL (skip if already stored during streaming)
473+
if not skip_raw_store:
474+
logger.info(f"[SQL] Storing {visit_count} raw visits to SQL")
475+
cache_manager.store_raw_visits(visit_dicts, visit_count)
476+
else:
477+
logger.info(f"[SQL] Skipping raw store ({visit_count} visits already in DB)")
418478

419479
# Branch based on terminal stage
420480
if config.terminal_stage == CacheStage.VISIT_LEVEL:

0 commit comments

Comments
 (0)