Skip to content

Commit e8e6d1e

Browse files
authored
Refactor rate limiting (fixes #326) (#327)
* Refactor rate limiting * Unit test for rate limiter
1 parent 7c039ee commit e8e6d1e

File tree

11 files changed

+97
-31
lines changed

11 files changed

+97
-31
lines changed

gramps_webapi/api/ratelimiter.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Rate limiting decorator."""
2+
3+
4+
from flask_limiter import Limiter
5+
from flask_limiter.util import get_remote_address
6+
7+
limiter = Limiter(key_func=get_remote_address)

gramps_webapi/api/resources/config.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,15 @@
1919

2020
"""User administration resources."""
2121

22-
import datetime
23-
from gettext import gettext as _
2422

25-
from flask import abort, current_app, jsonify, render_template
26-
from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity
27-
from flask_limiter import Limiter
28-
from flask_limiter.util import get_remote_address
23+
from flask import abort, current_app, jsonify
2924
from webargs import fields
3025

3126
from ...auth.const import PERM_EDIT_SETTINGS, PERM_VIEW_SETTINGS
3227
from ...const import DB_CONFIG_ALLOWED_KEYS
3328
from ..auth import require_permissions
3429
from ..util import use_args
35-
from . import ProtectedResource, Resource
36-
37-
limiter = Limiter(key_func=get_remote_address)
30+
from . import ProtectedResource
3831

3932

4033
class ConfigsResource(ProtectedResource):

gramps_webapi/api/resources/token.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
"""Authentication endpoint blueprint."""
2121

22-
import datetime
2322
from typing import Iterable
2423

2524
from flask import abort, current_app
@@ -28,16 +27,13 @@
2827
create_refresh_token,
2928
get_jwt_identity,
3029
)
31-
from flask_limiter import Limiter
32-
from flask_limiter.util import get_remote_address
3330
from webargs import fields, validate
3431

3532
from ...auth.const import CLAIM_LIMITED_SCOPE, SCOPE_CREATE_OWNER
33+
from ..ratelimiter import limiter
3634
from ..util import use_args
3735
from . import RefreshProtectedResource, Resource
3836

39-
limiter = Limiter(key_func=get_remote_address)
40-
4137

4238
def get_tokens(user_id: str, permissions: Iterable[str], include_refresh: bool = False):
4339
"""Create access token (and refresh token if desired)."""

gramps_webapi/api/resources/user.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424

2525
from flask import abort, current_app, jsonify, render_template
2626
from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity
27-
from flask_limiter import Limiter
28-
from flask_limiter.util import get_remote_address
2927
from webargs import fields
3028

3129
from ...auth.const import (
@@ -44,6 +42,7 @@
4442
SCOPE_RESET_PW,
4543
)
4644
from ..auth import require_permissions
45+
from ..ratelimiter import limiter
4746
from ..tasks import (
4847
send_email_confirm_email,
4948
send_email_new_user,
@@ -52,8 +51,6 @@
5251
from ..util import use_args
5352
from . import LimitedScopeProtectedResource, ProtectedResource, Resource
5453

55-
limiter = Limiter(key_func=get_remote_address)
56-
5754

5855
class UserChangeBase(ProtectedResource):
5956
"""Base class for user change endpoints."""

gramps_webapi/app.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import logging
2323
import os
24+
from typing import Any, Dict, Optional
2425

2526
from flask import Flask, abort, g, send_from_directory
2627
from flask_compress import Compress
@@ -29,15 +30,17 @@
2930

3031
from .api import api_blueprint
3132
from .api.cache import thumbnail_cache
32-
from .api.resources.token import limiter
33+
from .api.ratelimiter import limiter
3334
from .api.search import SearchIndexer
3435
from .auth import SQLAuth
3536
from .config import DefaultConfig, DefaultConfigJWT
3637
from .const import API_PREFIX, ENV_CONFIG_FILE
3738
from .dbmanager import WebDbManager
3839

3940

40-
def create_app(db_manager=None):
41+
def create_app(
42+
db_manager: Optional[WebDbManager] = None, config: Optional[Dict[str, Any]] = None
43+
):
4144
"""Flask application factory."""
4245
app = Flask(__name__)
4346

@@ -97,6 +100,10 @@ def create_app(db_manager=None):
97100
"MEDIA_BASE_DIR"
98101
)
99102

103+
# update config from dictionary if present
104+
if config:
105+
app.config.update(**config)
106+
100107
# instantiate DB manager
101108
if db_manager is None:
102109
app.config["DB_MANAGER"] = WebDbManager(

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
"Flask-Compress",
3636
"Flask-Cors",
3737
"Flask-JWT-Extended>=4.2.1, !=4.4.0, !=4.4.1",
38-
"Flask-Limiter<=2.8.0",
38+
"Flask-Limiter>=2.9.0",
3939
"marshmallow>=3.13.0",
4040
"webargs",
4141
"SQLAlchemy",

tests/test_endpoints/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ def setUpModule():
6767

6868
test_db = ExampleDbSQLite()
6969
with patch.dict("os.environ", {ENV_CONFIG_FILE: TEST_EXAMPLE_GRAMPS_AUTH_CONFIG}):
70-
test_app = create_app(db_manager=test_db)
71-
test_app.config["TESTING"] = True
70+
test_app = create_app(
71+
db_manager=test_db, config={"TESTING": True, "RATELIMIT_ENABLED": False}
72+
)
7273
TEST_CLIENT = test_app.test_client()
7374
search_index = test_app.config["SEARCH_INDEXER"]
7475
db = test_db.get_db().db

tests/test_endpoints/test_config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ def setUp(self):
4444
self.dbman = CLIDbManager(DbState())
4545
_, _name = self.dbman.create_new_db_cli(self.name, dbid="sqlite")
4646
with patch.dict("os.environ", {ENV_CONFIG_FILE: TEST_AUTH_CONFIG}):
47-
self.app = create_app()
48-
self.app.config["TESTING"] = True
47+
self.app = create_app(config={"TESTING": True, "RATELIMIT_ENABLED": False})
4948
self.client = self.app.test_client()
5049
sqlauth = self.app.config["AUTH_PROVIDER"]
5150
sqlauth.create_table()

tests/test_endpoints/test_user.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ def setUp(self):
4646
self.dbman = CLIDbManager(DbState())
4747
_, _name = self.dbman.create_new_db_cli(self.name, dbid="sqlite")
4848
with patch.dict("os.environ", {ENV_CONFIG_FILE: TEST_AUTH_CONFIG}):
49-
self.app = create_app()
50-
self.app.config["TESTING"] = True
49+
self.app = create_app(config={"TESTING": True, "RATELIMIT_ENABLED": False})
5150
self.client = self.app.test_client()
5251
sqlauth = self.app.config["AUTH_PROVIDER"]
5352
sqlauth.create_table()
@@ -703,8 +702,7 @@ def setUp(self):
703702
self.dbman = CLIDbManager(DbState())
704703
_, _name = self.dbman.create_new_db_cli(self.name, dbid="sqlite")
705704
with patch.dict("os.environ", {ENV_CONFIG_FILE: TEST_AUTH_CONFIG}):
706-
self.app = create_app()
707-
self.app.config["TESTING"] = True
705+
self.app = create_app(config={"TESTING": True, "RATELIMIT_ENABLED": False})
708706
self.client = self.app.test_client()
709707
sqlauth = self.app.config["AUTH_PROVIDER"]
710708
sqlauth.create_table()

tests/test_jwt.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ def setUpClass(cls):
5353
cls.dbman = CLIDbManager(DbState())
5454
_, _name = cls.dbman.create_new_db_cli(cls.name, dbid="sqlite")
5555
with patch.dict("os.environ", {ENV_CONFIG_FILE: TEST_AUTH_CONFIG}):
56-
cls.app = create_app()
57-
cls.app.config["TESTING"] = True
56+
cls.app = create_app(config={"TESTING": True, "RATELIMIT_ENABLED": False})
5857
cls.client = cls.app.test_client()
5958
sqlauth = cls.app.config["AUTH_PROVIDER"]
6059
sqlauth.create_table()
@@ -94,6 +93,7 @@ def test_person_endpoint(self):
9493
rv = self.client.post(
9594
"/api/token/", json={"username": "user", "password": "123"}
9695
)
96+
assert rv.status_code == 200
9797
token = rv.json["access_token"]
9898
rv = self.client.get(
9999
"/api/people/" + it["handle"] + "?profile=all",
@@ -114,6 +114,7 @@ def test_person_endpoint_privacy(self):
114114
rv = self.client.post(
115115
"/api/token/", json={"username": "user", "password": "123"}
116116
)
117+
assert rv.status_code == 200
117118
token_user = rv.json["access_token"]
118119
rv = self.client.post(
119120
"/api/token/", json={"username": "admin", "password": "123"}
@@ -159,6 +160,7 @@ def test_refresh_token_endpoint(self):
159160
rv = self.client.post(
160161
"/api/token/", json={"username": "user", "password": "123"}
161162
)
163+
assert rv.status_code == 200
162164
refresh_token = rv.json["refresh_token"]
163165
access_token = rv.json["access_token"]
164166
# incorrectly send access token instead of refresh token!

0 commit comments

Comments
 (0)