Commit ede9092

Support List/Tuple/Dict of dataclasses via CLI, config files, and encode/decode
Add support for container-of-dataclass fields (List[DC], Tuple[DC, ...], Dict[str, DC]) as first-class citizens. These can now be set from the command line as YAML/JSON strings (e.g. --points '[{x: 1.0, y: 2.0}]'), from config files, and via encode/decode. Deeply nested structures are also supported.

- Add is_dict_of_dataclasses utility and wire it into DataclassWrapper
- Replace NotImplementedError with working FieldWrapper registration
- Add 78 tests covering encode/decode, config file, CLI, and deep nesting
- Document the feature in README, docs/step_by_step, and docs/api
- Bump version to 0.3.2
- Drop Python 3.6 support (bump to >=3.7, remove dataclasses backport)
- Update CI to Python 3.7-3.12, actions v4/v5
- Fix mutable default warnings in test fixtures
1 parent 1e0586f commit ede9092

11 files changed: +1322 -21 lines changed

.github/workflows/pytest.yml

Lines changed: 3 additions & 6 deletions
@@ -7,15 +7,12 @@ jobs:
     strategy:
       matrix:
         os: [ ubuntu-latest, macos-latest ]
-        python-version: [ '3.6', '3.7', '3.8', '3.9', '3.10' ]
-        exclude:
-          - os: macos-latest
-            python-version: '3.6'
+        python-version: [ '3.7', '3.8', '3.9', '3.10', '3.11', '3.12' ]
     runs-on: ${{ matrix.os }}
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }}
-        uses: actions/setup-python@v1
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies

README.md

Lines changed: 59 additions & 0 deletions
@@ -243,6 +243,65 @@ Training my_third_exp...
 Saving to /share/experiments/my_third_exp
 ```

+#### Lists and Dicts of Dataclasses 🐲
+`pyrallis` also supports `List[Dataclass]`, `Tuple[Dataclass, ...]`, and `Dict[str, Dataclass]` fields. These can be set from a config file in the natural YAML structure, **or directly from the command line as a YAML/JSON string**:
+
+```python
+from dataclasses import dataclass, field
+from typing import List, Dict
+import pyrallis
+
+@dataclass
+class DatasetConfig:
+    path: str = ''
+    split: str = 'train'
+
+@dataclass
+class TrainConfig:
+    # A list of datasets to train on
+    datasets: List[DatasetConfig] = field(default_factory=list)
+    # Named dataset overrides per evaluation task
+    eval_sets: Dict[str, DatasetConfig] = field(default_factory=dict)
+
+@pyrallis.wrap()
+def main(cfg: TrainConfig):
+    for ds in cfg.datasets:
+        print(f'Training on: {ds.path} (split={ds.split})')
+    for name, ds in cfg.eval_sets.items():
+        print(f'Eval {name}: {ds.path} (split={ds.split})')
+```
+
+Pass a YAML flow sequence or mapping from the command line:
+```console
+$ python train_model.py \
+    --datasets '[{path: /data/imagenet, split: train}, {path: /data/coco, split: train}]' \
+    --eval_sets '{imagenet_val: {path: /data/imagenet, split: val}, coco_val: {path: /data/coco, split: val}}'
+Training on: /data/imagenet (split=train)
+Training on: /data/coco (split=train)
+Eval imagenet_val: /data/imagenet (split=val)
+Eval coco_val: /data/coco (split=val)
+```
+
+The equivalent YAML config file uses the standard nested structure:
+```yaml
+datasets:
+  - path: /data/imagenet
+    split: train
+  - path: /data/coco
+    split: train
+eval_sets:
+  imagenet_val:
+    path: /data/imagenet
+    split: val
+  coco_val:
+    path: /data/coco
+    split: val
+```
+
+This also works with deeply nested structures -- for example a `List[Dataclass]` where each dataclass itself contains a `Dict[str, Dataclass]` field, and so on.
+
+> CLI args always take priority over the config file, so you can mix and match: load a base config from a file and override specific container fields on the command line.
+
 ### 🐲 5/5 Easy Serialization with `pyrallis.dump` 🐲
 As your config get longer you will probably want to start working with configuration files. Pyrallis supports encoding a dataclass configuration into a `yaml` file 💾
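
The priority rule stated in the README addition can be pictured with plain dictionaries. The sketch below is only an illustration and not pyrallis code: `apply_cli_overrides` is a hypothetical helper, and it assumes that a container field given on the command line replaces the file's value wholesale, which the commit does not state explicitly.

```python
from typing import Any, Dict


def apply_cli_overrides(file_cfg: Dict[str, Any], cli_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Return the file config with every CLI-provided field replaced wholesale."""
    merged = dict(file_cfg)  # shallow copy so the input mapping is left untouched
    merged.update(cli_cfg)   # CLI values win for any field present in both
    return merged


# Values as they might look after parsing a config file and a CLI string.
file_cfg = {
    'datasets': [{'path': '/data/imagenet', 'split': 'train'}],
    'eval_sets': {'imagenet_val': {'path': '/data/imagenet', 'split': 'val'}},
}
cli_cfg = {'datasets': [{'path': '/data/coco', 'split': 'train'}]}

merged = apply_cli_overrides(file_cfg, cli_cfg)
print(merged['datasets'])   # taken from the CLI
print(merged['eval_sets'])  # untouched, taken from the file
```

Under this assumption, overriding `--datasets` does not append to the list loaded from the file; whether pyrallis merges or replaces container contents should be checked against its tests.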

docs/api.md

Lines changed: 9 additions & 0 deletions
@@ -353,6 +353,15 @@ Reading sequences and dictionaries is done using the `yaml` syntax. Notice that
 $ python train_model.py --worker_inds=[2,18,42] --worker_names="{2: 'George', 18: 'Ben'}"
 ```

+Collections can also contain dataclass types. `#!python typing.List[SomeDataclass]`, `#!python typing.Tuple[SomeDataclass, ...]`, and `#!python typing.Dict[str, SomeDataclass]` are all supported. From the command line these are passed as YAML/JSON strings:
+
+```console
+$ python train_model.py --points '[{x: 1.0, y: 2.0}, {x: 3.0, y: 4.0}]'
+$ python train_model.py --mapping '{origin: {x: 0.0, y: 0.0}, target: {x: 1.0, y: 2.0}}'
+```
+
+Deeply nested structures (e.g. a `List[Dataclass]` where each dataclass contains a `Dict[str, Dataclass]` field) are also supported.
+
 #### Typing Types

 `typing.Any`
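
To make "deeply nested" concrete, here is a minimal standard-library sketch of how already-parsed YAML/JSON data could be turned into a `List[Dataclass]` whose items hold a `Dict[str, Dataclass]`. It is not the pyrallis decoder: the `decode` function and the `Point`, `Shape`, and `Scene` classes are hypothetical, and only `List` and `Dict` containers are handled.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, Dict, List, get_args, get_origin


def decode(tp: Any, raw: Any) -> Any:
    """Recursively convert parsed data (dicts, lists, scalars) into type `tp`."""
    if is_dataclass(tp):
        # Decode each field that is present; missing fields keep their defaults.
        kwargs = {f.name: decode(f.type, raw[f.name]) for f in fields(tp) if f.name in raw}
        return tp(**kwargs)
    origin = get_origin(tp)
    if origin is list:
        (item_type,) = get_args(tp)
        return [decode(item_type, item) for item in raw]
    if origin is dict:
        _key_type, value_type = get_args(tp)
        return {key: decode(value_type, value) for key, value in raw.items()}
    return raw  # leaf value: returned exactly as parsed


@dataclass
class Point:
    x: float = 0.0
    y: float = 0.0


@dataclass
class Shape:
    name: str = ''
    anchors: Dict[str, Point] = field(default_factory=dict)


@dataclass
class Scene:
    shapes: List[Shape] = field(default_factory=list)


raw = {'shapes': [{'name': 'tri', 'anchors': {'origin': {'x': 0.0, 'y': 0.0},
                                              'tip': {'x': 1.0, 'y': 2.0}}}]}
scene = decode(Scene, raw)
print(scene.shapes[0].anchors['tip'])
```

The recursion follows the type annotations rather than the data, so each level of nesting is decoded by the same three branches.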

docs/step_by_step.md

Lines changed: 59 additions & 0 deletions
@@ -172,6 +172,65 @@ Training my_third_exp...
 Saving to /share/experiments/my_third_exp
 ```

+## Lists and Dicts of Dataclasses
+`pyrallis` also supports `List[Dataclass]`, `Tuple[Dataclass, ...]`, and `Dict[str, Dataclass]` fields. These can be set from a config file in the natural YAML structure, **or directly from the command line as a YAML/JSON string**:
+
+```python
+from dataclasses import dataclass, field
+from typing import List, Dict
+import pyrallis
+
+@dataclass
+class DatasetConfig:
+    path: str = ''
+    split: str = 'train'
+
+@dataclass
+class TrainConfig:
+    # A list of datasets to train on
+    datasets: List[DatasetConfig] = field(default_factory=list)
+    # Named dataset overrides per evaluation task
+    eval_sets: Dict[str, DatasetConfig] = field(default_factory=dict)
+
+@pyrallis.wrap()
+def main(cfg: TrainConfig):
+    for ds in cfg.datasets:
+        print(f'Training on: {ds.path} (split={ds.split})')
+    for name, ds in cfg.eval_sets.items():
+        print(f'Eval {name}: {ds.path} (split={ds.split})')
+```
+
+Pass a YAML flow sequence or mapping from the command line:
+```console
+$ python train_model.py \
+    --datasets '[{path: /data/imagenet, split: train}, {path: /data/coco, split: train}]' \
+    --eval_sets '{imagenet_val: {path: /data/imagenet, split: val}, coco_val: {path: /data/coco, split: val}}'
+Training on: /data/imagenet (split=train)
+Training on: /data/coco (split=train)
+Eval imagenet_val: /data/imagenet (split=val)
+Eval coco_val: /data/coco (split=val)
+```
+
+The equivalent YAML config file uses the standard nested structure:
+```yaml
+datasets:
+  - path: /data/imagenet
+    split: train
+  - path: /data/coco
+    split: train
+eval_sets:
+  imagenet_val:
+    path: /data/imagenet
+    split: val
+  coco_val:
+    path: /data/coco
+    split: val
+```
+
+This also works with deeply nested structures -- for example a `List[Dataclass]` where each dataclass itself contains a `Dict[str, Dataclass]` field, and so on.
+
+> CLI args always take priority over the config file, so you can mix and match: load a base config from a file and override specific container fields on the command line.
+
 ## Easy Serialization
 As your config get longer you will probably want to start working with configuration files. Pyrallis supports encoding a dataclass configuration into a `yaml` file 💾

pyrallis/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-__version__ = "0.3.1"
+__version__ = "0.3.2"

 from . import wrappers, utils
 from .argparsing import wrap, parse

pyrallis/utils.py

Lines changed: 9 additions & 0 deletions
@@ -173,6 +173,15 @@ def is_tuple_or_list_of_dataclasses(t: Type) -> bool:
     return is_tuple_or_list(t) and is_dataclass_type(get_item_type(t))


+def is_dict_of_dataclasses(t: Type) -> bool:
+    if not is_dict(t):
+        return False
+    args = get_type_arguments(t)
+    if args and len(args) == 2:
+        return is_dataclass_type(args[1])
+    return False
+
+
 def contains_dataclass_type_arg(t: Type) -> bool:
     if is_dataclass_type(t):
         return True
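
For readers without the pyrallis helpers at hand (`is_dict`, `get_type_arguments`, `is_dataclass_type`), the same check can be approximated with the standard library alone. This is a sketch under that assumption, not the code added by the commit; `looks_like_dict_of_dataclasses` is a hypothetical name, and `typing.get_origin`/`typing.get_args` need Python 3.8 or later.

```python
import dataclasses
from typing import Any, Dict, List, get_args, get_origin


def looks_like_dict_of_dataclasses(t: Any) -> bool:
    """Return True for annotations such as Dict[str, SomeDataclass]."""
    if get_origin(t) is not dict:
        return False
    args = get_args(t)
    # A parameterized Dict has exactly two type arguments: key type and value type.
    if len(args) != 2:
        return False
    value_type = args[1]
    # is_dataclass also accepts instances, so require an actual class here.
    return isinstance(value_type, type) and dataclasses.is_dataclass(value_type)


@dataclasses.dataclass
class Point:
    x: float = 0.0


print(looks_like_dict_of_dataclasses(Dict[str, Point]))
print(looks_like_dict_of_dataclasses(Dict[str, int]))
print(looks_like_dict_of_dataclasses(List[Point]))
```

As in the commit's version, only the value type is inspected; the key type is not constrained.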

pyrallis/wrappers/dataclass_wrapper.py

Lines changed: 8 additions & 4 deletions
@@ -51,11 +51,15 @@ def __init__(
             if not field.init:
                 continue

-            elif utils.is_tuple_or_list_of_dataclasses(field.type):
-                raise NotImplementedError(
-                    f"Field {field.name} is of type {field.type}, which isn't "
-                    f"supported yet. (container of a dataclass type)"
+            elif utils.is_tuple_or_list_of_dataclasses(field.type) or utils.is_dict_of_dataclasses(field.type):
+                # Container of dataclasses (List/Tuple/Dict) - treat as a regular field.
+                # Settable via config files, encode/decode, and via CLI as a YAML/JSON string
+                # (e.g. --points '[{x: 1.0, y: 2.0}]' or --mapping '{key: {x: 1.0}}').
+                field_wrapper = field_wrapper_class(field, parent=self, prefix=self.prefix)
+                logger.debug(
+                    f"wrapped container-of-dataclass field at {field_wrapper.dest} has a default value of {field_wrapper.default}"
                 )
+                self.fields.append(field_wrapper)

             elif dataclasses.is_dataclass(field.type):
                 # handle a nested dataclass attribute
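
The idea of treating a container of dataclasses as one regular field, whose whole value arrives as a single string argument, can be sketched with `argparse` and `json` from the standard library. This is an analogy rather than the pyrallis `FieldWrapper` mechanism: `parse_points` and `Point` are hypothetical, and strict JSON is used here, so keys must be quoted, unlike the YAML flow form shown in the comments above.

```python
import argparse
import json
from dataclasses import dataclass
from typing import List


@dataclass
class Point:
    x: float = 0.0
    y: float = 0.0


def parse_points(text: str) -> List[Point]:
    """Decode a JSON array of objects into a list of Point instances."""
    return [Point(**item) for item in json.loads(text)]


parser = argparse.ArgumentParser()
# One flag carries the entire container; the `type` callable does the decoding.
parser.add_argument('--points', type=parse_points, default=[])

args = parser.parse_args(['--points', '[{"x": 1.0, "y": 2.0}, {"x": 3.0, "y": 4.0}]'])
print(args.points)
```

Because the container is a single argument, no per-item flags need to be generated, which is consistent with the field no longer raising `NotImplementedError`.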

setup.py

Lines changed: 1 addition & 2 deletions
@@ -33,10 +33,9 @@ def find_version(*file_paths: str) -> str:
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
     ],
-    python_requires=">=3.6",
+    python_requires=">=3.7",
     install_requires=[
         "typing_inspect",
-        "dataclasses;python_version<'3.7'",
         'pyyaml'
     ],
     setup_requires=["pre-commit"],

tests/conftest.py

Lines changed: 6 additions & 6 deletions
@@ -181,7 +181,7 @@ class HyperParameters(TestSetup):
         use_custom_likes: bool = True

         # Gender model settings
-        gender: TaskHyperParameters = TaskHyperParameters(
+        gender: TaskHyperParameters = field(default_factory=lambda: TaskHyperParameters(
             "gender",
             num_layers=1,
             num_units=32,
@@ -190,10 +190,10 @@ class HyperParameters(TestSetup):
             dropout_rate=0.1,
             use_image_features=True,
             use_likes=True,
-        )
+        ))

         # Age Group Model settings
-        age_group: TaskHyperParameters = TaskHyperParameters(
+        age_group: TaskHyperParameters = field(default_factory=lambda: TaskHyperParameters(
             "age_group",
             num_layers=2,
             num_units=64,
@@ -202,10 +202,10 @@ class HyperParameters(TestSetup):
             dropout_rate=0.1,
             use_image_features=True,
             use_likes=True,
-        )
+        ))

         # Personality Model(s) settings:
-        personality: TaskHyperParameters = TaskHyperParameters(
+        personality: TaskHyperParameters = field(default_factory=lambda: TaskHyperParameters(
             "personality",
             num_layers=1,
             num_units=8,
@@ -214,6 +214,6 @@ class HyperParameters(TestSetup):
             dropout_rate=0.1,
             use_image_features=False,
             use_likes=False,
-        )
+        ))

     return HyperParameters
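
The fixture change swaps a shared dataclass instance for `field(default_factory=...)`. Since Python 3.11, `dataclasses` rejects unhashable default values, which includes instances of ordinary (non-frozen) dataclasses, so the old form no longer works on the newer interpreters added to the CI matrix; that this is the motivation for the change is an inference from the commit message. The sketch below uses a trimmed-down, hypothetical `TaskHyperParameters` to show the behavior the factory gives.

```python
from dataclasses import dataclass, field


@dataclass
class TaskHyperParameters:
    name: str
    num_layers: int = 1


@dataclass
class HyperParameters:
    # default_factory is called once per HyperParameters instance,
    # so each instance gets its own TaskHyperParameters object.
    gender: TaskHyperParameters = field(
        default_factory=lambda: TaskHyperParameters("gender", num_layers=1)
    )


a = HyperParameters()
b = HyperParameters()
a.gender.num_layers = 5      # mutate one instance's nested config
print(b.gender.num_layers)   # the other instance keeps its own default
```

With a plain instance as the default, `a.gender` and `b.gender` would have been the same object on interpreters that still accept it, so a mutation in one test could leak into another.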

tests/test_inheritance.py

Lines changed: 2 additions & 2 deletions
@@ -22,8 +22,8 @@ class ExtendedC(Base, TestSetup):

 @dataclass
 class Inheritance(TestSetup):
-    ext_b: ExtendedB = ExtendedB()
-    ext_c: ExtendedC = ExtendedC()
+    ext_b: ExtendedB = field(default_factory=ExtendedB)
+    ext_c: ExtendedC = field(default_factory=ExtendedC)


 def test_simple_subclassing_no_args():
