Skip to content

Commit 5e96236

Browse files
authored
Merge pull request #37 from AllenNeuralDynamics/feat-reencode-aind-data-schema-models
Add encoder for special characters for aind-data-schema models
2 parents 828066f + ce6cbab commit 5e96236

File tree

4 files changed

+79
-23
lines changed

4 files changed

+79
-23
lines changed

src/aind_behavior_experiment_launcher/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.3.1"
1+
__version__ = "0.3.2"
22

33
import logging
44
import logging.config
Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,78 @@
11
from __future__ import annotations
22

3-
try:
4-
import aind_data_schema
5-
except ImportError as _e:
6-
_e.add_note(
7-
"The 'aind-data-schema' package is required to use this module. \
8-
Install the optional dependencies defined in `project.toml' \
9-
by running `pip install .[aind-services]`"
3+
import importlib.util
4+
5+
if importlib.util.find_spec("aind_data_schema") is None:
6+
raise ImportError(
7+
"The 'aind-data-schema' package is required to use this module. "
8+
"Install the optional dependencies defined in `project.toml` "
9+
"by running `pip install .[aind-services]`"
1010
)
11-
raise _e
1211

1312
import abc
1413
import logging
15-
from typing import Generic, TypeVar, Union
16-
17-
import aind_data_schema.components.devices
18-
import aind_data_schema.core.rig
19-
import aind_data_schema.core.session
14+
from typing import Any, Generic, Type, TypeVar, Union
2015

21-
from . import _base
16+
from aind_data_schema.core import rig as ads_rig
17+
from aind_data_schema.core import session as ads_session
18+
from pydantic import BaseModel, create_model, model_validator
2219

23-
TAdsObject = TypeVar("TAdsObject", bound=Union[aind_data_schema.core.session.Session, aind_data_schema.core.rig.Rig])
20+
from aind_behavior_experiment_launcher.data_mapper import _base
2421

2522
logger = logging.getLogger(__name__)
2623

24+
_TAdsObject = TypeVar("_TAdsObject", bound=Union[ads_session.Session, ads_rig.Rig])
2725

28-
class AindDataSchemaDataMapper(_base.DataMapper[TAdsObject], abc.ABC, Generic[TAdsObject]):
26+
27+
class AindDataSchemaDataMapper(_base.DataMapper[_TAdsObject], abc.ABC, Generic[_TAdsObject]):
2928
@property
3029
@abc.abstractmethod
3130
def session_name(self) -> str: ...
3231

3332

34-
class AindDataSchemaSessionDataMapper(AindDataSchemaDataMapper[aind_data_schema.core.session.Session], abc.ABC): ...
33+
class AindDataSchemaSessionDataMapper(AindDataSchemaDataMapper[ads_session.Session], abc.ABC): ...
34+
35+
36+
class AindDataSchemaRigDataMapper(AindDataSchemaDataMapper[ads_rig.Rig], abc.ABC): ...
37+
38+
39+
_TModel = TypeVar("_TModel", bound=BaseModel)
3540

3641

37-
class AindDataSchemaRigDataMapper(AindDataSchemaDataMapper[aind_data_schema.core.rig.Rig], abc.ABC): ...
42+
def create_encoding_model(model: Type[_TModel]) -> Type[_TModel]:
43+
"""Creates a new BaseModel by wrapping the incoming model and adding a Before
44+
ModelValidator to replace _SPECIAL_CHARACTERS with the unicode, escaped,
45+
representation"""
46+
47+
_SPECIAL_CHARACTERS = ".$"
48+
49+
def _to_unicode_repr(character: str):
50+
if len(character) != 1:
51+
raise ValueError(f"Expected a single character, got {character}")
52+
return f"\\u{ord(character):04x}"
53+
54+
def _aind_data_schema_encoder(cls, data: Any) -> Any:
55+
if isinstance(data, dict):
56+
return _sanitize_dict(data)
57+
return data
58+
59+
def _sanitize_dict(value: dict) -> dict:
60+
if isinstance(value, dict):
61+
_keys = list(value.keys())
62+
for key in _keys:
63+
if isinstance(value[key], dict):
64+
value[key] = _sanitize_dict(value[key])
65+
if isinstance(sanitized_key := key, str):
66+
for char in _SPECIAL_CHARACTERS:
67+
if char in sanitized_key:
68+
sanitized_key = sanitized_key.replace(char, _to_unicode_repr(char))
69+
value[sanitized_key] = value.pop(key)
70+
return value
71+
72+
return create_model(
73+
f"_Wrapped{model.__class__.__name__}",
74+
__base__=model,
75+
__validators__={
76+
"encoder": model_validator(mode="before")(_aind_data_schema_encoder) # type: ignore
77+
},
78+
)

tests/test_data_mapper.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import unittest
22
from pathlib import Path
3-
from typing import Dict, List, Optional
3+
from typing import Any, Dict, List, Optional
44
from unittest.mock import patch
55

6+
from aind_data_schema.base import AindGeneric
67
from pydantic import BaseModel
78

9+
from aind_behavior_experiment_launcher.data_mapper.aind_data_schema import create_encoding_model
810
from aind_behavior_experiment_launcher.data_mapper.helpers import (
911
_sanity_snapshot_keys,
1012
snapshot_bonsai_environment,
@@ -65,5 +67,19 @@ def test_sanity_snapshot_keys_with_dots_and_dollars(self):
6567
self.assertEqual(result, expected)
6668

6769

70+
class TestAindDataMapper(unittest.TestCase):
71+
class MyMockModel(BaseModel):
72+
a_dict: Dict[str, Any]
73+
a_generic: AindGeneric
74+
75+
def test_encoding_with_illegal_characters(self):
76+
_input = {"key": "value", "$key.key": "value"}
77+
_expected = {"key": "value", "\\u0024key\\u002ekey": "value"}
78+
encoding_model = create_encoding_model(self.MyMockModel)
79+
test = encoding_model(a_dict=_input, a_generic=_input)
80+
self.assertEqual(test.a_dict, _expected)
81+
self.assertEqual(test.a_generic.model_dump(), _expected)
82+
83+
6884
if __name__ == "__main__":
6985
unittest.main()

uv.lock

Lines changed: 2 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)