Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions src/hdmf/build/classgenerator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from copy import deepcopy
from datetime import datetime, date
from collections.abc import Callable
import warnings

import numpy as np

from ..container import Container, Data, MultiContainerInterface
from ..spec import AttributeSpec, LinkSpec, RefSpec, GroupSpec
from ..spec import AttributeSpec, LinkSpec, RefSpec, GroupSpec, DatasetSpec
from ..spec.spec import BaseStorageSpec, ZERO_OR_MANY, ONE_OR_MANY
from ..utils import docval, getargs, ExtenderMeta, get_docval, popargs, AllowPositional

Expand Down Expand Up @@ -79,7 +80,7 @@ def generate_class(self, **kwargs):
break # each field_spec should be processed by only one generator

for class_generator in self.__custom_generators:
class_generator.post_process(classdict, bases, docval_args, spec)
class_generator.post_process(classdict, bases, docval_args, spec, type_map)

for class_generator in reversed(self.__custom_generators):
# go in reverse order so that base init is added first and
Expand Down Expand Up @@ -252,7 +253,7 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i
docval_arg = dict(
name=attr_name,
doc=field_spec.doc,
type=cls._get_type(field_spec, type_map)
type=dtype,
)
shape = getattr(field_spec, 'shape', None)
if shape is not None:
Expand Down Expand Up @@ -285,12 +286,13 @@ def _add_to_docval_args(cls, docval_args, arg, err_if_present=False):
docval_args.append(arg)

@classmethod
def post_process(cls, classdict, bases, docval_args, spec):
def post_process(cls, classdict, bases, docval_args, spec, type_map):
"""Convert classdict['__fields__'] to tuple and update docval args for a fixed name and default name.
:param classdict: The class dictionary to convert with '__fields__' key (or a different bases[0]._fieldsname)
:param bases: The list of base classes.
:param docval_args: The dict of docval arguments.
:param spec: The spec for the container class to generate.
:param type_map: The type map to use.
"""
# convert classdict['__fields__'] from list to tuple if present
for b in bases:
Expand All @@ -308,6 +310,33 @@ def post_process(cls, classdict, bases, docval_args, spec):
# set default name in docval args if provided
cls._set_default_name(docval_args, spec.default_name)

if isinstance(spec, DatasetSpec):
# handle the data field specially
# fixed and default values are not supported for datasets
if getattr(spec, 'value', None) is not None:
warnings.warn(
"Generating a class for a dataset with a fixed value is not supported. "
"The fixed value will be ignored."
)
if getattr(spec, 'default_value', None) is not None:
warnings.warn(
"Generating a class for a dataset with a default value is not supported. "
"The default value will be ignored."
)

data_docval_arg = dict(name='data', doc=spec.doc)
shape = spec.shape
if shape is None and spec.dims is None:
if spec.dtype is not None:
dtype = cls._get_type_from_spec_dtype(spec.dtype)
else:
dtype = ('scalar_data', 'array_data', 'data')
else:
dtype = ('array_data', 'data')
data_docval_arg['shape'] = shape
data_docval_arg['type'] = dtype
cls._add_to_docval_args(docval_args, data_docval_arg)

@classmethod
def _get_attrs_not_to_set_init(cls, classdict, parent_docval_args):
return parent_docval_args
Expand Down Expand Up @@ -413,12 +442,13 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i
cls._add_to_docval_args(docval_args, docval_arg)

@classmethod
def post_process(cls, classdict, bases, docval_args, spec):
def post_process(cls, classdict, bases, docval_args, spec, type_map):
"""Add MultiContainerInterface to the list of base classes.
:param classdict: The class dictionary.
:param bases: The list of base classes.
:param docval_args: The dict of docval arguments.
:param spec: The spec for the container class to generate.
:param type_map: The type map to use.
"""
if '__clsconf__' in classdict:
# do not add MCI as a base if a base is already a subclass of MultiContainerInterface
Expand Down
36 changes: 26 additions & 10 deletions src/hdmf/build/objectmapper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import logging
import re
import warnings
Expand Down Expand Up @@ -92,6 +93,17 @@ def _ascii(s):
raise ValueError("Expected unicode or ascii string, got %s" % type(s))


def _isoformat(s):
"""
A helper function for converting to ISO format
"""
if isinstance(s, (datetime.datetime, datetime.date)):
return s.isoformat()
elif isinstance(s, str): # probably already converted to isoformat
return s
else:
raise ValueError("Expected datetime, got %s" % type(s))

class ObjectMapper(metaclass=ExtenderMeta):
'''A class for mapping between Spec objects and AbstractContainer attributes

Expand Down Expand Up @@ -125,8 +137,8 @@ class ObjectMapper(metaclass=ExtenderMeta):
"utf-8": _unicode,
"ascii": _ascii,
"bytes": _ascii,
"isodatetime": _ascii,
"datetime": _ascii,
"isodatetime": _isoformat,
"datetime": _isoformat,
}

__no_convert = set()
Expand Down Expand Up @@ -230,8 +242,8 @@ def convert_dtype(cls, spec, value, spec_dtype=None) -> tuple: # noqa: C901
else:
ret = value.astype('U')
ret_dtype = "utf8"
elif spec_dtype_type is _ascii:
ret = value.astype('S')
elif spec_dtype_type in (_ascii, _isoformat):
ret = value.astype('S') # this works for datetime objects
ret_dtype = "ascii"
else:
dtype_func, warning_msg = cls.__resolve_numeric_dtype(value.dtype, spec_dtype_type)
Expand All @@ -245,7 +257,7 @@ def convert_dtype(cls, spec, value, spec_dtype=None) -> tuple: # noqa: C901
if len(value) == 0:
if spec_dtype_type is _unicode:
ret_dtype = 'utf8'
elif spec_dtype_type is _ascii:
elif spec_dtype_type in (_ascii, _isoformat):
ret_dtype = 'ascii'
else:
ret_dtype = spec_dtype_type
Expand All @@ -261,15 +273,16 @@ def convert_dtype(cls, spec, value, spec_dtype=None) -> tuple: # noqa: C901
ret = value
if spec_dtype_type is _unicode:
ret_dtype = "utf8"
elif spec_dtype_type is _ascii:
elif spec_dtype_type in (_ascii, _isoformat):
ret_dtype = "ascii"
else:
ret_dtype, warning_msg = cls.__resolve_numeric_dtype(value.dtype, spec_dtype_type)
else:
if spec_dtype_type in (_unicode, _ascii):
ret_dtype = 'ascii'
if spec_dtype_type in (_unicode, _ascii, _isoformat):
if spec_dtype_type is _unicode:
ret_dtype = 'utf8'
else:
ret_dtype = 'ascii'
ret = spec_dtype_type(value)
else:
dtype_func, warning_msg = cls.__resolve_numeric_dtype(type(value), spec_dtype_type)
Expand Down Expand Up @@ -343,6 +356,8 @@ def __check_edgecases(cls, spec, value, spec_dtype): # noqa: C901
elif np.issubdtype(value.dtype, np.dtype('O')):
# Only variable-length strings should ever appear as generic objects.
# Everything else should have a well-defined type
# NOTE: a datetime object would be converted to a string by this check
# but users should not provide arrays of datetime objects to an untyped/generic spec
ret_dtype = 'utf8'
else:
ret_dtype = value.dtype.type
Expand All @@ -357,7 +372,7 @@ def __check_edgecases(cls, spec, value, spec_dtype): # noqa: C901
cls.__check_convert_numeric(ret_dtype)
if ret_dtype is str:
ret_dtype = 'utf8'
elif ret_dtype is bytes:
elif ret_dtype in (bytes, datetime.datetime, datetime.date):
ret_dtype = 'ascii'
return value, ret_dtype
if isinstance(spec_dtype, RefSpec):
Expand Down Expand Up @@ -636,6 +651,7 @@ def __get_data_type(cls, spec):

def __convert_string(self, value, spec):
"""Convert string types to the specified dtype."""
# TODO: combine this with the logic in convert_dtype
def __apply_string_type(value, string_type):
# NOTE: if a user passes a h5py.Dataset that is not wrapped with a hdmf.utils.StrDataset,
# then this conversion may not be correct. Users should unpack their string h5py.Datasets
Expand All @@ -660,7 +676,7 @@ def __apply_string_type(value, string_type):
string_type = str
elif 'ascii' in spec.dtype:
string_type = bytes
elif 'isodatetime' in spec.dtype:
elif 'datetime' in spec.dtype:
def string_type(x):
return x.isoformat() # method works for both date and datetime
if string_type is not None:
Expand Down
3 changes: 2 additions & 1 deletion src/hdmf/common/io/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,13 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i
# do not add DynamicTable columns to init docval

@classmethod
def post_process(cls, classdict, bases, docval_args, spec):
def post_process(cls, classdict, bases, docval_args, spec, type_map):
"""Convert classdict['__columns__'] to tuple.
:param classdict: The class dictionary.
:param bases: The list of base classes.
:param docval_args: The dict of docval arguments.
:param spec: The spec for the container class to generate.
:param type_map: The type map to use.
"""
# convert classdict['__columns__'] from list to tuple if present
columns = classdict.get('__columns__')
Expand Down
3 changes: 2 additions & 1 deletion src/hdmf/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import copy as _copy
import datetime
import re
import types
import warnings
Expand All @@ -12,7 +13,7 @@

__macros = {
'array_data': [np.ndarray, list, tuple, h5py.Dataset],
'scalar_data': [str, int, float, bytes, bool],
'scalar_data': [str, int, float, bytes, bool, datetime.datetime, datetime.date],
'data': []
}

Expand Down
43 changes: 37 additions & 6 deletions tests/unit/build_tests/test_classgenerator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import numpy as np
import os
import shutil
Expand Down Expand Up @@ -35,7 +36,7 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i
classdict.setdefault('process_field_spec', list()).append(attr_name)

@classmethod
def post_process(cls, classdict, bases, docval_args, spec):
def post_process(cls, classdict, bases, docval_args, spec, type_map):
classdict['post_process'] = True

spec = GroupSpec(
Expand Down Expand Up @@ -510,6 +511,36 @@ def test_multi_container_spec_one_or_more_ok(self):
)
assert len(multi.bars) == 1

def test_get_class_include_scalar_datetime_attribute(self):
    """Test that get_class resolves a scalar datetime attribute."""
    # build a spec whose only field is a scalar attribute with dtype 'datetime'
    dt_attr = AttributeSpec(
        name='attr1',
        doc='a scalar datetime attribute',
        dtype='datetime',
    )
    goo_spec = GroupSpec(
        doc='A test group that has a scalar datetime attribute',
        data_type_def='Goo',
        attributes=[dt_attr],
    )
    self.spec_catalog.register_spec(goo_spec, 'extension.yaml')
    # generating the class and constructing it with a datetime must round-trip the value
    goo_cls = self.type_map.get_dt_container_cls('Goo', CORE_NAMESPACE)
    expected = datetime.datetime(2020, 1, 1, 0, 0, 0)
    goo = goo_cls(name='my_goo', attr1=expected)
    self.assertEqual(goo.attr1, expected)

def test_get_class_include_scalar_datetime_dataset(self):
    """Test that get_class resolves a scalar datetime dataset."""
    # a dataset type whose data itself has dtype 'datetime'
    goo_spec = DatasetSpec(
        data_type_def='Goo',
        doc='A test dataset with dtype datetime',
        dtype='datetime',
    )
    self.spec_catalog.register_spec(goo_spec, 'extension.yaml')
    generated_cls = self.type_map.get_dt_container_cls('Goo', CORE_NAMESPACE)
    # constructing with a datetime value must store it unchanged on .data
    value = datetime.datetime(2020, 1, 1, 0, 0, 0)
    goo = generated_cls(name='my_goo', data=value)
    self.assertEqual(goo.data, value)


class TestDynamicContainerFixedValue(TestCase):

Expand Down Expand Up @@ -1321,7 +1352,7 @@ def test_post_process_fixed_name(self):
docval_args = [{'name': 'name', 'type': str, 'doc': 'name'},
{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute',
'shape': [None]}]
CustomClassGenerator.post_process(classdict, bases, docval_args, spec)
CustomClassGenerator.post_process(classdict, bases, docval_args, spec, self.type_map)

expected = [{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute',
'shape': [None]}]
Expand All @@ -1348,7 +1379,7 @@ def test_post_process_default_name(self):
docval_args = [{'name': 'name', 'type': str, 'doc': 'name'},
{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute',
'shape': [None]}]
CustomClassGenerator.post_process(classdict, bases, docval_args, spec)
CustomClassGenerator.post_process(classdict, bases, docval_args, spec, self.type_map)

expected = [{'name': 'name', 'type': str, 'doc': 'name', 'default': 'MyBaz'},
{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute',
Expand Down Expand Up @@ -1450,7 +1481,7 @@ def test_post_process(self):
)
bases = [Bar]
docval_args = []
MCIClassGenerator.post_process(classdict, bases, docval_args, multi_spec)
MCIClassGenerator.post_process(classdict, bases, docval_args, multi_spec, self.type_map)
self.assertEqual(bases, [Bar, MultiContainerInterface])

def test_post_process_already_multi(self):
Expand Down Expand Up @@ -1478,7 +1509,7 @@ class Multi1(MultiContainerInterface):
)
bases = [Multi1]
docval_args = []
MCIClassGenerator.post_process(classdict, bases, docval_args, multi_spec)
MCIClassGenerator.post_process(classdict, bases, docval_args, multi_spec, self.type_map)
self.assertEqual(bases, [Multi1])

def test_post_process_container(self):
Expand All @@ -1505,5 +1536,5 @@ class Multi1(MultiContainerInterface):
)
bases = [Container]
docval_args = []
MCIClassGenerator.post_process(classdict, bases, docval_args, multi_spec)
MCIClassGenerator.post_process(classdict, bases, docval_args, multi_spec, self.type_map)
self.assertEqual(bases, [MultiContainerInterface, Container])
Loading
Loading