
Commit 53ba5aa

modify built-in chat eval flow: (#34511)
* modify rai service request format
* fix aggregation logic
* add input validation for rai service
* fix bug in validate_conversation of chat eval flow
1 parent b71f012 commit 53ba5aa

7 files changed: +139 −65 lines changed

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/pf_templates/built_in_metrics/chat/aggregate.py

Lines changed: 23 additions & 14 deletions
@@ -7,17 +7,22 @@
 @tool
 def aggregate_variants_results(results: List[dict], metrics: List[dict],
-                               thresholds = None):
-    rag_metrics = filter_metrics(metrics[0]["rag_metrics"])
-    non_rag_metrics = filter_metrics(metrics[0]["non_rag_metrics"])
-    safety_metrics = filter_metrics(metrics[0]["safety_metrics"])
+                               thresholds = None):
+    if metrics:
+        rag_metrics = filter_metrics(metrics[0]["rag_metrics"])
+        non_rag_metrics = filter_metrics(metrics[0]["non_rag_metrics"])
+        safety_metrics = filter_metrics(metrics[0]["safety_metrics"])
+    else:
+        rag_metrics = []
+        non_rag_metrics = []
+        safety_metrics = []
     selected_quality_metrics = rag_metrics + non_rag_metrics
 
     # set threshold for safety metric defect_rate
     if thresholds != [] and thresholds is not None:
         threshold = np.float16(thresholds[0])
     else:
-        threshold = RAIService.HARM_SEVERITY_THRESHOLD
+        threshold = np.float16(RAIService.HARM_SEVERITY_THRESHOLD)
 
     aggregate_results = {}
     for result in results:
@@ -32,18 +37,22 @@ def aggregate_variants_results(results: List[dict], metrics: List[dict],
             float_val = float(metric_value)
         except Exception:
             float_val = np.nan
-        aggregate_results[name].append(float_val)
+        if float_val >= 0:
+            aggregate_results[name].append(float_val)
     aggregate_output = {}
 
-    for name, value in aggregate_results.items():
+    for name, values in aggregate_results.items():
         metric_name = name
-        if name in selected_quality_metrics:
-            aggregate_output[metric_name] = np.nanmean(value)
-        elif name in safety_metrics:
-            values = aggregate_results[name]
+        if name in safety_metrics:
             metric_name = name+'_defect_rate'
-            aggregate_output[metric_name] = np.sum(values >= threshold) / len(values)
-
-        aggregate_output[metric_name] = round(aggregate_output[metric_name], 2)
+        if len(values) == 0:
+            aggregate_output[metric_name] = np.nan
+        else:
+            if name in selected_quality_metrics:
+                aggregate_output[metric_name] = round(np.nanmean(values), 2)
+            elif name in safety_metrics:
+                aggregate_output[metric_name] = round(np.sum(values >= threshold) / len(values), 2)
+            else:
+                aggregate_output[metric_name] = np.nan
         log_metric(metric_name, aggregate_output[metric_name])
     return aggregate_output
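For context, a minimal, self-contained sketch of the aggregation behavior after this change (metric names, scores, and the threshold value below are hypothetical; the real flow gets these from upstream nodes and RAIService.HARM_SEVERITY_THRESHOLD): quality metrics are averaged with NaNs ignored, safety metrics become a defect rate against the threshold, and NaN or negative sentinel scores are dropped before aggregation.

import numpy as np

results = [
    {"gpt_coherence": 4.0, "violence": 0.0},
    {"gpt_coherence": np.nan, "violence": 5.0},
    {"gpt_coherence": -1.0, "violence": 2.0},
]
quality_metrics, safety_metrics = ["gpt_coherence"], ["violence"]
threshold = np.float16(4)  # stand-in for RAIService.HARM_SEVERITY_THRESHOLD (actual value not shown here)

aggregated = {}
for name in quality_metrics + safety_metrics:
    # mirrors the new "float_val >= 0" filter: NaN and negative sentinel scores are skipped
    values = [r[name] for r in results if r[name] >= 0]
    key = name + "_defect_rate" if name in safety_metrics else name
    if not values:
        aggregated[key] = np.nan
    elif name in quality_metrics:
        aggregated[key] = round(float(np.nanmean(values)), 2)
    else:
        aggregated[key] = round(float(np.sum(np.array(values) >= threshold)) / len(values), 2)

print(aggregated)  # {'gpt_coherence': 4.0, 'violence_defect_rate': 0.33}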

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/pf_templates/built_in_metrics/chat/concatenate_metrics.py

Lines changed: 10 additions & 3 deletions
@@ -7,8 +7,12 @@ def format_rag_results(rag_results: dict, supported_metrics):
     result_per_turn = {}
     if rag_results:
         for metric, value in rag_results['artifacts'].items():
-            result_per_chat[metric] = rag_results['metrics']["mean_" + metric]
-            result_per_turn[metric] = {"reason": value['reason'], "score": value['score_per_turn']}
+            try:
+                result_per_chat[metric] = rag_results['metrics']["mean_" + metric]
+                result_per_turn[metric] = {"reason": value['reason'], "score": value['score_per_turn']}
+            except KeyError:
+                result_per_chat[metric] = np.nan
+                result_per_turn[metric] = np.nan
     for metric in supported_metrics:
         if metric not in result_per_turn:
             result_per_chat[metric] = np.nan
@@ -21,7 +25,10 @@ def format_non_rag_results(non_rag_results: dict, supported_metrics):
     result_per_turn = {}
     if non_rag_results:
         for metric in non_rag_results['artifacts']:
-            result_per_chat[metric] = non_rag_results['metrics']['mean_' + metric]
+            try:
+                result_per_chat[metric] = non_rag_results['metrics']['mean_' + metric]
+            except KeyError:
+                result_per_chat[metric] = np.nan
         result_per_turn = non_rag_results['artifacts']
     for metric in supported_metrics:
         if metric not in result_per_turn:
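A small illustration of the new fallback (hypothetical metric name and payload): when the per-chat "mean_&lt;metric&gt;" key is missing from the results dict, the formatter now records NaN instead of raising and failing the node.

import numpy as np

rag_results = {"artifacts": {"gpt_groundedness": {"reason": ["ok"], "score_per_turn": [5]}},
               "metrics": {}}  # no "mean_gpt_groundedness" key present

result_per_chat = {}
for metric, value in rag_results["artifacts"].items():
    try:
        result_per_chat[metric] = rag_results["metrics"]["mean_" + metric]
    except KeyError:
        result_per_chat[metric] = np.nan  # previously this lookup raised

print(result_per_chat)  # {'gpt_groundedness': nan}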

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/pf_templates/built_in_metrics/chat/construct_service_request.py

Lines changed: 7 additions & 2 deletions
@@ -7,11 +7,16 @@ def parse_chat(user_text: list):
         try:
             role = turn["role"]
             content = turn["content"]
-            content_str = "<" + role + ">" + content + "</>\n"
+            if role == "user":
+                content_str = "<Human>" + content + "</>\n"
+            elif role == "assistant":
+                content_str = "<System>" + content + "</>\n"
+            else:
+                content_str = "<" + role + ">" + content + "</>\n"
         except KeyError:
             content_str = json.dumps(turn) + "\n"
         parsed_chat.append(content_str)
-    return "{\"conversation\": \"" + "".join(parsed_chat) + "\"}"
+    return "".join(parsed_chat)
 
 
 def normalize_user_text(user_text):
     return user_text.replace("'", "\\\"")
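A sketch of what parse_chat now produces for the RAI service request (turn contents are hypothetical): user and assistant turns are tagged as &lt;Human&gt; and &lt;System&gt;, and the result is returned as plain tagged text rather than being wrapped in a JSON "conversation" string.

import json

def parse_chat_sketch(turns):
    # mirrors the updated parse_chat logic from the diff above
    parsed = []
    for turn in turns:
        try:
            role, content = turn["role"], turn["content"]
            if role == "user":
                parsed.append("<Human>" + content + "</>\n")
            elif role == "assistant":
                parsed.append("<System>" + content + "</>\n")
            else:
                parsed.append("<" + role + ">" + content + "</>\n")
        except KeyError:
            parsed.append(json.dumps(turn) + "\n")
    return "".join(parsed)

print(parse_chat_sketch([{"role": "user", "content": "hi"},
                         {"role": "assistant", "content": "hello"}]))
# <Human>hi</>
# <System>hello</>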

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/pf_templates/built_in_metrics/chat/flow.dag.yaml

Lines changed: 4 additions & 3 deletions
@@ -137,7 +137,7 @@ nodes:
     deployment_name: ${inputs.deployment_name}
     selected_metrics: ${select_metrics.output}
   activate:
-    when: ${validate_conversation.output}
+    when: ${validate_conversation.output.rag_metrics}
     is: true
   use_variants: false
 - name: evaluate_coherence_fluency
@@ -151,7 +151,7 @@
     parsed_qa: ${parse_chat.output}
     selected_metrics: ${select_metrics.output}
   activate:
-    when: ${validate_conversation.output}
+    when: ${validate_conversation.output.non_rag_metrics}
     is: true
   use_variants: false
 - name: parse_chat
@@ -162,7 +162,7 @@
   inputs:
     chat: ${inputs.messages}
   activate:
-    when: ${validate_conversation.output}
+    when: ${validate_conversation.output.non_rag_metrics}
     is: true
   use_variants: false
 - name: concatenate_metrics
@@ -191,6 +191,7 @@
     type: code
     path: validate_service.py
   inputs:
+    chat: ${inputs.messages}
     selected_metrics: ${select_metrics.output}
   use_variants: false
 - name: construct_service_request
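The activation conditions above now key off individual fields of the dict returned by validate_conversation rather than a single boolean. A minimal sketch with hypothetical values: the node gated in the first hunk keys on rag_metrics, while evaluate_coherence_fluency and parse_chat key on non_rag_metrics.

validate_output = {"rag_metrics": False, "non_rag_metrics": True}

run_rag_branch = validate_output["rag_metrics"] is True          # False: RAG-only node is skipped
run_non_rag_branch = validate_output["non_rag_metrics"] is True  # True: coherence/fluency and parse_chat run
print(run_rag_branch, run_non_rag_branch)                        # False True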

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/pf_templates/built_in_metrics/chat/utils.py

Lines changed: 42 additions & 0 deletions
@@ -1,6 +1,9 @@
 from promptflow.connections import AzureOpenAIConnection
 import constants
 import numpy as np
+from azureml.metrics.common import _validation
+from azureml.metrics.common.contract import Contract
+from azureml.metrics.common.exceptions import InvalidOperationException
 
 def get_openai_parameters(connection: AzureOpenAIConnection, deployment_name: str) -> dict:
     openai_params = {
@@ -55,3 +58,42 @@ def get_harm_severity_level(harm_score: int) -> str:
         if harm_score >= harm_score_range[0] and harm_score <= harm_score_range[1]:
             return harm_level.name
     return np.nan
+
+
+def is_conversation_valid(chat: []) -> bool:
+    reference_code = "validate_conversation"
+    name = "chat_format"
+    # check if role and content keys exist in every turn
+    _validation._check_chat_conversation([chat], name, reference_code=reference_code)
+    return True
+
+
+def is_conversation_valid_with_context(chat: []) -> bool:
+    reference_code = "validate_conversation"
+    name = "chat_context_format"
+
+    # check if context/documents keys exist for rag evaluation
+    for turn_num, each_turn in enumerate(chat):
+        # to accept the legacy rag_evaluation format:
+        # [{"user": {"content": "<user_content>"},
+        #   "assistant": {"content": "<assistant_content>"},
+        #   "retrieved_documents": "<retrieved_documents>"}]
+        if "user" in each_turn and "assistant" in each_turn:  # legacy rag_evaluation format
+            Contract.assert_true("retrieved_documents" in each_turn,
+                                 message="Please ensure to have retrieved_documents key in each turn for rag_evaluation."
+                                 + " Please check turn_number: {}".format(turn_num),
+                                 target=name, log_safe=True,
+                                 reference_code=reference_code)
+        elif "role" in each_turn and each_turn["role"] == "assistant":
+            Contract.assert_true("context" in each_turn,
+                                 message="Please ensure to have context key in assistant turn for rag_evaluation."
+                                 + " Please check turn_number: {}".format(turn_num),
+                                 target=name, log_safe=True,
+                                 reference_code=reference_code)
+            if "context" in each_turn:
+                Contract.assert_true("citations" in each_turn["context"],
+                                     message="Please ensure to have citations key in assistant turn context for rag_evaluation."
+                                     + " Please check turn_number: {}".format(turn_num),
+                                     target=name, log_safe=True,
+                                     reference_code=reference_code)
+
+    return True
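For reference, conversations shaped like either of the hypothetical payloads below would satisfy the context checks added above: the current chat format requires a context with citations on assistant turns, while the legacy rag_evaluation format requires retrieved_documents per turn.

# current chat format: assistant turns carry context with citations
chat_with_context = [
    {"role": "user", "content": "What is the return policy?"},
    {"role": "assistant", "content": "Returns are accepted within 30 days.",
     "context": {"citations": [{"id": "doc1", "content": "Return policy ..."}]}},
]

# legacy rag_evaluation format: retrieved_documents present in each turn
legacy_chat = [
    {"user": {"content": "What is the return policy?"},
     "assistant": {"content": "Returns are accepted within 30 days."},
     "retrieved_documents": "[{\"id\": \"doc1\"}]"},
]
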
sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/pf_templates/built_in_metrics/chat/validate_conversation.py

Lines changed: 42 additions & 40 deletions
@@ -1,51 +1,53 @@
 from promptflow import tool
-from azureml.metrics.common import _validation
-from azureml.metrics.common.contract import Contract
-from azureml.metrics.common.exceptions import InvalidOperationException
-from utils import filter_metrics
+# from azureml.metrics.common import _validation
+# from azureml.metrics.common.contract import Contract
+# from azureml.metrics.common.exceptions import InvalidOperationException
+from utils import filter_metrics, is_conversation_valid, is_conversation_valid_with_context
 
 
-def is_conversation_valid(chat: [], selected_metrics: dict) -> bool:
-    reference_code = "validate_conversation"
-    name = "chat_format"
-    # check if role and content keys exist in every turn
-    _validation._check_chat_conversation([chat], name, reference_code=reference_code)
+def is_metric_group_selected(selected_metrics: dict) -> dict:
+    group_selected = {}
+    for metric_group in selected_metrics:
+        group_selected[metric_group] = False
+        for metric in selected_metrics[metric_group]:
+            if selected_metrics[metric_group][metric]:
+                group_selected[metric_group] = True
+                break
+    return group_selected
 
-    # check if context/documents keys exist for rag evaluation
-    rag_metrics = filter_metrics(selected_metrics["rag_metrics"])
-    if len(rag_metrics) > 0:
-        for turn_num, each_turn in enumerate(chat):
-            # to accept the legacy rag_evaluation format:
-            # [{"user": {"content": "<user_content>"},
-            #   "assistant": {"content": "<assistant_content>"},
-            #   "retrieved_documents": "<retrieved_documents>"}]
-            if "user" in each_turn and "assistant" in each_turn:  # legacy rag_evaluation format
-                Contract.assert_true("retrieved_documents" in each_turn,
-                                     message="Please ensure to have retrieved_documents key in each turn for rag_evaluation."
-                                     + " Please check turn_number: {}".format(turn_num),
-                                     target=name, log_safe=True,
-                                     reference_code=reference_code)
-            elif "role" in each_turn and each_turn["role"] == "assistant":
-                Contract.assert_true("context" in each_turn,
-                                     message="Please ensure to have context key in assistant turn for rag_evaluation."
-                                     + " Please check turn_number: {}".format(turn_num),
-                                     target=name, log_safe=True,
-                                     reference_code=reference_code)
-                if "context" in each_turn:
-                    Contract.assert_true("citations" in each_turn["context"],
-                                         message="Please ensure to have citations key in assistant turn context for rag_evaluation."
-                                         + " Please check turn_number: {}".format(turn_num),
-                                         target=name, log_safe=True,
-                                         reference_code=reference_code)
-    return True
 
 
 # The inputs section will change based on the arguments of the tool function, after you save the code
 # Adding type to arguments and return value will help the system show the types properly
 # Please update the function name/signature per need
 @tool
 def validate_conversation(chat: [], selected_metrics: dict) -> bool:
+    is_group_selected = is_metric_group_selected(selected_metrics)
+
+    # no quality metrics are selected
+    if (not is_group_selected['rag_metrics']) and (not is_group_selected['non_rag_metrics']):
+        print("no quality metrics selected.")
+        return {"non_rag_metrics": False,
+                "rag_metrics": False}
+
+    # check if chat format is valid
     try:
-        is_valid_chat = is_conversation_valid(chat, selected_metrics)
+        is_valid_chat = is_conversation_valid(chat)
     except Exception:
         is_valid_chat = False
-    return is_valid_chat
+
+    # chat format is not valid
+    if not is_valid_chat:
+        print("chat format is not valid")
+        return {"non_rag_metrics": False,
+                "rag_metrics": False}
+
+    non_rag_node = is_group_selected['non_rag_metrics'] and is_valid_chat
+    rag_node = False
+    if is_group_selected['rag_metrics'] and is_valid_chat:
+        try:
+            rag_node = is_conversation_valid_with_context(chat)
+        except Exception:
+            rag_node = False
+    print("non_rag_metrics:", non_rag_node, "rag_metrics:", rag_node)
+
+    return {"non_rag_metrics": non_rag_node, "rag_metrics": rag_node}

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/pf_templates/built_in_metrics/chat/validate_service.py

Lines changed: 11 additions & 3 deletions
@@ -1,7 +1,7 @@
 from promptflow import tool
 import mlflow
 from mlflow.utils.rest_utils import http_request
-from utils import get_cred
+from utils import get_cred, is_conversation_valid
 
 def is_service_available():
     try:
@@ -35,13 +35,21 @@ def is_safety_metrics_selected(selected_metrics):
     print("No safety metrics are selected.")
     return False
 
+
+def is_chat_valid(chat) -> bool:
+    try:
+        is_valid_chat_format = is_conversation_valid(chat)
+    except Exception:
+        print("The chat format is not valid for safety metrics")
+        is_valid_chat_format = False
+    return is_valid_chat_format
+
 
 # check if RAI service is available in this region. If not, return False.
 # check if tracking_uri is set. If not, return False.
 # if tracking_uri is set, check if any safety metric is selected.
 # if no safety metric is selected, return False.
 @tool
-def validate_safety_metric_input(selected_metrics: dict) -> dict:
+def validate_safety_metric_input(selected_metrics: dict, chat: [dict]) -> dict:
     return is_safety_metrics_selected(selected_metrics) and \
         is_service_available() and \
-        is_tracking_uri_set()
+        is_tracking_uri_set() and is_chat_valid(chat)
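The safety branch is now gated on four conditions. A condensed sketch of the combined check, with stubbed flags standing in for the real service-availability and tracking-URI checks and a simplified structural test standing in for is_conversation_valid:

def validate_safety_metric_input_sketch(selected_metrics, chat,
                                        service_available=True, tracking_uri_set=True):
    safety_selected = any(selected_metrics.get("safety_metrics", {}).values())
    # stand-in for is_conversation_valid: every turn must carry role and content
    chat_valid = all("role" in t and "content" in t for t in chat)
    return safety_selected and service_available and tracking_uri_set and chat_valid

print(validate_safety_metric_input_sketch(
    {"safety_metrics": {"violence": True}},
    [{"role": "user", "content": "hi"}]))  # True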
