# Citation helpers for api/routers/chat.py.
#
# NOTE(review): this span was recovered from a unified diff whose line breaks
# and indentation had been collapsed. Only the helpers that are fully visible
# in that diff are defined here; the module's other imports (e.g. ``yaml``,
# ``oci``) and the bodies of ``_resolve_endpoint_ocid`` / ``chat_completions``
# are only partially visible and are intentionally not reproduced.
import json
import logging
import os
import posixpath
from typing import Any, Dict, List, Optional
from urllib.parse import unquote, urlparse

_HTTP_SCHEMES = ("http://", "https://")
_MAX_SOURCE_TEXT_CHARS = 500


def _is_http_url(value: Any) -> bool:
    """Return True when *value* is a string starting with an http(s) scheme."""
    return isinstance(value, str) and value.lower().startswith(_HTTP_SCHEMES)


def _json_or_str(value: Any) -> str:
    """Serialize *value* as JSON, falling back to ``str()`` when it is not serializable."""
    try:
        return json.dumps(value, ensure_ascii=False)
    except (TypeError, ValueError, RecursionError):
        return str(value)


def _normalize_source_location(source_location: Any) -> dict:
    """
    Return ``{"display_name": ..., "url": ...}`` for a citation source location.

    Handles:
    - objects exposing a ``url`` attribute
    - dicts with a ``"url"`` key
    - JSON-encoded dicts (strings starting with ``{``)
    - raw http(s) URLs (scheme matched case-insensitively)
    - plain strings / paths

    ``display_name`` is the last path segment (percent-decoded for URLs), or
    ``None`` when nothing usable is found. ``url`` is returned as found, or
    ``None``. Never raises: on any failure a warning is logged and both values
    are ``None``.
    """
    display_name = None
    url_value = None

    try:
        # 1) Object with a ``url`` attribute (empty values count as absent).
        url_value = getattr(source_location, "url", None) or None

        if url_value is None:
            if isinstance(source_location, dict):
                # 2) dict-like
                url_value = source_location.get("url")
            elif isinstance(source_location, str) and source_location.strip().startswith("{"):
                # 3) JSON-looking string: parse it, but keep treating it as a
                #    plain string when it is not valid JSON.
                try:
                    parsed = json.loads(source_location)
                except (ValueError, RecursionError):
                    parsed = None
                if isinstance(parsed, dict):
                    url_value = parsed.get("url")
                    source_location = parsed

        # 4) The location itself is a URL string.
        if url_value is None and _is_http_url(source_location):
            url_value = source_location

        # Prefer the URL for naming; otherwise fall back to the raw string.
        candidate = url_value or (source_location if isinstance(source_location, str) else None)
        if candidate:
            # Coerce first so a non-string url cannot raise inside basename()
            # and wipe out the url through the outer handler.
            candidate = str(candidate)
            if _is_http_url(candidate):
                # URL paths always use "/", so use posixpath regardless of OS,
                # and decode percent-escapes for a readable name.
                path = unquote(urlparse(candidate).path or "")
                display_name = posixpath.basename(path) or path.strip("/") or candidate
            else:
                display_name = os.path.basename(candidate) or candidate

    except Exception as e:  # best-effort: citations must never break the chat reply
        logging.getLogger(__name__).warning("Failed to normalize source_location: %s", e)
        display_name = None
        url_value = None

    return {"display_name": display_name, "url": url_value}


def _citation_to_dict(citation: Any) -> Dict[str, Any]:
    """Flatten one citation object into a plain dict, reading attributes defensively."""
    location = _normalize_source_location(getattr(citation, "source_location", None))
    return {
        "source_text": getattr(citation, "source_text", None),
        "title": getattr(citation, "title", None),
        "doc_id": getattr(citation, "doc_id", None),
        "page_numbers": getattr(citation, "page_numbers", None),
        "metadata": getattr(citation, "metadata", None),
        "location_display": location.get("display_name"),
        "location_url": location.get("url"),
    }


def _extract_citations_from_response(result, agent_name: str = "OCI Agent") -> Optional[Dict[str, Any]]:
    """
    Pull paragraph citations out of ``result.message.content.paragraph_citations``.

    Returns ``{"paragraph_citations": [...], "agent_name": agent_name}`` where
    each entry holds a ``paragraph`` dict (``text``/``start``/``end``) and a
    ``citations`` list of flattened citation dicts, or ``None`` when the
    response carries no citations. Entries lacking a ``paragraph`` or
    ``citations`` attribute are skipped. Never raises: on failure a warning is
    logged and ``None`` is returned.
    """
    try:
        message = getattr(result, "message", None) if result else None
        content = getattr(message, "content", None) if message else None
        raw_paragraphs = getattr(content, "paragraph_citations", None) if content else None
        if not raw_paragraphs:
            return None

        paragraph_citations = []
        for para_citation in raw_paragraphs:
            if not (hasattr(para_citation, "paragraph") and hasattr(para_citation, "citations")):
                continue

            paragraph = para_citation.paragraph
            # ``citations`` may be present but None; treat that as "no sources"
            # instead of letting a TypeError discard every other paragraph.
            citation_list = [_citation_to_dict(c) for c in (para_citation.citations or [])]

            paragraph_citations.append(
                {
                    "paragraph": {
                        "text": getattr(paragraph, "text", "") or "",
                        "start": getattr(paragraph, "start", 0),
                        "end": getattr(paragraph, "end", 0),
                    },
                    "citations": citation_list,
                }
            )

        if paragraph_citations:
            return {"paragraph_citations": paragraph_citations, "agent_name": agent_name}

        return None
    except Exception as e:  # best-effort: citations must never break the chat reply
        logging.getLogger(__name__).warning("Failed to extract citations: %s", e)
        return None


def _format_citations_for_display(
    citations: Dict[str, Any],
    agent_name: str = "OCI Agent",
    *,
    include_paragraph_text: bool = False,
    include_source_text: bool = False,
) -> str:
    """
    Render extracted citations as a plain-text block to append to a reply.

    Default layout (one "Sources:" group per cited paragraph; optional lines
    appear only when the corresponding value is present):

    --- Citations from [Agent Name] ---

     Sources:
     1. Title: ...
     Location: document.pdf
     Document ID: ...
     Pages: [1, 2]
     Metadata: {...}

    --- End Citations ---

    Args:
        citations: Dict with a ``"paragraph_citations"`` list, as produced by
            ``_extract_citations_from_response``. Its ``"agent_name"`` entry,
            when set, takes precedence over ``agent_name``.
        agent_name: Fallback name for the header.
        include_paragraph_text: Also emit ``N. Text: "<paragraph>"`` before each
            "Sources:" group. Off by default, matching the previous output
            (the line was disabled in the original code).
        include_source_text: Also emit a ``Source:`` line per citation,
            truncated to 500 characters. Off by default for the same reason.

    Returns:
        The formatted block, or ``""`` when there is nothing to render.
    """
    if not citations or "paragraph_citations" not in citations:
        return ""

    agent = citations.get("agent_name") or agent_name
    blocks = [f"\n\n--- Citations from [{agent}] ---\n"]

    for idx, para_citation in enumerate(citations["paragraph_citations"] or [], start=1):
        lines: List[str] = []

        if include_paragraph_text:
            paragraph = para_citation.get("paragraph", {}) or {}
            text = (paragraph.get("text") or "").strip()
            # json.dumps gives safe quoting/escaping and yields "" for empty text.
            lines.append(f"{idx}. Text: {json.dumps(text, ensure_ascii=False)}")
        lines.append(" Sources:")

        for jdx, c in enumerate(para_citation.get("citations", []) or [], start=1):
            title = c.get("title")
            loc_display = c.get("location_display")
            doc_id = c.get("doc_id")
            pages = c.get("page_numbers")
            source_text = c.get("source_text")
            metadata = c.get("metadata")

            lines.append(f" {jdx}. " + (f"Title: {title}" if title else "Title: (unknown)"))
            if loc_display:
                lines.append(f" Location: {loc_display}")
            if doc_id:
                lines.append(f" Document ID: {doc_id}")
            if pages:
                lines.append(f" Pages: {_json_or_str(pages)}")
            if include_source_text and source_text:
                # str() guards against non-string payloads.
                snippet = str(source_text).strip()
                if len(snippet) > _MAX_SOURCE_TEXT_CHARS:
                    snippet = snippet[:_MAX_SOURCE_TEXT_CHARS].rstrip() + "…"
                lines.append(f" Source: {snippet}")
            if metadata:
                # Same fallback as pages, rather than silently dropping the line.
                lines.append(f" Metadata: {_json_or_str(metadata)}")

        blocks.append("\n".join(lines) + "\n")

    blocks.append("--- End Citations ---")
    return "\n".join(blocks)


# Call-site changes carried by the same diff, inside ``chat_completions``
# (that function is only partially visible, so it is not redefined here):
#
#   configured-agent path, after ``text`` is read from the result:
#       agent_name = agent_cfg.get("name", "OCI Agent")
#       citations = _extract_citations_from_response(result, agent_name)
#       if citations:
#           citation_text = _format_citations_for_display(citations, agent_name)
#           text += citation_text
#
#   direct-endpoint path, before ``tag = f"oci:agentendpoint:{endpoint_id}"``:
#       citations = _extract_citations_from_response(result, "OCI Agent")
#       if citations:
#           citation_text = _format_citations_for_display(citations, "OCI Agent")
#           text += citation_text