"""Dialogue state tracking metrics for task-oriented dialogue evaluation.

This module provides general-purpose evaluation metrics for dialogue systems:
- Joint Goal Accuracy (JGA): whether all predicted slots match the ground truth
- Slot Accuracy: per-slot accuracy across all samples
- Slot F1: F1 score for slot value extraction

These metrics can be used by any task-oriented dialogue benchmark.
"""

import json
import logging
import re
from typing import Dict, List

from tqdm import tqdm

from metrics.metrics import Metrics
from utils import util
from utils.custom_logging import write_record_log, append_final_score

logger = logging.getLogger(__name__)

class _BaseDialogueMetric(Metrics):
    """Base class for dialogue state tracking metrics."""

    def __init__(self):
        super().__init__()
        self.instructions = None
        self.model_responses = []
        self.ground_truth_slots = []
        # Per-record score container; initialised defensively here in case the
        # Metrics base class does not define this attribute itself.
        self.record_level_scores = {}

    def __call__(self, candidates, references, instructions=None, *,
                 task_name: str | None = None, model_name: str | None = None,
                 model_responses=None, ground_truth_slots=None):
        """Evaluate model predictions against ground truth.

        Args:
            candidates: Model generated responses
            references: Ground truth agent responses
            instructions: Input instructions (contains dialogue context)
            task_name: Name of the task
            model_name: Name of the model
            model_responses: Raw model response objects
            ground_truth_slots: List of ground truth slot dictionaries
        """
        self.instructions = instructions
        self.model_responses = model_responses if model_responses else []
        self.ground_truth_slots = ground_truth_slots if ground_truth_slots else []

        # Compute scores
        overall = self.get_score(candidates, references, task_name, model_name)

        if task_name and model_name:
            scores = self.record_level_scores
            write_record_log(self, references, candidates, scores, task_name, model_name,
                             instructions=self.instructions, model_responses=self.model_responses)
            append_final_score(self, overall, task_name, model_name, self.model_responses)

        return overall

    def get_score(self, candidates: list, references: list,
                  task_name: str = None, model_name: str = None) -> Dict[str, float]:
        """Compute overall scores for the dataset."""
        if not self.record_level_scores:
            self.record_level_scores = self.compute_record_level_scores(
                candidates, references, task_name, model_name
            )

        results = {}
        for metric_name, scores in self.record_level_scores.items():
            valid_scores = [s for s in scores if s is not None]
            if valid_scores:
                results[metric_name] = util.smart_round(sum(valid_scores) / len(valid_scores) * 100, 2)
            else:
                results[metric_name] = 0.0

        return results

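    # Illustrative behaviour of the heuristic extractor below (hypothetical input;
    # every domain's patterns are applied to the full response, so a single
    # sentence may populate several domains):
    #   _extract_slots_from_response(
    #       "I found a restaurant in the centre area serving chinese food.")
    #   -> {'restaurant': {'area': 'centre', 'food': 'chinese'},
    #       'hotel': {'area': 'centre'},
    #       'attraction': {'area': 'centre'}}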
    def _extract_slots_from_response(self, response: str) -> Dict[str, Dict[str, str]]:
        """Extract slot values from the model's natural language response.

        Uses pattern matching to identify slot-value pairs mentioned in the response.
        """
        if not response:
            return {}

        extracted_slots = {}

        # Pattern-based extraction for common slot types
        patterns = {
            'restaurant': {
                'area': r'(?:in|at|around|near)\s+(?:the\s+)?(\w+)\s+(?:area|part|side)',
                'food': r'(?:serving|serves?|offer(?:s|ing)?|type of food)[:\s]+(\w+(?:\s+\w+)?)\s+(?:food|cuisine)?',
                'pricerange': r'(?:price[d]?\s*range|budget|cost)[:\s]+(\w+)|(\w+)\s+(?:price[d]?\s*range|priced)',
                'name': r'(?:called|named|restaurant)[:\s]+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
            },
            'hotel': {
                'area': r'(?:in|at|around|near)\s+(?:the\s+)?(\w+)\s+(?:area|part|side)',
                'pricerange': r'(?:price[d]?\s*range|budget|cost)[:\s]+(\w+)|(\w+)\s+(?:price[d]?\s*range|priced)',
                'stars': r'(\d+)\s*(?:star|stars)',
                'parking': r'(?:parking)[:\s]*(yes|no|free)',
                'internet': r'(?:internet|wifi)[:\s]*(yes|no|free)',
                'type': r'(?:type)[:\s]+(\w+)|(\w+)\s+(?:type)',
                'name': r'(?:called|named|hotel)[:\s]+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
            },
            'train': {
                'departure': r'(?:from|depart(?:ing|ure)?(?:\s+from)?)[:\s]+([A-Za-z]+)',
                'destination': r'(?:to|arrive?(?:\s+at)?|destination)[:\s]+([A-Za-z]+)',
                'day': r'(?:on|day)[:\s]+(\w+day)',
                'leaveat': r'(?:leav(?:e|ing)\s+(?:at)?|departure\s+time)[:\s]+(\d{1,2}[:\.]?\d{2})',
                'arriveby': r'(?:arriv(?:e|ing)\s+(?:by)?|arrival\s+time)[:\s]+(\d{1,2}[:\.]?\d{2})',
            },
            'taxi': {
                'departure': r'(?:from|pick\s*up)[:\s]+([A-Za-z\s]+?)(?:\s+to|\s*$)',
                'destination': r'(?:to|drop\s*off|destination)[:\s]+([A-Za-z\s]+)',
                'leaveat': r'(?:leav(?:e|ing)\s+(?:at)?)[:\s]+(\d{1,2}[:\.]?\d{2})',
                'arriveby': r'(?:arriv(?:e|ing)\s+(?:by)?)[:\s]+(\d{1,2}[:\.]?\d{2})',
            },
            'attraction': {
                'area': r'(?:in|at|around|near)\s+(?:the\s+)?(\w+)\s+(?:area|part|side)',
                'type': r'(?:type)[:\s]+(\w+)|(\w+)\s+(?:attraction|place)',
                'name': r'(?:called|named|attraction)[:\s]+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
            }
        }

        for domain, slot_patterns in patterns.items():
            for slot, pattern in slot_patterns.items():
                # 'name' patterns rely on capitalisation, so match them case-sensitively;
                # all other patterns are matched case-insensitively.
                flags = 0 if slot == 'name' else re.IGNORECASE
                matches = re.findall(pattern, response, flags)
                if matches:
                    # Patterns with two capture groups yield tuples; take the first
                    # non-empty group of the first match that has one.
                    value = None
                    for match in matches:
                        if isinstance(match, tuple):
                            value = next((m for m in match if m), None)
                        else:
                            value = match
                        if value:
                            break

                    if value:
                        if domain not in extracted_slots:
                            extracted_slots[domain] = {}
                        extracted_slots[domain][slot] = value.strip()

        return extracted_slots

    def _normalize_value(self, value: str) -> str:
        """Normalize slot value for comparison."""
        if not value:
            return ''
        return str(value).lower().strip().replace(' ', '').replace('-', '').replace('_', '')

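    # Ground-truth slots are expected as a nested {domain: {slot: value}} mapping,
    # passed either as a dict or as its JSON-encoded string, for example
    # (hypothetical): '{"restaurant": {"area": "centre", "food": "chinese"}}'.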
    def _parse_ground_truth_slots(self, gt_slots) -> Dict:
        """Parse ground truth slots from various formats."""
        if not gt_slots:
            return {}
        if isinstance(gt_slots, str):
            try:
                return json.loads(gt_slots)
            except json.JSONDecodeError:
                return {}
        return gt_slots


class JointGoalAccuracy(_BaseDialogueMetric):
    """Joint Goal Accuracy metric for dialogue state tracking.

    JGA is 1 if all predicted slots exactly match the ground truth, 0 otherwise.
    """

    def __init__(self):
        super().__init__()
        self.name = "joint_goal_accuracy"

    def compute_record_level_scores(self, candidates: list, references: list,
                                    task_name: str = None, model_name: str = None) -> Dict[str, List]:
        """Compute JGA for each sample."""
        scores = []

        desc = "JGA Eval"
        if task_name and model_name:
            desc = f"JGA [{task_name}] [{model_name}]"

        for i, (candidate, reference) in enumerate(tqdm(
            zip(candidates, references), desc=desc, total=len(candidates)
        )):
            gt_slots = {}
            if i < len(self.ground_truth_slots):
                gt_slots = self._parse_ground_truth_slots(self.ground_truth_slots[i])

            pred_slots = self._extract_slots_from_response(candidate)
            jga = self._compute_jga(pred_slots, gt_slots)
            scores.append(jga)

        return {self.name: scores}

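    # Worked example (hypothetical values):
    #   gt   = {'restaurant': {'area': 'centre', 'food': 'chinese'}}
    #   pred = {'restaurant': {'area': 'centre'}}                    -> 0.0 ('food' missing)
    #   pred = {'restaurant': {'area': 'centre', 'food': 'chinese'},
    #           'hotel': {'area': 'centre'}}                         -> 1.0 (extra predictions are ignored)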
    def _compute_jga(self, pred_slots: Dict, gt_slots: Dict) -> float:
        """Compute Joint Goal Accuracy."""
        if not gt_slots:
            return 1.0 if not pred_slots else 0.0

        for domain, slots in gt_slots.items():
            if domain not in pred_slots:
                return 0.0
            for slot, value in slots.items():
                pred_value = pred_slots.get(domain, {}).get(slot, '')
                if self._normalize_value(pred_value) != self._normalize_value(value):
                    return 0.0

        return 1.0


class SlotAccuracy(_BaseDialogueMetric):
    """Slot Accuracy metric for dialogue state tracking.

    Computes the proportion of individual slots correctly predicted.
    """

    def __init__(self):
        super().__init__()
        self.name = "slot_accuracy"

    def compute_record_level_scores(self, candidates: list, references: list,
                                    task_name: str = None, model_name: str = None) -> Dict[str, List]:
        """Compute slot accuracy for each sample."""
        scores = []

        desc = "Slot Accuracy Eval"
        if task_name and model_name:
            desc = f"Slot Accuracy [{task_name}] [{model_name}]"

        for i, (candidate, reference) in enumerate(tqdm(
            zip(candidates, references), desc=desc, total=len(candidates)
        )):
            gt_slots = {}
            if i < len(self.ground_truth_slots):
                gt_slots = self._parse_ground_truth_slots(self.ground_truth_slots[i])

            pred_slots = self._extract_slots_from_response(candidate)
            slot_acc = self._compute_slot_accuracy(pred_slots, gt_slots)
            scores.append(slot_acc)

        return {self.name: scores}

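    # Worked example (hypothetical values): with
    #   gt   = {'restaurant': {'area': 'centre', 'food': 'chinese', 'pricerange': 'cheap'}}
    #   pred = {'restaurant': {'area': 'centre', 'food': 'chinese'}}
    # two of the three ground-truth slots match, giving 2/3; predictions for slots
    # absent from the ground truth are not counted.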
    def _compute_slot_accuracy(self, pred_slots: Dict, gt_slots: Dict) -> float:
        """Compute per-slot accuracy."""
        if not gt_slots:
            return 1.0 if not pred_slots else 0.0

        total_slots = 0
        correct_slots = 0

        for domain, slots in gt_slots.items():
            for slot, value in slots.items():
                total_slots += 1
                pred_value = pred_slots.get(domain, {}).get(slot, '')
                if self._normalize_value(pred_value) == self._normalize_value(value):
                    correct_slots += 1

        return correct_slots / total_slots if total_slots > 0 else 1.0


class SlotF1(_BaseDialogueMetric):
    """Slot F1 metric for dialogue state tracking.

    Computes the F1 score for slot value extraction.
    """

    def __init__(self):
        super().__init__()
        self.name = "slot_f1"

    def compute_record_level_scores(self, candidates: list, references: list,
                                    task_name: str = None, model_name: str = None) -> Dict[str, List]:
        """Compute slot F1 for each sample."""
        scores = []

        desc = "Slot F1 Eval"
        if task_name and model_name:
            desc = f"Slot F1 [{task_name}] [{model_name}]"

        for i, (candidate, reference) in enumerate(tqdm(
            zip(candidates, references), desc=desc, total=len(candidates)
        )):
            gt_slots = {}
            if i < len(self.ground_truth_slots):
                gt_slots = self._parse_ground_truth_slots(self.ground_truth_slots[i])

            pred_slots = self._extract_slots_from_response(candidate)
            slot_f1 = self._compute_slot_f1(pred_slots, gt_slots)
            scores.append(slot_f1)

        return {self.name: scores}

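    # Worked example (hypothetical values): with
    #   gt   = {'restaurant': {'area': 'centre', 'food': 'chinese'}}
    #   pred = {'restaurant': {'area': 'centre'}, 'hotel': {'area': 'centre'}}
    # tp = 1, precision = 1/2, recall = 1/2, so F1 = 0.5. Unlike JGA and slot
    # accuracy, spurious predictions lower this score through precision.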
    def _compute_slot_f1(self, pred_slots: Dict, gt_slots: Dict) -> float:
        """Compute slot F1 score."""
        gt_set = set()
        pred_set = set()

        for domain, slots in gt_slots.items():
            for slot, value in slots.items():
                gt_set.add((domain, slot, self._normalize_value(value)))

        for domain, slots in pred_slots.items():
            for slot, value in slots.items():
                pred_set.add((domain, slot, self._normalize_value(value)))

        if not gt_set and not pred_set:
            return 1.0
        if not gt_set or not pred_set:
            return 0.0

        tp = len(gt_set & pred_set)
        precision = tp / len(pred_set) if pred_set else 0.0
        recall = tp / len(gt_set) if gt_set else 0.0

        if precision + recall == 0:
            return 0.0

        return 2 * precision * recall / (precision + recall)
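

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the benchmark pipeline).
    # The sample turn and ground truth are hypothetical; running this assumes the
    # project packages imported above (metrics, utils) are on the path.
    candidates = ["I found a restaurant in the centre area serving chinese food."]
    references = ["There is a chinese restaurant in the centre of town."]
    gt_slots = [{"restaurant": {"area": "centre", "food": "chinese"}}]

    for metric_cls in (JointGoalAccuracy, SlotAccuracy, SlotF1):
        metric = metric_cls()
        # task_name/model_name are omitted, so no record logs are written.
        print(metric.name, metric(candidates, references, ground_truth_slots=gt_slots))
    # With the heuristic extractor, JGA and slot accuracy should come out at 100.0,
    # while slot F1 is lower (about 66.7) because the area pattern also fires for
    # the hotel and attraction domains.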