Skip to content

Commit a5da4e7

Browse files
dpgasparclaude
andcommitted
feat: add missing CRUD hooks and auth event hooks
Add post_add/post_update calls to UserApi and GroupApi post()/put() methods which were skipping these hooks unlike the base ModelRestApi. Add on_user_login, on_user_login_failed, and on_user_logout overridable hooks to BaseSecurityManager for audit logging and custom auth event handling. Called from update_user_auth_stat and logout views respectively. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 49de9e2 commit a5da4e7

File tree

5 files changed

+300
-0
lines changed

5 files changed

+300
-0
lines changed

flask_appbuilder/security/manager.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -975,6 +975,27 @@ def reset_password(self, userid, password):
975975
)
976976
self.update_user(user)
977977

978+
def on_user_login(self, user) -> None:
979+
"""Called after a successful user login.
980+
Override to add custom logic (e.g., audit logging).
981+
982+
:param user: The authenticated user model
983+
"""
984+
985+
def on_user_login_failed(self, user) -> None:
986+
"""Called after a failed user login attempt.
987+
Override to add custom logic (e.g., audit logging).
988+
989+
:param user: The identified (but not authenticated) user model
990+
"""
991+
992+
def on_user_logout(self, user) -> None:
993+
"""Called when a user logs out.
994+
Override to add custom logic (e.g., audit logging).
995+
996+
:param user: The user model that is logging out
997+
"""
998+
978999
def update_user_auth_stat(self, user, success=True):
9791000
"""
9801001
Update user authentication stats upon successful/unsuccessful
@@ -997,8 +1018,10 @@ def update_user_auth_stat(self, user, success=True):
9971018
user.login_count += 1
9981019
user.last_login = datetime.datetime.now()
9991020
user.fail_login_count = 0
1021+
self.on_user_login(user)
10001022
else:
10011023
user.fail_login_count += 1
1024+
self.on_user_login_failed(user)
10021025
self.update_user(user)
10031026

10041027
def auth_user_db(self, username, password):

flask_appbuilder/security/sqla/apis/group/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def post(self):
120120

121121
self.pre_add(model)
122122
self.datamodel.add(model)
123+
self.post_add(model)
123124

124125
return self.response(201, id=model.id)
125126

@@ -216,6 +217,7 @@ def put(self, pk):
216217
model.users = users
217218
self.pre_update(model)
218219
self.datamodel.edit(model)
220+
self.post_update(model)
219221
return self.response(
220222
200,
221223
**{API_RESULT_RES_KEY: self.edit_model_schema.dump(item, many=False)},

flask_appbuilder/security/sqla/apis/user/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def post(self):
165165

166166
self.pre_add(model)
167167
self.datamodel.add(model)
168+
self.post_add(model)
168169
return self.response(201, id=model.id)
169170
except ValidationError as error:
170171
return self.response_400(message=error.messages)
@@ -281,6 +282,7 @@ def put(self, pk):
281282

282283
self.pre_update(model, item)
283284
self.datamodel.edit(model)
285+
self.post_update(model)
284286
return self.response(
285287
200,
286288
**{API_RESULT_RES_KEY: self.edit_model_schema.dump(item, many=False)},

flask_appbuilder/security/views.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,9 +577,22 @@ class AuthView(BaseView):
577577
def login(self):
578578
pass
579579

580+
def _get_authenticated_user(self):
581+
"""Resolve the current user before logout clears the session.
582+
Returns the User model or None if not authenticated.
583+
Note: g.user is a LocalProxy that becomes anonymous after
584+
logout_user(), so we must resolve it beforehand.
585+
"""
586+
if g.user is not None and g.user.is_authenticated:
587+
return self.appbuilder.sm.get_user_by_id(g.user.id)
588+
return None
589+
580590
@expose("/logout/")
581591
def logout(self):
592+
user = self._get_authenticated_user()
582593
logout_user()
594+
if user is not None:
595+
self.appbuilder.sm.on_user_logout(user)
583596
return redirect(
584597
current_app.config.get(
585598
"LOGOUT_REDIRECT_URL", self.appbuilder.get_url_for_index
@@ -865,7 +878,10 @@ def slo(self) -> WerkzeugResponse:
865878
try:
866879
idp = session.get("saml_idp")
867880
if not idp:
881+
user = self._get_authenticated_user()
868882
logout_user()
883+
if user is not None:
884+
self.appbuilder.sm.on_user_logout(user)
869885
return redirect(self.appbuilder.get_url_for_index)
870886

871887
url, should_logout = self.appbuilder.sm.get_saml_logout_redirect_url(
@@ -874,16 +890,25 @@ def slo(self) -> WerkzeugResponse:
874890
session_index=session.get("saml_session_index"),
875891
)
876892
if should_logout:
893+
user = self._get_authenticated_user()
877894
logout_user()
895+
if user is not None:
896+
self.appbuilder.sm.on_user_logout(user)
878897
return redirect(url or self.appbuilder.get_url_for_index)
879898

880899
except (OneLogin_Saml2_Error, OneLogin_Saml2_ValidationError) as e:
881900
log.error("SAML SLO validation error: %s", e)
901+
user = self._get_authenticated_user()
882902
logout_user()
903+
if user is not None:
904+
self.appbuilder.sm.on_user_logout(user)
883905
return redirect(self.appbuilder.get_url_for_index)
884906
except ValueError as e:
885907
log.error("SAML SLO configuration error: %s", e)
908+
user = self._get_authenticated_user()
886909
logout_user()
910+
if user is not None:
911+
self.appbuilder.sm.on_user_logout(user)
887912
return redirect(self.appbuilder.get_url_for_index)
888913

889914
@expose("/logout/")
@@ -895,7 +920,10 @@ def logout(self) -> WerkzeugResponse:
895920
return redirect(url_for(".slo"))
896921
except Exception:
897922
pass
923+
user = self._get_authenticated_user()
898924
logout_user()
925+
if user is not None:
926+
self.appbuilder.sm.on_user_logout(user)
899927
return redirect(
900928
current_app.config.get(
901929
"LOGOUT_REDIRECT_URL", self.appbuilder.get_url_for_index

tests/test_hooks.py

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
import logging
2+
import os
3+
from unittest.mock import MagicMock
4+
import uuid
5+
6+
from flask import Flask
7+
from flask_appbuilder import AppBuilder
8+
from flask_appbuilder.security.sqla.models import Group, User
9+
from flask_appbuilder.utils.legacy import get_sqla_class
10+
from tests.base import FABTestCase
11+
from tests.const import PASSWORD_ADMIN, USERNAME_ADMIN
12+
13+
14+
log = logging.getLogger(__name__)
15+
16+
17+
def _uid():
18+
return uuid.uuid4().hex[:8]
19+
20+
21+
class UserApiHooksTestCase(FABTestCase):
22+
"""Test that post_add and post_update hooks are called on UserApi."""
23+
24+
def setUp(self):
25+
self.app = Flask(__name__)
26+
self.basedir = os.path.abspath(os.path.dirname(__file__))
27+
self.app.config.from_object("tests.config_security_api")
28+
29+
self.ctx = self.app.app_context()
30+
self.ctx.push()
31+
SQLA = get_sqla_class()
32+
self.db = SQLA(self.app)
33+
self.appbuilder = AppBuilder(self.app, self.db.session)
34+
self.create_default_users(self.appbuilder)
35+
36+
# Patch hooks on the registered UserApi view
37+
self.user_api = None
38+
for view in self.appbuilder.baseviews:
39+
if view.__class__.__name__ == "UserApi":
40+
self.user_api = view
41+
break
42+
self.assertIsNotNone(self.user_api, "UserApi view not found")
43+
self.user_api.post_add = MagicMock()
44+
self.user_api.post_update = MagicMock()
45+
self._created_users = []
46+
47+
def tearDown(self):
48+
for username in self._created_users:
49+
user = self.appbuilder.sm.find_user(username=username)
50+
if user:
51+
self.appbuilder.session.delete(user)
52+
self.appbuilder.session.commit()
53+
self.ctx.pop()
54+
55+
def test_post_add_called_on_create_user(self):
56+
client = self.app.test_client()
57+
token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
58+
59+
uid = _uid()
60+
role = self.appbuilder.sm.find_role("Admin")
61+
username = f"hook_add_{uid}"
62+
payload = {
63+
"active": True,
64+
"email": f"hook_add_{uid}@fab.com",
65+
"first_name": "hook",
66+
"last_name": "test",
67+
"password": "password",
68+
"roles": [role.id],
69+
"username": username,
70+
}
71+
rv = self.auth_client_post(client, token, "api/v1/security/users/", payload)
72+
self.assertEqual(rv.status_code, 201)
73+
self._created_users.append(username)
74+
self.user_api.post_add.assert_called_once()
75+
called_model = self.user_api.post_add.call_args[0][0]
76+
self.assertIsInstance(called_model, User)
77+
self.assertEqual(called_model.username, username)
78+
79+
def test_post_update_called_on_edit_user(self):
80+
client = self.app.test_client()
81+
token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
82+
83+
uid = _uid()
84+
edit_username = f"hook_edit_{uid}"
85+
role = self.appbuilder.sm.find_role("Admin")
86+
user = self.appbuilder.sm.add_user(
87+
username=edit_username,
88+
first_name="hook",
89+
last_name="edit",
90+
email=f"hook_edit_{uid}@fab.com",
91+
role=role,
92+
password="password",
93+
)
94+
self._created_users.append(edit_username)
95+
96+
rv = self.auth_client_put(
97+
client,
98+
token,
99+
f"api/v1/security/users/{user.id}",
100+
{"first_name": "updated_hook"},
101+
)
102+
self.assertEqual(rv.status_code, 200)
103+
self.user_api.post_update.assert_called_once()
104+
called_model = self.user_api.post_update.call_args[0][0]
105+
self.assertIsInstance(called_model, User)
106+
self.assertEqual(called_model.first_name, "updated_hook")
107+
108+
109+
class GroupApiHooksTestCase(FABTestCase):
110+
"""Test that post_add and post_update hooks are called on GroupApi."""
111+
112+
def setUp(self):
113+
self.app = Flask(__name__)
114+
self.basedir = os.path.abspath(os.path.dirname(__file__))
115+
self.app.config.from_object("tests.config_api")
116+
self.app.config["FAB_ADD_SECURITY_API"] = True
117+
118+
self.ctx = self.app.app_context()
119+
self.ctx.push()
120+
SQLA = get_sqla_class()
121+
self.db = SQLA(self.app)
122+
self.appbuilder = AppBuilder(self.app, self.db.session)
123+
self.create_default_users(self.appbuilder)
124+
125+
# Patch hooks on the registered GroupApi view
126+
self.group_api = None
127+
for view in self.appbuilder.baseviews:
128+
if view.__class__.__name__ == "GroupApi":
129+
self.group_api = view
130+
break
131+
self.assertIsNotNone(self.group_api, "GroupApi view not found")
132+
self.group_api.post_add = MagicMock()
133+
self.group_api.post_update = MagicMock()
134+
135+
def tearDown(self):
136+
groups = self.appbuilder.session.query(Group).all()
137+
for group in groups:
138+
group.users = []
139+
group.roles = []
140+
self.appbuilder.session.delete(group)
141+
self.appbuilder.session.commit()
142+
self.appbuilder.session.close()
143+
self.ctx.pop()
144+
145+
def test_post_add_called_on_create_group(self):
146+
client = self.app.test_client()
147+
token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
148+
149+
payload = {
150+
"name": "hook_test_group",
151+
"label": "Hook Test",
152+
"description": "Test group for hooks",
153+
}
154+
rv = self.auth_client_post(client, token, "api/v1/security/groups/", payload)
155+
self.assertEqual(rv.status_code, 201)
156+
self.group_api.post_add.assert_called_once()
157+
called_model = self.group_api.post_add.call_args[0][0]
158+
self.assertIsInstance(called_model, Group)
159+
self.assertEqual(called_model.name, "hook_test_group")
160+
161+
def test_post_update_called_on_edit_group(self):
162+
client = self.app.test_client()
163+
token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
164+
165+
group = self.appbuilder.sm.add_group("hook_edit_group", "label", "description")
166+
self.appbuilder.session.commit()
167+
168+
rv = self.auth_client_put(
169+
client,
170+
token,
171+
f"api/v1/security/groups/{group.id}",
172+
{"label": "updated_label"},
173+
)
174+
self.assertEqual(rv.status_code, 200)
175+
self.group_api.post_update.assert_called_once()
176+
called_model = self.group_api.post_update.call_args[0][0]
177+
self.assertIsInstance(called_model, Group)
178+
179+
180+
class AuthEventHooksTestCase(FABTestCase):
181+
"""Test on_user_login, on_user_login_failed, and on_user_logout hooks."""
182+
183+
def setUp(self):
184+
self.app = Flask(__name__)
185+
self.basedir = os.path.abspath(os.path.dirname(__file__))
186+
self.app.config.from_object("tests.config_security_api")
187+
188+
self.ctx = self.app.app_context()
189+
self.ctx.push()
190+
SQLA = get_sqla_class()
191+
self.db = SQLA(self.app)
192+
self.appbuilder = AppBuilder(self.app, self.db.session)
193+
self.create_default_users(self.appbuilder)
194+
195+
# Patch auth hooks on the security manager
196+
self.appbuilder.sm.on_user_login = MagicMock()
197+
self.appbuilder.sm.on_user_login_failed = MagicMock()
198+
self.appbuilder.sm.on_user_logout = MagicMock()
199+
200+
def tearDown(self):
201+
self.ctx.pop()
202+
203+
def test_on_user_login_called_on_successful_login(self):
204+
client = self.app.test_client()
205+
rv = self._login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
206+
self.assertEqual(rv.status_code, 200)
207+
208+
self.appbuilder.sm.on_user_login.assert_called_once()
209+
called_user = self.appbuilder.sm.on_user_login.call_args[0][0]
210+
self.assertEqual(called_user.username, USERNAME_ADMIN)
211+
self.appbuilder.sm.on_user_login_failed.assert_not_called()
212+
213+
def test_on_user_login_failed_called_on_bad_password(self):
214+
client = self.app.test_client()
215+
rv = self._login(client, USERNAME_ADMIN, "wrong_password")
216+
self.assertEqual(rv.status_code, 401)
217+
218+
self.appbuilder.sm.on_user_login_failed.assert_called_once()
219+
called_user = self.appbuilder.sm.on_user_login_failed.call_args[0][0]
220+
self.assertEqual(called_user.username, USERNAME_ADMIN)
221+
self.appbuilder.sm.on_user_login.assert_not_called()
222+
223+
def test_on_user_logout_called_on_logout(self):
224+
client = self.app.test_client()
225+
# Login via browser (session-based) so logout works
226+
self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
227+
228+
# Verify user is logged in by accessing a protected resource
229+
rv = client.get("/logout/", follow_redirects=False)
230+
self.assertIn(rv.status_code, (301, 302))
231+
232+
self.appbuilder.sm.on_user_logout.assert_called_once()
233+
called_user = self.appbuilder.sm.on_user_logout.call_args[0][0]
234+
self.assertIsNotNone(called_user)
235+
self.assertEqual(called_user.username, USERNAME_ADMIN)
236+
237+
def test_hooks_receive_correct_user_object(self):
238+
"""Verify hooks receive the actual User model instance."""
239+
client = self.app.test_client()
240+
rv = self._login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
241+
self.assertEqual(rv.status_code, 200)
242+
243+
called_user = self.appbuilder.sm.on_user_login.call_args[0][0]
244+
self.assertIsInstance(called_user, User)
245+
self.assertTrue(called_user.is_authenticated)

0 commit comments

Comments
 (0)