
Commit f083534

Refactoring to use pydantic-gubbins
Parent: ec18de9

13 files changed: +129 additions, −144 deletions

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+# This workflow will upload a Python Package to PyPI when a release is created
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
+
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+name: Upload Python Package
+
+on:
+  release:
+    types: [published]
+
+permissions:
+  contents: read
+
+jobs:
+  release-build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.x"
+
+      - name: Build release distributions
+        run: |
+          # NOTE: put your own distribution build steps here.
+          python -m pip install build
+          python -m build
+
+      - name: Upload distributions
+        uses: actions/upload-artifact@v4
+        with:
+          name: release-dists
+          path: dist/
+
+  pypi-publish:
+    runs-on: ubuntu-latest
+    needs:
+      - release-build
+    permissions:
+      # IMPORTANT: this permission is mandatory for trusted publishing
+      id-token: write
+
+    # Dedicated environments with protections for publishing are strongly recommended.
+    # For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules
+    environment:
+      name: pypi
+      # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
+      # url: https://pypi.org/p/YOURPROJECT
+      #
+      # ALTERNATIVE: if your GitHub Release name is the PyPI project version string
+      # ALTERNATIVE: exactly, uncomment the following line instead:
+      # url: https://pypi.org/project/YOURPROJECT/${{ github.event.release.name }}
+
+    steps:
+      - name: Retrieve release distributions
+        uses: actions/download-artifact@v4
+        with:
+          name: release-dists
+          path: dist/
+
+      - name: Publish release distributions to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          packages-dir: dist/
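The build job is just the stock pypa/build invocation. If you want to reproduce the "Build release distributions" step locally before cutting a release, a minimal sketch of the same two commands (assuming nothing beyond the `build` package being installable) is:

import subprocess
import sys

# Mirror the workflow's build step: install pypa/build, then produce the
# sdist and wheel under dist/ (the directory the workflow uploads as an artifact).
subprocess.run([sys.executable, "-m", "pip", "install", "build"], check=True)
subprocess.run([sys.executable, "-m", "build"], check=True)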

object_model/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 from ._dataclasses import Base, Immutable, NamedPersistable, Persistable
 from ._json import dump, dumps, load, loads
 from ._pydantic import BaseModel, ImmutableModel, NamedPersistableModel, PersistableModel
-from ._typing import Subclass
 from ._descriptors import Id

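With `Subclass` gone, everything else the package re-exports is untouched. A quick smoke-test import of the surviving names (taken directly from the lines above) would be:

# Names still re-exported from the package root after this commit:
from object_model import (
    Base, Immutable, NamedPersistable, Persistable,                       # dataclass variants
    BaseModel, ImmutableModel, NamedPersistableModel, PersistableModel,   # pydantic variants
    dump, dumps, load, loads,                                             # JSON helpers
    Id,                                                                   # ID descriptor
)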

object_model/_json.py

Lines changed: 11 additions & 22 deletions
@@ -1,14 +1,14 @@
 from dataclasses import MISSING, is_dataclass, fields
-from orjson import loads as __loads, dumps as __dumps
+from orjson import dumps as __dumps
 from pydantic import BaseModel, ConfigDict, TypeAdapter
 from pydantic.alias_generators import to_camel
 from typing import Any


-from ._type_registry import TYPE_KEY, get_type
+from ._type_registry import get_type


-__type_adaptors: dict[str, TypeAdapter] = {}
+__type_adaptors: dict[type, TypeAdapter] = {}


 def get_type_adaptor(typ: type) -> TypeAdapter:
@@ -27,43 +27,32 @@ def get_type_adaptor(typ: type) -> TypeAdapter:


 def dump(data: Any) -> dict[str, Any]:
     if isinstance(data, BaseModel):
-        return data.model_dump(include={*data.model_fields_set, TYPE_KEY}, by_alias=True)
+        return data.model_dump(exclude_unset=True, by_alias=True)
     elif is_dataclass(data):
-        flds = [f.name for f in fields(data) if f.default is MISSING or f.default != getattr(data, f.name)]
+        flds = set(f.name for f in fields(data) if f.default is MISSING or f.default != getattr(data, f.name))
         return get_type_adaptor(type(data)).dump_python(data, include=flds)
     else:
         raise RuntimeError("Unsupported type")


 def dumps(data: Any) -> bytes:
     if isinstance(data, BaseModel):
-        return data.model_dump_json(include={*data.model_fields_set, TYPE_KEY}, by_alias=True).encode("utf-8")
+        return data.model_dump_json(exclude_unset=True, by_alias=True).encode()
     elif is_dataclass(data):
-        flds = [f.name for f in fields(data) if f.default is MISSING or f.default != getattr(data, f.name)]
+        flds = set(f.name for f in fields(data) if f.default is MISSING or f.default != getattr(data, f.name))
         return get_type_adaptor(type(data)).dump_json(data, by_alias=True, include=flds)
     else:
         return __dumps(data)


-def load(data: dict[str, Any]) -> Any:
-    type_name = data.get(TYPE_KEY)
-    if type_name is None:
-        raise RuntimeError("No type in data")
-
-    typ = get_type(type_name)
+def load(data: dict[str, Any], typ: type | str) -> Any:
+    if isinstance(typ, str):
+        typ = get_type(typ)

     return typ.model_validate(data) if issubclass(typ, BaseModel) else get_type_adaptor(typ).validate_python(data)


-def loads(data: bytes | str, typ: type | str | None = None) -> Any:
-    if typ is None:
-        ret = __loads(data)
-        # SUPER ineffecient ...
-        if TYPE_KEY in ret:
-            typ = ret[TYPE_KEY]
-        else:
-            return ret
-
+def loads(data: bytes | str, typ: type | str) -> Any:
     if isinstance(typ, str):
         typ = get_type(typ)
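The upshot for callers: `loads` no longer sniffs the old `t_` type key out of the payload, so the target type (or its registered name) must be supplied. A minimal sketch of the new round trip; the `Trade` model is invented here purely for illustration, and the string form assumes the registered name matches what `get_type_name` produces:

from object_model import BaseModel, dumps, loads

class Trade(BaseModel):      # hypothetical model, not part of the library
    book: str
    quantity: int = 0

trade = Trade(book="rates")
payload = dumps(trade)               # bytes; unset fields (quantity) are excluded
restored = loads(payload, Trade)     # the type is now passed explicitly...
# restored = loads(payload, "Trade") # ...or looked up by its registered name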

object_model/_type_checking.py

Lines changed: 21 additions & 44 deletions
@@ -1,18 +1,12 @@
-from dataclasses import field, is_dataclass
-from datetime import date, datetime
-from pydantic import BaseModel
-from typing import Any, Callable, ClassVar, Literal, Union, get_origin, get_args
+from pydantic._internal._config import ConfigWrapper
+from pydantic_gubbins.typing import FrozenDict, Union
+from typing import Any, Callable, Union as _Union, get_origin, get_args

-from ._typing import DiscriminatedUnion, FrozenDict
-from ._type_registry import CLASS_TYPE_KEY, TYPE_KEY, register_type
+from ._type_registry import register_type


-__base_type_order = {datetime: -2, date: -1, int: -1}
-
-
-def check_type(fld: str, typ: Any) -> Any:
+def check_type(fld: str, typ: Any, immutable_collections: bool) -> Any:
     # Check that we have no non-serialisable or ambiguously serialisable types
-    # Also, rewrite a couple of types to avoid common problems

     if typ in (object, Any, Callable):
         raise TypeError(f"{typ} is not a persistable type for {fld}")
@@ -22,50 +16,33 @@ def check_type(fld: str, typ: Any) -> Any:

     if not args:
         if origin in (dict, list, set, tuple):
-            raise TypeError(f"Cannot use untyped collection for {field}")
+            raise TypeError(f"Cannot use untyped collection for {fld}")

     for arg in args:
-        check_type(fld, arg)
-
-    if origin is set:
-        return frozenset[args]
-    elif origin is list:
-        return tuple[args + (...,)]
-    elif origin is dict:
-        return FrozenDict[args]
-    elif origin is Union:
-        # Re-order the args of unions so that e.g. datetime comes before str
-
-        object_types = tuple(t for t in args if issubclass(t, BaseModel) or is_dataclass(t))
-        base_types = set(args).difference(object_types) if object_types else args
-        base_types = tuple(sorted(base_types, key=lambda x: __base_type_order.get(x, 0)))
+        check_type(fld, arg, immutable_collections)

-        # If we have e.g. Union[date, MyClass, MyOtherClass] we need to use a discriminated union for the
-        # classes, so we need Union[date, Annotated[Union[MyClass, MyOtherClass], Field(discriminator=TYPE_KEY)]
+    if origin is _Union:
+        return Union[args]

-        if len(object_types) > 1:
-            return Union[base_types + DiscriminatedUnion[object_types]] if base_types else\
-                DiscriminatedUnion[object_types]
+    if immutable_collections:
+        if origin is set:
+            return frozenset[args]
+        elif origin is list:
+            return tuple[args + (...,)]
+        elif origin is dict:
+            return FrozenDict[args]

     return typ


 class TypeCheckMixin:
     def __new__(cls, cls_name: str, bases: tuple[type[Any], ...], namespace: dict[str, Any], **kwargs):
-        annotations = namespace.setdefault("__annotations__", {})
-
-        for name, typ in annotations.items():
-            annotations[name] = check_type(name, typ)
-
-        registered_name = annotations.get(TYPE_KEY)
-
-        if not registered_name:
-            registered_name = cls_name
-            annotations[TYPE_KEY] = Literal[registered_name]
-            namespace[TYPE_KEY] = field(default_factory=lambda: registered_name, init=False)
+        model_config = ConfigWrapper.for_model(bases, namespace, kwargs)
+        immutable_collections = model_config.frozen if model_config else False
+        annotations = namespace.get("__annotations__", {})

-        annotations[CLASS_TYPE_KEY] = ClassVar[str]
-        namespace[CLASS_TYPE_KEY] = registered_name
+        for name, typ in namespace.setdefault("__annotations__", {}).items():
+            annotations[name] = check_type(name, typ, immutable_collections)

         ret = super().__new__(cls, cls_name, bases, namespace, **kwargs)
         register_type(ret)
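So the discriminated-union juggling moves into pydantic-gubbins' `Union`, and the collection rewrites now only apply when the model config is frozen. A rough illustration of the new rewrite rules, calling the private helper directly (internal API, shown only to make the behaviour concrete):

from object_model._type_checking import check_type  # private module

# Frozen models (immutable_collections=True): mutable collections become hashable ones.
assert check_type("prices", list[float], True) == tuple[float, ...]
assert check_type("tags", set[str], True) == frozenset[str]
print(check_type("limits", dict[str, int], True))   # FrozenDict[str, int] from pydantic_gubbins

# Non-frozen models: annotations are left untouched.
assert check_type("prices", list[float], False) == list[float]

# Unserialisable types are still rejected outright:
# check_type("anything", object, True)  # raises TypeError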

object_model/_type_registry.py

Lines changed: 3 additions & 10 deletions
@@ -1,7 +1,5 @@
 import importlib.metadata as md
-
-TYPE_KEY = "t_"
-CLASS_TYPE_KEY = "t__"
+from pydantic_gubbins.typing import get_type_name


 class __TypeRegistry:
@@ -34,10 +32,7 @@ def is_temporary_type(self, type_name: str) -> bool:
         return is_temporary

     def register_type(self, typ: type):
-        type_name = getattr(typ, CLASS_TYPE_KEY, None)
-        if type_name is None:
-            raise RuntimeError(f"{typ} is missing attribute {TYPE_KEY}")
-
+        type_name = get_type_name(typ)
         if type_name in self.__object_store.names:
             return

@@ -53,9 +48,7 @@ def get_type(type_name: str) -> type:

 def is_temporary_type(typ: str | type) -> bool:
     if isinstance(typ, type):
-        type_name = getattr(typ, CLASS_TYPE_KEY, None)
-        if type_name is None:
-            raise RuntimeError(f"{typ} is missing attribute {TYPE_KEY}")
+        type_name = get_type_name(typ)
     else:
         type_name = typ
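The net effect is that the registry no longer needs the `t_`/`t__` keys stamped onto each class; it derives the name on demand from pydantic-gubbins. A small sketch of how a lookup helper might use the pair of functions shown above (hypothetical helper, and it assumes classes are registered under whatever `get_type_name` returns, as the registration path implies):

from object_model._type_registry import get_type   # private module, internal API
from pydantic_gubbins.typing import get_type_name

def lookup(cls_or_name):
    # Accept either a class or its registered name; resolve to the registered class.
    name = cls_or_name if isinstance(cls_or_name, str) else get_type_name(cls_or_name)
    return get_type(name)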

object_model/_typing.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

object_model/store/exception.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def message(cls) -> str:

 class WrongStoreError(ObjectStoreError):
     def __init__(self, object_type: str, object_id: bytes):
-        message = f"""Attempting to save object {object_type}:{object_id.decode("utf-8")}
+        message = f"""Attempting to save object {object_type}:{object_id.decode()}
         in a different store to the one from which it was loaded"""
         super().__init__(message)

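This one is purely cosmetic: `bytes.decode()` defaults to UTF-8, so the error message is unchanged. For example:

# decode() and decode("utf-8") produce the same string; UTF-8 is the default codec.
assert b"object-id".decode() == b"object-id".decode("utf-8")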

object_model/store/object_store.py

Lines changed: 3 additions & 2 deletions
@@ -5,13 +5,14 @@
 from orjson import loads
 from platform import system, uname
 from pydantic import BaseModel
+from pydantic_gubbins.typing import get_type_name
 from typing import Iterable

 from . import ObjectResult
 from .exception import NotFoundError
 from .persistable import ImmutableMixin, ObjectRecord, PersistableMixin
 from .._json import schema
-from .._type_registry import CLASS_TYPE_KEY, is_temporary_type
+from .._type_registry import is_temporary_type


 def _get_user_name():
@@ -145,7 +146,7 @@ def write(self, obj: PersistableMixin, as_of_effective_time: bool = False) -> Fu
         return ret

     def register_type(self, typ: type[PersistableMixin]):
-        self.register_schema(RegisterSchemaRequest(name=getattr(typ, CLASS_TYPE_KEY), json_schema=schema(typ)))
+        self.register_schema(RegisterSchemaRequest(name=get_type_name(typ), json_schema=schema(typ)))

     def register_schema(self, request: RegisterSchemaRequest):
         defs = request.json_schema.pop("$defs", {})

object_model/store/persistable.py

Lines changed: 4 additions & 4 deletions
@@ -4,12 +4,12 @@
 from datetime import datetime
 from functools import cached_property
 from hashlib import sha3_512
+from pydantic_gubbins.typing import get_type_name
 from uuid import UUID

 from .object_record import ObjectRecord
 from .._descriptors import Id
 from .._json import dumps, loads
-from .._type_registry import CLASS_TYPE_KEY, TYPE_KEY


 class UseDerived:
@@ -51,11 +51,11 @@ def id(self) -> tuple[type, tuple[str, ...]]:

     @property
     def object_type(self) -> str:
-        return getattr(self, TYPE_KEY)
+        return get_type_name(type(self))

     @property
     def object_id_type(self) -> str:
-        return getattr(self.id[0], CLASS_TYPE_KEY)
+        return get_type_name(self.id[0])

     @property
     def object_id(self) -> bytes:
@@ -81,7 +81,7 @@ def make_id(cls, *args, **kwargs) -> tuple[str, bytes]:
         else:
             raise ValueError(f"Missing ID field {name}")

-        return getattr(id_type, CLASS_TYPE_KEY), dumps(ret)
+        return get_type_name(id_type), dumps(ret)

     @classmethod
     def _check_persistable_class(cls):
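All of the identity accessors now route through the same pydantic-gubbins call. A rough sketch of the equivalence, assuming `obj` is any PersistableMixin instance (illustrative only, not library code):

from pydantic_gubbins.typing import get_type_name

def describe(obj):
    # object_type: the registered name of the object's concrete class
    # object_id_type: the registered name of the class that owns its ID fields
    return get_type_name(type(obj)), get_type_name(obj.id[0])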
