Skip to content

Commit 20f5daf

Browse files
committed
Merge pull request #116 from mtsgrd/master
Fix dangerous use self.app in init_app and test refactoring.
2 parents 63ea2ae + a7dc87a commit 20f5daf

File tree

8 files changed

+140
-104
lines changed

8 files changed

+140
-104
lines changed

flask_mongoengine/__init__.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import absolute_import
33

4-
from flask import abort
4+
from flask import abort, current_app
55

66
import mongoengine
77

@@ -23,6 +23,14 @@ def _include_mongoengine(obj):
2323

2424

2525
def _create_connection(conn_settings):
26+
27+
# Handle multiple connections recursively
28+
if isinstance(conn_settings, list):
29+
connections = {}
30+
for conn in conn_settings:
31+
connections[conn.get('alias')] = _create_connection(conn)
32+
return connections
33+
2634
conn = dict([(k.lower(), v) for k, v in conn_settings.items() if v])
2735

2836
if 'replicaset' in conn:
@@ -38,40 +46,60 @@ def _create_connection(conn_settings):
3846

3947
class MongoEngine(object):
4048

41-
def __init__(self, app=None):
49+
def __init__(self, app=None, config=None):
4250

4351
_include_mongoengine(self)
4452

4553
self.Document = Document
4654
self.DynamicDocument = DynamicDocument
4755

4856
if app is not None:
49-
self.init_app(app)
57+
self.init_app(app, config)
5058

51-
def init_app(self, app):
59+
def init_app(self, app, config=None):
5260

53-
conn_settings = app.config.get('MONGODB_SETTINGS', None)
61+
app.extensions = getattr(app, 'extensions', {})
5462

55-
if not conn_settings:
56-
conn_settings = {
57-
'db': app.config.get('MONGODB_DB', None),
58-
'username': app.config.get('MONGODB_USERNAME', None),
59-
'password': app.config.get('MONGODB_PASSWORD', None),
60-
'host': app.config.get('MONGODB_HOST', None),
61-
'port': int(app.config.get('MONGODB_PORT', 0)) or None
62-
}
63+
# Make documents JSON serializable
64+
overide_json_encoder(app)
6365

64-
if isinstance(conn_settings, list):
65-
self.connection = {}
66-
for conn in conn_settings:
67-
self.connection[conn.get('alias')] = _create_connection(conn)
68-
else:
69-
self.connection = _create_connection(conn_settings)
66+
if not 'mongoengine' in app.extensions:
67+
app.extensions['mongoengine'] = {}
7068

71-
app.extensions = getattr(app, 'extensions', {})
72-
app.extensions['mongoengine'] = self
73-
self.app = app
74-
overide_json_encoder(app)
69+
if self in app.extensions['mongoengine']:
70+
# Raise an exception if extension already initialized as
71+
# potentially new configuration would not be loaded.
72+
raise Exception('Extension already initialized')
73+
74+
if config:
75+
# If passed an explicit config then we must make sure to ignore
76+
# anything set in the application config.
77+
connection = _create_connection(config)
78+
else:
79+
# Set default config
80+
config = {}
81+
config.setdefault('db', app.config.get('MONGODB_DB', None))
82+
config.setdefault('host', app.config.get('MONGODB_HOST', None))
83+
config.setdefault('port', app.config.get('MONGODB_PORT', None))
84+
config.setdefault('username',
85+
app.config.get('MONGODB_USERNAME', None))
86+
config.setdefault('password',
87+
app.config.get('MONGODB_PASSWORD', None))
88+
89+
# Before using default config we check for MONGODB_SETTINGS
90+
if 'MONGODB_SETTINGS' in app.config:
91+
connection = _create_connection(app.config['MONGODB_SETTINGS'])
92+
else:
93+
connection = _create_connection(config)
94+
95+
# Store objects in application instance so that multiple apps do
96+
# not end up accessing the same objects.
97+
app.extensions['mongoengine'] = {self: {'app': app,
98+
'conn': connection}}
99+
100+
@property
101+
def connection(self):
102+
return current_app.extensions['mongoengine'][self]['conn']
75103

76104

77105
class BaseQuerySet(QuerySet):

tests/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import flask
2+
import unittest
3+
4+
class FlaskMongoEngineTestCase(unittest.TestCase):
5+
"""Parent class of all test cases"""
6+
7+
def setUp(self):
8+
self.app = flask.Flask(__name__)
9+
self.app.config['MONGODB_DB'] = 'testing'
10+
self.app.config['TESTING'] = True
11+
self.ctx = self.app.app_context()
12+
self.ctx.push()
13+
14+
def tearDown(self):
15+
self.ctx.pop()

tests/test_basic_app.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
import sys
2-
sys.path[0:0] = [""]
32

43
import unittest
54
import datetime
65
import flask
76

87
from flask.ext.mongoengine import MongoEngine
8+
from . import FlaskMongoEngineTestCase
99

1010

11-
class BasicAppTestCase(unittest.TestCase):
11+
class BasicAppTestCase(FlaskMongoEngineTestCase):
1212

1313
def setUp(self):
14-
app = flask.Flask(__name__)
15-
app.config['MONGODB_DB'] = 'testing'
16-
app.config['TESTING'] = True
14+
super(BasicAppTestCase, self).setUp()
1715
db = MongoEngine()
1816

1917
class Todo(db.Document):
@@ -22,53 +20,54 @@ class Todo(db.Document):
2220
done = db.BooleanField(default=False)
2321
pub_date = db.DateTimeField(default=datetime.datetime.now)
2422

25-
db.init_app(app)
23+
db.init_app(self.app)
2624

2725
Todo.drop_collection()
2826
self.Todo = Todo
2927

30-
@app.route('/')
28+
@self.app.route('/')
3129
def index():
3230
return '\n'.join(x.title for x in self.Todo.objects)
3331

34-
@app.route('/add', methods=['POST'])
32+
@self.app.route('/add', methods=['POST'])
3533
def add():
3634
form = flask.request.form
3735
todo = self.Todo(title=form['title'],
3836
text=form['text'])
3937
todo.save()
4038
return 'added'
4139

42-
@app.route('/show/<id>/')
40+
@self.app.route('/show/<id>/')
4341
def show(id):
4442
todo = self.Todo.objects.get_or_404(id=id)
4543
return '\n'.join([todo.title, todo.text])
4644

47-
self.app = app
4845
self.db = db
4946

5047
def test_connection_kwargs(self):
51-
app = flask.Flask(__name__)
52-
app.config['MONGODB_SETTINGS'] = {
48+
self.app.config['MONGODB_SETTINGS'] = {
5349
'DB': 'testing_tz_aware',
54-
'alias': 'tz_aware_true',
50+
'ALIAS': 'tz_aware_true',
5551
'TZ_AWARE': True
5652
}
57-
app.config['TESTING'] = True
53+
self.app.config['TESTING'] = True
5854
db = MongoEngine()
59-
db.init_app(app)
55+
db.init_app(self.app)
6056
self.assertTrue(db.connection.tz_aware)
6157

62-
app.config['MONGODB_SETTINGS'] = {
58+
# PyMongo defaults to tz_aware = True so we have to explicitly turn
59+
# it off.
60+
self.app.config['MONGODB_SETTINGS'] = {
6361
'DB': 'testing',
64-
'alias': 'tz_aware_false',
62+
'ALIAS': 'tz_aware_false',
63+
'TZ_AWARE': False
6564
}
66-
db.init_app(app)
65+
db = MongoEngine()
66+
db.init_app(self.app)
6767
self.assertFalse(db.connection.tz_aware)
6868

6969
def test_connection_kwargs_as_list(self):
70-
app = flask.Flask(__name__)
71-
app.config['MONGODB_SETTINGS'] = [{
70+
self.app.config['MONGODB_SETTINGS'] = [{
7271
'DB': 'testing_tz_aware',
7372
'alias': 'tz_aware_true',
7473
'TZ_AWARE': True
@@ -77,23 +76,22 @@ def test_connection_kwargs_as_list(self):
7776
'alias': 'tz_aware_false',
7877
'TZ_AWARE': False
7978
}]
80-
app.config['TESTING'] = True
79+
self.app.config['TESTING'] = True
8180
db = MongoEngine()
82-
db.init_app(app)
81+
db.init_app(self.app)
8382
self.assertTrue(db.connection['tz_aware_true'].tz_aware)
8483
self.assertFalse(db.connection['tz_aware_false'].tz_aware)
8584

8685
def test_connection_default(self):
87-
app = flask.Flask(__name__)
88-
app.config['MONGODB_SETTINGS'] = {}
89-
app.config['TESTING'] = True
86+
self.app.config['MONGODB_SETTINGS'] = {}
87+
self.app.config['TESTING'] = True
9088

9189
db = MongoEngine()
92-
db.init_app(app)
90+
db.init_app(self.app)
9391

94-
app.config['TESTING'] = True
92+
self.app.config['TESTING'] = True
9593
db = MongoEngine()
96-
db.init_app(app)
94+
db.init_app(self.app)
9795

9896
def test_with_id(self):
9997
c = self.app.test_client()

tests/test_forms.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,22 @@
1414
from flask.ext.mongoengine.wtf import model_form
1515

1616
from mongoengine import queryset_manager
17+
from . import FlaskMongoEngineTestCase
1718

1819

19-
class WTFormsAppTestCase(unittest.TestCase):
20+
class WTFormsAppTestCase(FlaskMongoEngineTestCase):
2021

2122
def setUp(self):
23+
super(WTFormsAppTestCase, self).setUp()
2224
self.db_name = 'testing'
23-
24-
app = flask.Flask(__name__)
25-
app.config['MONGODB_DB'] = self.db_name
26-
app.config['TESTING'] = True
25+
self.app.config['MONGODB_DB'] = self.db_name
26+
self.app.config['TESTING'] = True
2727
# For Flask-WTF < 0.9
28-
app.config['CSRF_ENABLED'] = False
28+
self.app.config['CSRF_ENABLED'] = False
2929
# For Flask-WTF >= 0.9
30-
app.config['WTF_CSRF_ENABLED'] = False
31-
self.app = app
30+
self.app.config['WTF_CSRF_ENABLED'] = False
3231
self.db = MongoEngine()
33-
self.db.init_app(app)
32+
self.db.init_app(self.app)
3433

3534
def tearDown(self):
3635
self.db.connection.drop_database(self.db_name)

tests/test_json.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from flask.ext.mongoengine import MongoEngine
99
from flask.ext.mongoengine.json import MongoEngineJSONEncoder
10+
from . import FlaskMongoEngineTestCase
1011

1112

1213
class DummyEncoder(flask.json.JSONEncoder):
@@ -17,7 +18,7 @@ class DummyEncoder(flask.json.JSONEncoder):
1718
'''
1819

1920

20-
class JSONAppTestCase(unittest.TestCase):
21+
class JSONAppTestCase(FlaskMongoEngineTestCase):
2122

2223
def dictContains(self,superset,subset):
2324
for k,v in subset.items():
@@ -29,14 +30,12 @@ def assertDictContains(self,superset,subset):
2930
return self.assertTrue(self.dictContains(superset,subset))
3031

3132
def setUp(self):
32-
app = flask.Flask(__name__)
33-
app.config['MONGODB_DB'] = 'testing'
34-
app.config['TESTING'] = True
35-
app.json_encoder = DummyEncoder
33+
super(JSONAppTestCase, self).setUp()
34+
self.app.config['MONGODB_DB'] = 'testing'
35+
self.app.config['TESTING'] = True
36+
self.app.json_encoder = DummyEncoder
3637
db = MongoEngine()
37-
db.init_app(app)
38-
39-
self.app = app
38+
db.init_app(self.app)
4039
self.db = db
4140

4241
def test_inheritance(self):

0 commit comments

Comments
 (0)