Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ sources = pydantic_extra_types tests

.PHONY: install ## Install the package, dependencies, and pre-commit for local development
install: .uv
uv sync --frozen --group all --all-extras
uv sync --frozen --all-groups --all-extras
uv pip install pre-commit
pre-commit install --install-hooks

Expand Down
71 changes: 71 additions & 0 deletions pydantic_extra_types/mongo_object_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""
Validation for MongoDB ObjectId fields.

Ref: https://github.com/pydantic/pydantic-extra-types/issues/133
"""

from typing import Any

from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

try:
from bson import ObjectId
except ModuleNotFoundError as e:
raise RuntimeError(
'The `mongo_object_id` module requires "pymongo" to be installed. You can install it with "pip install '
'pymongo".'
) from e


class MongoObjectId(str):
"""MongoObjectId parses and validates MongoDB bson.ObjectId.

```py
from pydantic import BaseModel

from pydantic_extra_types.mongo_object_id import MongoObjectId


class MongoDocument(BaseModel):
id: MongoObjectId


doc = MongoDocument(id='5f9f2f4b9d3c5a7b4c7e6c1d')
print(doc)
# > id='5f9f2f4b9d3c5a7b4c7e6c1d'
```

Raises:
PydanticCustomError: If the provided value is not a valid MongoDB ObjectId.
"""

OBJECT_ID_LENGTH = 24

@classmethod
def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.json_or_python_schema(
json_schema=core_schema.str_schema(min_length=cls.OBJECT_ID_LENGTH, max_length=cls.OBJECT_ID_LENGTH),
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(ObjectId),
core_schema.chain_schema(
[
core_schema.str_schema(min_length=cls.OBJECT_ID_LENGTH, max_length=cls.OBJECT_ID_LENGTH),
core_schema.no_info_plain_validator_function(cls.validate),
]
),
]
),
serialization=core_schema.plain_serializer_function_ser_schema(lambda x: str(x)),
)

@classmethod
def validate(cls, value: str) -> ObjectId:
"""Validate the MongoObjectId str is a valid ObjectId instance."""
if not ObjectId.is_valid(value):
raise ValueError(
f"Invalid ObjectId {value} has to be 24 characters long and in the format '5f9f2f4b9d3c5a7b4c7e6c1d'."
)

return ObjectId(value)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ all = [
'python-ulid>=1,<2; python_version<"3.9"',
'python-ulid>=1,<4; python_version>="3.9"',
'pendulum>=3.0.0,<4.0.0',
'pymongo>=4.0.0,<5.0.0',
'pytz>=2024.1',
'semver~=3.0.2',
'tzdata>=2024.1',
Expand Down
52 changes: 52 additions & 0 deletions tests/test_mongo_object_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Tests for the mongo_object_id module."""

import pytest
from pydantic import BaseModel, ValidationError

from pydantic_extra_types.mongo_object_id import MongoObjectId


class MongoDocument(BaseModel):
object_id: MongoObjectId


@pytest.mark.parametrize(
'object_id, result, valid',
[
# Valid ObjectId for str format
('611827f2878b88b49ebb69fc', '611827f2878b88b49ebb69fc', True),
('611827f2878b88b49ebb69fd', '611827f2878b88b49ebb69fd', True),
# Invalid ObjectId for str format
('611827f2878b88b49ebb69f', None, False), # Invalid ObjectId (short length)
('611827f2878b88b49ebb69fca', None, False), # Invalid ObjectId (long length)
# Valid ObjectId for bytes format
],
)
def test_format_for_object_id(object_id: str, result: str, valid: bool) -> None:
"""Test the MongoObjectId validation."""
if valid:
assert str(MongoDocument(object_id=object_id).object_id) == result
else:
with pytest.raises(ValidationError):
MongoDocument(object_id=object_id)
with pytest.raises(
ValueError,
match=f"Invalid ObjectId {object_id} has to be 24 characters long and in the format '5f9f2f4b9d3c5a7b4c7e6c1d'.",
):
MongoObjectId.validate(object_id)


def test_json_schema() -> None:
"""Test the MongoObjectId model_json_schema implementation."""
assert MongoDocument.model_json_schema(mode='validation') == {
'properties': {'object_id': {'maxLength': 24, 'minLength': 24, 'title': 'Object Id', 'type': 'string'}},
'required': ['object_id'],
'title': 'MongoDocument',
'type': 'object',
}
assert MongoDocument.model_json_schema(mode='serialization') == {
'properties': {'object_id': {'maxLength': 24, 'minLength': 24, 'title': 'Object Id', 'type': 'string'}},
'required': ['object_id'],
'title': 'MongoDocument',
'type': 'object',
}
Loading