Skip to content

Commit c1d956d

Browse files
authored
Merge pull request #2063 from oracle-devrel/matsliwins-patch-13
Citations update
2 parents 1613259 + 8040e6d commit c1d956d

File tree

1 file changed

+194
-0
lines changed
  • ai/gen-ai-agents/agentsOCI-OpenAI-gateway/api/routers

1 file changed

+194
-0
lines changed

ai/gen-ai-agents/agentsOCI-OpenAI-gateway/api/routers/chat.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import yaml
88
import logging
9+
from urllib.parse import urlparse
910
from typing import Annotated, Any, Dict, List, Optional, Tuple, Union
1011

1112
import oci
@@ -133,6 +134,185 @@ def _extract_user_text(messages: List[Dict[str, Any]] | List[Any]) -> str:
133134
)
134135
return ""
135136

137+
def _normalize_source_location(source_location: Any) -> dict:
138+
"""
139+
Returns a dict with display_name and url (when present).
140+
Handles:
141+
- OCI SDK objects with .url
142+
- dict-like with 'url'
143+
- JSON-stringified dicts
144+
- raw URLs
145+
- plain strings / paths
146+
"""
147+
display_name = None
148+
url_value = None
149+
150+
try:
151+
# 1) SDK object with attribute 'url'
152+
if hasattr(source_location, "url"):
153+
url_value = getattr(source_location, "url") or None
154+
155+
# 2) dict-like
156+
if url_value is None:
157+
if isinstance(source_location, dict):
158+
url_value = source_location.get("url")
159+
else:
160+
# 3) JSON-like string? try parse
161+
if isinstance(source_location, str) and source_location.strip().startswith("{"):
162+
try:
163+
parsed = json.loads(source_location)
164+
if isinstance(parsed, dict):
165+
url_value = parsed.get("url")
166+
source_location = parsed
167+
except Exception:
168+
pass
169+
170+
# 4) If it's a URL string
171+
if url_value is None and isinstance(source_location, str):
172+
if source_location.startswith("http://") or source_location.startswith("https://"):
173+
url_value = source_location
174+
175+
# Decide display_name
176+
candidate_for_name = url_value or (source_location if isinstance(source_location, str) else None)
177+
if candidate_for_name:
178+
if isinstance(candidate_for_name, str) and (
179+
candidate_for_name.startswith("http://") or candidate_for_name.startswith("https://")
180+
):
181+
path = urlparse(candidate_for_name).path or ""
182+
base = os.path.basename(path) or path.strip("/")
183+
display_name = base or candidate_for_name
184+
else:
185+
display_name = os.path.basename(candidate_for_name) or str(candidate_for_name)
186+
else:
187+
display_name = None
188+
189+
except Exception as e:
190+
logging.getLogger(__name__).warning(f"Failed to normalize source_location: {e}")
191+
display_name = None
192+
url_value = None
193+
194+
return {"display_name": display_name, "url": url_value}
195+
196+
def _extract_citations_from_response(result, agent_name: str = "OCI Agent") -> Optional[Dict[str, Any]]:
197+
try:
198+
if not result or not hasattr(result, 'message') or not result.message:
199+
return None
200+
201+
message = result.message
202+
if not hasattr(message, 'content') or not message.content:
203+
return None
204+
205+
content = message.content
206+
if not hasattr(content, 'paragraph_citations') or not content.paragraph_citations:
207+
return None
208+
209+
paragraph_citations = []
210+
for para_citation in content.paragraph_citations:
211+
if hasattr(para_citation, 'paragraph') and hasattr(para_citation, 'citations'):
212+
paragraph = para_citation.paragraph
213+
citations = para_citation.citations
214+
215+
citation_list = []
216+
for citation in citations:
217+
normalized_loc = _normalize_source_location(getattr(citation, 'source_location', None))
218+
citation_dict = {
219+
"source_text": getattr(citation, 'source_text', None),
220+
"title": getattr(citation, 'title', None),
221+
"doc_id": getattr(citation, 'doc_id', None),
222+
"page_numbers": getattr(citation, 'page_numbers', None),
223+
"metadata": getattr(citation, 'metadata', None),
224+
"location_display": normalized_loc.get("display_name"),
225+
"location_url": normalized_loc.get("url"),
226+
}
227+
citation_list.append(citation_dict)
228+
229+
paragraph_dict = {
230+
"paragraph": {
231+
"text": getattr(paragraph, 'text', '') or '',
232+
"start": getattr(paragraph, 'start', 0),
233+
"end": getattr(paragraph, 'end', 0)
234+
},
235+
"citations": citation_list
236+
}
237+
paragraph_citations.append(paragraph_dict)
238+
239+
if paragraph_citations:
240+
return {"paragraph_citations": paragraph_citations, "agent_name": agent_name}
241+
242+
return None
243+
except Exception as e:
244+
logging.getLogger(__name__).warning(f"Failed to extract citations: {e}")
245+
return None
246+
247+
def _format_citations_for_display(citations: Dict[str, Any], agent_name: str = "OCI Agent") -> str:
248+
"""
249+
Renders like:
250+
251+
--- Citations from [Agent Name] ---
252+
253+
1. Text: "..."
254+
Sources:
255+
1. Title: ...
256+
Location: document.pdf
257+
Document ID: ...
258+
Pages: [1, 2]
259+
Source: ...
260+
Metadata: {...}
261+
262+
--- End Citations ---
263+
"""
264+
if not citations or "paragraph_citations" not in citations:
265+
return ""
266+
267+
agent = citations.get("agent_name") or agent_name
268+
blocks = []
269+
blocks.append(f"\n\n--- Citations from [{agent}] ---\n")
270+
271+
for idx, para_citation in enumerate(citations["paragraph_citations"], start=1):
272+
p = para_citation.get("paragraph", {}) or {}
273+
text = (p.get("text") or "").strip()
274+
275+
line = []
276+
# Ensure quoted text; json.dumps gives safe quoting and escapes
277+
#line.append(f"{idx}. Text: {json.dumps(text) if text else '\"\"'}")
278+
line.append(" Sources:")
279+
280+
for jdx, c in enumerate(para_citation.get("citations", []) or [], start=1):
281+
title = c.get("title")
282+
loc_display = c.get("location_display")
283+
doc_id = c.get("doc_id")
284+
pages = c.get("page_numbers")
285+
source_text = c.get("source_text")
286+
metadata = c.get("metadata")
287+
288+
line.append(f" {jdx}. " + (f"Title: {title}" if title else "Title: (unknown)"))
289+
if loc_display:
290+
line.append(f" Location: {loc_display}")
291+
if doc_id:
292+
line.append(f" Document ID: {doc_id}")
293+
if pages:
294+
try:
295+
pages_str = json.dumps(pages, ensure_ascii=False)
296+
except Exception:
297+
pages_str = str(pages)
298+
line.append(f" Pages: {pages_str}")
299+
if source_text:
300+
st = (source_text or "").strip()
301+
if len(st) > 500:
302+
st = st[:500].rstrip() + "…"
303+
#line.append(f" Source: {st}")
304+
if metadata:
305+
try:
306+
md_str = json.dumps(metadata, ensure_ascii=False)
307+
line.append(f" Metadata: {md_str}")
308+
except Exception:
309+
pass
310+
311+
blocks.append("\n".join(line) + "\n")
312+
313+
blocks.append("--- End Citations ---")
314+
return "\n".join(blocks)
315+
136316
def _resolve_endpoint_ocid(region: str, endpoint_ocid: Optional[str], agent_ocid: Optional[str], compartment_ocid: Optional[str]) -> str:
137317
if endpoint_ocid:
138318
return endpoint_ocid
@@ -227,6 +407,13 @@ async def chat_completions(
227407
text = ""
228408
if getattr(result, "message", None) and getattr(result.message, "content", None):
229409
text = getattr(result.message.content, "text", "") or ""
410+
411+
agent_name = agent_cfg.get("name", "OCI Agent")
412+
citations = _extract_citations_from_response(result, agent_name)
413+
414+
if citations:
415+
citation_text = _format_citations_for_display(citations, agent_name)
416+
text += citation_text
230417
except oci.exceptions.ServiceError as se:
231418
raise HTTPException(status_code=502, detail=f"Agent chat failed ({se.status}): {getattr(se,'message',str(se))}")
232419

@@ -257,6 +444,13 @@ async def chat_completions(
257444
result = runtime.chat(agent_endpoint_id=endpoint_id, chat_details=chat_details).data
258445
text = getattr(getattr(result, "message", None), "content", None)
259446
text = getattr(text, "text", "") if text else ""
447+
448+
citations = _extract_citations_from_response(result, "OCI Agent")
449+
450+
if citations:
451+
citation_text = _format_citations_for_display(citations, "OCI Agent")
452+
text += citation_text
453+
260454
tag = f"oci:agentendpoint:{endpoint_id}"
261455
if getattr(chat_request, "stream", False):
262456
return StreamingResponse(_stream_one_chunk(text, tag), media_type="text/event-stream", headers={"x-oci-session-id": session_id})

0 commit comments

Comments
 (0)