Skip to content

Commit 27d361f

Browse files
committed
merge master into use-flaskform
2 parents 6ec4cdc + 3777f74 commit 27d361f

24 files changed

+228
-173
lines changed

flask_mongoengine/__init__.py

Lines changed: 78 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,48 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import absolute_import
33
import inspect
4-
import mongoengine
54

6-
from flask import abort, current_app
5+
from flask import Flask, abort, current_app
6+
import mongoengine
7+
from mongoengine.base import ValidationError
78
from mongoengine.base.fields import BaseField
8-
from mongoengine.queryset import (MultipleObjectsReturned,
9-
DoesNotExist, QuerySet)
9+
from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned,
10+
QuerySet)
1011

11-
from mongoengine.base import ValidationError
12-
from .sessions import *
13-
from .pagination import *
14-
from .metadata import *
12+
from .connection import *
1513
from .json import override_json_encoder
14+
from .metadata import *
15+
from .pagination import *
16+
from .sessions import *
1617
from .wtf import WtfBaseField
17-
from .connection import *
18+
1819

1920
def redirect_connection_calls(cls):
2021
"""
21-
Redirect mongonengine.connection
22-
calls via flask_mongoengine.connection
22+
Monkey-patch mongoengine's connection methods so that they use
23+
Flask-MongoEngine's equivalents.
24+
25+
Given a random mongoengine class (`cls`), get the module it's in,
26+
and iterate through all of that module's members to find the
27+
particular methods we want to monkey-patch.
2328
"""
29+
# TODO this is so whack... Why don't we pass particular connection
30+
# settings down to mongoengine and just use their original implementation?
2431

25-
# Proxy all 'mongoengine.connection'
26-
# specific attr via 'flask_mongoengine'
32+
# Map of mongoengine method/variable names and flask-mongoengine
33+
# methods they should point to
2734
connection_methods = {
28-
'get_db' : get_db,
29-
'DEFAULT_CONNECTION_NAME' : DEFAULT_CONNECTION_NAME,
30-
'get_connection' : get_connection
35+
'get_db': get_db,
36+
'DEFAULT_CONNECTION_NAME': DEFAULT_CONNECTION_NAME,
37+
'get_connection': get_connection
3138
}
32-
3339
cls_module = inspect.getmodule(cls)
3440
if cls_module != mongoengine.connection:
3541
for attr in inspect.getmembers(cls_module):
3642
n = attr[0]
37-
if connection_methods.get(n, None):
38-
setattr(cls_module, n, connection_methods.get(n, None))
43+
if n in connection_methods:
44+
setattr(cls_module, n, connection_methods[n])
45+
3946

4047
def _patch_base_field(obj, name):
4148
"""
@@ -50,55 +57,61 @@ def _patch_base_field(obj, name):
5057
@see: flask_mongoengine.wtf.base.WtfBaseField.
5158
@see: model_form in flask_mongoengine.wtf.orm
5259
53-
@param obj: The object whose footprint to locate the class.
54-
@param name: Name of the class to locate.
60+
@param obj: MongoEngine instance in which we should locate the class.
61+
@param name: Name of an attribute which may or may not be a BaseField.
5562
"""
5663

57-
# locate class
64+
# get an attribute of the MongoEngine class and return if it's not
65+
# a class
5866
cls = getattr(obj, name)
5967
if not inspect.isclass(cls):
6068
return
6169

62-
# fetch class base classes
70+
# if it is a class, inspect all of its parent classes
6371
cls_bases = list(cls.__bases__)
6472

65-
# replace BaseField with WtfBaseField
73+
# if any of them is a BaseField, replace it with WtfBaseField
6674
for index, base in enumerate(cls_bases):
6775
if base == BaseField:
6876
cls_bases[index] = WtfBaseField
6977
cls.__bases__ = tuple(cls_bases)
7078
break
7179

72-
# re-assign class back to
73-
# object footprint
80+
# re-assign the class back to the MongoEngine instance
7481
delattr(obj, name)
7582
setattr(obj, name, cls)
7683
redirect_connection_calls(cls)
7784

85+
7886
def _include_mongoengine(obj):
79-
for module in mongoengine, mongoengine.fields:
80-
for key in module.__all__:
81-
if not hasattr(obj, key):
82-
setattr(obj, key, getattr(module, key))
87+
"""
88+
Copy all of the attributes from mongoengine and mongoengine.fields
89+
onto obj (which should be an instance of the MongoEngine class).
90+
"""
91+
# TODO why do we need this? What's wrong with importing from the
92+
# original modules?
93+
for module in (mongoengine, mongoengine.fields):
94+
for attr_name in module.__all__:
95+
if not hasattr(obj, attr_name):
96+
setattr(obj, attr_name, getattr(module, attr_name))
8397

8498
# patch BaseField if available
85-
_patch_base_field(obj, key)
99+
_patch_base_field(obj, attr_name)
100+
86101

87102
def current_mongoengine_instance():
88103
"""
89104
Obtain instance of MongoEngine in the
90105
current working app instance.
91106
"""
92-
me = current_app.extensions.get('mongoengine', None)
93-
if current_app and me:
94-
instance_dict = me.items()\
95-
if (sys.version_info >= (3, 0)) else me.iteritems()
96-
for k, v in instance_dict:
97-
if isinstance(k, MongoEngine):
98-
return k
99-
return None
107+
me = current_app.extensions.get('mongoengine', {})
108+
for k, v in me.items():
109+
if isinstance(k, MongoEngine):
110+
return k
111+
100112

101113
class MongoEngine(object):
114+
"""Main class used for initialization of Flask-MongoEngine."""
102115

103116
def __init__(self, app=None, config=None):
104117
_include_mongoengine(self)
@@ -110,7 +123,6 @@ def __init__(self, app=None, config=None):
110123
self.init_app(app, config)
111124

112125
def init_app(self, app, config=None):
113-
from flask import Flask
114126
if not app or not isinstance(app, Flask):
115127
raise Exception('Invalid Flask application instance')
116128

@@ -143,6 +155,7 @@ def init_app(self, app, config=None):
143155
app.extensions['mongoengine'][self] = s
144156

145157
def disconnect(self):
158+
"""Close all connections to MongoDB."""
146159
conn_settings = fetch_connection_settings(current_app.config)
147160
if isinstance(conn_settings, list):
148161
for setting in conn_settings:
@@ -155,46 +168,64 @@ def disconnect(self):
155168

156169
@property
157170
def connection(self):
171+
"""
172+
Return MongoDB connection associated with this MongoEngine
173+
instance.
174+
"""
158175
return current_app.extensions['mongoengine'][self]['conn']
159176

160177

161178
class BaseQuerySet(QuerySet):
162-
"""
163-
A base queryset with handy extras
164-
"""
179+
"""Mongoengine's queryset extended with handy extras."""
165180

166181
def get_or_404(self, *args, **kwargs):
182+
"""
183+
Get a document and raise a 404 Not Found error if it doesn't
184+
exist.
185+
"""
167186
try:
168187
return self.get(*args, **kwargs)
169188
except (MultipleObjectsReturned, DoesNotExist, ValidationError):
189+
# TODO probably only DoesNotExist should raise a 404
170190
abort(404)
171191

172192
def first_or_404(self):
173-
193+
"""Same as get_or_404, but uses .first, not .get."""
174194
obj = self.first()
175195
if obj is None:
176196
abort(404)
177197

178198
return obj
179199

180-
def paginate(self, page, per_page, error_out=True):
200+
def paginate(self, page, per_page, **kwargs):
201+
"""
202+
Paginate the QuerySet with a certain number of docs per page
203+
and return docs for a given page.
204+
"""
181205
return Pagination(self, page, per_page)
182206

183-
def paginate_field(self, field_name, doc_id, page, per_page,
184-
total=None):
207+
def paginate_field(self, field_name, doc_id, page, per_page, total=None):
208+
"""
209+
Paginate items within a list field from one document in the
210+
QuerySet.
211+
"""
212+
# TODO this doesn't sound useful at all - remove in next release?
185213
item = self.get(id=doc_id)
186214
count = getattr(item, field_name + "_count", '')
187215
total = total or count or len(getattr(item, field_name))
188216
return ListFieldPagination(self, doc_id, field_name, page, per_page,
189217
total=total)
190218

219+
191220
class Document(mongoengine.Document):
192221
"""Abstract document with extra helpers in the queryset class"""
193222

194223
meta = {'abstract': True,
195224
'queryset_class': BaseQuerySet}
196225

197226
def paginate_field(self, field_name, page, per_page, total=None):
227+
"""Paginate items within a list field."""
228+
# TODO this doesn't sound useful at all - remove in next release?
198229
count = getattr(self, field_name + "_count", '')
199230
total = total or count or len(getattr(self, field_name))
200231
return ListFieldPagination(self.__class__.objects, self.pk, field_name,

0 commit comments

Comments
 (0)