Skip to content

Commit 67fb422

Browse files
authored
Merge pull request #17 from digitalocean/enable-langgraph-retriever-spans
Add network interception for KBaaS and retriever spans
2 parents 883cc68 + 6b52489 commit 67fb422

File tree

4 files changed

+632
-22
lines changed

4 files changed

+632
-22
lines changed

gradient_adk/runtime/langgraph/langgraph_instrumentor.py

Lines changed: 93 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212

1313
from ..interfaces import NodeExecution
1414
from ..digitalocean_tracker import DigitalOceanTracesTracker
15-
from ..network_interceptor import get_network_interceptor
15+
from ..network_interceptor import (
16+
get_network_interceptor,
17+
is_inference_url,
18+
is_kbaas_url,
19+
)
1620

1721

1822
WRAPPED_FLAG = "__do_wrapped__"
@@ -232,17 +236,68 @@ def _had_hits_since(intr, token) -> bool:
232236
return False
233237

234238

235-
def _get_captured_payloads(intr, token) -> tuple:
236-
"""Get captured API request/response payloads if available (e.g., for LLM calls)."""
239+
def _get_captured_payloads_with_type(intr, token) -> tuple:
240+
"""Get captured API request/response payloads and classify the call type.
241+
242+
Returns:
243+
(request_payload, response_payload, is_llm, is_retriever)
244+
"""
237245
try:
238246
captured = intr.get_captured_requests_since(token)
239247
if captured:
240248
# Use the first captured request (most common case)
241249
call = captured[0]
242-
return call.request_payload, call.response_payload
250+
url = call.url
251+
is_llm = is_inference_url(url)
252+
is_retriever = is_kbaas_url(url)
253+
return call.request_payload, call.response_payload, is_llm, is_retriever
243254
except Exception:
244255
pass
245-
return None, None
256+
return None, None, False, False
257+
258+
259+
def _transform_kbaas_response(response: Optional[Dict[str, Any]]) -> Optional[list]:
260+
"""Transform KBaaS response to standard retriever format.
261+
262+
Extracts results and maps content fields to 'page_content'.
263+
264+
For hierarchical KB (parent retrieval):
265+
- Uses 'parent_chunk_text' as 'page_content' (the context users typically want)
266+
- Preserves 'text_content' as 'embedded_content' for reference
267+
268+
For standard KB:
269+
- Uses 'text_content' as 'page_content'
270+
271+
Returns a list of dicts as expected for retriever spans.
272+
"""
273+
if not isinstance(response, dict):
274+
return response
275+
276+
results = response.get("results", [])
277+
if not isinstance(results, list):
278+
return response
279+
280+
transformed_results = []
281+
for item in results:
282+
if isinstance(item, dict):
283+
new_item = dict(item)
284+
285+
# For hierarchical KB: prefer parent_chunk_text as page_content
286+
if "parent_chunk_text" in new_item:
287+
new_item["page_content"] = new_item.pop("parent_chunk_text")
288+
# Preserve embedded text as embedded_content for reference
289+
if "text_content" in new_item:
290+
new_item["embedded_content"] = new_item.pop("text_content")
291+
elif "text_content" in new_item:
292+
# Standard KB: use text_content as page_content
293+
new_item["page_content"] = new_item.pop("text_content")
294+
295+
transformed_results.append(new_item)
296+
else:
297+
transformed_results.append(item)
298+
299+
# Return just the array of results
300+
return transformed_results
246301

247302

248303
class LangGraphInstrumentor:
@@ -280,20 +335,33 @@ def _finish_ok(
280335
# (_wrap_async_func, _wrap_sync_func, etc.) BEFORE calling _finish_ok.
281336
# The wrappers collect streamed content and pass {"content": "..."} here.
282337

283-
# Check if this node made any tracked API calls (e.g., LLM inference)
338+
# Check if this node made any tracked API calls (e.g., LLM inference or KBaaS retrieval)
284339
if _had_hits_since(intr, tok):
285-
_ensure_meta(rec)["is_llm_call"] = True
340+
# Get captured payloads and classify the call type
341+
api_request, api_response, is_llm, is_retriever = (
342+
_get_captured_payloads_with_type(intr, tok)
343+
)
286344

287-
# Try to get actual API request/response payloads (for LLM calls)
288-
api_request, api_response = _get_captured_payloads(intr, tok)
345+
# Set metadata based on call type
346+
meta = _ensure_meta(rec)
347+
if is_llm:
348+
meta["is_llm_call"] = True
349+
elif is_retriever:
350+
meta["is_retriever_call"] = True
351+
else:
352+
# Fallback: assume LLM call for backward compatibility
353+
meta["is_llm_call"] = True
289354

290355
if api_request or api_response:
291356
# Use actual API payloads instead of function args
292357
if api_request:
293358
rec.inputs = _freeze(api_request)
294359

295-
# Use actual API response as output (e.g., LLM completion)
360+
# Use actual API response as output
296361
if api_response:
362+
# Transform KBaaS response to standard retriever format
363+
if is_retriever:
364+
api_response = _transform_kbaas_response(api_response)
297365
out_payload = _freeze(api_response)
298366
else:
299367
out_payload = _canonical_output(inputs_snapshot, a, kw, ret)
@@ -306,10 +374,21 @@ def _finish_ok(
306374

307375
def _finish_err(rec: NodeExecution, intr, tok, e: BaseException):
308376
if _had_hits_since(intr, tok):
309-
_ensure_meta(rec)["is_llm_call"] = True
377+
# Get captured payloads and classify the call type
378+
api_request, _, is_llm, is_retriever = _get_captured_payloads_with_type(
379+
intr, tok
380+
)
381+
382+
# Set metadata based on call type
383+
meta = _ensure_meta(rec)
384+
if is_llm:
385+
meta["is_llm_call"] = True
386+
elif is_retriever:
387+
meta["is_retriever_call"] = True
388+
else:
389+
# Fallback: assume LLM call for backward compatibility
390+
meta["is_llm_call"] = True
310391

311-
# Try to get actual API request payload even on error
312-
api_request, _ = _get_captured_payloads(intr, tok)
313392
if api_request:
314393
rec.inputs = _freeze(api_request)
315394

@@ -623,4 +702,4 @@ def wrapped_add_node(graph_self, *args, **kwargs):
623702
return original_add_node(graph_self, *args, **kwargs)
624703

625704
StateGraph.add_node = wrapped_add_node
626-
self._installed = True
705+
self._installed = True

gradient_adk/runtime/network_interceptor.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ class CapturedRequest:
2424

2525
    def __init__(
        self,
        # Target URL of the intercepted request; used by callers to classify
        # the call (inference vs. KBaaS). None when unknown.
        url: Optional[str] = None,
        # Parsed body sent with the request, if one was captured.
        request_payload: Optional[Dict[str, Any]] = None,
        # Parsed body of the response, if one was captured.
        response_payload: Optional[Dict[str, Any]] = None,
    ):
        """Store the URL and request/response payloads of one captured call."""
        self.url = url
        self.request_payload = request_payload
        self.response_payload = response_payload
3234

@@ -287,9 +289,9 @@ def _record_request(
287289
with self._lock:
288290
if self._is_tracked_url(url):
289291
self._hit_count += 1
290-
# Create a new captured request record
292+
# Create a new captured request record with URL
291293
self._captured_requests.append(
292-
CapturedRequest(request_payload=request_payload)
294+
CapturedRequest(url=url, request_payload=request_payload)
293295
)
294296

295297
def _record_response(
@@ -401,6 +403,25 @@ def hook(url: str, headers: Dict[str, str]) -> Dict[str, str]:
401403
return hook
402404

403405

406+
# URL classification helpers for different DigitalOcean services
INFERENCE_URL_PATTERNS = ["inference.do-ai.run", "inference.do-ai-test.run"]
KBAAS_URL_PATTERNS = ["kbaas.do-ai.run", "kbaas.do-ai-test.run"]


def _url_contains_any(url: Optional[str], patterns: list) -> bool:
    """Return True when *url* is non-empty and contains any of *patterns*."""
    if not url:
        return False
    return any(pattern in url for pattern in patterns)


def is_inference_url(url: Optional[str]) -> bool:
    """Check if URL matches DigitalOcean inference (LLM) endpoints."""
    return _url_contains_any(url, INFERENCE_URL_PATTERNS)


def is_kbaas_url(url: Optional[str]) -> bool:
    """Check if URL matches DigitalOcean KBaaS (Knowledge Base) endpoints."""
    return _url_contains_any(url, KBAAS_URL_PATTERNS)
423+
424+
404425
# Global instance
405426
_global_interceptor = NetworkInterceptor()
406427

@@ -411,14 +432,21 @@ def get_network_interceptor() -> NetworkInterceptor:
411432

412433
def setup_digitalocean_interception() -> None:
    """Register all DigitalOcean endpoint patterns and start interception.

    Tracks both inference (LLM) and KBaaS (Knowledge Base) endpoints, and
    installs the ADK User-Agent hook for every DO endpoint before turning
    interception on.
    """
    interceptor = get_network_interceptor()

    # Track inference patterns first, then KBaaS, matching registration order.
    do_patterns = INFERENCE_URL_PATTERNS + KBAAS_URL_PATTERNS
    for pattern in do_patterns:
        interceptor.add_endpoint_pattern(pattern)

    # Register User-Agent hook for ADK identification (all DO endpoints).
    interceptor.add_request_hook(
        create_adk_user_agent_hook(
            version=_get_adk_version(),
            url_patterns=do_patterns,
        )
    )

    interceptor.start_intercepting()

0 commit comments

Comments
 (0)