11# -*- coding: utf-8 -*-
22from __future__ import absolute_import
33import inspect
4- import mongoengine
54
6- from flask import abort , current_app
5+ from flask import Flask , abort , current_app
6+ import mongoengine
7+ from mongoengine .base import ValidationError
78from mongoengine .base .fields import BaseField
8- from mongoengine .queryset import (MultipleObjectsReturned ,
9- DoesNotExist , QuerySet )
9+ from mongoengine .queryset import (DoesNotExist , MultipleObjectsReturned ,
10+ QuerySet )
1011
11- from mongoengine .base import ValidationError
12- from .sessions import *
13- from .pagination import *
14- from .metadata import *
12+ from .connection import *
1513from .json import override_json_encoder
14+ from .metadata import *
15+ from .pagination import *
16+ from .sessions import *
1617from .wtf import WtfBaseField
17- from . connection import *
18+
1819
1920def redirect_connection_calls (cls ):
2021 """
21- Redirect mongonengine.connection
22- calls via flask_mongoengine.connection
22+ Monkey-patch mongoengine's connection methods so that they use
23+ Flask-MongoEngine's equivalents.
24+
25+ Given a random mongoengine class (`cls`), get the module it's in,
26+ and iterate through all of that module's members to find the
27+ particular methods we want to monkey-patch.
2328 """
29+ # TODO this is so whack... Why don't we pass particular connection
30+ # settings down to mongoengine and just use their original implementation?
2431
25- # Proxy all ' mongoengine.connection'
26- # specific attr via 'flask_mongoengine'
32+ # Map of mongoengine method/variable names and flask-mongoengine
33+ # methods they should point to
2734 connection_methods = {
28- 'get_db' : get_db ,
29- 'DEFAULT_CONNECTION_NAME' : DEFAULT_CONNECTION_NAME ,
30- 'get_connection' : get_connection
35+ 'get_db' : get_db ,
36+ 'DEFAULT_CONNECTION_NAME' : DEFAULT_CONNECTION_NAME ,
37+ 'get_connection' : get_connection
3138 }
32-
3339 cls_module = inspect .getmodule (cls )
3440 if cls_module != mongoengine .connection :
3541 for attr in inspect .getmembers (cls_module ):
3642 n = attr [0 ]
37- if connection_methods .get (n , None ):
38- setattr (cls_module , n , connection_methods .get (n , None ))
43+ if n in connection_methods :
44+ setattr (cls_module , n , connection_methods [n ])
45+
3946
4047def _patch_base_field (obj , name ):
4148 """
@@ -50,55 +57,61 @@ def _patch_base_field(obj, name):
5057 @see: flask_mongoengine.wtf.base.WtfBaseField.
5158 @see: model_form in flask_mongoengine.wtf.orm
5259
53- @param obj: The object whose footprint to locate the class.
54- @param name: Name of the class to locate .
60+ @param obj: MongoEngine instance in which we should locate the class.
61+ @param name: Name of an attribute which may or may not be a BaseField .
5562 """
5663
57- # locate class
64+ # get an attribute of the MongoEngine class and return if it's not
65+ # a class
5866 cls = getattr (obj , name )
5967 if not inspect .isclass (cls ):
6068 return
6169
62- # fetch class base classes
70+ # if it is a class, inspect all of its parent classes
6371 cls_bases = list (cls .__bases__ )
6472
65- # replace BaseField with WtfBaseField
73+ # if any of them is a BaseField, replace it with WtfBaseField
6674 for index , base in enumerate (cls_bases ):
6775 if base == BaseField :
6876 cls_bases [index ] = WtfBaseField
6977 cls .__bases__ = tuple (cls_bases )
7078 break
7179
72- # re-assign class back to
73- # object footprint
80+ # re-assign the class back to the MongoEngine instance
7481 delattr (obj , name )
7582 setattr (obj , name , cls )
7683 redirect_connection_calls (cls )
7784
85+
7886def _include_mongoengine (obj ):
79- for module in mongoengine , mongoengine .fields :
80- for key in module .__all__ :
81- if not hasattr (obj , key ):
82- setattr (obj , key , getattr (module , key ))
87+ """
88+ Copy all of the attributes from mongoengine and mongoengine.fields
89+ onto obj (which should be an instance of the MongoEngine class).
90+ """
91+ # TODO why do we need this? What's wrong with importing from the
92+ # original modules?
93+ for module in (mongoengine , mongoengine .fields ):
94+ for attr_name in module .__all__ :
95+ if not hasattr (obj , attr_name ):
96+ setattr (obj , attr_name , getattr (module , attr_name ))
8397
8498 # patch BaseField if available
85- _patch_base_field (obj , key )
99+ _patch_base_field (obj , attr_name )
100+
86101
87102def current_mongoengine_instance ():
88103 """
89104 Obtain instance of MongoEngine in the
90105 current working app instance.
91106 """
92- me = current_app .extensions .get ('mongoengine' , None )
93- if current_app and me :
94- instance_dict = me .items ()\
95- if (sys .version_info >= (3 , 0 )) else me .iteritems ()
96- for k , v in instance_dict :
97- if isinstance (k , MongoEngine ):
98- return k
99- return None
107+ me = current_app .extensions .get ('mongoengine' , {})
108+ for k , v in me .items ():
109+ if isinstance (k , MongoEngine ):
110+ return k
111+
100112
101113class MongoEngine (object ):
114+ """Main class used for initialization of Flask-MongoEngine."""
102115
103116 def __init__ (self , app = None , config = None ):
104117 _include_mongoengine (self )
@@ -110,7 +123,6 @@ def __init__(self, app=None, config=None):
110123 self .init_app (app , config )
111124
112125 def init_app (self , app , config = None ):
113- from flask import Flask
114126 if not app or not isinstance (app , Flask ):
115127 raise Exception ('Invalid Flask application instance' )
116128
@@ -143,6 +155,7 @@ def init_app(self, app, config=None):
143155 app .extensions ['mongoengine' ][self ] = s
144156
145157 def disconnect (self ):
158+ """Close all connections to MongoDB."""
146159 conn_settings = fetch_connection_settings (current_app .config )
147160 if isinstance (conn_settings , list ):
148161 for setting in conn_settings :
@@ -155,46 +168,64 @@ def disconnect(self):
155168
156169 @property
157170 def connection (self ):
171+ """
172+ Return MongoDB connection associated with this MongoEngine
173+ instance.
174+ """
158175 return current_app .extensions ['mongoengine' ][self ]['conn' ]
159176
160177
161178class BaseQuerySet (QuerySet ):
162- """
163- A base queryset with handy extras
164- """
179+ """Mongoengine's queryset extended with handy extras."""
165180
166181 def get_or_404 (self , * args , ** kwargs ):
182+ """
183+ Get a document and raise a 404 Not Found error if it doesn't
184+ exist.
185+ """
167186 try :
168187 return self .get (* args , ** kwargs )
169188 except (MultipleObjectsReturned , DoesNotExist , ValidationError ):
189+ # TODO probably only DoesNotExist should raise a 404
170190 abort (404 )
171191
172192 def first_or_404 (self ):
173-
193+ """Same as get_or_404, but uses .first, not .get."""
174194 obj = self .first ()
175195 if obj is None :
176196 abort (404 )
177197
178198 return obj
179199
180- def paginate (self , page , per_page , error_out = True ):
200+ def paginate (self , page , per_page , ** kwargs ):
201+ """
202+ Paginate the QuerySet with a certain number of docs per page
203+ and return docs for a given page.
204+ """
181205 return Pagination (self , page , per_page )
182206
183- def paginate_field (self , field_name , doc_id , page , per_page ,
184- total = None ):
207+ def paginate_field (self , field_name , doc_id , page , per_page , total = None ):
208+ """
209+ Paginate items within a list field from one document in the
210+ QuerySet.
211+ """
212+ # TODO this doesn't sound useful at all - remove in next release?
185213 item = self .get (id = doc_id )
186214 count = getattr (item , field_name + "_count" , '' )
187215 total = total or count or len (getattr (item , field_name ))
188216 return ListFieldPagination (self , doc_id , field_name , page , per_page ,
189217 total = total )
190218
219+
191220class Document (mongoengine .Document ):
192221 """Abstract document with extra helpers in the queryset class"""
193222
194223 meta = {'abstract' : True ,
195224 'queryset_class' : BaseQuerySet }
196225
197226 def paginate_field (self , field_name , page , per_page , total = None ):
227+ """Paginate items within a list field."""
228+ # TODO this doesn't sound useful at all - remove in next release?
198229 count = getattr (self , field_name + "_count" , '' )
199230 total = total or count or len (getattr (self , field_name ))
200231 return ListFieldPagination (self .__class__ .objects , self .pk , field_name ,
0 commit comments