Skip to content

Commit 322e092

Browse files
authored
Fix dataset serialization when inputs have discriminators with defaults (#3079)
1 parent 5406124 commit 322e092

File tree

3 files changed

+47
-6
lines changed

3 files changed

+47
-6
lines changed

docs/evals.md

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -653,10 +653,12 @@ async def main():
653653
print(output_file.read_text())
654654
"""
655655
# yaml-language-server: $schema=questions_cases_schema.json
656+
name: null
656657
cases:
657658
- name: Easy Capital Question
658659
inputs:
659660
question: What is the capital of France?
661+
context: null
660662
metadata:
661663
difficulty: easy
662664
category: Geography
@@ -668,6 +670,7 @@ async def main():
668670
- name: Challenging Landmark Question
669671
inputs:
670672
question: Which world-famous landmark is located on the banks of the Seine River?
673+
context: null
671674
metadata:
672675
difficulty: hard
673676
category: Landmarks
@@ -676,6 +679,7 @@ async def main():
676679
confidence: 0.9
677680
evaluators:
678681
- EqualsExpected
682+
evaluators: []
679683
"""
680684
```
681685

@@ -713,11 +717,13 @@ async def main():
713717
"""
714718
{
715719
"$schema": "questions_cases_schema.json",
720+
"name": null,
716721
"cases": [
717722
{
718723
"name": "Easy Capital Question",
719724
"inputs": {
720-
"question": "What is the capital of France?"
725+
"question": "What is the capital of France?",
726+
"context": null
721727
},
722728
"metadata": {
723729
"difficulty": "easy",
@@ -734,7 +740,8 @@ async def main():
734740
{
735741
"name": "Challenging Landmark Question",
736742
"inputs": {
737-
"question": "Which world-famous landmark is located on the banks of the Seine River?"
743+
"question": "Which world-famous landmark is located on the banks of the Seine River?",
744+
"context": null
738745
},
739746
"metadata": {
740747
"difficulty": "hard",
@@ -748,7 +755,8 @@ async def main():
748755
"EqualsExpected"
749756
]
750757
}
751-
]
758+
],
759+
"evaluators": []
752760
}
753761
"""
754762
```

pydantic_evals/pydantic_evals/dataset.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -646,15 +646,15 @@ def to_file(
646646

647647
context: dict[str, Any] = {'use_short_form': True}
648648
if fmt == 'yaml':
649-
dumped_data = self.model_dump(mode='json', by_alias=True, exclude_defaults=True, context=context)
649+
dumped_data = self.model_dump(mode='json', by_alias=True, context=context)
650650
content = yaml.dump(dumped_data, sort_keys=False)
651651
if schema_ref: # pragma: no branch
652652
yaml_language_server_line = f'{_YAML_SCHEMA_LINE_PREFIX}{schema_ref}'
653653
content = f'{yaml_language_server_line}\n{content}'
654654
path.write_text(content)
655655
else:
656656
context['$schema'] = schema_ref
657-
json_data = self.model_dump_json(indent=2, by_alias=True, exclude_defaults=True, context=context)
657+
json_data = self.model_dump_json(indent=2, by_alias=True, context=context)
658658
path.write_text(json_data + '\n')
659659

660660
@classmethod
@@ -724,6 +724,7 @@ class Case(BaseModel, extra='forbid'): # pyright: ignore[reportUnusedClass] #
724724
evaluators: list[Union[tuple(evaluator_schema_types)]] = [] # pyright: ignore # noqa UP007
725725

726726
class Dataset(BaseModel, extra='forbid'):
727+
name: str | None = None
727728
cases: list[Case]
728729
if evaluator_schema_types: # pragma: no branch
729730
evaluators: list[Union[tuple(evaluator_schema_types)]] = [] # pyright: ignore # noqa UP007

tests/evals/test_dataset.py

Lines changed: 33 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import sys
55
from dataclasses import dataclass, field
66
from pathlib import Path
7-
from typing import Any
7+
from typing import Any, Literal
88

99
import pytest
1010
import yaml
@@ -863,6 +863,38 @@ async def test_serialization_to_json(example_dataset: Dataset[TaskInput, TaskOut
863863
assert (tmp_path / schema).exists()
864864

865865

866+
def test_serializing_parts_with_discriminators(tmp_path: Path):
867+
class Foo(BaseModel):
868+
foo: str
869+
kind: Literal['foo'] = 'foo'
870+
871+
class Bar(BaseModel):
872+
bar: str
873+
kind: Literal['bar'] = 'bar'
874+
875+
items = [Foo(foo='foo'), Bar(bar='bar')]
876+
877+
dataset = Dataset[list[Foo | Bar]](cases=[Case(inputs=items)])
878+
yaml_path = tmp_path / 'test_cases.yaml'
879+
dataset.to_file(yaml_path)
880+
881+
loaded_dataset = Dataset[list[Foo | Bar]].from_file(yaml_path)
882+
assert loaded_dataset == snapshot(
883+
Dataset(
884+
name='test_cases',
885+
cases=[
886+
Case(
887+
name=None,
888+
inputs=[
889+
Foo(foo='foo'),
890+
Bar(bar='bar'),
891+
],
892+
)
893+
],
894+
)
895+
)
896+
897+
866898
def test_serialization_errors(tmp_path: Path):
867899
with pytest.raises(ValueError) as exc_info:
868900
Dataset[TaskInput, TaskOutput, TaskMetadata].from_file(tmp_path / 'test_cases.abc')

0 commit comments

Comments
 (0)