Skip to content

Commit b1ad899

Browse files
committed
Add support for pymongo bson ObjectId (#133)
1 parent b7ddcfa commit b1ad899

File tree

5 files changed

+206
-2
lines changed

5 files changed

+206
-2
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ sources = pydantic_extra_types tests
77

88
.PHONY: install ## Install the package, dependencies, and pre-commit for local development
99
install: .uv
10-
uv sync --frozen --group all --all-extras
10+
uv sync --frozen --all-groups --all-extras
1111
uv pip install pre-commit
1212
pre-commit install --install-hooks
1313

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""
2+
Validation for MongoDB ObjectId fields.
3+
4+
Ref: https://github.com/pydantic/pydantic-extra-types/issues/133
5+
"""
6+
7+
from typing import Any
8+
9+
from pydantic import GetCoreSchemaHandler
10+
from pydantic_core import core_schema
11+
12+
try:
13+
from bson import ObjectId
14+
except ModuleNotFoundError as e:
15+
raise RuntimeError(
16+
'The `mongo_object_id` module requires "pymongo" to be installed. You can install it with "pip install '
17+
'pymongo".'
18+
) from e
19+
20+
21+
class MongoObjectId(str):
22+
"""MongoObjectId parses and validates MongoDB bson.ObjectId.
23+
24+
```py
25+
from pydantic import BaseModel
26+
27+
from pydantic_extra_types.mongo_object_id import MongoObjectId
28+
29+
30+
class MongoDocument(BaseModel):
31+
id: MongoObjectId
32+
33+
34+
doc = MongoDocument(id='5f9f2f4b9d3c5a7b4c7e6c1d')
35+
print(doc)
36+
# > id='5f9f2f4b9d3c5a7b4c7e6c1d'
37+
```
38+
39+
Raises:
40+
PydanticCustomError: If the provided value is not a valid MongoDB ObjectId.
41+
"""
42+
43+
OBJECT_ID_LENGTH = 24
44+
45+
@classmethod
46+
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
47+
return core_schema.json_or_python_schema(
48+
json_schema=core_schema.str_schema(min_length=cls.OBJECT_ID_LENGTH, max_length=cls.OBJECT_ID_LENGTH),
49+
python_schema=core_schema.union_schema(
50+
[
51+
core_schema.is_instance_schema(ObjectId),
52+
core_schema.chain_schema(
53+
[
54+
core_schema.str_schema(min_length=cls.OBJECT_ID_LENGTH, max_length=cls.OBJECT_ID_LENGTH),
55+
core_schema.no_info_plain_validator_function(cls.validate),
56+
]
57+
),
58+
]
59+
),
60+
serialization=core_schema.plain_serializer_function_ser_schema(lambda x: str(x)),
61+
)
62+
63+
@classmethod
64+
def validate(cls, value: str) -> ObjectId:
65+
"""Validate the MongoObjectId str is a valid ObjectId instance."""
66+
if not ObjectId.is_valid(value):
67+
raise ValueError(
68+
f"Invalid ObjectId {value} has to be 24 characters long and in the format '5f9f2f4b9d3c5a7b4c7e6c1d'."
69+
)
70+
71+
return ObjectId(value)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ all = [
5050
'python-ulid>=1,<2; python_version<"3.9"',
5151
'python-ulid>=1,<4; python_version>="3.9"',
5252
'pendulum>=3.0.0,<4.0.0',
53+
'pymongo>=4.0.0,<5.0.0',
5354
'pytz>=2024.1',
5455
'semver~=3.0.2',
5556
'tzdata>=2024.1',

tests/test_mongo_object_id.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Tests for the mongo_object_id module."""
2+
3+
import pytest
4+
from pydantic import BaseModel, ValidationError
5+
6+
from pydantic_extra_types.mongo_object_id import MongoObjectId
7+
8+
9+
class MongoDocument(BaseModel):
10+
object_id: MongoObjectId
11+
12+
13+
@pytest.mark.parametrize(
14+
'object_id, result, valid',
15+
[
16+
# Valid ObjectId for str format
17+
('611827f2878b88b49ebb69fc', '611827f2878b88b49ebb69fc', True),
18+
('611827f2878b88b49ebb69fd', '611827f2878b88b49ebb69fd', True),
19+
# Invalid ObjectId for str format
20+
('611827f2878b88b49ebb69f', None, False), # Invalid ObjectId (short length)
21+
('611827f2878b88b49ebb69fca', None, False), # Invalid ObjectId (long length)
22+
# Valid ObjectId for bytes format
23+
],
24+
)
25+
def test_format_for_object_id(object_id: str, result: str, valid: bool) -> None:
26+
"""Test the MongoObjectId validation."""
27+
if valid:
28+
assert str(MongoDocument(object_id=object_id).object_id) == result
29+
else:
30+
with pytest.raises(ValidationError):
31+
MongoDocument(object_id=object_id)
32+
with pytest.raises(
33+
ValueError,
34+
match=f"Invalid ObjectId {object_id} has to be 24 characters long and in the format '5f9f2f4b9d3c5a7b4c7e6c1d'.",
35+
):
36+
MongoObjectId.validate(object_id)
37+
38+
39+
def test_json_schema() -> None:
40+
"""Test the MongoObjectId model_json_schema implementation."""
41+
assert MongoDocument.model_json_schema(mode='validation') == {
42+
'properties': {'object_id': {'maxLength': 24, 'minLength': 24, 'title': 'Object Id', 'type': 'string'}},
43+
'required': ['object_id'],
44+
'title': 'MongoDocument',
45+
'type': 'object',
46+
}
47+
assert MongoDocument.model_json_schema(mode='serialization') == {
48+
'properties': {'object_id': {'maxLength': 24, 'minLength': 24, 'title': 'Object Id', 'type': 'string'}},
49+
'required': ['object_id'],
50+
'title': 'MongoDocument',
51+
'type': 'object',
52+
}

0 commit comments

Comments
 (0)