Skip to content

Commit 6b6d813

Browse files
dpgasparclaude
andcommitted
feat: add missing CRUD hooks and auth event hooks
Add post_add/post_update calls to UserApi and GroupApi post()/put() methods which were skipping these hooks unlike the base ModelRestApi. Add on_user_login, on_user_login_failed, and on_user_logout overridable hooks to BaseSecurityManager for audit logging and custom auth event handling. Called from update_user_auth_stat and logout views respectively. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 49de9e2 commit 6b6d813

File tree

5 files changed

+306
-0
lines changed

5 files changed

+306
-0
lines changed

flask_appbuilder/security/manager.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -975,6 +975,27 @@ def reset_password(self, userid, password):
975975
)
976976
self.update_user(user)
977977

978+
def on_user_login(self, user) -> None:
979+
"""Called after a successful user login.
980+
Override to add custom logic (e.g., audit logging).
981+
982+
:param user: The authenticated user model
983+
"""
984+
985+
def on_user_login_failed(self, user) -> None:
986+
"""Called after a failed user login attempt.
987+
Override to add custom logic (e.g., audit logging).
988+
989+
:param user: The identified (but not authenticated) user model
990+
"""
991+
992+
def on_user_logout(self, user) -> None:
993+
"""Called when a user logs out.
994+
Override to add custom logic (e.g., audit logging).
995+
996+
:param user: The user model that is logging out
997+
"""
998+
978999
def update_user_auth_stat(self, user, success=True):
9791000
"""
9801001
Update user authentication stats upon successful/unsuccessful
@@ -997,8 +1018,10 @@ def update_user_auth_stat(self, user, success=True):
9971018
user.login_count += 1
9981019
user.last_login = datetime.datetime.now()
9991020
user.fail_login_count = 0
1021+
self.on_user_login(user)
10001022
else:
10011023
user.fail_login_count += 1
1024+
self.on_user_login_failed(user)
10021025
self.update_user(user)
10031026

10041027
def auth_user_db(self, username, password):

flask_appbuilder/security/sqla/apis/group/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def post(self):
120120

121121
self.pre_add(model)
122122
self.datamodel.add(model)
123+
self.post_add(model)
123124

124125
return self.response(201, id=model.id)
125126

@@ -216,6 +217,7 @@ def put(self, pk):
216217
model.users = users
217218
self.pre_update(model)
218219
self.datamodel.edit(model)
220+
self.post_update(model)
219221
return self.response(
220222
200,
221223
**{API_RESULT_RES_KEY: self.edit_model_schema.dump(item, many=False)},

flask_appbuilder/security/sqla/apis/user/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def post(self):
165165

166166
self.pre_add(model)
167167
self.datamodel.add(model)
168+
self.post_add(model)
168169
return self.response(201, id=model.id)
169170
except ValidationError as error:
170171
return self.response_400(message=error.messages)
@@ -281,6 +282,7 @@ def put(self, pk):
281282

282283
self.pre_update(model, item)
283284
self.datamodel.edit(model)
285+
self.post_update(model)
284286
return self.response(
285287
200,
286288
**{API_RESULT_RES_KEY: self.edit_model_schema.dump(item, many=False)},

flask_appbuilder/security/views.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,9 +577,22 @@ class AuthView(BaseView):
577577
def login(self):
578578
pass
579579

580+
def _get_authenticated_user(self):
581+
"""Resolve the current user before logout clears the session.
582+
Returns the User model or None if not authenticated.
583+
Note: g.user is a LocalProxy that becomes anonymous after
584+
logout_user(), so we must resolve it beforehand.
585+
"""
586+
if g.user is not None and g.user.is_authenticated:
587+
return self.appbuilder.sm.get_user_by_id(g.user.id)
588+
return None
589+
580590
@expose("/logout/")
581591
def logout(self):
592+
user = self._get_authenticated_user()
582593
logout_user()
594+
if user is not None:
595+
self.appbuilder.sm.on_user_logout(user)
583596
return redirect(
584597
current_app.config.get(
585598
"LOGOUT_REDIRECT_URL", self.appbuilder.get_url_for_index
@@ -865,7 +878,10 @@ def slo(self) -> WerkzeugResponse:
865878
try:
866879
idp = session.get("saml_idp")
867880
if not idp:
881+
user = self._get_authenticated_user()
868882
logout_user()
883+
if user is not None:
884+
self.appbuilder.sm.on_user_logout(user)
869885
return redirect(self.appbuilder.get_url_for_index)
870886

871887
url, should_logout = self.appbuilder.sm.get_saml_logout_redirect_url(
@@ -874,16 +890,25 @@ def slo(self) -> WerkzeugResponse:
874890
session_index=session.get("saml_session_index"),
875891
)
876892
if should_logout:
893+
user = self._get_authenticated_user()
877894
logout_user()
895+
if user is not None:
896+
self.appbuilder.sm.on_user_logout(user)
878897
return redirect(url or self.appbuilder.get_url_for_index)
879898

880899
except (OneLogin_Saml2_Error, OneLogin_Saml2_ValidationError) as e:
881900
log.error("SAML SLO validation error: %s", e)
901+
user = self._get_authenticated_user()
882902
logout_user()
903+
if user is not None:
904+
self.appbuilder.sm.on_user_logout(user)
883905
return redirect(self.appbuilder.get_url_for_index)
884906
except ValueError as e:
885907
log.error("SAML SLO configuration error: %s", e)
908+
user = self._get_authenticated_user()
886909
logout_user()
910+
if user is not None:
911+
self.appbuilder.sm.on_user_logout(user)
887912
return redirect(self.appbuilder.get_url_for_index)
888913

889914
@expose("/logout/")
@@ -895,7 +920,10 @@ def logout(self) -> WerkzeugResponse:
895920
return redirect(url_for(".slo"))
896921
except Exception:
897922
pass
923+
user = self._get_authenticated_user()
898924
logout_user()
925+
if user is not None:
926+
self.appbuilder.sm.on_user_logout(user)
899927
return redirect(
900928
current_app.config.get(
901929
"LOGOUT_REDIRECT_URL", self.appbuilder.get_url_for_index

tests/test_hooks.py

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
import logging
2+
import os
3+
import uuid
4+
from unittest.mock import MagicMock
5+
6+
from flask import Flask
7+
from flask_appbuilder import AppBuilder
8+
from flask_appbuilder.security.sqla.models import Group, Role, User
9+
from flask_appbuilder.utils.legacy import get_sqla_class
10+
from tests.base import FABTestCase
11+
from tests.const import PASSWORD_ADMIN, USERNAME_ADMIN
12+
13+
14+
log = logging.getLogger(__name__)
15+
16+
17+
def _uid():
18+
return uuid.uuid4().hex[:8]
19+
20+
21+
class UserApiHooksTestCase(FABTestCase):
22+
"""Test that post_add and post_update hooks are called on UserApi."""
23+
24+
def setUp(self):
25+
self.app = Flask(__name__)
26+
self.basedir = os.path.abspath(os.path.dirname(__file__))
27+
self.app.config.from_object("tests.config_security_api")
28+
29+
self.ctx = self.app.app_context()
30+
self.ctx.push()
31+
SQLA = get_sqla_class()
32+
self.db = SQLA(self.app)
33+
self.appbuilder = AppBuilder(self.app, self.db.session)
34+
self.create_default_users(self.appbuilder)
35+
36+
# Patch hooks on the registered UserApi view
37+
self.user_api = None
38+
for view in self.appbuilder.baseviews:
39+
if view.__class__.__name__ == "UserApi":
40+
self.user_api = view
41+
break
42+
self.assertIsNotNone(self.user_api, "UserApi view not found")
43+
self.user_api.post_add = MagicMock()
44+
self.user_api.post_update = MagicMock()
45+
self._created_users = []
46+
47+
def tearDown(self):
48+
for username in self._created_users:
49+
user = self.appbuilder.sm.find_user(username=username)
50+
if user:
51+
self.appbuilder.session.delete(user)
52+
self.appbuilder.session.commit()
53+
self.ctx.pop()
54+
55+
def test_post_add_called_on_create_user(self):
56+
client = self.app.test_client()
57+
token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
58+
59+
uid = _uid()
60+
role = self.appbuilder.sm.find_role("Admin")
61+
username = f"hook_add_{uid}"
62+
payload = {
63+
"active": True,
64+
"email": f"hook_add_{uid}@fab.com",
65+
"first_name": "hook",
66+
"last_name": "test",
67+
"password": "password",
68+
"roles": [role.id],
69+
"username": username,
70+
}
71+
rv = self.auth_client_post(
72+
client, token, "api/v1/security/users/", payload
73+
)
74+
self.assertEqual(rv.status_code, 201)
75+
self._created_users.append(username)
76+
self.user_api.post_add.assert_called_once()
77+
called_model = self.user_api.post_add.call_args[0][0]
78+
self.assertIsInstance(called_model, User)
79+
self.assertEqual(called_model.username, username)
80+
81+
def test_post_update_called_on_edit_user(self):
82+
client = self.app.test_client()
83+
token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
84+
85+
uid = _uid()
86+
edit_username = f"hook_edit_{uid}"
87+
role = self.appbuilder.sm.find_role("Admin")
88+
user = self.appbuilder.sm.add_user(
89+
username=edit_username,
90+
first_name="hook",
91+
last_name="edit",
92+
email=f"hook_edit_{uid}@fab.com",
93+
role=role,
94+
password="password",
95+
)
96+
self._created_users.append(edit_username)
97+
98+
rv = self.auth_client_put(
99+
client,
100+
token,
101+
f"api/v1/security/users/{user.id}",
102+
{"first_name": "updated_hook"},
103+
)
104+
self.assertEqual(rv.status_code, 200)
105+
self.user_api.post_update.assert_called_once()
106+
called_model = self.user_api.post_update.call_args[0][0]
107+
self.assertIsInstance(called_model, User)
108+
self.assertEqual(called_model.first_name, "updated_hook")
109+
110+
111+
class GroupApiHooksTestCase(FABTestCase):
112+
"""Test that post_add and post_update hooks are called on GroupApi."""
113+
114+
def setUp(self):
115+
self.app = Flask(__name__)
116+
self.basedir = os.path.abspath(os.path.dirname(__file__))
117+
self.app.config.from_object("tests.config_api")
118+
self.app.config["FAB_ADD_SECURITY_API"] = True
119+
120+
self.ctx = self.app.app_context()
121+
self.ctx.push()
122+
SQLA = get_sqla_class()
123+
self.db = SQLA(self.app)
124+
self.appbuilder = AppBuilder(self.app, self.db.session)
125+
self.create_default_users(self.appbuilder)
126+
127+
# Patch hooks on the registered GroupApi view
128+
self.group_api = None
129+
for view in self.appbuilder.baseviews:
130+
if view.__class__.__name__ == "GroupApi":
131+
self.group_api = view
132+
break
133+
self.assertIsNotNone(self.group_api, "GroupApi view not found")
134+
self.group_api.post_add = MagicMock()
135+
self.group_api.post_update = MagicMock()
136+
137+
def tearDown(self):
138+
groups = self.appbuilder.session.query(Group).all()
139+
for group in groups:
140+
group.users = []
141+
group.roles = []
142+
self.appbuilder.session.delete(group)
143+
self.appbuilder.session.commit()
144+
self.appbuilder.session.close()
145+
self.ctx.pop()
146+
147+
def test_post_add_called_on_create_group(self):
148+
client = self.app.test_client()
149+
token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
150+
151+
payload = {
152+
"name": "hook_test_group",
153+
"label": "Hook Test",
154+
"description": "Test group for hooks",
155+
}
156+
rv = self.auth_client_post(
157+
client, token, "api/v1/security/groups/", payload
158+
)
159+
self.assertEqual(rv.status_code, 201)
160+
self.group_api.post_add.assert_called_once()
161+
called_model = self.group_api.post_add.call_args[0][0]
162+
self.assertIsInstance(called_model, Group)
163+
self.assertEqual(called_model.name, "hook_test_group")
164+
165+
def test_post_update_called_on_edit_group(self):
166+
client = self.app.test_client()
167+
token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
168+
169+
group = self.appbuilder.sm.add_group(
170+
"hook_edit_group", "label", "description"
171+
)
172+
self.appbuilder.session.commit()
173+
174+
rv = self.auth_client_put(
175+
client,
176+
token,
177+
f"api/v1/security/groups/{group.id}",
178+
{"label": "updated_label"},
179+
)
180+
self.assertEqual(rv.status_code, 200)
181+
self.group_api.post_update.assert_called_once()
182+
called_model = self.group_api.post_update.call_args[0][0]
183+
self.assertIsInstance(called_model, Group)
184+
185+
186+
class AuthEventHooksTestCase(FABTestCase):
187+
"""Test on_user_login, on_user_login_failed, and on_user_logout hooks."""
188+
189+
def setUp(self):
190+
self.app = Flask(__name__)
191+
self.basedir = os.path.abspath(os.path.dirname(__file__))
192+
self.app.config.from_object("tests.config_security_api")
193+
194+
self.ctx = self.app.app_context()
195+
self.ctx.push()
196+
SQLA = get_sqla_class()
197+
self.db = SQLA(self.app)
198+
self.appbuilder = AppBuilder(self.app, self.db.session)
199+
self.create_default_users(self.appbuilder)
200+
201+
# Patch auth hooks on the security manager
202+
self.appbuilder.sm.on_user_login = MagicMock()
203+
self.appbuilder.sm.on_user_login_failed = MagicMock()
204+
self.appbuilder.sm.on_user_logout = MagicMock()
205+
206+
def tearDown(self):
207+
self.ctx.pop()
208+
209+
def test_on_user_login_called_on_successful_login(self):
210+
client = self.app.test_client()
211+
rv = self._login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
212+
self.assertEqual(rv.status_code, 200)
213+
214+
self.appbuilder.sm.on_user_login.assert_called_once()
215+
called_user = self.appbuilder.sm.on_user_login.call_args[0][0]
216+
self.assertEqual(called_user.username, USERNAME_ADMIN)
217+
self.appbuilder.sm.on_user_login_failed.assert_not_called()
218+
219+
def test_on_user_login_failed_called_on_bad_password(self):
220+
client = self.app.test_client()
221+
rv = self._login(client, USERNAME_ADMIN, "wrong_password")
222+
self.assertEqual(rv.status_code, 401)
223+
224+
self.appbuilder.sm.on_user_login_failed.assert_called_once()
225+
called_user = self.appbuilder.sm.on_user_login_failed.call_args[0][0]
226+
self.assertEqual(called_user.username, USERNAME_ADMIN)
227+
self.appbuilder.sm.on_user_login.assert_not_called()
228+
229+
def test_on_user_logout_called_on_logout(self):
230+
client = self.app.test_client()
231+
# Login via browser (session-based) so logout works
232+
self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
233+
234+
# Verify user is logged in by accessing a protected resource
235+
rv = client.get("/logout/", follow_redirects=False)
236+
self.assertIn(rv.status_code, (301, 302))
237+
238+
self.appbuilder.sm.on_user_logout.assert_called_once()
239+
called_user = self.appbuilder.sm.on_user_logout.call_args[0][0]
240+
self.assertIsNotNone(called_user)
241+
self.assertEqual(called_user.username, USERNAME_ADMIN)
242+
243+
def test_hooks_receive_correct_user_object(self):
244+
"""Verify hooks receive the actual User model instance."""
245+
client = self.app.test_client()
246+
rv = self._login(client, USERNAME_ADMIN, PASSWORD_ADMIN)
247+
self.assertEqual(rv.status_code, 200)
248+
249+
called_user = self.appbuilder.sm.on_user_login.call_args[0][0]
250+
self.assertIsInstance(called_user, User)
251+
self.assertTrue(called_user.is_authenticated)

0 commit comments

Comments
 (0)