|
6 | 6 | import json |
7 | 7 | import yaml |
8 | 8 | import logging |
| 9 | +from urllib.parse import urlparse |
9 | 10 | from typing import Annotated, Any, Dict, List, Optional, Tuple, Union |
10 | 11 |
|
11 | 12 | import oci |
@@ -133,6 +134,185 @@ def _extract_user_text(messages: List[Dict[str, Any]] | List[Any]) -> str: |
133 | 134 | ) |
134 | 135 | return "" |
135 | 136 |
|
| 137 | +def _normalize_source_location(source_location: Any) -> dict: |
| 138 | + """ |
| 139 | + Returns a dict with display_name and url (when present). |
| 140 | + Handles: |
| 141 | + - OCI SDK objects with .url |
| 142 | + - dict-like with 'url' |
| 143 | + - JSON-stringified dicts |
| 144 | + - raw URLs |
| 145 | + - plain strings / paths |
| 146 | + """ |
| 147 | + display_name = None |
| 148 | + url_value = None |
| 149 | + |
| 150 | + try: |
| 151 | + # 1) SDK object with attribute 'url' |
| 152 | + if hasattr(source_location, "url"): |
| 153 | + url_value = getattr(source_location, "url") or None |
| 154 | + |
| 155 | + # 2) dict-like |
| 156 | + if url_value is None: |
| 157 | + if isinstance(source_location, dict): |
| 158 | + url_value = source_location.get("url") |
| 159 | + else: |
| 160 | + # 3) JSON-like string? try parse |
| 161 | + if isinstance(source_location, str) and source_location.strip().startswith("{"): |
| 162 | + try: |
| 163 | + parsed = json.loads(source_location) |
| 164 | + if isinstance(parsed, dict): |
| 165 | + url_value = parsed.get("url") |
| 166 | + source_location = parsed |
| 167 | + except Exception: |
| 168 | + pass |
| 169 | + |
| 170 | + # 4) If it's a URL string |
| 171 | + if url_value is None and isinstance(source_location, str): |
| 172 | + if source_location.startswith("http://") or source_location.startswith("https://"): |
| 173 | + url_value = source_location |
| 174 | + |
| 175 | + # Decide display_name |
| 176 | + candidate_for_name = url_value or (source_location if isinstance(source_location, str) else None) |
| 177 | + if candidate_for_name: |
| 178 | + if isinstance(candidate_for_name, str) and ( |
| 179 | + candidate_for_name.startswith("http://") or candidate_for_name.startswith("https://") |
| 180 | + ): |
| 181 | + path = urlparse(candidate_for_name).path or "" |
| 182 | + base = os.path.basename(path) or path.strip("/") |
| 183 | + display_name = base or candidate_for_name |
| 184 | + else: |
| 185 | + display_name = os.path.basename(candidate_for_name) or str(candidate_for_name) |
| 186 | + else: |
| 187 | + display_name = None |
| 188 | + |
| 189 | + except Exception as e: |
| 190 | + logging.getLogger(__name__).warning(f"Failed to normalize source_location: {e}") |
| 191 | + display_name = None |
| 192 | + url_value = None |
| 193 | + |
| 194 | + return {"display_name": display_name, "url": url_value} |
| 195 | + |
| 196 | +def _extract_citations_from_response(result, agent_name: str = "OCI Agent") -> Optional[Dict[str, Any]]: |
| 197 | + try: |
| 198 | + if not result or not hasattr(result, 'message') or not result.message: |
| 199 | + return None |
| 200 | + |
| 201 | + message = result.message |
| 202 | + if not hasattr(message, 'content') or not message.content: |
| 203 | + return None |
| 204 | + |
| 205 | + content = message.content |
| 206 | + if not hasattr(content, 'paragraph_citations') or not content.paragraph_citations: |
| 207 | + return None |
| 208 | + |
| 209 | + paragraph_citations = [] |
| 210 | + for para_citation in content.paragraph_citations: |
| 211 | + if hasattr(para_citation, 'paragraph') and hasattr(para_citation, 'citations'): |
| 212 | + paragraph = para_citation.paragraph |
| 213 | + citations = para_citation.citations |
| 214 | + |
| 215 | + citation_list = [] |
| 216 | + for citation in citations: |
| 217 | + normalized_loc = _normalize_source_location(getattr(citation, 'source_location', None)) |
| 218 | + citation_dict = { |
| 219 | + "source_text": getattr(citation, 'source_text', None), |
| 220 | + "title": getattr(citation, 'title', None), |
| 221 | + "doc_id": getattr(citation, 'doc_id', None), |
| 222 | + "page_numbers": getattr(citation, 'page_numbers', None), |
| 223 | + "metadata": getattr(citation, 'metadata', None), |
| 224 | + "location_display": normalized_loc.get("display_name"), |
| 225 | + "location_url": normalized_loc.get("url"), |
| 226 | + } |
| 227 | + citation_list.append(citation_dict) |
| 228 | + |
| 229 | + paragraph_dict = { |
| 230 | + "paragraph": { |
| 231 | + "text": getattr(paragraph, 'text', '') or '', |
| 232 | + "start": getattr(paragraph, 'start', 0), |
| 233 | + "end": getattr(paragraph, 'end', 0) |
| 234 | + }, |
| 235 | + "citations": citation_list |
| 236 | + } |
| 237 | + paragraph_citations.append(paragraph_dict) |
| 238 | + |
| 239 | + if paragraph_citations: |
| 240 | + return {"paragraph_citations": paragraph_citations, "agent_name": agent_name} |
| 241 | + |
| 242 | + return None |
| 243 | + except Exception as e: |
| 244 | + logging.getLogger(__name__).warning(f"Failed to extract citations: {e}") |
| 245 | + return None |
| 246 | + |
| 247 | +def _format_citations_for_display(citations: Dict[str, Any], agent_name: str = "OCI Agent") -> str: |
| 248 | + """ |
| 249 | + Renders like: |
| 250 | +
|
| 251 | + --- Citations from [Agent Name] --- |
| 252 | +
|
| 253 | + 1. Text: "..." |
| 254 | + Sources: |
| 255 | + 1. Title: ... |
| 256 | + Location: document.pdf |
| 257 | + Document ID: ... |
| 258 | + Pages: [1, 2] |
| 259 | + Source: ... |
| 260 | + Metadata: {...} |
| 261 | +
|
| 262 | + --- End Citations --- |
| 263 | + """ |
| 264 | + if not citations or "paragraph_citations" not in citations: |
| 265 | + return "" |
| 266 | + |
| 267 | + agent = citations.get("agent_name") or agent_name |
| 268 | + blocks = [] |
| 269 | + blocks.append(f"\n\n--- Citations from [{agent}] ---\n") |
| 270 | + |
| 271 | + for idx, para_citation in enumerate(citations["paragraph_citations"], start=1): |
| 272 | + p = para_citation.get("paragraph", {}) or {} |
| 273 | + text = (p.get("text") or "").strip() |
| 274 | + |
| 275 | + line = [] |
| 276 | + # Ensure quoted text; json.dumps gives safe quoting and escapes |
| 277 | + #line.append(f"{idx}. Text: {json.dumps(text) if text else '\"\"'}") |
| 278 | + line.append(" Sources:") |
| 279 | + |
| 280 | + for jdx, c in enumerate(para_citation.get("citations", []) or [], start=1): |
| 281 | + title = c.get("title") |
| 282 | + loc_display = c.get("location_display") |
| 283 | + doc_id = c.get("doc_id") |
| 284 | + pages = c.get("page_numbers") |
| 285 | + source_text = c.get("source_text") |
| 286 | + metadata = c.get("metadata") |
| 287 | + |
| 288 | + line.append(f" {jdx}. " + (f"Title: {title}" if title else "Title: (unknown)")) |
| 289 | + if loc_display: |
| 290 | + line.append(f" Location: {loc_display}") |
| 291 | + if doc_id: |
| 292 | + line.append(f" Document ID: {doc_id}") |
| 293 | + if pages: |
| 294 | + try: |
| 295 | + pages_str = json.dumps(pages, ensure_ascii=False) |
| 296 | + except Exception: |
| 297 | + pages_str = str(pages) |
| 298 | + line.append(f" Pages: {pages_str}") |
| 299 | + if source_text: |
| 300 | + st = (source_text or "").strip() |
| 301 | + if len(st) > 500: |
| 302 | + st = st[:500].rstrip() + "…" |
| 303 | + #line.append(f" Source: {st}") |
| 304 | + if metadata: |
| 305 | + try: |
| 306 | + md_str = json.dumps(metadata, ensure_ascii=False) |
| 307 | + line.append(f" Metadata: {md_str}") |
| 308 | + except Exception: |
| 309 | + pass |
| 310 | + |
| 311 | + blocks.append("\n".join(line) + "\n") |
| 312 | + |
| 313 | + blocks.append("--- End Citations ---") |
| 314 | + return "\n".join(blocks) |
| 315 | + |
136 | 316 | def _resolve_endpoint_ocid(region: str, endpoint_ocid: Optional[str], agent_ocid: Optional[str], compartment_ocid: Optional[str]) -> str: |
137 | 317 | if endpoint_ocid: |
138 | 318 | return endpoint_ocid |
@@ -227,6 +407,13 @@ async def chat_completions( |
227 | 407 | text = "" |
228 | 408 | if getattr(result, "message", None) and getattr(result.message, "content", None): |
229 | 409 | text = getattr(result.message.content, "text", "") or "" |
| 410 | + |
| 411 | + agent_name = agent_cfg.get("name", "OCI Agent") |
| 412 | + citations = _extract_citations_from_response(result, agent_name) |
| 413 | + |
| 414 | + if citations: |
| 415 | + citation_text = _format_citations_for_display(citations, agent_name) |
| 416 | + text += citation_text |
230 | 417 | except oci.exceptions.ServiceError as se: |
231 | 418 | raise HTTPException(status_code=502, detail=f"Agent chat failed ({se.status}): {getattr(se,'message',str(se))}") |
232 | 419 |
|
@@ -257,6 +444,13 @@ async def chat_completions( |
257 | 444 | result = runtime.chat(agent_endpoint_id=endpoint_id, chat_details=chat_details).data |
258 | 445 | text = getattr(getattr(result, "message", None), "content", None) |
259 | 446 | text = getattr(text, "text", "") if text else "" |
| 447 | + |
| 448 | + citations = _extract_citations_from_response(result, "OCI Agent") |
| 449 | + |
| 450 | + if citations: |
| 451 | + citation_text = _format_citations_for_display(citations, "OCI Agent") |
| 452 | + text += citation_text |
| 453 | + |
260 | 454 | tag = f"oci:agentendpoint:{endpoint_id}" |
261 | 455 | if getattr(chat_request, "stream", False): |
262 | 456 | return StreamingResponse(_stream_one_chunk(text, tag), media_type="text/event-stream", headers={"x-oci-session-id": session_id}) |
|
0 commit comments