Skip to content

Commit ce74658

Browse files
authored
Add tool call handling to TrustworthyRAG (#103)
1 parent cfc6b82 commit ce74658

File tree

6 files changed

+540
-0
lines changed

6 files changed

+540
-0
lines changed

src/cleanlab_tlm/internal/rag.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from __future__ import annotations
2+
3+
from functools import wraps
4+
from typing import TYPE_CHECKING, Any, Callable, TypeVar
5+
6+
from cleanlab_tlm.utils.chat import _TOOL_CALL_TAG_END, _TOOL_CALL_TAG_START
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Coroutine
10+
11+
# Define type variables for the response types
12+
ResponseT = TypeVar("ResponseT")
13+
14+
15+
def _is_tool_call_response(response_string: str) -> bool:
    """Return True if the response string consists solely of tool call section(s)."""
    text = response_string.strip()

    # Without both the opening and closing tag present, this cannot be a tool call.
    if _TOOL_CALL_TAG_START not in text or _TOOL_CALL_TAG_END not in text:
        return False

    # Strip out every well-formed <start>...<end> tool call section, tags included.
    while True:
        open_idx = text.find(_TOOL_CALL_TAG_START)
        if open_idx == -1:
            break
        close_idx = text.find(_TOOL_CALL_TAG_END, open_idx)
        if close_idx == -1:
            # Unmatched opening tag: stop stripping.
            break
        text = text[:open_idx] + text[close_idx + len(_TOOL_CALL_TAG_END):]

    # A pure tool call response leaves nothing but whitespace behind; any
    # remaining non-whitespace content means regular text was mixed in.
    return not text.strip()
40+
41+
42+
def _handle_tool_call_filtering(
    func: Callable[..., Coroutine[Any, Any, ResponseT]],
) -> Callable[..., Coroutine[Any, Any, ResponseT]]:
    """
    Decorator to handle tool call filtering for scoring methods.

    When a tool call response is detected:
    - Filters out evals whose names are in the instance's exclude set and that
      read the response (these would get None scores)
    - Calls the original method with the remaining evals via a context wrapper
    - Rebuilds the result so every eval appears, with None scores for the
      filtered evals

    This implementation avoids modifying the original instance state to prevent
    race conditions in concurrent async operations.
    """

    # Defined once per decorated method (not per call): a read-only proxy that
    # overrides ``_evals`` without mutating the wrapped instance.
    class _EvalsContextWrapper:
        def __init__(self, original_instance: Any, filtered_evals: list[Any]):
            self._original = original_instance
            self._filtered_evals = filtered_evals

        def __getattr__(self, name: str) -> Any:
            # Only ``_evals`` is overridden; every other attribute is delegated.
            if name == "_evals":
                return self._filtered_evals
            return getattr(self._original, name)

        def __repr__(self) -> str:
            return repr(self._original)

        def __str__(self) -> str:
            return str(self._original)

    @wraps(func)
    async def wrapper(self: Any, **kwargs: Any) -> ResponseT:
        # NOTE(review): assumes ``response`` is a mapping carrying the response
        # text under the "response" key — confirm against callers.
        response = kwargs.get("response", {})
        response_text = response.get("response", "")

        # If not a tool call, just call the original method.
        if not _is_tool_call_response(str(response_text)):
            return await func(self, **kwargs)

        # It's a tool call - determine which evals to process vs. filter.
        # Default behavior (configured at init):
        # - Evals with response_identifier listed in the exclude set get score None
        # - All other evals are still evaluated normally
        exclude_names = set(getattr(self, "_tool_call_eval_exclude_names", None) or ())

        evals_to_process: list[Any] = []
        for eval_obj in self._evals:
            if eval_obj.response_identifier is not None and eval_obj.name in exclude_names:
                continue  # filtered: will be re-added with score None below
            evals_to_process.append(eval_obj)

        # Run the wrapped method against a proxy exposing only the kept evals,
        # leaving the original instance untouched.
        backend_response: ResponseT = await func(
            _EvalsContextWrapper(self, evals_to_process), **kwargs
        )
        # Re-insert the filtered evals (as score None) in canonical eval order.
        return _rebuild_response(backend_response, self._evals)

    return wrapper
113+
114+
115+
def _rebuild_response(backend_response: ResponseT, evals: list[Any]) -> ResponseT:
116+
eval_names = [e.name for e in evals]
117+
ordered = {}
118+
119+
for k, v in backend_response.items(): # type: ignore
120+
if k not in eval_names:
121+
ordered[k] = v
122+
123+
for e in evals:
124+
name = e.name
125+
if name in backend_response: # type: ignore
126+
ordered[name] = backend_response[name] # type: ignore
127+
else:
128+
ordered[name] = {"score": None} # filtered or missing
129+
130+
return ordered # type: ignore

src/cleanlab_tlm/utils/rag.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
_VALID_TLM_QUALITY_PRESETS,
4242
)
4343
from cleanlab_tlm.internal.exception_handling import handle_tlm_exceptions
44+
from cleanlab_tlm.internal.rag import _handle_tool_call_filtering
4445
from cleanlab_tlm.internal.validation import (
4546
_validate_trustworthy_rag_options,
4647
tlm_score_process_response_and_kwargs,
@@ -87,6 +88,10 @@ class TrustworthyRAG(BaseTLM):
8788
To come up with your custom `evals`, we recommend you first run [get_default_evals()](#function-get_default_evals) and then add/remove/modify the returned list.
8889
Each [Eval](#class-eval) in this list provides real-time detection of specific issues in your RAG application based on the user query, retrieved context (documents), and/or LLM-generated response.
8990
Set this to an empty list to only score response trustworthiness without additional evaluations.
91+
92+
Tool call handling: by default, when a tool call response is detected, evaluations that analyze the response content
93+
(those with a `response_identifier`) are assigned `score=None`. You can override this behavior for specific evals via
94+
`_configure_tool_call_eval_overrides()`.
9095
"""
9196

9297
def __init__(
@@ -135,6 +140,35 @@ def __init__(
135140

136141
_validate_trustworthy_rag_options(options=options, initialized_evals=self._evals)
137142

143+
# Optional per-eval tool call overrides
144+
# These are name-based include/exclude sets used only in the _handle_tool_call_filtering decorator
145+
self._configure_tool_call_eval_overrides(exclude_names=[k.name for k in self._evals if k.response_identifier])
146+
147+
def _configure_tool_call_eval_overrides(
148+
self,
149+
*,
150+
exclude_names: Optional[list[str]] = None,
151+
) -> None:
152+
"""Validates and stores tool-call exclusion names.
153+
154+
Only evals that read from the model response (have a non-None `response_identifier`)
155+
are eligible for tool-call filtering. We validate here (configuration boundary) so the
156+
decorator `_handle_tool_call_filtering` can assume a correct set and remain simple.
157+
158+
- If an eval name is in exclude_names, it will be filtered (score=None) during tool call handling.
159+
160+
Args:
161+
exclude_names (list[str] | None): Evaluation names to always filter during tool calls.
162+
"""
163+
names = exclude_names or []
164+
eligible = {e.name for e in self._evals if e.response_identifier is not None}
165+
invalid = [n for n in names if n not in eligible]
166+
if invalid:
167+
raise ValidationError(
168+
f"Invalid eval name(s) for tool-call exclusion (must exist and have response_identifier): {', '.join(invalid)}"
169+
)
170+
self._tool_call_eval_exclude_names = set(names) # membership filter; order/dupes irrelevant
171+
138172
def score(
139173
self,
140174
*,
@@ -434,6 +468,7 @@ async def _batch_async(
434468
await gather_task,
435469
)
436470

471+
@_handle_tool_call_filtering
437472
@handle_tlm_exceptions("TrustworthyRAGResponse")
438473
async def _generate_async(
439474
self,
@@ -480,6 +515,7 @@ async def _generate_async(
480515
),
481516
)
482517

518+
@_handle_tool_call_filtering
483519
@handle_tlm_exceptions("TrustworthyRAGScore")
484520
async def _score_async(
485521
self,

tests/internal/__init__.py

Whitespace-only changes.

tests/internal/test_rag.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
from typing import Any
2+
from unittest import mock
3+
4+
from cleanlab_tlm.utils.rag import TrustworthyRAG
5+
from tests.test_tlm_rag import (
6+
test_context,
7+
test_prompt,
8+
test_query,
9+
test_response,
10+
trustworthy_rag, # noqa: F401
11+
trustworthy_rag_api_key, # noqa: F401
12+
)
13+
14+
15+
def test_decorator_skips_bulk_logic_for_non_tool_calls(trustworthy_rag: TrustworthyRAG) -> None:  # noqa: F811
    """Tests that the _handle_tool_call_filtering decorator skips the bulk of its logic for non-tool calls.

    Expected:
    - When _is_tool_call_response returns False, the decorator should skip eval filtering logic
    - The original _evals should not be modified during execution
    - No None scores should be added for tool call filtered evals
    """
    # Snapshot the evals by value and identity so we can prove they were untouched.
    original_evals = trustworthy_rag._evals.copy()
    original_evals_id = id(trustworthy_rag._evals)

    with mock.patch("cleanlab_tlm.internal.rag._is_tool_call_response", return_value=False) as mock_is_tool_call:
        # Record every assignment to ``_evals`` (none should happen for non-tool calls).
        evals_modifications: list[tuple[str, Any, int]] = []
        original_setattr = object.__setattr__

        def track_evals_setattr(self: Any, name: str, value: Any) -> Any:
            if name == "_evals" and hasattr(self, "_evals"):
                evals_modifications.append((name, value, id(value)))
            return original_setattr(self, name, value)

        with mock.patch.object(type(trustworthy_rag), "__setattr__", track_evals_setattr):
            response = trustworthy_rag.score(
                query=test_query,
                context=test_context,
                response=test_response,
                prompt=test_prompt,
            )

        # The decorator was entered and consulted the tool-call check.
        assert mock_is_tool_call.call_count > 0

        # No temporary reassignment of _evals occurred (bulk logic was skipped):
        # only the original object (by identity) may ever have been assigned.
        evals_temp_modifications = [mod for mod in evals_modifications if mod[2] != original_evals_id]
        assert not evals_temp_modifications, f"Evals were temporarily modified: {evals_temp_modifications}"

        # Evals are unchanged after the call, by value and by identity.
        assert trustworthy_rag._evals == original_evals
        assert id(trustworthy_rag._evals) == original_evals_id

        # Non-tool calls must yield real scores, never the None scores produced
        # by tool call filtering (trustworthiness may legitimately be disabled,
        # so it is skipped rather than guarded with a tautological condition).
        assert isinstance(response, dict)
        for eval_name, eval_data in response.items():
            if eval_name != "trustworthiness":
                assert eval_data["score"] is not None
64+
65+
66+
def test_decorator_calls_api_with_full_evals_for_non_tool_calls(trustworthy_rag_api_key: str) -> None:  # noqa: F811
    """Decorator should pass full evals to API for non-tool-call responses.

    Expected:
    - When _is_tool_call_response returns False, the decorator should call the underlying API
      with the complete _evals parameter (no filtering applied).
    """
    tlm_rag = TrustworthyRAG(api_key=trustworthy_rag_api_key)
    # Keep a copy of the configured evals to compare against the API call.
    expected_evals = tlm_rag._evals.copy()

    with (
        mock.patch("cleanlab_tlm.internal.rag._is_tool_call_response", return_value=False),
        mock.patch("cleanlab_tlm.internal.api.api.tlm_rag_score") as mock_api_score,
    ):
        # Backend returns a plausible score payload for every configured eval.
        mock_api_score.return_value = {eval_name: {"score": 0.8, "reason": "test"} for eval_name in expected_evals}

        response = tlm_rag.score(
            query=test_query,
            context=test_context,
            response=test_response,
        )

    # Exactly one backend call was made, carrying the unfiltered eval list.
    assert mock_api_score.call_count == 1
    call_args = mock_api_score.call_args
    assert call_args is not None
    assert call_args.kwargs.get("evals") == expected_evals

    # A normal, fully scored response comes back.
    assert isinstance(response, dict)
    for eval_dict in response.values():
        assert isinstance(eval_dict["score"], float)
106+
107+
108+
def test_ordering_preserved_for_non_tool_calls(trustworthy_rag_api_key: str) -> None:  # noqa: F811
    """When not a tool call, ordering should match exactly what the mocked api.tlm_rag_score returns."""
    tlm_rag = TrustworthyRAG(api_key=trustworthy_rag_api_key)

    # Backend payload with a deliberately custom insertion order, so we can
    # detect any re-ordering on the way out.
    mocked_backend = {
        "trustworthiness": {"score": 0.91},
        "query_ease": {"score": 0.11},
        "context_sufficiency": {"score": 0.22},
        "response_helpfulness": {"score": 0.33},
        "response_groundedness": {"score": 0.44},
    }

    with (
        mock.patch("cleanlab_tlm.internal.rag._is_tool_call_response", return_value=False),
        mock.patch("cleanlab_tlm.internal.api.api.tlm_rag_score", return_value=mocked_backend),
    ):
        result = tlm_rag.score(
            query=test_query,
            context=test_context,
            response=test_response,
        )

    # Key order must be exactly the backend's insertion order.
    assert isinstance(result, dict)
    assert list(result.keys()) == list(mocked_backend.keys())
134+
135+
136+
def test_ordering_rebuilt_for_tool_calls(trustworthy_rag_api_key: str) -> None:  # noqa: F811
    """For tool calls, non-eval keys keep backend order, then all evals in self._evals order with filtered as None."""
    tlm_rag = TrustworthyRAG(api_key=trustworthy_rag_api_key)

    # Sanity-check the default eval ordering this test relies on.
    eval_order = [e.name for e in tlm_rag._evals]
    assert eval_order == [
        "context_sufficiency",
        "response_groundedness",
        "response_helpfulness",
        "query_ease",
    ]

    # During tool calls the decorator filters response-based evals, so the
    # backend only scores the rest; the processed evals are deliberately put
    # in a non-canonical order to prove the rebuild re-imposes eval_order.
    mocked_backend_processed = {
        "trustworthiness": {"score": 0.9},
        "query_ease": {"score": 0.5},
        "context_sufficiency": {"score": 0.8},
    }

    with (
        mock.patch("cleanlab_tlm.internal.rag._is_tool_call_response", return_value=True),
        mock.patch("cleanlab_tlm.internal.api.api.tlm_rag_score", return_value=mocked_backend_processed),
    ):
        result = tlm_rag.score(
            query=test_query,
            context=test_context,
            response=test_response,
            prompt=test_prompt,
        )

    assert isinstance(result, dict)

    # Non-eval keys (trustworthiness) come first in backend order, then every
    # eval in canonical eval_order.
    assert list(result.keys()) == ["trustworthiness", *eval_order]

    # Filtered response-based evals carry None scores; the rest keep the
    # backend's scores untouched.
    assert result["response_groundedness"]["score"] is None
    assert result["response_helpfulness"]["score"] is None
    assert result["query_ease"]["score"] == mocked_backend_processed["query_ease"]["score"]
    assert result["context_sufficiency"]["score"] == mocked_backend_processed["context_sufficiency"]["score"]

0 commit comments

Comments
 (0)