diff --git a/src/hdmf/build/classgenerator.py b/src/hdmf/build/classgenerator.py index 73dc30a12..63e515e8d 100644 --- a/src/hdmf/build/classgenerator.py +++ b/src/hdmf/build/classgenerator.py @@ -154,7 +154,7 @@ def _get_container_type(cls, type_name, type_map): return container_type @classmethod - def _get_type(cls, spec, type_map): + def _get_type(cls, spec, type_map, parent_spec): """Get the type of a spec for use in docval. Returns a container class, a type, a tuple of types, ('array_data', 'data') for specs with non-scalar shape, or (Data, Container) when an attribute reference target has not been mapped to a container @@ -176,6 +176,10 @@ def _get_type(cls, spec, type_map): if isinstance(spec, LinkSpec): return cls._get_container_type(spec.target_type, type_map) if spec.data_type is not None: + if spec.data_type == parent_spec.data_type_def: # handle case where A contains A + return spec.data_type # docval handles class names as strings specially + # TODO handle the rare case where A contains the definition of B which contains A + # the workaround is to define B separately (avoid nested type definitions) return cls._get_container_type(spec.data_type, type_map) if spec.shape is None and spec.dims is None: return cls._get_type_from_spec_dtype(spec.dtype) @@ -217,7 +221,7 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i :param spec: The spec for the container class to generate. """ field_spec = not_inherited_fields[attr_name] - dtype = cls._get_type(field_spec, type_map) + dtype = cls._get_type(field_spec, type_map, spec) fields_conf = {'name': attr_name, 'doc': field_spec['doc']} if cls._ischild(dtype) and issubclass(parent_cls, Container) and not isinstance(field_spec, LinkSpec): @@ -229,7 +233,7 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i docval_arg = dict( name=attr_name, doc=field_spec.doc, - type=cls._get_type(field_spec, type_map) + type=cls._get_type(field_spec, type_map, spec) ) shape = getattr(field_spec, 'shape', None) if shape is not None: @@ -343,7 +347,7 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i field_spec = not_inherited_fields[attr_name] field_clsconf = dict( attr=attr_name, - type=cls._get_type(field_spec, type_map), + type=cls._get_type(field_spec, type_map, spec), add='add_{}'.format(attr_name), get='get_{}'.format(attr_name), create='create_{}'.format(attr_name) @@ -354,7 +358,7 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i docval_arg = dict( name=attr_name, doc=field_spec.doc, - type=(list, tuple, dict, cls._get_type(field_spec, type_map)) + type=(list, tuple, dict, cls._get_type(field_spec, type_map, spec)) ) if cls._check_spec_optional(field_spec, spec): docval_arg['default'] = getattr(field_spec, 'default_value', None) diff --git a/src/hdmf/build/manager.py b/src/hdmf/build/manager.py index 744e8ec9b..73ceed4dd 100644 --- a/src/hdmf/build/manager.py +++ b/src/hdmf/build/manager.py @@ -527,23 +527,26 @@ def get_dt_container_cls(self, **kwargs): def __check_dependent_types(self, spec, namespace): """Ensure that classes for all types used by this type exist in this namespace and generate them if not. """ - def __check_dependent_types_helper(spec, namespace): + def __check_dependent_types_helper(spec, namespace, higher_types): if isinstance(spec, (GroupSpec, DatasetSpec)): - if spec.data_type_inc is not None: - self.get_dt_container_cls(spec.data_type_inc, namespace) # TODO handle recursive definitions + if spec.data_type_inc is not None and spec.data_type_inc not in higher_types: + self.get_dt_container_cls(spec.data_type_inc, namespace) if spec.data_type_def is not None: # nested type definition self.get_dt_container_cls(spec.data_type_def, namespace) - else: # spec is a LinkSpec + elif spec.target_type not in higher_types: # spec is a LinkSpec self.get_dt_container_cls(spec.target_type, namespace) if isinstance(spec, GroupSpec): + if spec.data_type_def is not None: + higher_types.add(spec.data_type_def) + # NOTE the same higher_types set is used in all calls and may grow in execution branches for child_spec in (spec.groups + spec.datasets + spec.links): - __check_dependent_types_helper(child_spec, namespace) + __check_dependent_types_helper(child_spec, namespace, higher_types=higher_types) if spec.data_type_inc is not None: self.get_dt_container_cls(spec.data_type_inc, namespace) if isinstance(spec, GroupSpec): for child_spec in (spec.groups + spec.datasets + spec.links): - __check_dependent_types_helper(child_spec, namespace) + __check_dependent_types_helper(child_spec, namespace, higher_types=set([spec.data_type_def])) def __get_parent_cls(self, namespace, data_type, spec): dt_hier = self.__ns_catalog.get_hierarchy(namespace, data_type) diff --git a/src/hdmf/container.py b/src/hdmf/container.py index 752e98e48..3a714ac5a 100644 --- a/src/hdmf/container.py +++ b/src/hdmf/container.py @@ -885,6 +885,12 @@ def __build_conf_methods(cls, conf_dict, conf_index, multi): # get container type container_type = conf_dict.get('type') + if isinstance(container_type, str): # handle self-referencing/recursive type + for supercls in cls.__mro__: + if container_type == supercls.__name__: + container_type = supercls + break + if container_type is None: msg = "MultiContainerInterface subclass %s is missing 'type' key in __clsconf__" % cls.__name__ if multi: