Skip to content

Commit 8f37cec

Browse files
authored
Merge pull request #25 from akkaouim/labs-mbw-v3
Continued work on mbw v1 using updated pipeline work. Reviewed and tested.
2 parents 4fb7998 + 2d73077 commit 8f37cec

File tree

17 files changed

+2164
-217
lines changed

17 files changed

+2164
-217
lines changed

commcare_connect/labs/analysis/backends/sql/backend.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
All analysis is done via SQL queries, not Python/pandas.
66
"""
77

8+
import json
89
import logging
910
from collections.abc import Generator
1011
from datetime import date, datetime
@@ -54,6 +55,29 @@ def _model_to_visit_dict(row) -> dict:
5455
}
5556

5657

58+
def _build_visit_dict(row: dict) -> dict:
59+
"""Build a visit context dict from a raw SQL row for transform/extractor post-processing."""
60+
form_json = row.get("form_json", {})
61+
if isinstance(form_json, str):
62+
try:
63+
form_json = json.loads(form_json) if form_json else {}
64+
except (ValueError, json.JSONDecodeError):
65+
form_json = {}
66+
images = row.get("images", [])
67+
if isinstance(images, str):
68+
try:
69+
images = json.loads(images) if images else []
70+
except (ValueError, json.JSONDecodeError):
71+
images = []
72+
return {
73+
"form_json": form_json,
74+
"images": images,
75+
"username": row.get("username"),
76+
"visit_date": row.get("visit_date"),
77+
"entity_name": row.get("entity_name"),
78+
}
79+
80+
5781
class SQLBackend:
5882
"""
5983
SQL backend for analysis.
@@ -431,6 +455,8 @@ def _process_visit_level(
431455

432456
# Apply post-processing transforms that need full visit context
433457
# (e.g., extract_images_with_question_ids needs both form_json and images)
458+
# Build visit_dict once per row (lazy); transforms must not mutate it.
459+
visit_dict = None
434460
for field in config.fields:
435461
if field.name not in computed_field_names:
436462
continue
@@ -445,30 +471,22 @@ def _process_visit_level(
445471
# If transform takes 'visit_data' param, it needs full context
446472
if "visit_data" in params or len(params) == 0:
447473
try:
448-
import json
449-
450-
# Build full visit dict for transform
451-
# Note: form_json and images come back as JSON strings from SQL
452-
form_json = row.get("form_json", {})
453-
if isinstance(form_json, str):
454-
form_json = json.loads(form_json) if form_json else {}
455-
456-
images = row.get("images", [])
457-
if isinstance(images, str):
458-
images = json.loads(images) if images else []
459-
460-
visit_dict = {
461-
"form_json": form_json,
462-
"images": images,
463-
"username": row.get("username"),
464-
"visit_date": row.get("visit_date"),
465-
"entity_name": row.get("entity_name"),
466-
}
474+
if visit_dict is None:
475+
visit_dict = _build_visit_dict(row)
467476
computed[field.name] = field.transform(visit_dict)
468477
except Exception as e:
469478
logger.warning(f"Transform for {field.name} failed: {e}")
470479
computed[field.name] = None
471480

481+
elif field.extractor and callable(field.extractor):
482+
try:
483+
if visit_dict is None:
484+
visit_dict = _build_visit_dict(row)
485+
computed[field.name] = field.extractor(visit_dict)
486+
except Exception as e:
487+
logger.warning(f"Extractor for {field.name} failed: {e}")
488+
computed[field.name] = None
489+
472490
# Parse visit_date
473491
visit_date_val = row.get("visit_date")
474492
if visit_date_val and isinstance(visit_date_val, date):
@@ -495,7 +513,7 @@ def _process_visit_level(
495513
# Cache computed visits (store base fields as columns to avoid joins later)
496514
computed_cache_data = [
497515
{
498-
"visit_id": int(row.id),
516+
"visit_id": row.id,
499517
"username": row.username,
500518
# Handle both date and datetime objects
501519
"visit_date": row.visit_date.date()

commcare_connect/labs/analysis/backends/sql/cchq_fetcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def normalize_cchq_form_to_visit_dict(form: dict, index: int) -> dict:
4646
"entity_id": "",
4747
"entity_name": "",
4848
"deliver_unit": "",
49-
"deliver_unit_id": "",
49+
"deliver_unit_id": None,
5050
"location": "",
5151
"flagged": False,
5252
"flag_reason": "",
@@ -58,7 +58,7 @@ def normalize_cchq_form_to_visit_dict(form: dict, index: int) -> dict:
5858
"review_created_on": None,
5959
"justification": "",
6060
"date_created": received_on,
61-
"completed_work_id": "",
61+
"completed_work_id": None,
6262
"images": [],
6363
}
6464

commcare_connect/labs/analysis/backends/sql/query_builder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,9 @@ def build_visit_extraction_query(
360360
# Check if any field needs full visit context (form_json, images)
361361
needs_full_context = False
362362
for field in config.fields:
363+
if field.extractor and callable(field.extractor):
364+
needs_full_context = True
365+
break
363366
if field.transform and callable(field.transform):
364367
import inspect
365368

@@ -378,6 +381,12 @@ def build_visit_extraction_query(
378381

379382
# Add computed fields from config (no aggregation, just extraction + transform)
380383
for field in config.fields:
384+
# Handle extractor fields — need post-processing with full visit context
385+
if field.extractor and callable(field.extractor):
386+
select_parts.append(f"NULL as {field.name}")
387+
computed_field_names.append(field.name)
388+
continue
389+
381390
# Skip fields that will be computed from full visit context (special markers like __images__)
382391
if field.transform and callable(field.transform):
383392
import inspect

commcare_connect/labs/analysis/data_access.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
logger = logging.getLogger(__name__)
1919

2020

21-
def fetch_flw_names(access_token: str, opportunity_id: int, use_cache: bool = True) -> dict[str, str]:
21+
def fetch_flw_names(
22+
access_token: str,
23+
opportunity_id: int,
24+
use_cache: bool = True,
25+
last_active_out: dict | None = None,
26+
) -> dict[str, str]:
2227
"""
2328
Fetch FLW display names for an opportunity from Connect API.
2429
@@ -29,6 +34,9 @@ def fetch_flw_names(access_token: str, opportunity_id: int, use_cache: bool = Tr
2934
access_token: OAuth Bearer token for Connect API
3035
opportunity_id: Opportunity ID to fetch FLW names for
3136
use_cache: Whether to use Django cache (default True)
37+
last_active_out: Optional dict to populate with {username: last_active_str}.
38+
When provided, last_active data is written directly into this dict,
39+
avoiding reliance on Django cache for in-process data sharing.
3240
3341
Returns:
3442
Dictionary mapping username to display name.
@@ -44,8 +52,16 @@ def fetch_flw_names(access_token: str, opportunity_id: int, use_cache: bool = Tr
4452
try:
4553
cached = cache.get(cache_key)
4654
if cached is not None:
47-
logger.debug(f"FLW names loaded from cache for opp {opportunity_id}")
48-
return cached
55+
# If last_active was requested, only use cache if la data is also cached
56+
if last_active_out is not None:
57+
la_cached = cache.get(f"flw_last_active_{opportunity_id}")
58+
if la_cached is not None:
59+
last_active_out.update(la_cached)
60+
else:
61+
cached = None # Force fresh fetch to populate last_active
62+
if cached is not None:
63+
logger.debug(f"FLW names loaded from cache for opp {opportunity_id}")
64+
return cached
4965
except Exception as e:
5066
logger.warning(f"Cache get failed for {cache_key}: {e}")
5167

@@ -69,20 +85,30 @@ def fetch_flw_names(access_token: str, opportunity_id: int, use_cache: bool = Tr
6985

7086
# Parse CSV response
7187
df = pd.read_csv(StringIO(response.text))
72-
logger.info(f"Fetched {len(df)} FLWs from Connect")
88+
logger.info(f"Fetched {len(df)} FLWs from Connect. CSV columns: {list(df.columns)}")
7389

7490
# Build mapping: username -> name (fallback to username if name is empty)
91+
# Also extract last_active for the "Last Active" dashboard column
7592
flw_names = {}
93+
flw_last_active = {}
7694
for _, row in df.iterrows():
7795
username = row.get("username")
7896
name = row.get("name")
7997
if username:
8098
flw_names[username] = name if name else username
99+
last_active = row.get("last_active")
100+
if pd.notna(last_active):
101+
flw_last_active[username] = str(last_active)
81102

82-
# Cache the result
103+
# Populate caller's dict directly (avoids cache dependency)
104+
if last_active_out is not None:
105+
last_active_out.update(flw_last_active)
106+
107+
# Cache the results
83108
if use_cache:
84109
try:
85110
cache.set(cache_key, flw_names, DJANGO_CACHE_TTL)
111+
cache.set(f"flw_last_active_{opportunity_id}", flw_last_active, DJANGO_CACHE_TTL)
86112
logger.debug(f"FLW names cached for opp {opportunity_id}")
87113
except Exception as e:
88114
logger.warning(f"Cache set failed for {cache_key}: {e}")

commcare_connect/labs/analysis/sse_streaming.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import json
99
import logging
10+
import queue
11+
import threading
1012
import time
1113
from collections.abc import Callable, Generator
1214

@@ -69,6 +71,9 @@ def stream_data(self, request) -> Generator[str, None, None]:
6971
yield send_sse_event("Complete!", data={"result": 123})
7072
"""
7173

74+
heartbeat_enabled = True
75+
heartbeat_interval = 20 # seconds between heartbeat comments
76+
7277
def get(self, request):
7378
"""
7479
Handle GET request and return streaming response.
@@ -79,14 +84,82 @@ def get(self, request):
7984
if not request.user.is_authenticated:
8085
return JsonResponse({"error": "Not authenticated"}, status=401)
8186

87+
generator = self.stream_data(request)
88+
if self.heartbeat_enabled:
89+
generator = self._with_heartbeat(generator)
90+
8291
response = StreamingHttpResponse(
83-
self.stream_data(request),
92+
generator,
8493
content_type="text/event-stream",
8594
)
8695
response["Cache-Control"] = "no-cache"
8796
response["X-Accel-Buffering"] = "no" # Disable nginx buffering
8897
return response
8998

99+
def _with_heartbeat(self, generator, interval=None):
100+
"""Wrap a generator with periodic SSE heartbeat comments.
101+
102+
Prevents ALB/browser timeouts during long-running blocking operations
103+
(CSV parsing, data processing) by sending SSE comment lines every
104+
``interval`` seconds when the generator isn't yielding real data.
105+
106+
SSE comment format ``: heartbeat\\n\\n`` keeps the TCP connection
107+
alive but does not trigger EventSource.onmessage on the frontend.
108+
109+
Set ``heartbeat_enabled = False`` on a subclass to disable.
110+
"""
111+
if interval is None:
112+
interval = self.heartbeat_interval
113+
114+
data_queue: queue.Queue = queue.Queue(maxsize=100)
115+
stop_event = threading.Event()
116+
117+
def _producer():
118+
try:
119+
for item in generator:
120+
if stop_event.is_set():
121+
break
122+
while not stop_event.is_set():
123+
try:
124+
data_queue.put(("data", item), timeout=1)
125+
break
126+
except queue.Full:
127+
continue
128+
except Exception as e: # noqa: BLE001
129+
try:
130+
data_queue.put(("error", e), timeout=1)
131+
except queue.Full:
132+
pass
133+
finally:
134+
try:
135+
data_queue.put(("done", None), timeout=1)
136+
except queue.Full:
137+
pass
138+
try:
139+
generator.close()
140+
except (GeneratorExit, RuntimeError):
141+
pass
142+
143+
thread = threading.Thread(target=_producer, daemon=True)
144+
thread.start()
145+
146+
try:
147+
while True:
148+
try:
149+
msg_type, value = data_queue.get(timeout=interval)
150+
if msg_type == "data":
151+
yield value
152+
elif msg_type == "done":
153+
break
154+
elif msg_type == "error":
155+
raise value
156+
except queue.Empty:
157+
# No data for `interval` seconds — send SSE comment to keep alive
158+
yield ": heartbeat\n\n"
159+
finally:
160+
stop_event.set()
161+
thread.join(timeout=2)
162+
90163
def stream_data(self, request) -> Generator[str, None, None]:
91164
"""
92165
Generator that yields SSE events.

commcare_connect/static/js/workflow-runner.tsx

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,16 @@ function createActionHandlers(csrfToken: string): ActionHandlers {
545545
}
546546
},
547547

548+
getAISessions: async (taskId: number): Promise<Record<string, unknown>> => {
  // Fetch the AI session list for a task. Read-only GET, so no CSRF token needed.
  try {
    const response = await fetch(`/tasks/${taskId}/ai/sessions/`);
    if (!response.ok) {
      // Surface HTTP errors in the same {success, error} shape the catch
      // branch uses, instead of returning a raw error payload — or throwing
      // when the error page body isn't JSON.
      return { success: false, error: `Request failed with status ${response.status}` };
    }
    const data = await response.json();
    return data;
  } catch (e) {
    return { success: false, error: e instanceof Error ? e.message : 'Failed to fetch AI sessions' };
  }
},
557+
548558
updateTask: async (taskId: number, data: Record<string, unknown>): Promise<Record<string, unknown>> => {
549559
try {
550560
const response = await fetch(`/tasks/api/${taskId}/update/`, {

0 commit comments

Comments
 (0)