|
5 | 5 | import json |
6 | 6 |
|
7 | 7 | from collections import defaultdict |
8 | | -from dataclasses import dataclass, field, is_dataclass, asdict |
| 8 | +from dataclasses import dataclass, field, fields, is_dataclass, asdict |
9 | 9 | from functools import reduce |
10 | 10 | from pathlib import Path |
11 | 11 | from typing import Dict, Iterable, Optional, Set, Tuple, List, Any |
@@ -90,6 +90,24 @@ def languages(self) -> Set[str]: |
90 | 90 | def expand_entities(self, text: str) -> Tuple[str, EntityErrors]: |
91 | 91 | return expand_all_entities(text, self.entities) |
92 | 92 |
|
| 93 | + def expand_entity_fields(self, obj: object): |
| 94 | + if isinstance(obj, list): |
| 95 | + for o in obj: |
| 96 | + self.expand_entity_fields(o) |
| 97 | + if isinstance(obj, dict): |
| 98 | + for val in obj.values(): |
| 99 | + self.expand_entity_fields(val) |
| 100 | + if is_dataclass(obj) and not isinstance(obj, type): |
| 101 | + for f in fields(obj): |
| 102 | + val = getattr(obj, f.name) |
| 103 | + if isinstance(val, str): |
| 104 | + [expanded, errs] = self.expand_entities(val) |
| 105 | + if errs: |
| 106 | + self.errors.extend(errs) |
| 107 | + else: |
| 108 | + setattr(obj, f.name, expanded) |
| 109 | + self.expand_entity_fields(val) |
| 110 | + |
93 | 111 | def merge(self, other: "DocGen") -> MetadataErrors: |
94 | 112 | """Merge fields from other into self, prioritizing self fields.""" |
95 | 113 | warnings = MetadataErrors() |
@@ -332,7 +350,7 @@ def count_genai(d: Dict[str, int], e: Example): |
332 | 350 | # and arguably not useful either. |
333 | 351 | class DocGenEncoder(json.JSONEncoder): |
334 | 352 | def default(self, obj): |
335 | | - if is_dataclass(obj): |
| 353 | + if is_dataclass(obj) and not isinstance(obj, type): |
336 | 354 | return asdict(obj) |
337 | 355 |
|
338 | 356 | if isinstance(obj, Path): |
|
0 commit comments