Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions commcare_connect/labs/analysis/backends/csv_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import io
import json
import logging
from collections.abc import Generator

import pandas as pd

Expand Down Expand Up @@ -210,3 +211,33 @@ def parse_csv_bytes(
logger.info(f"Parsed {len(visits)} visits (full mode)")

return visits


def parse_csv_file_chunks(
    csv_path: str,
    opportunity_id: int,
    chunksize: int = 1000,
) -> Generator[list[dict], None, None]:
    """
    Stream visit dicts from a CSV file in fixed-size batches.

    Reads straight from disk with the pandas C parser, so — unlike
    parse_csv_bytes() — no extra BytesIO copy of the CSV is ever held
    in Python memory.

    Args:
        csv_path: Path to the CSV file on disk.
        opportunity_id: Opportunity ID used as a fallback when the CSV
            does not carry one.
        chunksize: Number of rows per yielded batch (default 1000).

    Yields:
        One list of visit dicts (form_json included) per CSV chunk.
    """
    seen = 0
    # on_bad_lines="warn" logs malformed rows instead of aborting the parse.
    reader = pd.read_csv(csv_path, chunksize=chunksize, on_bad_lines="warn")
    for frame in reader:
        batch = [
            _row_to_visit_dict(row, opportunity_id, include_form_json=True)
            for _, row in frame.iterrows()
        ]
        seen += len(batch)
        yield batch

    logger.info(f"Parsed {seen} visits from file (chunked)")
164 changes: 115 additions & 49 deletions commcare_connect/labs/analysis/backends/sql/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import json
import logging
import os
import tempfile
from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
Expand All @@ -18,7 +20,7 @@
from django.http import HttpRequest
from django.utils.dateparse import parse_date

from commcare_connect.labs.analysis.backends.csv_parsing import parse_csv_bytes
from commcare_connect.labs.analysis.backends.csv_parsing import parse_csv_bytes, parse_csv_file_chunks
from commcare_connect.labs.analysis.backends.sql.cache import SQLCacheManager
from commcare_connect.labs.analysis.backends.sql.query_builder import execute_flw_aggregation, execute_visit_extraction
from commcare_connect.labs.analysis.config import AnalysisPipelineConfig, CacheStage
Expand Down Expand Up @@ -150,20 +152,30 @@ def stream_raw_visits(
Stream raw visit data with progress events.

SQL backend checks RawVisitCache first. If hit, yields immediately.
Otherwise streams download from API with progress.
Otherwise streams download to temp file, then parses and stores in
memory-efficient batches (1000 rows at a time).

On cache hit, yields slim dicts (no form_json) since the raw data
is already in the database for SQL extraction.

On cache miss, downloads to temp file (0 bytes in Python memory),
parses CSV in chunks, stores each chunk to DB with form_json,
and yields slim dicts (form_json stripped after DB storage).
Peak memory: ~50 MB instead of ~2 GB.
"""
cache_manager = SQLCacheManager(opportunity_id, config=None)

# Check SQL cache first
if not force_refresh and expected_visit_count:
if cache_manager.has_valid_raw_cache(expected_visit_count, tolerance_pct=tolerance_pct):
logger.info(f"[SQL] Raw cache HIT for opp {opportunity_id}")
visit_dicts = self._load_from_cache(cache_manager, skip_form_json=False, filter_visit_ids=None)
# Load slim dicts (no form_json) — SQL extraction reads from DB directly
visit_dicts = self._load_from_cache(cache_manager, skip_form_json=True, filter_visit_ids=None)
yield ("cached", visit_dicts)
return

# Cache miss - stream download from API
logger.info(f"[SQL] Raw cache MISS for opp {opportunity_id}, streaming from API")
# Cache miss - stream download to temp file (0 bytes in Python memory)
logger.info(f"[SQL] Raw cache MISS for opp {opportunity_id}, streaming to temp file")

url = f"{settings.CONNECT_PRODUCTION_URL}/export/opportunity/{opportunity_id}/user_visits/"
headers = {
Expand All @@ -174,58 +186,103 @@ def stream_raw_visits(
# Use shared progress interval from SSE streaming module
from commcare_connect.labs.analysis.sse_streaming import DOWNLOAD_PROGRESS_INTERVAL_BYTES

chunks = []
bytes_downloaded = 0
progress_interval = DOWNLOAD_PROGRESS_INTERVAL_BYTES # 5MB progress intervals
csv_tmpfile = None

try:
with httpx.stream("GET", url, headers=headers, timeout=580.0) as response:
response.raise_for_status()
total_bytes = int(response.headers.get("content-length", 0))
last_progress_at = 0
# Download directly to temp file — never hold CSV bytes in memory
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as f:
csv_tmpfile = f.name
raw_line_count = 0
bytes_downloaded = 0

try:
with httpx.stream("GET", url, headers=headers, timeout=580.0) as response:
response.raise_for_status()
total_bytes = int(response.headers.get("content-length", 0))
last_progress_at = 0

for chunk in response.iter_bytes(chunk_size=65536):
f.write(chunk)
raw_line_count += chunk.count(b"\n")
bytes_downloaded = response.num_bytes_downloaded

if bytes_downloaded - last_progress_at >= progress_interval:
yield ("progress", bytes_downloaded, total_bytes)
last_progress_at = bytes_downloaded

# Always yield final progress to ensure UI shows 100%
if bytes_downloaded > last_progress_at:
yield ("progress", bytes_downloaded, total_bytes)

except httpx.TimeoutException as e:
logger.error(f"[SQL] Timeout downloading for opp {opportunity_id}: {e}")
sentry_sdk.capture_exception(e)
raise RuntimeError("Connect API timeout") from e

csv_size = os.path.getsize(csv_tmpfile)
logger.info(
f"[SQL] Download complete: {csv_size} bytes on disk, "
f"{raw_line_count} raw lines (expect ~{expected_visit_count}+1 if complete)"
)

for chunk in response.iter_bytes(chunk_size=65536):
chunks.append(chunk)
# Use num_bytes_downloaded to track actual network traffic (compressed bytes)
bytes_downloaded = response.num_bytes_downloaded
# Yield status before slow CSV parsing so frontend can show progress
yield ("parsing", csv_size, raw_line_count)

# Yield progress every 5MB for real-time UI updates
if bytes_downloaded - last_progress_at >= progress_interval:
yield ("progress", bytes_downloaded, total_bytes)
last_progress_at = bytes_downloaded
# Parse and store in streaming batches (memory-efficient)
visit_count, slim_dicts = self._parse_and_store_streaming(
csv_tmpfile, opportunity_id, raw_line_count
)

# Always yield final progress to ensure UI shows 100%
if bytes_downloaded > last_progress_at:
yield ("progress", bytes_downloaded, total_bytes)
if visit_count != raw_line_count - 1: # -1 for header
logger.warning(
f"[SQL] CSV parsing dropped rows: {raw_line_count - 1} raw data lines "
f"but only {visit_count} parsed. Delta: {raw_line_count - 1 - visit_count} lost"
)

except httpx.TimeoutException as e:
logger.error(f"[SQL] Timeout downloading for opp {opportunity_id}: {e}")
sentry_sdk.capture_exception(e)
raise RuntimeError("Connect API timeout") from e
yield ("complete", slim_dicts)

csv_bytes = b"".join(chunks)
finally:
if csv_tmpfile and os.path.exists(csv_tmpfile):
os.unlink(csv_tmpfile)

# Count raw CSV lines to diagnose truncation vs parsing issues
raw_line_count = csv_bytes.count(b"\n")
logger.info(
f"[SQL] Download complete: {len(csv_bytes)} bytes, "
f"{raw_line_count} raw lines (expect ~{expected_visit_count}+1 if complete)"
)
def _parse_and_store_streaming(
self, csv_path: str, opportunity_id: int, raw_line_count: int
) -> tuple[int, list[dict]]:
"""
Parse CSV from file and store to DB in streaming batches.

# Yield status before slow CSV parsing so frontend can show progress
yield ("parsing", len(csv_bytes), raw_line_count)
For each chunk of 1000 rows:
1. Parse rows from CSV (with form_json) — ~15 MB per chunk
2. Store to RawVisitCache via bulk_create
3. Strip form_json from dicts — frees ~15 MB
4. Append slim dicts to result list — ~200 bytes per dict

visit_dicts = parse_csv_bytes(csv_bytes, opportunity_id, skip_form_json=False)
if len(visit_dicts) != raw_line_count - 1: # -1 for header
logger.warning(
f"[SQL] CSV parsing dropped rows: {raw_line_count - 1} raw data lines "
f"but only {len(visit_dicts)} parsed. Delta: {raw_line_count - 1 - len(visit_dicts)} lost"
)
Returns:
(visit_count, slim_dicts) where slim_dicts have form_json={}
Peak memory: ~50 MB instead of ~2 GB
"""
cache_manager = SQLCacheManager(opportunity_id, config=None)
estimated_count = max(0, raw_line_count - 1)

# Clear existing cache and prepare for batched inserts
cache_manager.store_raw_visits_start(estimated_count)

slim_dicts = []
actual_count = 0

# NOTE: Don't store to SQL here - let process_and_cache handle it.
# Storage happens in process_and_cache after pipeline yields "Processing X visits..."
for batch in parse_csv_file_chunks(csv_path, opportunity_id, chunksize=1000):
# Store full dicts (with form_json) to DB
cache_manager.store_raw_visits_batch(batch)
actual_count += len(batch)

yield ("complete", visit_dicts)
# Strip form_json to save memory, keep slim versions for pipeline
for v in batch:
v["form_json"] = {}
slim_dicts.extend(batch)

logger.info(f"[SQL] Streamed {actual_count} visits to DB, keeping {len(slim_dicts)} slim dicts")
return actual_count, slim_dicts

def has_valid_raw_cache(self, opportunity_id: int, expected_visit_count: int, tolerance_pct: int = 100) -> bool:
"""Check if valid raw cache exists in SQL."""
Expand Down Expand Up @@ -395,26 +452,35 @@ def process_and_cache(
config: AnalysisPipelineConfig,
opportunity_id: int,
visit_dicts: list[dict],
skip_raw_store: bool = False,
) -> FLWAnalysisResult | VisitAnalysisResult:
"""
Process visits using SQL and cache results.

For VISIT_LEVEL:
1. Store raw visits in SQL
1. Store raw visits in SQL (unless skip_raw_store=True)
2. Execute visit extraction query (no aggregation)
3. Cache computed visits and return VisitAnalysisResult

For AGGREGATED:
1. Store raw visits in SQL
1. Store raw visits in SQL (unless skip_raw_store=True)
2. Execute FLW aggregation query
3. Cache and return FLWAnalysisResult

Args:
skip_raw_store: If True, skip storing raw visits (already stored
during streaming parse or already in cache from a cache hit).
visit_dicts are only used for len() when this is True.
"""
cache_manager = SQLCacheManager(opportunity_id, config)
visit_count = len(visit_dicts)

# Step 1: Store raw visits to SQL (idempotent - replaces existing)
logger.info(f"[SQL] Storing {visit_count} raw visits to SQL")
cache_manager.store_raw_visits(visit_dicts, visit_count)
# Step 1: Store raw visits to SQL (skip if already stored during streaming)
if not skip_raw_store:
logger.info(f"[SQL] Storing {visit_count} raw visits to SQL")
cache_manager.store_raw_visits(visit_dicts, visit_count)
else:
logger.info(f"[SQL] Skipping raw store ({visit_count} visits already in DB)")

# Branch based on terminal stage
if config.terminal_stage == CacheStage.VISIT_LEVEL:
Expand Down
52 changes: 52 additions & 0 deletions commcare_connect/labs/analysis/backends/sql/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,58 @@ def store_raw_visits(self, visit_dicts: list[dict], visit_count: int):

logger.info(f"[SQLCache] Stored {len(rows)} raw visits for opp {self.opportunity_id}")

def store_raw_visits_start(self, visit_count: int):
    """
    Reset the raw-visit cache for this opportunity ahead of batched inserts.

    Records the metadata (visit_count, expires_at) that every subsequent
    store_raw_visits_batch() call stamps onto its rows, then drops any
    rows already cached for this opportunity.

    Must be called once before any store_raw_visits_batch() call.
    """
    # Stash metadata for the batch inserts that follow.
    self._pending_visit_count = visit_count
    expires_at = self._get_expires_at()
    self._pending_expires_at = expires_at
    # Drop whatever is currently cached for this opportunity.
    RawVisitCache.objects.filter(opportunity_id=self.opportunity_id).delete()
    logger.info(f"[SQLCache] Cleared raw cache for opp {self.opportunity_id}, preparing for {visit_count} visits")

def store_raw_visits_batch(self, visit_dicts: list[dict]) -> int:
    """
    Insert a batch of raw visits into RawVisitCache.

    store_raw_visits_start() must have been called first: it clears the
    old cache and records the visit_count / expires_at metadata that is
    stamped onto every row inserted here.

    Args:
        visit_dicts: Parsed visit dicts (as produced by the CSV parsers).

    Returns:
        Number of rows inserted.

    Raises:
        RuntimeError: If store_raw_visits_start() was not called first.
    """
    # Fail loudly on contract violation instead of a bare AttributeError
    # on self._pending_visit_count deep inside the row loop.
    if not hasattr(self, "_pending_visit_count"):
        raise RuntimeError("store_raw_visits_start() must be called before store_raw_visits_batch()")

    # Nothing to insert — skip the DB round-trip.
    if not visit_dicts:
        return 0

    # `x or default` normalizes both missing keys and falsy values
    # (None, "") to the model's expected empty defaults.
    rows = [
        RawVisitCache(
            opportunity_id=self.opportunity_id,
            visit_count=self._pending_visit_count,
            expires_at=self._pending_expires_at,
            visit_id=v.get("id", 0),
            username=v.get("username") or "",
            deliver_unit=v.get("deliver_unit") or "",
            deliver_unit_id=v.get("deliver_unit_id"),
            entity_id=v.get("entity_id") or "",
            entity_name=v.get("entity_name") or "",
            visit_date=_parse_date(v.get("visit_date")),
            status=v.get("status") or "",
            reason=v.get("reason") or "",
            location=v.get("location") or "",
            flagged=v.get("flagged") or False,
            flag_reason=v.get("flag_reason") or {},
            form_json=v.get("form_json") or {},
            completed_work=v.get("completed_work") or {},
            status_modified_date=_parse_datetime(v.get("status_modified_date")),
            review_status=v.get("review_status") or "",
            review_created_on=_parse_datetime(v.get("review_created_on")),
            justification=v.get("justification") or "",
            date_created=_parse_datetime(v.get("date_created")),
            completed_work_id=v.get("completed_work_id"),
            images=v.get("images") or [],
        )
        for v in visit_dicts
    ]
    RawVisitCache.objects.bulk_create(rows, batch_size=1000)
    return len(rows)

def get_raw_visits_queryset(self):
"""Get queryset of cached raw visits."""
return RawVisitCache.objects.filter(
Expand Down
Loading