Skip to content

Commit a5c7ad3

Browse files
committed
[feat] Allow dictionary of datasets as a field in a Group
1 parent ca8920c commit a5c7ad3

File tree

3 files changed

+205
-28
lines changed

3 files changed

+205
-28
lines changed

examples/custom_group.ipynb

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,151 @@
291291
"parse = Datastore.model_validate_hdf5(filepath)\n",
292292
"pprint(parse)"
293293
]
294+
},
295+
{
296+
"cell_type": "code",
297+
"execution_count": 7,
298+
"metadata": {},
299+
"outputs": [],
300+
"source": [
301+
"from typing import Dict\n",
302+
"\n",
303+
"from oqd_dataschema.base import CastDataset\n",
304+
"\n",
305+
"\n",
306+
"class A(GroupBase):\n",
307+
" data: Dict[str, CastDataset]"
308+
]
309+
},
310+
{
311+
"cell_type": "code",
312+
"execution_count": 8,
313+
"metadata": {},
314+
"outputs": [
315+
{
316+
"data": {
317+
"text/html": [
318+
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Datastore</span><span style=\"font-weight: bold\">(</span>\n",
319+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">groups</span>=<span style=\"font-weight: bold\">{</span>\n",
320+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'A'</span>: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">A</span><span style=\"font-weight: bold\">(</span>\n",
321+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">attrs</span>=<span style=\"font-weight: bold\">{}</span>,\n",
322+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">data</span>=<span style=\"font-weight: bold\">{</span>\n",
323+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'x'</span>: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Dataset</span><span style=\"font-weight: bold\">(</span>\n",
324+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">dtype</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'float64'</span>,\n",
325+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">shape</span>=<span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>,<span style=\"font-weight: bold\">)</span>,\n",
326+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">data</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">array</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.90326782</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.17363226</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.13827196</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8917397</span> , <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.68175954</span>,\n",
327+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.47647195</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.88443397</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.75703312</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.74991232</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.68161151</span><span style=\"font-weight: bold\">])</span>,\n",
328+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">attrs</span>=<span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'type'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'mytype'</span><span style=\"font-weight: bold\">}</span>\n",
329+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">)</span>\n",
330+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
331+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">class_</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'A'</span>\n",
332+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">)</span>\n",
333+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"font-weight: bold\">}</span>,\n",
334+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">attrs</span>=<span style=\"font-weight: bold\">{}</span>\n",
335+
"<span style=\"font-weight: bold\">)</span>\n",
336+
"</pre>\n"
337+
],
338+
"text/plain": [
339+
"\u001b[1;35mDatastore\u001b[0m\u001b[1m(\u001b[0m\n",
340+
"\u001b[2;32m│ \u001b[0m\u001b[33mgroups\u001b[0m=\u001b[1m{\u001b[0m\n",
341+
"\u001b[2;32m│ │ \u001b[0m\u001b[32m'A'\u001b[0m: \u001b[1;35mA\u001b[0m\u001b[1m(\u001b[0m\n",
342+
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n",
343+
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mdata\u001b[0m=\u001b[1m{\u001b[0m\n",
344+
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[32m'x'\u001b[0m: \u001b[1;35mDataset\u001b[0m\u001b[1m(\u001b[0m\n",
345+
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mdtype\u001b[0m=\u001b[32m'float64'\u001b[0m,\n",
346+
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m10\u001b[0m,\u001b[1m)\u001b[0m,\n",
347+
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mdata\u001b[0m=\u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.90326782\u001b[0m, \u001b[1;36m0.17363226\u001b[0m, \u001b[1;36m0.13827196\u001b[0m, \u001b[1;36m0.8917397\u001b[0m , \u001b[1;36m0.68175954\u001b[0m,\n",
348+
"\u001b[2;32m│ \u001b[0m\u001b[1;36m0.47647195\u001b[0m, \u001b[1;36m0.88443397\u001b[0m, \u001b[1;36m0.75703312\u001b[0m, \u001b[1;36m0.74991232\u001b[0m, \u001b[1;36m0.68161151\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m,\n",
349+
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m'type'\u001b[0m: \u001b[32m'mytype'\u001b[0m\u001b[1m}\u001b[0m\n",
350+
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m)\u001b[0m\n",
351+
"\u001b[2;32m│ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
352+
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_\u001b[0m=\u001b[32m'A'\u001b[0m\n",
353+
"\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n",
354+
"\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n",
355+
"\u001b[2;32m│ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
356+
"\u001b[1m)\u001b[0m\n"
357+
]
358+
},
359+
"metadata": {},
360+
"output_type": "display_data"
361+
}
362+
],
363+
"source": [
364+
"filepath = pathlib.Path(\"test.h5\")\n",
365+
"\n",
366+
"datastore = Datastore(\n",
367+
" groups={\n",
368+
" \"A\": A(data={\"x\": Dataset(data=np.random.rand(10), attrs={\"type\": \"mytype\"})})\n",
369+
" }\n",
370+
")\n",
371+
"pprint(datastore)\n",
372+
"datastore.model_dump_hdf5(filepath)"
373+
]
374+
},
375+
{
376+
"cell_type": "code",
377+
"execution_count": 9,
378+
"metadata": {},
379+
"outputs": [
380+
{
381+
"data": {
382+
"text/html": [
383+
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Datastore</span><span style=\"font-weight: bold\">(</span>\n",
384+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">groups</span>=<span style=\"font-weight: bold\">{</span>\n",
385+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'A'</span>: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">A</span><span style=\"font-weight: bold\">(</span>\n",
386+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">attrs</span>=<span style=\"font-weight: bold\">{}</span>,\n",
387+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">data</span>=<span style=\"font-weight: bold\">{</span>\n",
388+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'x'</span>: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Dataset</span><span style=\"font-weight: bold\">(</span>\n",
389+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">dtype</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'float64'</span>,\n",
390+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">shape</span>=<span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>,<span style=\"font-weight: bold\">)</span>,\n",
391+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">data</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">array</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.90326782</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.17363226</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.13827196</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8917397</span> , <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.68175954</span>,\n",
392+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.47647195</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.88443397</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.75703312</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.74991232</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.68161151</span><span style=\"font-weight: bold\">])</span>,\n",
393+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">attrs</span>=<span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'type'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'mytype'</span><span style=\"font-weight: bold\">}</span>\n",
394+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">)</span>\n",
395+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
396+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">class_</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'A'</span>\n",
397+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">)</span>\n",
398+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"font-weight: bold\">}</span>,\n",
399+
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">attrs</span>=<span style=\"font-weight: bold\">{}</span>\n",
400+
"<span style=\"font-weight: bold\">)</span>\n",
401+
"</pre>\n"
402+
],
403+
"text/plain": [
404+
"\u001b[1;35mDatastore\u001b[0m\u001b[1m(\u001b[0m\n",
405+
"\u001b[2;32m│ \u001b[0m\u001b[33mgroups\u001b[0m=\u001b[1m{\u001b[0m\n",
406+
"\u001b[2;32m│ │ \u001b[0m\u001b[32m'A'\u001b[0m: \u001b[1;35mA\u001b[0m\u001b[1m(\u001b[0m\n",
407+
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n",
408+
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mdata\u001b[0m=\u001b[1m{\u001b[0m\n",
409+
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[32m'x'\u001b[0m: \u001b[1;35mDataset\u001b[0m\u001b[1m(\u001b[0m\n",
410+
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mdtype\u001b[0m=\u001b[32m'float64'\u001b[0m,\n",
411+
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m10\u001b[0m,\u001b[1m)\u001b[0m,\n",
412+
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mdata\u001b[0m=\u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.90326782\u001b[0m, \u001b[1;36m0.17363226\u001b[0m, \u001b[1;36m0.13827196\u001b[0m, \u001b[1;36m0.8917397\u001b[0m , \u001b[1;36m0.68175954\u001b[0m,\n",
413+
"\u001b[2;32m│ \u001b[0m\u001b[1;36m0.47647195\u001b[0m, \u001b[1;36m0.88443397\u001b[0m, \u001b[1;36m0.75703312\u001b[0m, \u001b[1;36m0.74991232\u001b[0m, \u001b[1;36m0.68161151\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m,\n",
414+
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m'type'\u001b[0m: \u001b[32m'mytype'\u001b[0m\u001b[1m}\u001b[0m\n",
415+
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m)\u001b[0m\n",
416+
"\u001b[2;32m│ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
417+
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_\u001b[0m=\u001b[32m'A'\u001b[0m\n",
418+
"\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n",
419+
"\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n",
420+
"\u001b[2;32m│ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
421+
"\u001b[1m)\u001b[0m\n"
422+
]
423+
},
424+
"metadata": {},
425+
"output_type": "display_data"
426+
}
427+
],
428+
"source": [
429+
"parse = Datastore.model_validate_hdf5(filepath)\n",
430+
"pprint(parse)"
431+
]
432+
},
433+
{
434+
"cell_type": "code",
435+
"execution_count": null,
436+
"metadata": {},
437+
"outputs": [],
438+
"source": []
294439
}
295440
],
296441
"metadata": {

src/oqd_dataschema/base.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,12 @@ def validate_data_matches_shape_dtype(self):
165165
def __getitem__(self, idx):
166166
return self.data[idx]
167167

168+
@classmethod
169+
def _is_dataset_type(cls, type_):
170+
return type_ == cls or (
171+
typing.get_origin(type_) is Annotated and type_.__origin__ is cls
172+
)
173+
168174

169175
def _constrain_dtype(dataset, *, dtype_constraint=None):
170176
if (not isinstance(dtype_constraint, str)) and isinstance(
@@ -261,8 +267,14 @@ def __init_subclass__(cls, **kwargs):
261267

262268
if (
263269
k not in ["class_", "attrs"]
264-
and v not in [Dataset, ClassVar]
265-
and not (typing.get_origin(v) == Annotated and v.__origin__ is Dataset)
270+
and v is not ClassVar
271+
and not Dataset._is_dataset_type(v)
272+
and not (typing.get_origin(v) is Annotated and v.__origin__ is Dataset)
273+
and not (
274+
typing.get_origin(v) is dict
275+
and v.__args__[0] is str
276+
and Dataset._is_dataset_type(v.__args__[1])
277+
)
266278
):
267279
raise TypeError(
268280
"All fields of `GroupBase` have to be of type `Dataset`."

0 commit comments

Comments
 (0)