
Commit a572abe

add spokenwoz speech and text (#24)
1 parent 8099646 commit a572abe

12 files changed: +694 −2 lines changed

engine.py

Lines changed: 7 additions & 0 deletions
@@ -423,6 +423,13 @@ async def score_model_with_tokens(model_name, outs):
                     metric, outs, model_targets, source_sentences,
                     instructions=instructions, task_name=self.task_name, model_name=model_name, model_responses=model_responses
                 )
+            elif metric_name in ('joint_goal_accuracy', 'slot_accuracy', 'slot_f1'):
+                ground_truth_slots = process_result.get("ground_truth_slots", [])
+                result = await asyncio.to_thread(
+                    metric, outs, model_targets,
+                    instructions=instructions, task_name=self.task_name, model_name=model_name,
+                    model_responses=model_responses, ground_truth_slots=ground_truth_slots
+                )
             else:
                 result = await asyncio.to_thread(
                     metric, outs, model_targets,
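For context, the new branch simply pulls `ground_truth_slots` out of the processed record and forwards it as an extra keyword when the (synchronous) dialogue metric is run off the event loop. Below is a rough, self-contained sketch of that call path; it is not part of the commit, and `fake_metric`, the toy `process_result`, and the data shapes are illustrative assumptions only.

import asyncio

# Hypothetical stand-in for a dialogue metric; the engine only assumes it
# accepts ground_truth_slots as a keyword argument.
def fake_metric(outs, targets, *, instructions=None, task_name=None,
                model_name=None, model_responses=None, ground_truth_slots=None):
    return {"joint_goal_accuracy": 100.0 if ground_truth_slots else 0.0}

async def demo():
    # Assumed shape of the processed record carrying the reference dialogue state.
    process_result = {"ground_truth_slots": [{"hotel": {"area": "north"}}]}
    ground_truth_slots = process_result.get("ground_truth_slots", [])
    # asyncio.to_thread runs the blocking metric in a worker thread.
    result = await asyncio.to_thread(
        fake_metric, ["model output"], ["reference"],
        instructions=["..."], task_name="spoken_dialogue", model_name="demo-model",
        model_responses=None, ground_truth_slots=ground_truth_slots,
    )
    print(result)

asyncio.run(demo())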

metrics/README.md

Lines changed: 28 additions & 1 deletion
@@ -26,6 +26,9 @@ For more detailed documentation regarding which metrics can be used for which ta
 | `sql_score` (↑) | SQL correctness and execution match | text2sql_score |
 | `instruction_following` (↑) | LLM-judged instruction following capability | final |
 | `gsm8k_exact_match` (↑) | Exact-match accuracy of the final numerical answer. | gsm8k_exact_match |
+| `joint_goal_accuracy` (↑) | Dialogue state tracking - all slots match | joint_goal_accuracy |
+| `slot_accuracy` (↑) | Dialogue state tracking - per-slot accuracy | slot_accuracy |
+| `slot_f1` (↑) | Dialogue state tracking - slot extraction F1 | slot_f1 |
 
 ---
 
@@ -156,4 +159,28 @@ For more detailed documentation regarding which metrics can be used for which ta
 - **Type**: Math correctness metric
 - **Description**: Measure the exact-match accuracy of the final numerical answer (expected within `\boxed{}`) with the reference numerical answer.
 - **Scoring (record-level)** Score between `0` and `100`, higher is better.
-- **Used In**: Math (`gsm8k`)
+- **Used In**: Math (`gsm8k`)
+
+---
+
+### `joint_goal_accuracy`
+- **Type**: Dialogue state tracking metric
+- **Description**: Evaluates whether all predicted slots exactly match the ground truth dialogue state. A sample scores 1 only if every slot-value pair is correct.
+- **Scoring (record-level)** Score `0` or `1`, higher is better.
+- **Used In**: Task-Oriented Dialogue (`spoken_dialogue`)
+
+---
+
+### `slot_accuracy`
+- **Type**: Dialogue state tracking metric
+- **Description**: Computes the proportion of individual slots correctly predicted across all samples.
+- **Scoring (record-level)** Score between `0` and `1`, higher is better.
+- **Used In**: Task-Oriented Dialogue (`spoken_dialogue`)
+
+---
+
+### `slot_f1`
+- **Type**: Dialogue state tracking metric
+- **Description**: Computes F1 score for slot value extraction, balancing precision and recall of predicted slot-value pairs.
+- **Scoring (record-level)** Score between `0` and `1`, higher is better.
+- **Used In**: Task-Oriented Dialogue (`spoken_dialogue`)
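To make the three new README entries concrete, here is a small worked example (illustrative and slightly simplified, not part of the commit; the toy hotel state below is made up). The ground truth has two slots and the prediction gets one of them right, so JGA is 0, slot accuracy is 0.5, and slot F1 is 0.5.

# Illustrative toy dialogue state; simplified re-derivation of the three scores.
gt   = {"hotel": {"area": "north", "stars": "4"}}
pred = {"hotel": {"area": "north", "stars": "3"}}

gt_set   = {(d, s, v) for d, slots in gt.items()   for s, v in slots.items()}
pred_set = {(d, s, v) for d, slots in pred.items() for s, v in slots.items()}

jga = 1.0 if gt_set == pred_set else 0.0            # one slot wrong -> 0.0
slot_acc = len(gt_set & pred_set) / len(gt_set)     # 1 of 2 slots correct -> 0.5
tp = len(gt_set & pred_set)
precision, recall = tp / len(pred_set), tp / len(gt_set)
slot_f1 = 2 * precision * recall / (precision + recall)   # -> 0.5
print(jga, slot_acc, slot_f1)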

metrics/dialogue_metrics.py

Lines changed: 322 additions & 0 deletions
@@ -0,0 +1,322 @@
"""Dialogue state tracking metrics for task-oriented dialogue evaluation.

This module provides general-purpose evaluation metrics for dialogue systems:
- Joint Goal Accuracy (JGA): Whether all slots match the ground truth
- Slot Accuracy: Per-slot accuracy across all samples
- Slot F1: F1 score for slot value extraction

These metrics can be used by any task-oriented dialogue benchmark.
"""

import json
import logging
import re
from typing import Dict, List, Any

from tqdm import tqdm

from metrics.metrics import Metrics
from utils import util
from utils.custom_logging import write_record_log, append_final_score

logger = logging.getLogger(__name__)


class _BaseDialogueMetric(Metrics):
    """Base class for dialogue state tracking metrics."""

    def __init__(self):
        super().__init__()
        self.instructions = None
        self.model_responses = []
        self.ground_truth_slots = []

    def __call__(self, candidates, references, instructions=None, *,
                 task_name: str | None = None, model_name: str | None = None,
                 model_responses=None, ground_truth_slots=None):
        """Evaluate model predictions against ground truth.

        Args:
            candidates: Model generated responses
            references: Ground truth agent responses
            instructions: Input instructions (contains dialogue context)
            task_name: Name of the task
            model_name: Name of the model
            model_responses: Raw model response objects
            ground_truth_slots: List of ground truth slot dictionaries
        """
        self.instructions = instructions
        self.model_responses = model_responses if model_responses else []
        self.ground_truth_slots = ground_truth_slots if ground_truth_slots else []

        # Compute scores
        overall = self.get_score(candidates, references, task_name, model_name)

        if task_name and model_name:
            scores = self.record_level_scores
            write_record_log(self, references, candidates, scores, task_name, model_name,
                             instructions=self.instructions, model_responses=self.model_responses)
            append_final_score(self, overall, task_name, model_name, self.model_responses)

        return overall

    def get_score(self, candidates: list, references: list,
                  task_name: str = None, model_name: str = None) -> Dict[str, float]:
        """Compute overall scores for the dataset."""
        if not self.record_level_scores:
            self.record_level_scores = self.compute_record_level_scores(
                candidates, references, task_name, model_name
            )

        results = {}
        for metric_name, scores in self.record_level_scores.items():
            valid_scores = [s for s in scores if s is not None]
            if valid_scores:
                results[metric_name] = util.smart_round(sum(valid_scores) / len(valid_scores) * 100, 2)
            else:
                results[metric_name] = 0.0

        return results

    def _extract_slots_from_response(self, response: str) -> Dict[str, Dict[str, str]]:
        """Extract slot values from model's natural language response.

        Uses pattern matching to identify slot-value pairs mentioned in the response.
        """
        if not response:
            return {}

        response_lower = response.lower()
        extracted_slots = {}

        # Pattern-based extraction for common slot types
        patterns = {
            'restaurant': {
                'area': r'(?:in|at|around|near)\s+(?:the\s+)?(\w+)\s+(?:area|part|side)',
                'food': r'(?:serving|serves?|offering?|type of food[:\s]+)(\w+(?:\s+\w+)?)\s+(?:food|cuisine)?',
                'pricerange': r'(?:price[d]?\s*range|budget|cost)[:\s]+(\w+)|(\w+)\s+(?:price[d]?\s*range|priced)',
                'name': r'(?:called|named|restaurant[:\s]+)([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
            },
            'hotel': {
                'area': r'(?:in|at|around|near)\s+(?:the\s+)?(\w+)\s+(?:area|part|side)',
                'pricerange': r'(?:price[d]?\s*range|budget|cost)[:\s]+(\w+)|(\w+)\s+(?:price[d]?\s*range|priced)',
                'stars': r'(\d+)\s*(?:star|stars)',
                'parking': r'(?:parking)[:\s]*(yes|no|free)',
                'internet': r'(?:internet|wifi)[:\s]*(yes|no|free)',
                'type': r'(?:type)[:\s]+(\w+)|(\w+)\s+(?:type)',
                'name': r'(?:called|named|hotel[:\s]+)([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
            },
            'train': {
                'departure': r'(?:from|depart(?:ing|ure)?(?:\s+from)?)[:\s]+([A-Za-z]+)',
                'destination': r'(?:to|arrive?(?:\s+at)?|destination)[:\s]+([A-Za-z]+)',
                'day': r'(?:on|day)[:\s]+(\w+day)',
                'leaveat': r'(?:leav(?:e|ing)\s+(?:at)?|departure\s+time)[:\s]+(\d{1,2}[:\.]?\d{2})',
                'arriveby': r'(?:arriv(?:e|ing)\s+(?:by)?|arrival\s+time)[:\s]+(\d{1,2}[:\.]?\d{2})',
            },
            'taxi': {
                'departure': r'(?:from|pick\s*up)[:\s]+([A-Za-z\s]+?)(?:\s+to|\s*$)',
                'destination': r'(?:to|drop\s*off|destination)[:\s]+([A-Za-z\s]+)',
                'leaveat': r'(?:leav(?:e|ing)\s+(?:at)?)[:\s]+(\d{1,2}[:\.]?\d{2})',
                'arriveby': r'(?:arriv(?:e|ing)\s+(?:by)?)[:\s]+(\d{1,2}[:\.]?\d{2})',
            },
            'attraction': {
                'area': r'(?:in|at|around|near)\s+(?:the\s+)?(\w+)\s+(?:area|part|side)',
                'type': r'(?:type)[:\s]+(\w+)|(\w+)\s+(?:attraction|place)',
                'name': r'(?:called|named|attraction[:\s]+)([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
            }
        }

        for domain, slot_patterns in patterns.items():
            for slot, pattern in slot_patterns.items():
                matches = re.findall(pattern, response_lower, re.IGNORECASE)
                if matches:
                    value = None
                    for match in matches:
                        if isinstance(match, tuple):
                            value = next((m for m in match if m), None)
                        else:
                            value = match
                        if value:
                            break

                    if value:
                        if domain not in extracted_slots:
                            extracted_slots[domain] = {}
                        extracted_slots[domain][slot] = value.strip()

        return extracted_slots

    def _normalize_value(self, value: str) -> str:
        """Normalize slot value for comparison."""
        if not value:
            return ''
        return str(value).lower().strip().replace(' ', '').replace('-', '').replace('_', '')

    def _parse_ground_truth_slots(self, gt_slots) -> Dict:
        """Parse ground truth slots from various formats."""
        if not gt_slots:
            return {}
        if isinstance(gt_slots, str):
            try:
                return json.loads(gt_slots)
            except json.JSONDecodeError:
                return {}
        return gt_slots


class JointGoalAccuracy(_BaseDialogueMetric):
    """Joint Goal Accuracy metric for dialogue state tracking.

    JGA is 1 if all predicted slots exactly match ground truth, 0 otherwise.
    """

    def __init__(self):
        super().__init__()
        self.name = "joint_goal_accuracy"

    def compute_record_level_scores(self, candidates: list, references: list,
                                    task_name: str = None, model_name: str = None) -> Dict[str, List]:
        """Compute JGA for each sample."""
        scores = []

        desc = f"JGA Eval"
        if task_name and model_name:
            desc = f"JGA [{task_name}] [{model_name}]"

        for i, (candidate, reference) in enumerate(tqdm(
            zip(candidates, references), desc=desc, total=len(candidates)
        )):
            gt_slots = {}
            if i < len(self.ground_truth_slots):
                gt_slots = self._parse_ground_truth_slots(self.ground_truth_slots[i])

            pred_slots = self._extract_slots_from_response(candidate)
            jga = self._compute_jga(pred_slots, gt_slots)
            scores.append(jga)

        return {self.name: scores}

    def _compute_jga(self, pred_slots: Dict, gt_slots: Dict) -> float:
        """Compute Joint Goal Accuracy."""
        if not gt_slots:
            return 1.0 if not pred_slots else 0.0

        for domain, slots in gt_slots.items():
            if domain not in pred_slots:
                return 0.0
            for slot, value in slots.items():
                pred_value = pred_slots.get(domain, {}).get(slot, '')
                if self._normalize_value(pred_value) != self._normalize_value(value):
                    return 0.0

        return 1.0


class SlotAccuracy(_BaseDialogueMetric):
    """Slot Accuracy metric for dialogue state tracking.

    Computes the proportion of individual slots correctly predicted.
    """

    def __init__(self):
        super().__init__()
        self.name = "slot_accuracy"

    def compute_record_level_scores(self, candidates: list, references: list,
                                    task_name: str = None, model_name: str = None) -> Dict[str, List]:
        """Compute slot accuracy for each sample."""
        scores = []

        desc = f"Slot Accuracy Eval"
        if task_name and model_name:
            desc = f"Slot Accuracy [{task_name}] [{model_name}]"

        for i, (candidate, reference) in enumerate(tqdm(
            zip(candidates, references), desc=desc, total=len(candidates)
        )):
            gt_slots = {}
            if i < len(self.ground_truth_slots):
                gt_slots = self._parse_ground_truth_slots(self.ground_truth_slots[i])

            pred_slots = self._extract_slots_from_response(candidate)
            slot_acc = self._compute_slot_accuracy(pred_slots, gt_slots)
            scores.append(slot_acc)

        return {self.name: scores}

    def _compute_slot_accuracy(self, pred_slots: Dict, gt_slots: Dict) -> float:
        """Compute per-slot accuracy."""
        if not gt_slots:
            return 1.0 if not pred_slots else 0.0

        total_slots = 0
        correct_slots = 0

        for domain, slots in gt_slots.items():
            for slot, value in slots.items():
                total_slots += 1
                pred_value = pred_slots.get(domain, {}).get(slot, '')
                if self._normalize_value(pred_value) == self._normalize_value(value):
                    correct_slots += 1

        return correct_slots / total_slots if total_slots > 0 else 1.0


class SlotF1(_BaseDialogueMetric):
    """Slot F1 metric for dialogue state tracking.

    Computes F1 score for slot value extraction.
    """

    def __init__(self):
        super().__init__()
        self.name = "slot_f1"

    def compute_record_level_scores(self, candidates: list, references: list,
                                    task_name: str = None, model_name: str = None) -> Dict[str, List]:
        """Compute slot F1 for each sample."""
        scores = []

        desc = f"Slot F1 Eval"
        if task_name and model_name:
            desc = f"Slot F1 [{task_name}] [{model_name}]"

        for i, (candidate, reference) in enumerate(tqdm(
            zip(candidates, references), desc=desc, total=len(candidates)
        )):
            gt_slots = {}
            if i < len(self.ground_truth_slots):
                gt_slots = self._parse_ground_truth_slots(self.ground_truth_slots[i])

            pred_slots = self._extract_slots_from_response(candidate)
            slot_f1 = self._compute_slot_f1(pred_slots, gt_slots)
            scores.append(slot_f1)

        return {self.name: scores}

    def _compute_slot_f1(self, pred_slots: Dict, gt_slots: Dict) -> float:
        """Compute slot F1 score."""
        gt_set = set()
        pred_set = set()

        for domain, slots in gt_slots.items():
            for slot, value in slots.items():
                gt_set.add((domain, slot, self._normalize_value(value)))

        for domain, slots in pred_slots.items():
            for slot, value in slots.items():
                pred_set.add((domain, slot, self._normalize_value(value)))

        if not gt_set and not pred_set:
            return 1.0
        if not gt_set or not pred_set:
            return 0.0

        tp = len(gt_set & pred_set)
        precision = tp / len(pred_set) if pred_set else 0.0
        recall = tp / len(gt_set) if gt_set else 0.0

        if precision + recall == 0:
            return 0.0

        return 2 * precision * recall / (precision + recall)
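As a quick sketch of how these classes are intended to be driven outside the engine (illustrative usage, not part of the commit; the toy sentences, slot values, and the assumption that the `Metrics` base class initializes `record_level_scores` to an empty value are not taken from the diff):

from metrics.dialogue_metrics import JointGoalAccuracy, SlotAccuracy, SlotF1

# Made-up example data; one candidate/reference pair with its reference dialogue state.
candidates = ["I booked a hotel in the north area with 4 stars."]
references = ["Booked: a 4-star hotel in the north."]
ground_truth_slots = [{"hotel": {"area": "north", "stars": "4"}}]

for metric_cls in (JointGoalAccuracy, SlotAccuracy, SlotF1):
    metric = metric_cls()
    # task_name/model_name are omitted here, so no record logs are written.
    print(metric(candidates, references, ground_truth_slots=ground_truth_slots))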
