Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions app/api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,38 @@ async def start_conversation(
initial_state=session_data.initial_state,
)

# Resolve school name from school_wriveted_id if not already set
session_state = conversation_session.state or {}
ctx = session_state.get("context", {})
school_wriveted_id = ctx.get("school_wriveted_id")
if school_wriveted_id and not ctx.get("school_name"):
try:
from app.repositories.school_repository import school_repository

school_obj = await school_repository.aget_by_wriveted_id_or_404(
db=session, wriveted_id=school_wriveted_id
)
ctx["school_name"] = school_obj.name
session_state["context"] = ctx
await chat_repo.update_session_state(
session,
session_id=conversation_session.id,
state_updates=session_state,
expected_revision=conversation_session.revision,
)
# Refresh session to pick up updated state
refreshed = await chat_repo.get_session_by_id(
session, conversation_session.id
)
if refreshed:
conversation_session = refreshed
except Exception as e:
logger.warning(
"Could not resolve school name",
school_wriveted_id=school_wriveted_id,
error=str(e),
)

# Get initial node
initial_node = await chat_runtime.get_initial_node(
session, session_data.flow_id, conversation_session
Expand Down
2 changes: 1 addition & 1 deletion app/repositories/cms_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ async def get_random_content(

json_filter = json.dumps({key: value})
conditions.append(
text("info @> :json_filter::jsonb").bindparams(
text("info @> cast(:json_filter as jsonb)").bindparams(
json_filter=json_filter
)
)
Expand Down
107 changes: 93 additions & 14 deletions app/services/action_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
- aggregate: Aggregate values from a list using various operations
"""

from datetime import datetime
import datetime
from typing import Any, Dict

from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -29,6 +29,21 @@
logger = get_logger()


def _extract_nested(data: Dict[str, Any], key_path: str) -> Any:
"""Extract a value from a nested dict using dot notation."""
keys = key_path.split(".")
value = data
try:
for key in keys:
if isinstance(value, dict):
value = value.get(key)
else:
return None
return value
except (KeyError, TypeError):
return None


class ActionNodeProcessor(NodeProcessor):
"""Processor for ACTION nodes with support for api_call actions."""

Expand Down Expand Up @@ -145,6 +160,11 @@ async def _execute_actions_sync(

current_state = session.state or {}

# Inject db session and user ID for internal API calls
context = {**context, "db": db}
if session.user_id:
context["session_user_id"] = str(session.user_id)

for i, action in enumerate(actions):
action_type = action.get("type")
action_id = f"{node_id}_action_{i}"
Expand Down Expand Up @@ -194,7 +214,7 @@ async def _execute_actions_sync(

# Update session state if variables were modified
if variables_updated:
current_state.update(variables_updated)
self._deep_merge(current_state, variables_updated)
await chat_repo.update_session_state(
db,
session_id=session.id,
Expand All @@ -214,7 +234,7 @@ async def _execute_actions_sync(
"variables_updated": list(variables_updated.keys()),
"success": success,
"errors": errors,
"timestamp": datetime.utcnow().isoformat(),
"timestamp": datetime.datetime.utcnow().isoformat(),
"processed_async": False,
},
)
Expand All @@ -233,9 +253,10 @@ async def _handle_set_variable(
value = action.get("value")

if variable and value is not None:
# Substitute variables in value if it's a string
if isinstance(value, str):
value = self.runtime.substitute_variables(value, state)
# Recursively substitute variables in the value.
# substitute_object handles strings, lists, and dicts, and preserves
# typed values when the entire string is a single {{var}} reference.
value = self.runtime.substitute_object(value, state)

self._set_nested_value(updates, variable, value)
logger.debug(f"Set variable {variable} = {value}")
Expand Down Expand Up @@ -285,6 +306,51 @@ async def _handle_api_call(
) -> None:
"""Handle api_call action."""
api_config_data = action.get("config", {})
auth_type = api_config_data.get("auth_type", "internal")

# For internal endpoints, try direct service call (bypasses HTTP + auth)
if auth_type == "internal":
from app.services.internal_api_handlers import INTERNAL_HANDLERS

endpoint = api_config_data.get("endpoint", "")
db = context.get("db")

if endpoint in INTERNAL_HANDLERS and db is not None:
resolved_body = self.runtime.substitute_object(
api_config_data.get("body", {}), state
)
resolved_params = self.runtime.substitute_object(
api_config_data.get("query_params", {}), state
)

result_data = await INTERNAL_HANDLERS[endpoint](
db, resolved_body, resolved_params
)

response_mapping = api_config_data.get("response_mapping", {})
for response_path, variable_name in response_mapping.items():
value = _extract_nested(result_data, response_path)
if value is not None:
self._set_nested_value(updates, variable_name, value)

logger.info(
"Internal API call via direct service call",
endpoint=endpoint,
variables_updated=list(response_mapping.values()),
)
return

# For authenticated sessions, generate a short-lived JWT
if auth_type == "internal" and context.get("session_user_id"):
from app.services.security import create_access_token

token = create_access_token(
subject=f"Wriveted:User-Account:{context['session_user_id']}",
expires_delta=datetime.timedelta(minutes=5),
)
api_config_data = {**api_config_data}
api_config_data["auth_type"] = "bearer"
api_config_data["auth_config"] = {"token": token}

# Create API call configuration
api_config = ApiCallConfig(**api_config_data)
Expand All @@ -297,8 +363,9 @@ async def _handle_api_call(
result = await api_client.execute_api_call(api_config, state, composite_scopes)

if result.success:
# Update variables with API response
updates.update(result.variables_updated)
# Update variables with API response using nested paths
for var_name, var_value in result.variables_updated.items():
self._set_nested_value(updates, var_name, var_value)
logger.info(
"API call successful",
endpoint=api_config.endpoint,
Expand All @@ -307,12 +374,16 @@ async def _handle_api_call(
else:
# Store error information
error_var = api_config_data.get("error_variable", "api_error")
updates[error_var] = {
"error": result.error_message,
"status_code": result.status_code,
"timestamp": datetime.utcnow().isoformat(),
"fallback_used": result.fallback_used,
}
self._set_nested_value(
updates,
error_var,
{
"error": result.error_message,
"status_code": result.status_code,
"timestamp": datetime.datetime.utcnow().isoformat(),
"fallback_used": result.fallback_used,
},
)
logger.error(
"API call failed",
endpoint=api_config.endpoint,
Expand Down Expand Up @@ -450,6 +521,14 @@ def _build_cel_expression(
else:
raise ValueError(f"Unknown aggregate operation: {operation}")

def _deep_merge(self, base: Dict[str, Any], updates: Dict[str, Any]) -> None:
"""Recursively merge updates into base dict, preserving existing nested keys."""
for key, value in updates.items():
if key in base and isinstance(base[key], dict) and isinstance(value, dict):
self._deep_merge(base[key], value)
else:
base[key] = value

def _get_nested_value(self, data: Dict[str, Any], key_path: str) -> Any:
"""Get nested value from dictionary using dot notation."""
keys = key_path.split(".")
Expand Down
10 changes: 7 additions & 3 deletions app/services/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
with proper authentication, error handling, and response processing.
"""

import logging
from typing import Any, Dict, Optional

import httpx
from pydantic import BaseModel
from structlog import get_logger

from app.config import get_settings
from app.services.circuit_breaker import (
Expand All @@ -18,7 +18,7 @@
get_circuit_breaker,
)

logger = logging.getLogger(__name__)
logger = get_logger()


class ApiCallConfig(BaseModel):
Expand Down Expand Up @@ -68,7 +68,11 @@ class InternalApiClient:

def __init__(self):
self.settings = get_settings()
self.base_url = self.settings.WRIVETED_INTERNAL_API
self.base_url = (
str(self.settings.WRIVETED_INTERNAL_API)
if self.settings.WRIVETED_INTERNAL_API
else None
)
Comment on lines +71 to +75
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InternalApiClient now allows base_url to be None, but initialize() still passes self.base_url into httpx.AsyncClient(base_url=...). If WRIVETED_INTERNAL_API is unset, this will fail at runtime with a low-signal error. Consider validating in init/initialize and raising a clear exception (or providing a safe default) when base_url is missing.

Copilot uses AI. Check for mistakes.
self.session: Optional[httpx.AsyncClient] = None

async def initialize(self) -> None:
Expand Down
15 changes: 15 additions & 0 deletions app/services/cel_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,19 @@ def _cel_collect(items: List[Any]) -> List[Any]:
return _cel_flatten(items)


def _cel_top_keys(d: Dict[str, Any], n: int = 5) -> List[str]:
"""Return the top N keys from a dict sorted by value descending.
Useful for converting a hue_profile dict (hue→weight) into a ranked
list of hue keys for the recommendation API.
"""
if not isinstance(d, dict):
return []
numeric_items = [(k, v) for k, v in d.items() if isinstance(v, (int, float))]
numeric_items.sort(key=lambda pair: pair[1], reverse=True)
return [k for k, _ in numeric_items[:n]]


# Registry of custom functions available in CEL expressions
CUSTOM_CEL_FUNCTIONS: Dict[str, Callable] = {
"sum": _cel_sum,
Expand All @@ -108,6 +121,7 @@ def _cel_collect(items: List[Any]) -> List[Any]:
"merge_last": _cel_merge_last,
"flatten": _cel_flatten,
"collect": _cel_collect,
"top_keys": _cel_top_keys,
}


Expand Down Expand Up @@ -174,6 +188,7 @@ def evaluate_cel_expression(
- min(temp.quiz_results.map(x, x.time))
- merge(selections.map(x, x.preferences))
- flatten(items.map(x, x.tags))
- top_keys(user.hue_profile, 5)
"""
try:
if include_aggregation_functions:
Expand Down
Loading
Loading