Skip to content

Commit 7875d02

Browse files
authored
feat: fix inconsistent async handling bug for watcher.update() API (#406)
1 parent 243d51f commit 7875d02

File tree

2 files changed

+98
-10
lines changed

2 files changed

+98
-10
lines changed

casbin/async_internal_enforcer.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ async def save_policy(self):
113113
else:
114114
update_for_save_policy(self.model)
115115
else:
116-
self.watcher.update()
116+
if inspect.iscoroutinefunction(self.watcher.update):
117+
await self.watcher.update()
118+
else:
119+
self.watcher.update()
117120

118121
async def _add_policy(self, sec, ptype, rule):
119122
"""async adds a rule to the current policy."""
@@ -133,7 +136,10 @@ async def _add_policy(self, sec, ptype, rule):
133136
else:
134137
update_for_add_policy(sec, ptype, rule)
135138
else:
136-
self.watcher.update()
139+
if inspect.iscoroutinefunction(self.watcher.update):
140+
await self.watcher.update()
141+
else:
142+
self.watcher.update()
137143

138144
rule_added = self.model.add_policy(sec, ptype, rule)
139145

@@ -161,7 +167,10 @@ async def _add_policies(self, sec, ptype, rules):
161167
else:
162168
update_for_add_policies(sec, ptype, rules)
163169
else:
164-
self.watcher.update()
170+
if inspect.iscoroutinefunction(self.watcher.update):
171+
await self.watcher.update()
172+
else:
173+
self.watcher.update()
165174

166175
rules_added = self.model.add_policies(sec, ptype, rules)
167176

@@ -180,7 +189,10 @@ async def _update_policy(self, sec, ptype, old_rule, new_rule):
180189
return False
181190

182191
if self.watcher and self.auto_notify_watcher:
183-
self.watcher.update()
192+
if inspect.iscoroutinefunction(self.watcher.update):
193+
await self.watcher.update()
194+
else:
195+
self.watcher.update()
184196

185197
return rule_updated
186198

@@ -197,7 +209,10 @@ async def _update_policies(self, sec, ptype, old_rules, new_rules):
197209
return False
198210

199211
if self.watcher and self.auto_notify_watcher:
200-
self.watcher.update()
212+
if inspect.iscoroutinefunction(self.watcher.update):
213+
await self.watcher.update()
214+
else:
215+
self.watcher.update()
201216

202217
return rules_updated
203218

@@ -225,7 +240,10 @@ async def _update_filtered_policies(self, sec, ptype, new_rules, field_index, *f
225240
if sec == "g":
226241
self.build_role_links()
227242
if self.watcher and self.auto_notify_watcher:
228-
self.watcher.update()
243+
if inspect.iscoroutinefunction(self.watcher.update):
244+
await self.watcher.update()
245+
else:
246+
self.watcher.update()
229247
return is_rule_changed
230248

231249
async def _remove_policy(self, sec, ptype, rule):
@@ -247,7 +265,10 @@ async def _remove_policy(self, sec, ptype, rule):
247265
else:
248266
update_for_remove_policy(sec, ptype, rule)
249267
else:
250-
self.watcher.update()
268+
if inspect.iscoroutinefunction(self.watcher.update):
269+
await self.watcher.update()
270+
else:
271+
self.watcher.update()
251272

252273
return rule_removed
253274

@@ -273,7 +294,10 @@ async def _remove_policies(self, sec, ptype, rules):
273294
else:
274295
update_for_remove_policies(sec, ptype, rules)
275296
else:
276-
self.watcher.update()
297+
if inspect.iscoroutinefunction(self.watcher.update):
298+
await self.watcher.update()
299+
else:
300+
self.watcher.update()
277301

278302
return rules_removed
279303

@@ -296,7 +320,10 @@ async def _remove_filtered_policy(self, sec, ptype, field_index, *field_values):
296320
else:
297321
update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
298322
else:
299-
self.watcher.update()
323+
if inspect.iscoroutinefunction(self.watcher.update):
324+
await self.watcher.update()
325+
else:
326+
self.watcher.update()
300327

301328
return rule_removed
302329

@@ -312,7 +339,10 @@ async def _remove_filtered_policy_returns_effects(self, sec, ptype, field_index,
312339
return False
313340

314341
if self.watcher and self.auto_notify_watcher:
315-
self.watcher.update()
342+
if inspect.iscoroutinefunction(self.watcher.update):
343+
await self.watcher.update()
344+
else:
345+
self.watcher.update()
316346

317347
return rule_removed
318348

tests/test_watcher_ex.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,24 @@ def test_auto_notify_disabled(self):
287287
self.assertEqual(w.notify_message, None)
288288

289289

290+
class AsyncMinimalWatcher:
291+
"""A minimal async watcher that only implements async update() method."""
292+
293+
def __init__(self):
294+
self.update_count = 0
295+
296+
async def update(self):
297+
"""update the policy"""
298+
self.update_count += 1
299+
return True
300+
301+
async def close(self):
302+
pass
303+
304+
async def set_update_callback(self, callback):
305+
pass
306+
307+
290308
class TestAsyncWatcherEx(IsolatedAsyncioTestCase):
291309
def get_enforcer(self, model=None, adapter=None):
292310
return casbin.AsyncEnforcer(
@@ -365,3 +383,43 @@ async def test_auto_notify_disabled(self):
365383

366384
await e.remove_policies(rules)
367385
self.assertEqual(w.notify_message, None)
386+
387+
async def test_async_minimal_watcher(self):
388+
"""Test that a watcher with only async update() method works properly."""
389+
e = self.get_enforcer(
390+
get_examples("basic_model.conf"),
391+
get_examples("basic_policy.csv"),
392+
)
393+
await e.load_policy()
394+
395+
w = AsyncMinimalWatcher()
396+
e.set_watcher(w)
397+
e.enable_auto_notify_watcher(True)
398+
399+
# Test save_policy
400+
await e.save_policy()
401+
self.assertEqual(w.update_count, 1)
402+
403+
# Test add_policy - fallback to update()
404+
await e.add_policy("admin", "data1", "read")
405+
self.assertEqual(w.update_count, 2)
406+
407+
# Test remove_policy - fallback to update()
408+
await e.remove_policy("admin", "data1", "read")
409+
self.assertEqual(w.update_count, 3)
410+
411+
# Test remove_filtered_policy - fallback to update()
412+
await e.remove_filtered_policy(1, "data1")
413+
self.assertEqual(w.update_count, 4)
414+
415+
# Test add_policies - fallback to update()
416+
rules = [
417+
["jack", "data4", "read"],
418+
["katy", "data4", "write"],
419+
]
420+
await e.add_policies(rules)
421+
self.assertEqual(w.update_count, 5)
422+
423+
# Test remove_policies - fallback to update()
424+
await e.remove_policies(rules)
425+
self.assertEqual(w.update_count, 6)

0 commit comments

Comments
 (0)