Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions ai/gen-ai-agents/agentsOCI-OpenAI-gateway/api/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import yaml
import logging
from urllib.parse import urlparse
from typing import Annotated, Any, Dict, List, Optional, Tuple, Union

import oci
Expand Down Expand Up @@ -133,6 +134,185 @@ def _extract_user_text(messages: List[Dict[str, Any]] | List[Any]) -> str:
)
return ""

def _normalize_source_location(source_location: Any) -> dict:
    """
    Normalize a citation source location into a display name and a URL.

    Handles:
    - objects exposing a ``url`` attribute (e.g. OCI SDK objects)
    - dicts with a ``url`` key
    - JSON-stringified dicts
    - raw http(s) URLs (scheme matched case-insensitively)
    - plain strings / paths

    Args:
        source_location: Location value taken from a citation; any type.

    Returns:
        ``{"display_name": ..., "url": ...}``. ``url`` is None when no
        (non-empty) URL was found. ``display_name`` is the last path segment
        of the URL (percent-decoded) or of the plain string, and None when
        nothing usable was supplied. Never raises: unexpected errors are
        logged as a warning and yield ``{"display_name": None, "url": None}``.
    """
    # Function-scope import: this block only relies on ``urlparse`` being
    # imported at module level.
    from urllib.parse import unquote

    def _is_http_url(value: Any) -> bool:
        # URL schemes are case-insensitive, so compare on a lowercased copy.
        return isinstance(value, str) and value.lower().startswith(("http://", "https://"))

    display_name = None
    url_value = None

    try:
        # 1) Object with a 'url' attribute; empty values count as "no URL".
        url_value = getattr(source_location, "url", None) or None

        if url_value is None:
            # 2) JSON-like string: parse it and continue with the decoded dict.
            if isinstance(source_location, str) and source_location.strip().startswith("{"):
                try:
                    parsed = json.loads(source_location)
                except (ValueError, RecursionError):
                    # Not valid JSON after all; treated as a plain string below.
                    parsed = None
                if isinstance(parsed, dict):
                    source_location = parsed

            if isinstance(source_location, dict):
                # 3) dict-like; empty values count as "no URL" (same as step 1).
                url_value = source_location.get("url") or None
            elif _is_http_url(source_location):
                # 4) Raw URL string
                url_value = source_location

        # Decide display_name: prefer the URL, else the plain string itself.
        candidate = url_value or (source_location if isinstance(source_location, str) else None)
        if candidate:
            # A non-string URL value is stringified for display purposes only;
            # the returned "url" keeps the original value.
            label = candidate if isinstance(candidate, str) else str(candidate)
            if _is_http_url(label):
                # Decode percent-escapes so "my%20doc.pdf" shows as "my doc.pdf"
                # and "folder%2Ffile.pdf" resolves to "file.pdf".
                path = unquote(urlparse(label).path or "")
                display_name = os.path.basename(path) or path.strip("/") or label
            else:
                display_name = os.path.basename(label) or label

    except Exception as e:
        # Best-effort helper: a malformed location must not break the response.
        logging.getLogger(__name__).warning(f"Failed to normalize source_location: {e}")
        display_name = None
        url_value = None

    return {"display_name": display_name, "url": url_value}

def _extract_citations_from_response(result, agent_name: str = "OCI Agent") -> Optional[Dict[str, Any]]:
    """
    Collect paragraph-level citations from an agent chat result.

    Reads ``result.message.content.paragraph_citations`` and converts every
    entry that exposes both ``paragraph`` and ``citations`` into plain dicts.

    Args:
        result: Chat result object; may be None or lack any of the attributes.
        agent_name: Name stored under ``"agent_name"`` in the returned dict.

    Returns:
        ``{"paragraph_citations": [...], "agent_name": agent_name}`` where each
        list item is ``{"paragraph": {"text", "start", "end"}, "citations": [...]}``,
        or None when there is nothing to report or extraction fails (the
        failure is logged as a warning; this function does not raise).
    """
    try:
        # Walk result -> message -> content, stopping at the first falsy link.
        message = getattr(result, "message", None) if result else None
        content = getattr(message, "content", None) if message else None
        raw_entries = getattr(content, "paragraph_citations", None) if content else None
        if not raw_entries:
            return None

        paragraph_citations = []
        for para_citation in raw_entries:
            if not (hasattr(para_citation, "paragraph") and hasattr(para_citation, "citations")):
                continue
            paragraph = para_citation.paragraph

            citation_list = []
            # ``citations`` may be None; treat it as empty so one paragraph
            # without sources does not discard every other citation.
            for citation in para_citation.citations or []:
                normalized_loc = _normalize_source_location(getattr(citation, "source_location", None))
                citation_list.append(
                    {
                        "source_text": getattr(citation, "source_text", None),
                        "title": getattr(citation, "title", None),
                        "doc_id": getattr(citation, "doc_id", None),
                        "page_numbers": getattr(citation, "page_numbers", None),
                        "metadata": getattr(citation, "metadata", None),
                        "location_display": normalized_loc.get("display_name"),
                        "location_url": normalized_loc.get("url"),
                    }
                )

            paragraph_citations.append(
                {
                    "paragraph": {
                        "text": getattr(paragraph, "text", "") or "",
                        "start": getattr(paragraph, "start", 0),
                        "end": getattr(paragraph, "end", 0),
                    },
                    "citations": citation_list,
                }
            )

        if paragraph_citations:
            return {"paragraph_citations": paragraph_citations, "agent_name": agent_name}

        return None
    except Exception as e:
        # Citations are optional extras; never let them fail the chat response.
        logging.getLogger(__name__).warning(f"Failed to extract citations: {e}")
        return None

def _format_citations_for_display(citations: Dict[str, Any], agent_name: str = "OCI Agent") -> str:
    """
    Render extracted citations as a plain-text block.

    Output shape (paragraph text and citation source excerpts are currently
    not rendered; optional lines appear only when the value is present):

    --- Citations from [Agent Name] ---
     Sources:
     1. Title: ...
     Location: document.pdf
     Document ID: ...
     Pages: [1, 2]
     Metadata: {...}
    --- End Citations ---

    Args:
        citations: Dict with a ``"paragraph_citations"`` list and an optional
            ``"agent_name"`` that takes precedence over ``agent_name``.
        agent_name: Fallback name used in the header line.

    Returns:
        The rendered text (starting with two newlines), or ``""`` when
        ``citations`` is empty or has no ``"paragraph_citations"`` key.
    """
    if not citations or "paragraph_citations" not in citations:
        return ""

    def _as_json(value: Any) -> str:
        # Prefer JSON for readability; fall back to str() for values the
        # json module cannot encode so the information is not silently lost.
        try:
            return json.dumps(value, ensure_ascii=False)
        except (TypeError, ValueError, RecursionError):
            return str(value)

    agent = citations.get("agent_name") or agent_name
    blocks = [f"\n\n--- Citations from [{agent}] ---\n"]

    # Tolerate an explicit None under the key instead of crashing on iteration.
    for para_citation in citations["paragraph_citations"] or []:
        lines = [" Sources:"]

        for jdx, c in enumerate(para_citation.get("citations", []) or [], start=1):
            title = c.get("title")
            loc_display = c.get("location_display")
            doc_id = c.get("doc_id")
            pages = c.get("page_numbers")
            metadata = c.get("metadata")

            lines.append(f" {jdx}. " + (f"Title: {title}" if title else "Title: (unknown)"))
            if loc_display:
                lines.append(f" Location: {loc_display}")
            if doc_id:
                lines.append(f" Document ID: {doc_id}")
            if pages:
                lines.append(f" Pages: {_as_json(pages)}")
            if metadata:
                lines.append(f" Metadata: {_as_json(metadata)}")

        blocks.append("\n".join(lines) + "\n")

    blocks.append("--- End Citations ---")
    return "\n".join(blocks)

def _resolve_endpoint_ocid(region: str, endpoint_ocid: Optional[str], agent_ocid: Optional[str], compartment_ocid: Optional[str]) -> str:
if endpoint_ocid:
return endpoint_ocid
Expand Down Expand Up @@ -227,6 +407,13 @@ async def chat_completions(
text = ""
if getattr(result, "message", None) and getattr(result.message, "content", None):
text = getattr(result.message.content, "text", "") or ""

agent_name = agent_cfg.get("name", "OCI Agent")
citations = _extract_citations_from_response(result, agent_name)

if citations:
citation_text = _format_citations_for_display(citations, agent_name)
text += citation_text
except oci.exceptions.ServiceError as se:
raise HTTPException(status_code=502, detail=f"Agent chat failed ({se.status}): {getattr(se,'message',str(se))}")

Expand Down Expand Up @@ -257,6 +444,13 @@ async def chat_completions(
result = runtime.chat(agent_endpoint_id=endpoint_id, chat_details=chat_details).data
text = getattr(getattr(result, "message", None), "content", None)
text = getattr(text, "text", "") if text else ""

citations = _extract_citations_from_response(result, "OCI Agent")

if citations:
citation_text = _format_citations_for_display(citations, "OCI Agent")
text += citation_text

tag = f"oci:agentendpoint:{endpoint_id}"
if getattr(chat_request, "stream", False):
return StreamingResponse(_stream_one_chunk(text, tag), media_type="text/event-stream", headers={"x-oci-session-id": session_id})
Expand Down