Skip to content

Commit 8dc732f

Browse files
committed
fix wrapper usage
1 parent 22d83dc commit 8dc732f

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed

flask_pymongo/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@
3333
import pymongo
3434
from flask import Flask, Response, abort, current_app, request
3535
from gridfs import GridFS, NoFile
36-
from pymongo import MongoClient, uri_parser
37-
from pymongo.database import Database
36+
from pymongo import uri_parser
3837
from pymongo.driver_info import DriverInfo
3938
from werkzeug.wsgi import wrap_file
4039

4140
from flask_pymongo._version import __version__
4241
from flask_pymongo.helpers import BSONObjectIdConverter, BSONProvider
42+
from flask_pymongo.wrappers import Database, MongoClient
4343

4444
DESCENDING = pymongo.DESCENDING
4545
"""Descending sort order."""
@@ -64,8 +64,8 @@ class PyMongo:
6464
def __init__(
6565
self, app: Flask | None = None, uri: str | None = None, *args: Any, **kwargs: Any
6666
) -> None:
67-
self.cx: MongoClient[dict[str, Any]] | None = None
68-
self.db: Database[dict[str, Any]] | None = None
67+
self.cx: MongoClient | None = None
68+
self.db: Database | None = None
6969

7070
if app is not None:
7171
self.init_app(app, uri, *args, **kwargs)

flask_pymongo/wrappers.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,70 @@
2727
from typing import Any
2828

2929
from flask import abort
30-
from pymongo import collection
30+
from pymongo import collection, database, mongo_client
31+
32+
33+
class MongoClient(mongo_client.MongoClient[dict[str, Any]]):
34+
"""Wrapper for :class:`~pymongo.mongo_client.MongoClient`.
35+
36+
Returns instances of Flask-PyMongo
37+
:class:`~flask_pymongo.wrappers.Database` instead of native PyMongo
38+
:class:`~pymongo.database.Database` when accessed with dot notation.
39+
40+
"""
41+
42+
def __getattr__(self, name: str) -> Any:
43+
attr = super().__getattr__(name)
44+
if isinstance(attr, database.Database):
45+
return Database(self, name)
46+
return attr
47+
48+
def __getitem__(self, name: str) -> Any:
49+
attr = super().__getitem__(name)
50+
if isinstance(attr, database.Database):
51+
return Database(self, name)
52+
return attr
53+
54+
55+
class Database(database.Database[dict[str, Any]]):
56+
"""Wrapper for :class:`~pymongo.database.Database`.
57+
58+
Returns instances of Flask-PyMongo
59+
:class:`~flask_pymongo.wrappers.Collection` instead of native PyMongo
60+
:class:`~pymongo.collection.Collection` when accessed with dot notation.
61+
62+
"""
63+
64+
def __getattr__(self, name: str) -> Any:
65+
attr = super().__getattr__(name)
66+
if isinstance(attr, collection.Collection):
67+
return Collection(self, name)
68+
return attr
69+
70+
def __getitem__(self, name: str) -> Any:
71+
item_ = super().__getitem__(name)
72+
if isinstance(item_, collection.Collection):
73+
return Collection(self, name)
74+
return item_
3175

3276

3377
class Collection(collection.Collection[dict[str, Any]]):
3478
"""Sub-class of PyMongo :class:`~pymongo.collection.Collection` with helpers."""
3579

80+
def __getattr__(self, name: str) -> Any:
81+
attr = super().__getattr__(name)
82+
if isinstance(attr, collection.Collection):
83+
db = self._Collection__database
84+
return Collection(db, attr.name)
85+
return attr
86+
87+
def __getitem__(self, name: str) -> Any:
88+
item = super().__getitem__(name)
89+
if isinstance(item, collection.Collection):
90+
db = self._Collection__database
91+
return Collection(db, item.name)
92+
return item
93+
3694
def find_one_or_404(self, *args: Any, **kwargs: Any) -> Any:
3795
"""Find a single document or raise a 404.
3896

0 commit comments

Comments
 (0)