Skip to content

Commit 2e3d453

Browse files
authored
[evaluation] feat: Forward input columns that aren't named in evaluator signature to **kwargs (#42893)
* feat: Forward input columns that aren't named in evaluator signature to **kwargs
* tests: Add tests for **kwargs behavior
* test,fix: Check correct row in test
1 parent be04205 commit 2e3d453

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_legacy/_batch_engine/_engine.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -344,8 +344,13 @@ def __preprocess_inputs(self, inputs: Mapping[str, Any]) -> Mapping[str, Any]:
344344

345345
func_params = inspect.signature(self._func).parameters
346346

347-
filtered_params = {key: value for key, value in inputs.items() if key in func_params}
348-
return filtered_params
347+
has_kwargs = any(p.kind == p.VAR_KEYWORD for p in func_params.values())
348+
349+
if has_kwargs:
350+
return inputs
351+
else:
352+
filtered_params = {key: value for key, value in inputs.items() if key in func_params}
353+
return filtered_params
349354

350355
async def _exec_line_async(
351356
self,

sdk/evaluation/azure-ai-evaluation/tests/unittests/test_evaluate.py

Lines changed: 76 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -975,6 +975,82 @@ def test_name_map_conversion(self):
975975
assert result[EvaluationRunProperties.NAME_MAP_LENGTH] == -1
976976
assert len(result) == 1
977977

978+
def test_evaluate_evaluator_only_kwargs_param(self, evaluate_test_data_jsonl_file):
979+
"""Validate that an evaluator with only an **kwargs param receives all input in kwargs."""
980+
981+
def evaluator(**kwargs):
982+
return locals()
983+
984+
result = evaluate(data=evaluate_test_data_jsonl_file, evaluators={"test": evaluator})
985+
986+
assert len(result["rows"]) == 3
987+
988+
assert {"query", "response", "ground_truth", "context"}.issubset(result["rows"][0]["outputs.test.kwargs"])
989+
assert {"query", "response", "ground_truth", "context"}.issubset(result["rows"][1]["outputs.test.kwargs"])
990+
assert {"query", "response", "ground_truth", "context"}.issubset(result["rows"][2]["outputs.test.kwargs"])
991+
992+
def test_evaluate_evaluator_kwargs_param(self, evaluate_test_data_jsonl_file):
993+
"""Validate that an evaluator with named parameters and **kwargs obeys python function call semantics."""
994+
995+
def evaluator(query, response, *, bar=None, **kwargs):
996+
return locals()
997+
998+
result = evaluate(data=evaluate_test_data_jsonl_file, evaluators={"test": evaluator})
999+
1000+
assert len(result["rows"]) == 3
1001+
1002+
row1_kwargs = result["rows"][0]["outputs.test.kwargs"]
1003+
row2_kwargs = result["rows"][1]["outputs.test.kwargs"]
1004+
row3_kwargs = result["rows"][2]["outputs.test.kwargs"]
1005+
1006+
assert {"ground_truth", "context"}.issubset(row1_kwargs), "Unnamed parameters should be in kwargs"
1007+
assert {"query", "response", "bar"}.isdisjoint(row1_kwargs), "Named parameters should not be in kwargs"
1008+
1009+
assert {"ground_truth", "context"}.issubset(row2_kwargs), "Unnamed parameters should be in kwargs"
1010+
assert {"query", "response", "bar"}.isdisjoint(row2_kwargs), "Named parameters should not be in kwargs"
1011+
1012+
assert {"ground_truth", "context"}.issubset(row3_kwargs), "Unnamed parameters should be in kwargs"
1013+
assert {"query", "response", "bar"}.isdisjoint(row3_kwargs), "Named parameters should not be in kwargs"
1014+
1015+
def test_evaluate_evaluator_kwargs_param_column_mapping(self, evaluate_test_data_jsonl_file):
1016+
"""Validate that an evaluator with kwargs can receive column mapped values."""
1017+
1018+
def evaluator(query, response, *, bar=None, **kwargs):
1019+
return locals()
1020+
1021+
result = evaluate(
1022+
data=evaluate_test_data_jsonl_file,
1023+
evaluators={"test": evaluator},
1024+
evaluator_config={
1025+
"default": {
1026+
"column_mapping": {
1027+
"query": "${data.query}",
1028+
"response": "${data.response}",
1029+
"foo": "${data.context}",
1030+
"bar": "${data.ground_truth}",
1031+
}
1032+
}
1033+
},
1034+
)
1035+
1036+
assert len(result["rows"]) == 3
1037+
1038+
row1_kwargs = result["rows"][0]["outputs.test.kwargs"]
1039+
row2_kwargs = result["rows"][1]["outputs.test.kwargs"]
1040+
row3_kwargs = result["rows"][2]["outputs.test.kwargs"]
1041+
1042+
assert {"ground_truth", "context"}.issubset(row1_kwargs), "Unnamed parameters should be in kwargs"
1043+
assert "foo" in row1_kwargs, "Making a column mapping to an unnamed parameter should appear in kwargs"
1044+
assert {"query", "response", "bar"}.isdisjoint(row1_kwargs), "Named parameters should not be in kwargs"
1045+
1046+
assert {"ground_truth", "context"}.issubset(row2_kwargs), "Unnamed parameters should be in kwargs"
1047+
assert "foo" in row2_kwargs, "Making a column mapping to an unnamed parameter should appear in kwargs"
1048+
assert {"query", "response", "bar"}.isdisjoint(row2_kwargs), "Named parameters should not be in kwargs"
1049+
1050+
assert {"ground_truth", "context"}.issubset(row3_kwargs), "Unnamed parameters should be in kwargs"
1051+
assert "foo" in row3_kwargs, "Making a column mapping to an unnamed parameter should appear in kwargs"
1052+
assert {"query", "response", "bar"}.isdisjoint(row3_kwargs), "Named parameters should not be in kwargs"
1053+
9781054

9791055
@pytest.mark.unittest
9801056
class TestTagsInLoggingFunctions:

0 commit comments

Comments
 (0)