Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions beanie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
before_event,
)
from beanie.odm.bulk import BulkWriter
from beanie.odm.client import ODMClient
from beanie.odm.custom_types import DecimalAnnotation
from beanie.odm.custom_types.bson.binary import BsonBinary
from beanie.odm.documents import (
Expand Down Expand Up @@ -66,6 +67,7 @@
"Update",
# Bulk Write
"BulkWriter",
"ODMClient",
# Migrations
"iterative_migration",
"free_fall_migration",
Expand Down
129 changes: 129 additions & 0 deletions beanie/odm/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import asyncio
import logging
import os
from typing import Any, Dict, List, Optional, Type

from pymongo import AsyncMongoClient
from pymongo.asynchronous.database import AsyncDatabase
from typing_extensions import Self

from beanie.executors.migrate import MigrationSettings, run_migrate
from beanie.odm.documents import Document
from beanie.odm.utils.init import init_beanie

logger = logging.getLogger(__name__)


class ODMClient:
"""
An asynchronous ODM client for managing MongoDB connections using PyMongo and Beanie.
"""

def __init__(self, uri: str, **kwargs: Any) -> None:
"""Initializes the ODM client."""
self.uri = uri
self.client: AsyncMongoClient = AsyncMongoClient(uri, **kwargs)
self.databases: Dict[str, AsyncDatabase] = {}
self._migration_lock = asyncio.Lock()

async def init_db(
self,
db_config: Dict[str, List[Type[Document]]],
migrations_path: Optional[str] = None,
allow_index_dropping: bool = False,
recreate_views: bool = False,
skip_indexes: bool = False,
):
"""
Initializes all specified databases and their models from a configuration.

Args:
db_config (Dict[str, List[Type[Document]]]): A dictionary where keys are
database names and values are lists
of Beanie Document classes.
migrations_path (Optional[str]): Path to the migrations directory.
allow_index_dropping (bool): Whether to allow index dropping.
recreate_views (bool): Whether to recreate views.
skip_indexes (bool): Whether to skip index creation.
"""
tasks = [
self.register_database(
db_name,
models,
migrations_path,
allow_index_dropping,
recreate_views,
skip_indexes,
)
for db_name, models in db_config.items()
]
await asyncio.gather(*tasks)

async def register_database(
self,
db_name: str,
models: List[Type[Document]],
migrations_path: Optional[str] = None,
allow_index_dropping: bool = False,
recreate_views: bool = False,
skip_indexes: bool = False,
):
"""
Initializes Beanie for a specific database with its document models.

NOTE: Beanie binds document models to a database globally. If the same model
is registered in multiple databases, the last registration will prevail
for that model class.
"""
if db_name in self.databases:
logger.info(f"Database {db_name} is already registered.")
return

logger.info(f"Initializing database: {db_name}")
db = self.client[db_name]

# Handle Migrations
if migrations_path and os.path.exists(migrations_path):
logger.info(
f"Running migrations for {db_name} from {migrations_path}"
)
settings = MigrationSettings(
connection_uri=self.uri,
database_name=db_name,
path=migrations_path,
)
async with self._migration_lock:
await run_migrate(settings)

await init_beanie(
database=db,
document_models=models,
allow_index_dropping=allow_index_dropping,
recreate_views=recreate_views,
skip_indexes=skip_indexes,
)

self.databases[db_name] = db
logger.info(f"Successfully initialized database: {db_name}")

def get_database(self, db_name: str) -> Optional[AsyncDatabase]:
"""Retrieves a registered database instance by its name."""
return self.databases.get(db_name)

async def close(self):
"""Closes the underlying MongoDB client connection."""
if self.client:
await self.client.close()
logger.info("MongoDB client connection closed.")

async def __aenter__(self) -> Self:
"""
Async context manager entry point. Returns the client instance.
"""
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""
Async context manager exit point. Ensures the connection is closed.
"""
await self.close()
84 changes: 84 additions & 0 deletions docs/tutorial/odm-client.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# ODM Client

The `ODMClient` is a high-level utility for managing MongoDB connections and initializing multiple databases with Beanie models in a centralized way. It is especially useful for applications that need to interact with multiple databases or want a structured way to handle migrations and initialization.

## Initialization

You can initialize the `ODMClient` by providing a MongoDB connection URI and optional keyword arguments for the underlying `AsyncMongoClient`.

```python
from beanie import ODMClient

uri = "mongodb://localhost:27017"
client = ODMClient(uri)
```

## Registering Databases

The `register_database` method allows you to initialize Beanie for a specific database with a list of document models.

```python
from beanie import Document

class User(Document):
name: str

async def init():
await client.register_database(
db_name="user_db",
models=[User],
allow_index_dropping=False
)
```

### Multiple Databases at Once

The `init_db` method allows you to initialize multiple databases from a configuration dictionary.

```python
class Product(Document):
title: str

db_config = {
"user_db": [User],
"product_db": [Product]
}

async def init_all():
await client.init_db(db_config)
```

## Migrations

`ODMClient` supports running migrations during database registration. If a `migrations_path` is provided, Beanie will run migrations for the specified database before initializing the models.

```python
await client.register_database(
db_name="app_db",
models=[User],
migrations_path="path/to/migrations"
)
```

## Async Context Manager

`ODMClient` can be used as an async context manager to ensure that the MongoDB connection is properly closed when the application exits.

```python
async with ODMClient(uri) as client:
await client.init_db(db_config)
# ... use the client ...
# Connection is closed automatically here
```

## Important Note: Global Model Binding

Beanie binds document models to a database **globally**. If you register the same model class in multiple databases using `ODMClient`, the **last** registration will prevail for that class across your entire application.

```python
# User model will be bound to "db_one"
await client.register_database("db_one", [User])

# User model will now be bound to "db_two" globally
await client.register_database("db_two", [User])
```
5 changes: 5 additions & 0 deletions pydoc-markdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ renderer:
source: docs/tutorial/defining-a-document.md
- title: Initialization
source: docs/tutorial/init.md
- title: ODM Client
source: docs/tutorial/odm-client.md
- title: Inserting into the database
source: docs/tutorial/insert.md
- title: Finding documents
Expand Down Expand Up @@ -125,6 +127,9 @@ renderer:
- title: Fields
contents:
- beanie.odm.fields.*
- title: Client
contents:
- beanie.odm.client.*
- title: Development
source: docs/development.md
- title: Code of conduct
Expand Down
91 changes: 91 additions & 0 deletions tests/odm/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import pytest
from pymongo import AsyncMongoClient
from pymongo.asynchronous.database import AsyncDatabase

from beanie.odm.client import ODMClient
from beanie.odm.documents import Document


class SampleDoc(Document):
name: str


class AnotherDoc(Document):
title: str


@pytest.fixture
async def odm_client():
# Use a test database URI.
uri = "mongodb://localhost:27017"
client = ODMClient(uri)
yield client
await client.close()


@pytest.mark.asyncio
async def test_odm_client_init(odm_client):
assert isinstance(odm_client.client, AsyncMongoClient)
assert odm_client.databases == {}


@pytest.mark.asyncio
async def test_odm_client_register_database(odm_client):
db_name = "test_odm_client_db"
models = [SampleDoc]

await odm_client.register_database(db_name, models)

assert db_name in odm_client.databases
assert isinstance(odm_client.get_database(db_name), AsyncDatabase)

# Verify beanie was initialized for the model
# We can check if the collection is set
assert SampleDoc.get_pymongo_collection() is not None
assert SampleDoc.get_pymongo_collection().name == "SampleDoc"

# Cleanup
await odm_client.client.drop_database(db_name)


@pytest.mark.asyncio
async def test_odm_client_init_db(odm_client):
db_config = {"db1": [SampleDoc], "db2": [AnotherDoc]}

await odm_client.init_db(db_config)

assert "db1" in odm_client.databases
assert "db2" in odm_client.databases

assert SampleDoc.get_pymongo_collection().database.name == "db1"
assert AnotherDoc.get_pymongo_collection().database.name == "db2"

await odm_client.client.drop_database("db1")
await odm_client.client.drop_database("db2")


@pytest.mark.asyncio
async def test_odm_client_multiple_db_same_model(odm_client):
# This test demonstrates the global binding limitation of Beanie
db1_name = "test_db_1"
db2_name = "test_db_2"

await odm_client.register_database(db1_name, [SampleDoc])
assert SampleDoc.get_pymongo_collection().database.name == db1_name

await odm_client.register_database(db2_name, [SampleDoc])
# Now it should be bound to db2
assert SampleDoc.get_pymongo_collection().database.name == db2_name

await odm_client.client.drop_database(db1_name)
await odm_client.client.drop_database(db2_name)


@pytest.mark.asyncio
async def test_odm_client_context_manager():
uri = "mongodb://localhost:27017"
async with ODMClient(uri) as client:
assert isinstance(client, ODMClient)
assert client.client is not None

# Client should be closed
Loading