11# -*- coding: utf-8 -*-
22from __future__ import absolute_import
3- import inspect
3+ import mongoengine , inspect
44
55from flask import abort , current_app
6+ from mongoengine .base .fields import BaseField
7+ from mongoengine .queryset import (MultipleObjectsReturned ,
8+ DoesNotExist , QuerySet )
69
7- import mongoengine
8-
9- if mongoengine .__version__ == '0.7.10' :
10- from mongoengine .base import BaseField
11- else :
12- from mongoengine .base .fields import BaseField
13-
14-
15- from mongoengine .queryset import MultipleObjectsReturned , DoesNotExist , QuerySet
1610from mongoengine .base import ValidationError
17-
1811from pymongo import uri_parser
19-
2012from .sessions import *
2113from .pagination import *
2214from .metadata import *
23- from .json import overide_json_encoder
15+ from .json import override_json_encoder
2416from .wtf import WtfBaseField
17+ from .connection import *
18+ import flask_mongoengine
2519
26- def _patch_base_field (object , name ):
20+ def redirect_connection_calls (cls ):
21+ """
22+ Redirect mongonengine.connection
23+ calls via flask_mongoengine.connection
24+ """
25+
26+ # Proxy all 'mongoengine.connection'
27+ # specific attr via 'flask_mongoengine'
28+ connection_methods = {
29+ 'get_db' : get_db ,
30+ 'DEFAULT_CONNECTION_NAME' : DEFAULT_CONNECTION_NAME ,
31+ 'get_connection' : get_connection
32+ }
33+
34+ cls_module = inspect .getmodule (cls )
35+ if cls_module != mongoengine .connection :
36+ for attr in inspect .getmembers (cls_module ):
37+ n = attr [0 ]
38+ if connection_methods .get (n , None ):
39+ setattr (cls_module , n , connection_methods .get (n , None ))
40+
41+ def _patch_base_field (obj , name ):
2742 """
2843 If the object submitted has a class whose base class is
2944 mongoengine.base.fields.BaseField, then monkey patch to
@@ -36,12 +51,12 @@ def _patch_base_field(object, name):
3651 @see: flask_mongoengine.wtf.base.WtfBaseField.
3752 @see: model_form in flask_mongoengine.wtf.orm
3853
39- @param object: The object whose footprint to locate the class.
54+ @param obj: The object whose footprint to locate the class.
4055 @param name: Name of the class to locate.
4156
4257 """
4358 # locate class
44- cls = getattr (object , name )
59+ cls = getattr (obj , name )
4560 if not inspect .isclass (cls ):
4661 return
4762
@@ -57,9 +72,9 @@ def _patch_base_field(object, name):
5772
5873 # re-assign class back to
5974 # object footprint
60- delattr (object , name )
61- setattr (object , name , cls )
62-
75+ delattr (obj , name )
76+ setattr (obj , name , cls )
77+ redirect_connection_calls ( cls )
6378
6479def _include_mongoengine (obj ):
6580 for module in mongoengine , mongoengine .fields :
@@ -70,30 +85,19 @@ def _include_mongoengine(obj):
7085 # patch BaseField if available
7186 _patch_base_field (obj , key )
7287
73-
74- def _create_connection (conn_settings ):
75-
76- # Handle multiple connections recursively
77- if isinstance (conn_settings , list ):
78- connections = {}
79- for conn in conn_settings :
80- connections [conn .get ('alias' )] = _create_connection (conn )
81- return connections
82-
83- # Ugly dict comprehention in order to support python 2.6
84- conn = dict ((k .lower (), v ) for k , v in conn_settings .items () if v is not None )
85-
86- if 'replicaset' in conn :
87- conn ['replicaSet' ] = conn .pop ('replicaset' )
88-
89- # Handle uri style connections
90- if "://" in conn .get ('host' , '' ):
91- uri_dict = uri_parser .parse_uri (conn ['host' ])
92- conn ['db' ] = uri_dict ['database' ]
93-
94- return mongoengine .connect (conn .pop ('db' , 'test' ), ** conn )
95-
96-
88+ def current_mongoengine_instance ():
89+ """
90+ Obtain instance of MongoEngine in the
91+ current working app instance.
92+ """
93+ me = current_app .extensions .get ('mongoengine' , None )
94+ if current_app and me :
95+ instance_dict = me .items ()\
96+ if (sys .version_info >= (3 , 0 )) else me .iteritems ()
97+ for k , v in instance_dict :
98+ if isinstance (k , MongoEngine ):
99+ return k
100+ return None
97101
98102class MongoEngine (object ):
99103
@@ -107,11 +111,10 @@ def __init__(self, app=None, config=None):
107111 self .init_app (app , config )
108112
109113 def init_app (self , app , config = None ):
110-
111114 app .extensions = getattr (app , 'extensions' , {})
112115
113116 # Make documents JSON serializable
114- overide_json_encoder (app )
117+ override_json_encoder (app )
115118
116119 if not 'mongoengine' in app .extensions :
117120 app .extensions ['mongoengine' ] = {}
@@ -122,27 +125,30 @@ def init_app(self, app, config=None):
122125 raise Exception ('Extension already initialized' )
123126
124127 if not config :
125- # If not passed a config then we read the connection settings
126- # from the app config.
128+ # If not passed a config then we
129+ # read the connection settings from
130+ # the app config.
127131 config = app .config
128132
129- if 'MONGODB_SETTINGS' in config :
130- # Connection settings provided as a dictionary.
131- connection = _create_connection (config ['MONGODB_SETTINGS' ])
133+ # Obtain db connection
134+ connection = create_connection (config )
135+
136+ # Store objects in application instance
137+ # so that multiple apps do not end up
138+ # accessing the same objects.
139+ s = {'app' : app , 'conn' : connection }
140+ app .extensions ['mongoengine' ][self ] = s
141+
142+ def disconnect (self ):
143+ conn_settings = fetch_connection_settings (current_app .config )
144+ if isinstance (conn_settings , list ):
145+ for setting in conn_settings :
146+ alias = setting .get ('alias' , DEFAULT_CONNECTION_NAME )
147+ disconnect (alias , setting .get ('preserve_temp_db' , False ))
132148 else :
133- # Connection settings provided in standard format.
134- settings = {'alias' : config .get ('MONGODB_ALIAS' , None ),
135- 'db' : config .get ('MONGODB_DB' , None ),
136- 'host' : config .get ('MONGODB_HOST' , None ),
137- 'password' : config .get ('MONGODB_PASSWORD' , None ),
138- 'port' : config .get ('MONGODB_PORT' , None ),
139- 'username' : config .get ('MONGODB_USERNAME' , None )}
140- connection = _create_connection (settings )
141-
142- # Store objects in application instance so that multiple apps do
143- # not end up accessing the same objects.
144- app .extensions ['mongoengine' ][self ] = {'app' : app ,
145- 'conn' : connection }
149+ alias = conn_settings .get ('alias' , DEFAULT_CONNECTION_NAME )
150+ disconnect (alias , conn_settings .get ('preserve_temp_db' , False ))
151+ return True
146152
147153 @property
148154 def connection (self ):
@@ -179,7 +185,6 @@ def paginate_field(self, field_name, doc_id, page, per_page,
179185 return ListFieldPagination (self , doc_id , field_name , page , per_page ,
180186 total = total )
181187
182-
183188class Document (mongoengine .Document ):
184189 """Abstract document with extra helpers in the queryset class"""
185190
0 commit comments