|
| 1 | +from typing import Any |
| 2 | +from unittest import mock |
| 3 | + |
| 4 | +from cleanlab_tlm.utils.rag import TrustworthyRAG |
| 5 | +from tests.test_tlm_rag import ( |
| 6 | + test_context, |
| 7 | + test_prompt, |
| 8 | + test_query, |
| 9 | + test_response, |
| 10 | + trustworthy_rag, # noqa: F401 |
| 11 | + trustworthy_rag_api_key, # noqa: F401 |
| 12 | +) |
| 13 | + |
| 14 | + |
def test_decorator_skips_bulk_logic_for_non_tool_calls(trustworthy_rag: TrustworthyRAG) -> None:  # noqa: F811
    """Tests that the _handle_tool_call_filtering decorator skips the bulk of its logic for non-tool calls.

    Expected:
        - When _is_tool_call_response returns False, the decorator should skip eval filtering logic
        - The original _evals should not be modified during execution
        - No None scores should be added for tool call filtered evals
    """
    # Store original evals for comparison
    original_evals = trustworthy_rag._evals.copy()
    original_evals_id = id(trustworthy_rag._evals)

    # Mock to track if the bulk logic is executed
    with mock.patch("cleanlab_tlm.internal.rag._is_tool_call_response", return_value=False) as mock_is_tool_call:
        # Track if evals are temporarily modified (which shouldn't happen for non-tool calls)
        evals_modifications = []
        original_setattr = object.__setattr__

        def track_evals_setattr(self: Any, name: str, value: Any) -> Any:
            # Record every reassignment of `_evals`; the hasattr guard skips the
            # very first assignment (during __init__), when the attribute does
            # not exist yet.
            if name == "_evals" and hasattr(self, "_evals"):
                evals_modifications.append((name, value, id(value)))
            return original_setattr(self, name, value)

        with mock.patch.object(type(trustworthy_rag), "__setattr__", track_evals_setattr):
            response = trustworthy_rag.score(
                query=test_query,
                context=test_context,
                response=test_response,
                prompt=test_prompt,
            )

    # Verify _is_tool_call_response was called (decorator logic was entered)
    assert mock_is_tool_call.call_count > 0

    # Verify that evals were not temporarily modified (bulk logic was skipped).
    # The only modification allowed is one that re-binds the exact original
    # evals object; any other object id indicates a temporary swap.
    evals_temp_modifications = [mod for mod in evals_modifications if mod[2] != original_evals_id]
    assert len(evals_temp_modifications) == 0, f"Evals were temporarily modified: {evals_temp_modifications}"

    # Verify evals are unchanged after the call
    assert trustworthy_rag._evals == original_evals
    assert id(trustworthy_rag._evals) == original_evals_id

    # Verify we got a normal response with actual scores (not None scores from tool call filtering)
    assert isinstance(response, dict)
    for eval_name, eval_data in response.items():
        # Non-tool calls should have actual scores; trustworthiness is excluded
        # because it might legitimately have a None score if disabled.
        # (The original `... or eval_name == "trustworthiness"` clause was dead:
        # it could never be True inside this branch.)
        if eval_name != "trustworthiness":
            assert eval_data["score"] is not None
| 64 | + |
| 65 | + |
def test_decorator_calls_api_with_full_evals_for_non_tool_calls(trustworthy_rag_api_key: str) -> None:  # noqa: F811
    """Decorator should pass full evals to API for non-tool-call responses.

    Expected:
        - When _is_tool_call_response returns False, the decorator should call the underlying API
          with the complete _evals parameter (no filtering applied).
    """
    rag = TrustworthyRAG(api_key=trustworthy_rag_api_key)

    # Snapshot the configured evals so we can check they reach the API untouched.
    expected_evals = rag._evals.copy()

    # Force the non-tool-call path and intercept the backend scoring call.
    with (
        mock.patch("cleanlab_tlm.internal.rag._is_tool_call_response", return_value=False),
        mock.patch("cleanlab_tlm.internal.api.api.tlm_rag_score") as scored,
    ):
        # Hand back a well-formed score for every configured eval.
        scored.return_value = {name: {"score": 0.8, "reason": "test"} for name in expected_evals}

        result = rag.score(
            query=test_query,
            context=test_context,
            response=test_response,
        )

        # Exactly one backend call, carrying the unfiltered evals.
        assert scored.call_count == 1
        assert scored.call_args is not None
        assert scored.call_args.kwargs.get("evals") == expected_evals

    # The response itself should look like a normal, fully scored result.
    assert isinstance(result, dict)
    assert all(isinstance(entry["score"], float) for entry in result.values())
| 106 | + |
| 107 | + |
def test_ordering_preserved_for_non_tool_calls(trustworthy_rag_api_key: str) -> None:  # noqa: F811
    """When not a tool call, ordering should match exactly what the mocked api.tlm_rag_score returns."""
    rag = TrustworthyRAG(api_key=trustworthy_rag_api_key)

    # Backend payload with a deliberate, non-default key order so we can tell
    # whether the decorator reorders anything on the non-tool-call path.
    backend_payload = {
        "trustworthiness": {"score": 0.91},
        "query_ease": {"score": 0.11},
        "context_sufficiency": {"score": 0.22},
        "response_helpfulness": {"score": 0.33},
        "response_groundedness": {"score": 0.44},
    }

    with (
        mock.patch("cleanlab_tlm.internal.rag._is_tool_call_response", return_value=False),
        mock.patch("cleanlab_tlm.internal.api.api.tlm_rag_score", return_value=backend_payload),
    ):
        outcome = rag.score(
            query=test_query,
            context=test_context,
            response=test_response,
        )

    assert isinstance(outcome, dict)
    # Key order must be exactly the order the backend produced.
    assert list(outcome.keys()) == list(backend_payload.keys())
| 134 | + |
| 135 | + |
def test_ordering_rebuilt_for_tool_calls(trustworthy_rag_api_key: str) -> None:  # noqa: F811
    """For tool calls, non-eval keys keep backend order, then all evals in self._evals order with filtered as None."""
    rag = TrustworthyRAG(api_key=trustworthy_rag_api_key)

    # Sanity-check the default eval ordering that the rebuild step must follow.
    configured_order = [ev.name for ev in rag._evals]
    assert configured_order == [
        "context_sufficiency",
        "response_groundedness",
        "response_helpfulness",
        "query_ease",
    ]

    # During tool calls the backend only scores non-response evals (the decorator
    # filters response-based ones). Scramble the processed keys so the test can
    # prove the decorator rebuilds the result in configured_order regardless.
    processed = {
        "trustworthiness": {"score": 0.9},
        "query_ease": {"score": 0.5},
        "context_sufficiency": {"score": 0.8},
    }

    with (
        mock.patch("cleanlab_tlm.internal.rag._is_tool_call_response", return_value=True),
        mock.patch("cleanlab_tlm.internal.api.api.tlm_rag_score", return_value=processed),
    ):
        outcome = rag.score(
            query=test_query,
            context=test_context,
            response=test_response,
            prompt=test_prompt,
        )

    assert isinstance(outcome, dict)

    # Non-eval keys (just trustworthiness here) come first in backend order,
    # followed by every eval in the configured order.
    assert list(outcome.keys()) == ["trustworthiness", *configured_order]

    # Filtered response-based evals are rebuilt with None scores; the
    # backend-processed evals keep their scores.
    assert outcome["response_groundedness"]["score"] is None
    assert outcome["response_helpfulness"]["score"] is None
    assert outcome["query_ease"]["score"] == processed["query_ease"]["score"]
    assert outcome["context_sufficiency"]["score"] == processed["context_sufficiency"]["score"]