Skip to content

Commit 321c23d

Browse files
authored
Fix chat runtime option matching and add direct service calls (#577)
* Fix chat runtime option matching and add direct service calls - Store inline question options in _current_options (not just CMS-sourced), fixing cascading failures where age/reading answers stored as raw strings instead of full option dicts with typed fields like age_number - Add internal API handler registry for direct service-layer calls, bypassing HTTP auth for anonymous chatbot sessions (e.g. /v1/recommend) - Fix broken import in _find_matching_connection (app.crud.chat → chat_repo) - Resolve school name server-side from school_wriveted_id during session start - Add CEL functions for hue profile aggregation (merge, top_keys) - Expand seed fixtures with book catalog, themes, and flow_file loading * Update docs for internal service calls, option matching, and seed fixtures - Fix repository path reference (app/crud/chat_repo.py → app/repositories/) - Document direct service call architecture for anonymous chatbot sessions - Document choice option matching behavior (full option objects with typed fields) - Add top_keys CEL function to aggregation docs - Fix user scope as read/write (writable by action nodes) - Document flow_file key for external flow JSON in seed fixtures - Document server-side school name resolution and theme loading in /start * Harden seed script: guard missing flow files and fix theme lookup - _load_flow_config returns None with warning instead of raising FileNotFoundError when referenced flow JSON is absent - _ensure_theme uses (name, school_id) composite filter with .first() to avoid MultipleResultsFound and remove unused seed_key branch * Fix variable resolver tests to expect preserved types substitute_object returns raw typed values for single {{var}} refs (int 30, bool True) rather than stringified versions. * Add unit tests for top_keys CEL function Covers ranking, n-parameter limiting, non-numeric value filtering, empty/non-dict inputs, and the full merge-then-rank pipeline used in Huey hue profile aggregation. * Fix limit parsing, internal API URL, and stale option clearing - Safely parse and clamp limit param in recommend handler (1-50) - Point api service WRIVETED_INTERNAL_API at internal:8888, not itself - Always write _current_options (even empty) to clear stale options from previous questions, preventing incorrect option reuse
1 parent 0e888c3 commit 321c23d

17 files changed

+1009
-118
lines changed

app/api/chat.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,38 @@ async def start_conversation(
9797
initial_state=session_data.initial_state,
9898
)
9999

100+
# Resolve school name from school_wriveted_id if not already set
101+
session_state = conversation_session.state or {}
102+
ctx = session_state.get("context", {})
103+
school_wriveted_id = ctx.get("school_wriveted_id")
104+
if school_wriveted_id and not ctx.get("school_name"):
105+
try:
106+
from app.repositories.school_repository import school_repository
107+
108+
school_obj = await school_repository.aget_by_wriveted_id_or_404(
109+
db=session, wriveted_id=school_wriveted_id
110+
)
111+
ctx["school_name"] = school_obj.name
112+
session_state["context"] = ctx
113+
await chat_repo.update_session_state(
114+
session,
115+
session_id=conversation_session.id,
116+
state_updates=session_state,
117+
expected_revision=conversation_session.revision,
118+
)
119+
# Refresh session to pick up updated state
120+
refreshed = await chat_repo.get_session_by_id(
121+
session, conversation_session.id
122+
)
123+
if refreshed:
124+
conversation_session = refreshed
125+
except Exception as e:
126+
logger.warning(
127+
"Could not resolve school name",
128+
school_wriveted_id=school_wriveted_id,
129+
error=str(e),
130+
)
131+
100132
# Get initial node
101133
initial_node = await chat_runtime.get_initial_node(
102134
session, session_data.flow_id, conversation_session

app/repositories/cms_repository.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ async def get_random_content(
374374

375375
json_filter = json.dumps({key: value})
376376
conditions.append(
377-
text("info @> :json_filter::jsonb").bindparams(
377+
text("info @> cast(:json_filter as jsonb)").bindparams(
378378
json_filter=json_filter
379379
)
380380
)

app/services/action_processor.py

Lines changed: 93 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
- aggregate: Aggregate values from a list using various operations
1010
"""
1111

12-
from datetime import datetime
12+
import datetime
1313
from typing import Any, Dict
1414

1515
from sqlalchemy.ext.asyncio import AsyncSession
@@ -29,6 +29,21 @@
2929
logger = get_logger()
3030

3131

32+
def _extract_nested(data: Dict[str, Any], key_path: str) -> Any:
33+
"""Extract a value from a nested dict using dot notation."""
34+
keys = key_path.split(".")
35+
value = data
36+
try:
37+
for key in keys:
38+
if isinstance(value, dict):
39+
value = value.get(key)
40+
else:
41+
return None
42+
return value
43+
except (KeyError, TypeError):
44+
return None
45+
46+
3247
class ActionNodeProcessor(NodeProcessor):
3348
"""Processor for ACTION nodes with support for api_call actions."""
3449

@@ -145,6 +160,11 @@ async def _execute_actions_sync(
145160

146161
current_state = session.state or {}
147162

163+
# Inject db session and user ID for internal API calls
164+
context = {**context, "db": db}
165+
if session.user_id:
166+
context["session_user_id"] = str(session.user_id)
167+
148168
for i, action in enumerate(actions):
149169
action_type = action.get("type")
150170
action_id = f"{node_id}_action_{i}"
@@ -194,7 +214,7 @@ async def _execute_actions_sync(
194214

195215
# Update session state if variables were modified
196216
if variables_updated:
197-
current_state.update(variables_updated)
217+
self._deep_merge(current_state, variables_updated)
198218
await chat_repo.update_session_state(
199219
db,
200220
session_id=session.id,
@@ -214,7 +234,7 @@ async def _execute_actions_sync(
214234
"variables_updated": list(variables_updated.keys()),
215235
"success": success,
216236
"errors": errors,
217-
"timestamp": datetime.utcnow().isoformat(),
237+
"timestamp": datetime.datetime.utcnow().isoformat(),
218238
"processed_async": False,
219239
},
220240
)
@@ -233,9 +253,10 @@ async def _handle_set_variable(
233253
value = action.get("value")
234254

235255
if variable and value is not None:
236-
# Substitute variables in value if it's a string
237-
if isinstance(value, str):
238-
value = self.runtime.substitute_variables(value, state)
256+
# Recursively substitute variables in the value.
257+
# substitute_object handles strings, lists, and dicts, and preserves
258+
# typed values when the entire string is a single {{var}} reference.
259+
value = self.runtime.substitute_object(value, state)
239260

240261
self._set_nested_value(updates, variable, value)
241262
logger.debug(f"Set variable {variable} = {value}")
@@ -285,6 +306,51 @@ async def _handle_api_call(
285306
) -> None:
286307
"""Handle api_call action."""
287308
api_config_data = action.get("config", {})
309+
auth_type = api_config_data.get("auth_type", "internal")
310+
311+
# For internal endpoints, try direct service call (bypasses HTTP + auth)
312+
if auth_type == "internal":
313+
from app.services.internal_api_handlers import INTERNAL_HANDLERS
314+
315+
endpoint = api_config_data.get("endpoint", "")
316+
db = context.get("db")
317+
318+
if endpoint in INTERNAL_HANDLERS and db is not None:
319+
resolved_body = self.runtime.substitute_object(
320+
api_config_data.get("body", {}), state
321+
)
322+
resolved_params = self.runtime.substitute_object(
323+
api_config_data.get("query_params", {}), state
324+
)
325+
326+
result_data = await INTERNAL_HANDLERS[endpoint](
327+
db, resolved_body, resolved_params
328+
)
329+
330+
response_mapping = api_config_data.get("response_mapping", {})
331+
for response_path, variable_name in response_mapping.items():
332+
value = _extract_nested(result_data, response_path)
333+
if value is not None:
334+
self._set_nested_value(updates, variable_name, value)
335+
336+
logger.info(
337+
"Internal API call via direct service call",
338+
endpoint=endpoint,
339+
variables_updated=list(response_mapping.values()),
340+
)
341+
return
342+
343+
# For authenticated sessions, generate a short-lived JWT
344+
if auth_type == "internal" and context.get("session_user_id"):
345+
from app.services.security import create_access_token
346+
347+
token = create_access_token(
348+
subject=f"Wriveted:User-Account:{context['session_user_id']}",
349+
expires_delta=datetime.timedelta(minutes=5),
350+
)
351+
api_config_data = {**api_config_data}
352+
api_config_data["auth_type"] = "bearer"
353+
api_config_data["auth_config"] = {"token": token}
288354

289355
# Create API call configuration
290356
api_config = ApiCallConfig(**api_config_data)
@@ -297,8 +363,9 @@ async def _handle_api_call(
297363
result = await api_client.execute_api_call(api_config, state, composite_scopes)
298364

299365
if result.success:
300-
# Update variables with API response
301-
updates.update(result.variables_updated)
366+
# Update variables with API response using nested paths
367+
for var_name, var_value in result.variables_updated.items():
368+
self._set_nested_value(updates, var_name, var_value)
302369
logger.info(
303370
"API call successful",
304371
endpoint=api_config.endpoint,
@@ -307,12 +374,16 @@ async def _handle_api_call(
307374
else:
308375
# Store error information
309376
error_var = api_config_data.get("error_variable", "api_error")
310-
updates[error_var] = {
311-
"error": result.error_message,
312-
"status_code": result.status_code,
313-
"timestamp": datetime.utcnow().isoformat(),
314-
"fallback_used": result.fallback_used,
315-
}
377+
self._set_nested_value(
378+
updates,
379+
error_var,
380+
{
381+
"error": result.error_message,
382+
"status_code": result.status_code,
383+
"timestamp": datetime.datetime.utcnow().isoformat(),
384+
"fallback_used": result.fallback_used,
385+
},
386+
)
316387
logger.error(
317388
"API call failed",
318389
endpoint=api_config.endpoint,
@@ -450,6 +521,14 @@ def _build_cel_expression(
450521
else:
451522
raise ValueError(f"Unknown aggregate operation: {operation}")
452523

524+
def _deep_merge(self, base: Dict[str, Any], updates: Dict[str, Any]) -> None:
525+
"""Recursively merge updates into base dict, preserving existing nested keys."""
526+
for key, value in updates.items():
527+
if key in base and isinstance(base[key], dict) and isinstance(value, dict):
528+
self._deep_merge(base[key], value)
529+
else:
530+
base[key] = value
531+
453532
def _get_nested_value(self, data: Dict[str, Any], key_path: str) -> Any:
454533
"""Get nested value from dictionary using dot notation."""
455534
keys = key_path.split(".")

app/services/api_client.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
with proper authentication, error handling, and response processing.
66
"""
77

8-
import logging
98
from typing import Any, Dict, Optional
109

1110
import httpx
1211
from pydantic import BaseModel
12+
from structlog import get_logger
1313

1414
from app.config import get_settings
1515
from app.services.circuit_breaker import (
@@ -18,7 +18,7 @@
1818
get_circuit_breaker,
1919
)
2020

21-
logger = logging.getLogger(__name__)
21+
logger = get_logger()
2222

2323

2424
class ApiCallConfig(BaseModel):
@@ -68,7 +68,11 @@ class InternalApiClient:
6868

6969
def __init__(self):
7070
self.settings = get_settings()
71-
self.base_url = self.settings.WRIVETED_INTERNAL_API
71+
self.base_url = (
72+
str(self.settings.WRIVETED_INTERNAL_API)
73+
if self.settings.WRIVETED_INTERNAL_API
74+
else None
75+
)
7276
self.session: Optional[httpx.AsyncClient] = None
7377

7478
async def initialize(self) -> None:

app/services/cel_evaluator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,19 @@ def _cel_collect(items: List[Any]) -> List[Any]:
9595
return _cel_flatten(items)
9696

9797

98+
def _cel_top_keys(d: Dict[str, Any], n: int = 5) -> List[str]:
99+
"""Return the top N keys from a dict sorted by value descending.
100+
101+
Useful for converting a hue_profile dict (hue→weight) into a ranked
102+
list of hue keys for the recommendation API.
103+
"""
104+
if not isinstance(d, dict):
105+
return []
106+
numeric_items = [(k, v) for k, v in d.items() if isinstance(v, (int, float))]
107+
numeric_items.sort(key=lambda pair: pair[1], reverse=True)
108+
return [k for k, _ in numeric_items[:n]]
109+
110+
98111
# Registry of custom functions available in CEL expressions
99112
CUSTOM_CEL_FUNCTIONS: Dict[str, Callable] = {
100113
"sum": _cel_sum,
@@ -108,6 +121,7 @@ def _cel_collect(items: List[Any]) -> List[Any]:
108121
"merge_last": _cel_merge_last,
109122
"flatten": _cel_flatten,
110123
"collect": _cel_collect,
124+
"top_keys": _cel_top_keys,
111125
}
112126

113127

@@ -174,6 +188,7 @@ def evaluate_cel_expression(
174188
- min(temp.quiz_results.map(x, x.time))
175189
- merge(selections.map(x, x.preferences))
176190
- flatten(items.map(x, x.tags))
191+
- top_keys(user.hue_profile, 5)
177192
"""
178193
try:
179194
if include_aggregation_functions:

0 commit comments

Comments
 (0)