# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from concurrent.futures import as_completed
-from typing import Callable, Dict, List, Union
+from typing import Callable, Dict, List, Union, Optional
+from typing_extensions import override

from promptflow.tracing import ThreadPoolExecutorWithContext as ThreadPoolExecutor

+from azure.ai.evaluation._evaluators._common import EvaluatorBase
from azure.ai.evaluation._common._experimental import experimental

from ._hate_unfairness import HateUnfairnessEvaluator
from ._self_harm import SelfHarmEvaluator
from ._sexual import SexualEvaluator
from ._violence import ViolenceEvaluator


@experimental
-class ContentSafetyEvaluator:
+class ContentSafetyEvaluator(EvaluatorBase):
    """
    Initialize a content safety evaluator configured to evaluate content safety metrics for a question-answering scenario.

@@ -24,8 +26,12 @@ class ContentSafetyEvaluator:
    :param azure_ai_project: The scope of the Azure AI project.
        It contains subscription id, resource group, and project name.
    :type azure_ai_project: ~azure.ai.evaluation.AzureAIProject
-    :param parallel: If True, use parallel execution for evaluators. Else, use sequential execution.
-        Default is True.
+    :param eval_last_turn: Whether to evaluate only the last turn of a conversation. Default is False.
+    :type eval_last_turn: bool
+    :param kwargs: Additional arguments to pass to the evaluator.
+    :type kwargs: Any
+    :return: A function that evaluates content-safety metrics for a "question-answering" scenario.
+    :rtype: Callable

    **Usage**

@@ -62,41 +68,69 @@ class ContentSafetyEvaluator:
        }
    """

-    def __init__(self, credential, azure_ai_project, parallel: bool = True):
-        self._parallel = parallel
+    def __init__(self, credential, azure_ai_project, eval_last_turn: bool = False, **kwargs):
+        super().__init__(eval_last_turn=eval_last_turn)
+        self._parallel = kwargs.pop("parallel", True)  # parallel execution can still be disabled via kwargs
        self._evaluators: List[Callable[..., Dict[str, Union[str, float]]]] = [
            ViolenceEvaluator(credential, azure_ai_project),
            SexualEvaluator(credential, azure_ai_project),
            SelfHarmEvaluator(credential, azure_ai_project),
            HateUnfairnessEvaluator(credential, azure_ai_project),
        ]

-    def __call__(self, *, query: str, response: str, **kwargs):
+    @override
+    def __call__(
+        self,
+        *,
+        query: Optional[str] = None,
+        response: Optional[str] = None,
+        conversation=None,
+        **kwargs,
+    ):
90+ """Evaluate a collection of content safety metrics for the given query/response pair or conversation.
91+ This inputs must supply either a query AND response, or a conversation, but not both.
92+
93+ :keyword query: The query to evaluate.
94+ :paramtype query: Optional[str]
95+ :keyword response: The response to evaluate.
96+ :paramtype response: Optional[str]
97+ :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
98+ key "messages", and potentially a global context under the key "context". Conversation turns are expected
99+ to be dictionaries with keys "content", "role", and possibly "context".
100+ :paramtype conversation: Optional[~azure.ai.evaluation.Conversation]
101+ :return: The evaluation result.
102+ :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]]
        """
-        Evaluates content-safety metrics for "question-answering" scenario.
-
-        :keyword query: The query to be evaluated.
-        :paramtype query: str
-        :keyword response: The response to be evaluated.
-        :paramtype response: str
-        :keyword parallel: Whether to evaluate in parallel.
-        :paramtype parallel: bool
-        :return: The scores for content-safety.
-        :rtype: Dict[str, Union[str, float]]
+        return super().__call__(query=query, response=response, conversation=conversation, **kwargs)
+
+    @override
+    async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[str, float]]:
+        """Run each child content safety evaluator (violence, sexual, self-harm,
+        hate/unfairness) against the input and merge their results into a single
+        dictionary, evaluating in parallel when enabled.
+
+        :param eval_input: The input to the evaluation function.
+        :type eval_input: Dict
+        :return: The evaluation result.
+        :rtype: Dict
        """
+        query = eval_input.get("query", None)
+        response = eval_input.get("response", None)
+        conversation = eval_input.get("conversation", None)
        results: Dict[str, Union[str, float]] = {}
        if self._parallel:
            with ThreadPoolExecutor() as executor:
                futures = {
-                    executor.submit(evaluator, query=query, response=response, **kwargs): evaluator
+                    executor.submit(evaluator, query=query, response=response, conversation=conversation): evaluator
                    for evaluator in self._evaluators
                }

                for future in as_completed(futures):
                    results.update(future.result())
        else:
            for evaluator in self._evaluators:
-                result = evaluator(query=query, response=response, **kwargs)
+                result = evaluator(query=query, response=response, conversation=conversation)
                results.update(result)

        return results
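
A minimal usage sketch of the updated evaluator (assuming the public azure.ai.evaluation
and azure.identity entry points; the project-scope values are placeholders, and the
conversation shape follows the docstring above — a "messages" list of turns with "role"
and "content" keys):

    from azure.identity import DefaultAzureCredential
    from azure.ai.evaluation import ContentSafetyEvaluator

    # Placeholder project scope; substitute real values.
    azure_ai_project = {
        "subscription_id": "<subscription-id>",
        "resource_group_name": "<resource-group>",
        "project_name": "<project-name>",
    }

    # The four child evaluators run in parallel by default; constructing with
    # parallel=False would run them sequentially.
    evaluator = ContentSafetyEvaluator(DefaultAzureCredential(), azure_ai_project)

    # Evaluate a single query/response pair...
    result = evaluator(query="What is the capital of France?", response="Paris.")

    # ...or a conversation (mutually exclusive with query/response).
    conversation = {
        "messages": [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "Paris."},
        ]
    }
    result = evaluator(conversation=conversation)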