diff --git a/src/nipanel/_convert.py b/src/nipanel/_convert.py index 8ea2b68..3b934fe 100644 --- a/src/nipanel/_convert.py +++ b/src/nipanel/_convert.py @@ -2,11 +2,13 @@ from __future__ import annotations +import enum import logging from collections.abc import Collection from typing import Any, Iterable from google.protobuf import any_pb2 +from nitypes.vector import Vector from nitypes.waveform import AnalogWaveform, ComplexWaveform from nipanel.converters import Converter @@ -73,17 +75,18 @@ VectorConverter(), ] -_CONVERTIBLE_COLLECTION_TYPES = { - frozenset, - list, - set, - tuple, -} - _CONVERTER_FOR_PYTHON_TYPE = {entry.python_typename: entry for entry in _CONVERTIBLE_TYPES} _CONVERTER_FOR_GRPC_TYPE = {entry.protobuf_typename: entry for entry in _CONVERTIBLE_TYPES} _SUPPORTED_PYTHON_TYPES = _CONVERTER_FOR_PYTHON_TYPE.keys() +_SKIPPED_COLLECTIONS = ( + str, # Handled by StrConverter + bytes, # Handled by BytesConverter + dict, # Unsupported data type + enum.Enum, # Handled by IntConverter + Vector, # Handled by VectorConverter +) + def to_any(python_value: object) -> any_pb2.Any: """Convert a Python object to a protobuf Any.""" @@ -97,7 +100,7 @@ def _get_best_matching_type(python_value: object) -> str: additional_info_string = _get_additional_type_info_string(python_value) container_types = [] - value_is_collection = any(_CONVERTIBLE_COLLECTION_TYPES.intersection(underlying_parents)) + value_is_collection = _is_collection_for_convert(python_value) # Variable to use when traversing down through collection types. working_python_value = python_value while value_is_collection: @@ -119,9 +122,7 @@ def _get_best_matching_type(python_value: object) -> str: # If this element is a collection, we want to continue traversing. Once we find a # non-collection, underlying_parents will refer to the candidates for the non- # collection type. - value_is_collection = any( - _CONVERTIBLE_COLLECTION_TYPES.intersection(underlying_parents) - ) + value_is_collection = _is_collection_for_convert(working_python_value) container_types.append(Collection) best_matching_type = None @@ -192,3 +193,9 @@ def _get_additional_type_info_string(python_value: object) -> str: return str(python_value.dtype) else: return "" + + +def _is_collection_for_convert(python_value: object) -> bool: + return isinstance(python_value, Collection) and not isinstance( + python_value, _SKIPPED_COLLECTIONS + )