Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ sources = pydantic_extra_types tests

.PHONY: install ## Install the package, dependencies, and pre-commit for local development
install: .uv
uv sync --frozen --group all --all-extras
uv sync --frozen --all-groups --all-extras
uv pip install pre-commit
pre-commit install --install-hooks

Expand Down
71 changes: 71 additions & 0 deletions pydantic_extra_types/mongo_object_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""
Validation for MongoDB ObjectId fields.

Ref: https://github.com/pydantic/pydantic-extra-types/issues/133
"""

from typing import Any

from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

try:
from bson import ObjectId
except ModuleNotFoundError as e: # pragma: no cover
raise RuntimeError(
'The `mongo_object_id` module requires "pymongo" to be installed. You can install it with "pip install '
'pymongo".'
) from e


class MongoObjectId(str):
"""MongoObjectId parses and validates MongoDB bson.ObjectId.

```py
from pydantic import BaseModel

from pydantic_extra_types.mongo_object_id import MongoObjectId


class MongoDocument(BaseModel):
id: MongoObjectId


doc = MongoDocument(id='5f9f2f4b9d3c5a7b4c7e6c1d')
print(doc)
# > id='5f9f2f4b9d3c5a7b4c7e6c1d'
```

Raises:
PydanticCustomError: If the provided value is not a valid MongoDB ObjectId.
"""

OBJECT_ID_LENGTH = 24

@classmethod
def __get_pydantic_core_schema__(cls, _: Any, __: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.json_or_python_schema(
json_schema=core_schema.str_schema(min_length=cls.OBJECT_ID_LENGTH, max_length=cls.OBJECT_ID_LENGTH),
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(ObjectId),
core_schema.chain_schema(
[
core_schema.str_schema(min_length=cls.OBJECT_ID_LENGTH, max_length=cls.OBJECT_ID_LENGTH),
core_schema.no_info_plain_validator_function(cls.validate),
]
),
]
),
serialization=core_schema.plain_serializer_function_ser_schema(lambda x: str(x)),
)

@classmethod
def validate(cls, value: str) -> ObjectId:
"""Validate the MongoObjectId str is a valid ObjectId instance."""
if not ObjectId.is_valid(value):
raise ValueError(
f"Invalid ObjectId {value} has to be 24 characters long and in the format '5f9f2f4b9d3c5a7b4c7e6c1d'."
)

return ObjectId(value)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ all = [
'python-ulid>=1,<2; python_version<"3.9"',
'python-ulid>=1,<4; python_version>="3.9"',
'pendulum>=3.0.0,<4.0.0',
'pymongo>=4.0.0,<5.0.0',
'pytz>=2024.1',
'semver~=3.0.2',
'tzdata>=2024.1',
Expand Down
23 changes: 21 additions & 2 deletions tests/test_json_schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Any, Dict, Union

import pycountry
import pytest
Expand All @@ -15,6 +15,7 @@
from pydantic_extra_types.isbn import ISBN
from pydantic_extra_types.language_code import ISO639_3, ISO639_5, LanguageAlpha2, LanguageName
from pydantic_extra_types.mac_address import MacAddress
from pydantic_extra_types.mongo_object_id import MongoObjectId
from pydantic_extra_types.payment import PaymentCardNumber
from pydantic_extra_types.pendulum_dt import DateTime
from pydantic_extra_types.phone_numbers import PhoneNumber, PhoneNumberValidator
Expand Down Expand Up @@ -494,9 +495,27 @@
],
},
),
(
MongoObjectId,
{
'title': 'Model',
'type': 'object',
'properties': {
'x': {
'maxLength': MongoObjectId.OBJECT_ID_LENGTH,
'minLength': MongoObjectId.OBJECT_ID_LENGTH,
'title': 'X',
'type': 'string',
},
},
'required': ['x'],
},
),
],
)
def test_json_schema(cls, expected):
def test_json_schema(cls: Any, expected: Dict[str, Any]) -> None:
"""Test the model_json_schema implementation for all extra types."""

class Model(BaseModel):
x: cls

Expand Down
71 changes: 71 additions & 0 deletions tests/test_mongo_object_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Tests for the mongo_object_id module."""

import pytest
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
from pydantic.json_schema import JsonSchemaMode

from pydantic_extra_types.mongo_object_id import MongoObjectId


class MongoDocument(BaseModel):
object_id: MongoObjectId


@pytest.mark.parametrize(
'object_id, result, valid',
[
# Valid ObjectId for str format
('611827f2878b88b49ebb69fc', '611827f2878b88b49ebb69fc', True),
('611827f2878b88b49ebb69fd', '611827f2878b88b49ebb69fd', True),
# Invalid ObjectId for str format
('611827f2878b88b49ebb69f', None, False), # Invalid ObjectId (short length)
('611827f2878b88b49ebb69fca', None, False), # Invalid ObjectId (long length)
# Valid ObjectId for bytes format
],
)
def test_format_for_object_id(object_id: str, result: str, valid: bool) -> None:
"""Test the MongoObjectId validation."""
if valid:
assert str(MongoDocument(object_id=object_id).object_id) == result
else:
with pytest.raises(ValidationError):
MongoDocument(object_id=object_id)
with pytest.raises(
ValueError,
match=f"Invalid ObjectId {object_id} has to be 24 characters long and in the format '5f9f2f4b9d3c5a7b4c7e6c1d'.",
):
MongoObjectId.validate(object_id)


@pytest.mark.parametrize(
'schema_mode',
[
'validation',
'serialization',
],
)
def test_json_schema(schema_mode: JsonSchemaMode) -> None:
"""Test the MongoObjectId model_json_schema implementation."""
expected_json_schema = {
'properties': {
'object_id': {
'maxLength': MongoObjectId.OBJECT_ID_LENGTH,
'minLength': MongoObjectId.OBJECT_ID_LENGTH,
'title': 'Object Id',
'type': 'string',
}
},
'required': ['object_id'],
'title': 'MongoDocument',
'type': 'object',
}
assert MongoDocument.model_json_schema(mode=schema_mode) == expected_json_schema


def test_get_pydantic_core_schema() -> None:
"""Test the __get_pydantic_core_schema__ method override."""
schema = MongoObjectId.__get_pydantic_core_schema__(MongoObjectId, GetCoreSchemaHandler())
assert isinstance(schema, dict)
assert 'json_schema' in schema
assert 'python_schema' in schema
assert schema['json_schema']['type'] == 'str'
Loading