|
3 | 3 | """ |
4 | 4 |
|
5 | 5 | import datetime |
6 | | -import json |
7 | 6 | import types |
| 7 | +from importlib import resources |
8 | 8 | from typing import Union, get_args, get_origin |
9 | 9 |
|
10 | 10 | import pyarrow as pa |
| 11 | +import yaml |
11 | 12 | from datasets import Features |
12 | 13 | from pydantic import BaseModel |
13 | 14 |
|
@@ -61,19 +62,34 @@ def _schema_from_pydantic(model: type[BaseModel]) -> list[pa.Field]: |
61 | 62 |
|
62 | 63 | def features_from_pydantic(model: type[BaseModel]) -> Features: |
63 | 64 | """ |
64 | | - Build a Hugging Face Features object from a Pydantic BaseModel using PyArrow schema. |
| 65 | + Build a HuggingFace Features object from a Pydantic BaseModel using PyArrow schema. |
65 | 66 | """ |
66 | 67 | pa_fields = _schema_from_pydantic(model) |
67 | 68 | pa_schema = pa.schema(pa_fields) |
68 | 69 | return Features.from_arrow_schema(pa_schema) |
69 | 70 |
|
70 | 71 |
|
71 | | -def generate_dataset_infos(output_path: str = "dataset_infos.json"): |
| 72 | +def write_dataset_features(output_path: str) -> None: |
72 | 73 | """ |
73 | | - Generate a dataset_infos.json file from the EvalResult schema. |
| 74 | + Write the HuggingFace Features data inferred from the EvalResult schema. |
74 | 75 | """ |
75 | 76 | features = features_from_pydantic(EvalResult) |
76 | | - infos = {"default": {"features": features.to_dict()}} |
77 | 77 | with open(output_path, "w", encoding="utf-8") as f: |
78 | | - json.dump(infos, f, indent=2) |
79 | | - print(f"Generated dataset_infos.json at {output_path}") |
| 78 | + yaml_values = features._to_yaml_list() |
| 79 | + yaml.safe_dump(yaml_values, f, indent=2, sort_keys=False) |
| 80 | + |
| 81 | + |
| 82 | +def load_dataset_features(input_path: str | None = None) -> Features: |
| 83 | + """ |
| 84 | + Load the HuggingFace Features data from a YAML file. |
| 85 | + """ |
| 86 | + if input_path is None: |
| 87 | + # load the shipped dataset_features.yml from the package |
| 88 | + with resources.open_text( |
| 89 | + "agenteval", "dataset_features.yml", encoding="utf-8" |
| 90 | + ) as f: |
| 91 | + yaml_values = yaml.safe_load(f) |
| 92 | + else: |
| 93 | + with open(input_path, "r", encoding="utf-8") as f: |
| 94 | + yaml_values = yaml.safe_load(f) |
| 95 | + return Features._from_yaml_list(yaml_values) |
0 commit comments