|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -try: |
4 | | - import aind_data_schema |
5 | | -except ImportError as _e: |
6 | | - _e.add_note( |
7 | | - "The 'aind-data-schema' package is required to use this module. \ |
8 | | - Install the optional dependencies defined in `project.toml' \ |
9 | | - by running `pip install .[aind-services]`" |
| 3 | +import importlib.util |
| 4 | + |
| 5 | +if importlib.util.find_spec("aind_data_schema") is None: |
| 6 | + raise ImportError( |
| 7 | + "The 'aind-data-schema' package is required to use this module. " |
| 8 | + "Install the optional dependencies defined in `project.toml` " |
| 9 | + "by running `pip install .[aind-services]`" |
10 | 10 | ) |
11 | | - raise _e |
12 | 11 |
|
13 | 12 | import abc |
14 | 13 | import logging |
15 | | -from typing import Generic, TypeVar, Union |
16 | | - |
17 | | -import aind_data_schema.components.devices |
18 | | -import aind_data_schema.core.rig |
19 | | -import aind_data_schema.core.session |
| 14 | +from typing import Any, Generic, Type, TypeVar, Union |
20 | 15 |
|
21 | | -from . import _base |
| 16 | +from aind_data_schema.core import rig as ads_rig |
| 17 | +from aind_data_schema.core import session as ads_session |
| 18 | +from pydantic import BaseModel, create_model, model_validator |
22 | 19 |
|
23 | | -TAdsObject = TypeVar("TAdsObject", bound=Union[aind_data_schema.core.session.Session, aind_data_schema.core.rig.Rig]) |
| 20 | +from aind_behavior_experiment_launcher.data_mapper import _base |
24 | 21 |
|
25 | 22 | logger = logging.getLogger(__name__) |
26 | 23 |
|
| 24 | +_TAdsObject = TypeVar("_TAdsObject", bound=Union[ads_session.Session, ads_rig.Rig]) |
27 | 25 |
|
28 | | -class AindDataSchemaDataMapper(_base.DataMapper[TAdsObject], abc.ABC, Generic[TAdsObject]): |
| 26 | + |
| 27 | +class AindDataSchemaDataMapper(_base.DataMapper[_TAdsObject], abc.ABC, Generic[_TAdsObject]): |
29 | 28 | @property |
30 | 29 | @abc.abstractmethod |
31 | 30 | def session_name(self) -> str: ... |
32 | 31 |
|
33 | 32 |
|
34 | | -class AindDataSchemaSessionDataMapper(AindDataSchemaDataMapper[aind_data_schema.core.session.Session], abc.ABC): ... |
| 33 | +class AindDataSchemaSessionDataMapper(AindDataSchemaDataMapper[ads_session.Session], abc.ABC): ... |
| 34 | + |
| 35 | + |
| 36 | +class AindDataSchemaRigDataMapper(AindDataSchemaDataMapper[ads_rig.Rig], abc.ABC): ... |
| 37 | + |
| 38 | + |
| 39 | +_TModel = TypeVar("_TModel", bound=BaseModel) |
35 | 40 |
|
36 | 41 |
|
37 | | -class AindDataSchemaRigDataMapper(AindDataSchemaDataMapper[aind_data_schema.core.rig.Rig], abc.ABC): ... |
| 42 | +def create_encoding_model(model: Type[_TModel]) -> Type[_TModel]: |
| 43 | + """Creates a new BaseModel by wrapping the incoming model and adding a Before |
| 44 | + ModelValidator to replace _SPECIAL_CHARACTERS with the unicode, escaped, |
| 45 | + representation""" |
| 46 | + |
| 47 | + _SPECIAL_CHARACTERS = ".$" |
| 48 | + |
| 49 | + def _to_unicode_repr(character: str): |
| 50 | + if len(character) != 1: |
| 51 | + raise ValueError(f"Expected a single character, got {character}") |
| 52 | + return f"\\u{ord(character):04x}" |
| 53 | + |
| 54 | + def _aind_data_schema_encoder(cls, data: Any) -> Any: |
| 55 | + if isinstance(data, dict): |
| 56 | + return _sanitize_dict(data) |
| 57 | + return data |
| 58 | + |
| 59 | + def _sanitize_dict(value: dict) -> dict: |
| 60 | + if isinstance(value, dict): |
| 61 | + _keys = list(value.keys()) |
| 62 | + for key in _keys: |
| 63 | + if isinstance(value[key], dict): |
| 64 | + value[key] = _sanitize_dict(value[key]) |
| 65 | + if isinstance(sanitized_key := key, str): |
| 66 | + for char in _SPECIAL_CHARACTERS: |
| 67 | + if char in sanitized_key: |
| 68 | + sanitized_key = sanitized_key.replace(char, _to_unicode_repr(char)) |
| 69 | + value[sanitized_key] = value.pop(key) |
| 70 | + return value |
| 71 | + |
| 72 | + return create_model( |
| 73 | + f"_Wrapped{model.__class__.__name__}", |
| 74 | + __base__=model, |
| 75 | + __validators__={ |
| 76 | + "encoder": model_validator(mode="before")(_aind_data_schema_encoder) # type: ignore |
| 77 | + }, |
| 78 | + ) |
0 commit comments