diff --git a/casbin/async_internal_enforcer.py b/casbin/async_internal_enforcer.py index 33f9452..0c65728 100644 --- a/casbin/async_internal_enforcer.py +++ b/casbin/async_internal_enforcer.py @@ -113,7 +113,10 @@ async def save_policy(self): else: update_for_save_policy(self.model) else: - self.watcher.update() + if inspect.iscoroutinefunction(self.watcher.update): + await self.watcher.update() + else: + self.watcher.update() async def _add_policy(self, sec, ptype, rule): """async adds a rule to the current policy.""" @@ -133,7 +136,10 @@ async def _add_policy(self, sec, ptype, rule): else: update_for_add_policy(sec, ptype, rule) else: - self.watcher.update() + if inspect.iscoroutinefunction(self.watcher.update): + await self.watcher.update() + else: + self.watcher.update() rule_added = self.model.add_policy(sec, ptype, rule) @@ -161,7 +167,10 @@ async def _add_policies(self, sec, ptype, rules): else: update_for_add_policies(sec, ptype, rules) else: - self.watcher.update() + if inspect.iscoroutinefunction(self.watcher.update): + await self.watcher.update() + else: + self.watcher.update() rules_added = self.model.add_policies(sec, ptype, rules) @@ -180,7 +189,10 @@ async def _update_policy(self, sec, ptype, old_rule, new_rule): return False if self.watcher and self.auto_notify_watcher: - self.watcher.update() + if inspect.iscoroutinefunction(self.watcher.update): + await self.watcher.update() + else: + self.watcher.update() return rule_updated @@ -197,7 +209,10 @@ async def _update_policies(self, sec, ptype, old_rules, new_rules): return False if self.watcher and self.auto_notify_watcher: - self.watcher.update() + if inspect.iscoroutinefunction(self.watcher.update): + await self.watcher.update() + else: + self.watcher.update() return rules_updated @@ -225,7 +240,10 @@ async def _update_filtered_policies(self, sec, ptype, new_rules, field_index, *f if sec == "g": self.build_role_links() if self.watcher and self.auto_notify_watcher: - self.watcher.update() + if inspect.iscoroutinefunction(self.watcher.update): + await self.watcher.update() + else: + self.watcher.update() return is_rule_changed async def _remove_policy(self, sec, ptype, rule): @@ -247,7 +265,10 @@ async def _remove_policy(self, sec, ptype, rule): else: update_for_remove_policy(sec, ptype, rule) else: - self.watcher.update() + if inspect.iscoroutinefunction(self.watcher.update): + await self.watcher.update() + else: + self.watcher.update() return rule_removed @@ -273,7 +294,10 @@ async def _remove_policies(self, sec, ptype, rules): else: update_for_remove_policies(sec, ptype, rules) else: - self.watcher.update() + if inspect.iscoroutinefunction(self.watcher.update): + await self.watcher.update() + else: + self.watcher.update() return rules_removed @@ -296,7 +320,10 @@ async def _remove_filtered_policy(self, sec, ptype, field_index, *field_values): else: update_for_remove_filtered_policy(sec, ptype, field_index, *field_values) else: - self.watcher.update() + if inspect.iscoroutinefunction(self.watcher.update): + await self.watcher.update() + else: + self.watcher.update() return rule_removed @@ -312,7 +339,10 @@ async def _remove_filtered_policy_returns_effects(self, sec, ptype, field_index, return False if self.watcher and self.auto_notify_watcher: - self.watcher.update() + if inspect.iscoroutinefunction(self.watcher.update): + await self.watcher.update() + else: + self.watcher.update() return rule_removed diff --git a/tests/test_watcher_ex.py b/tests/test_watcher_ex.py index 3684087..bfcd557 100644 --- a/tests/test_watcher_ex.py +++ b/tests/test_watcher_ex.py @@ -287,6 +287,24 @@ def test_auto_notify_disabled(self): self.assertEqual(w.notify_message, None) +class AsyncMinimalWatcher: + """A minimal async watcher that only implements async update() method.""" + + def __init__(self): + self.update_count = 0 + + async def update(self): + """update the policy""" + self.update_count += 1 + return True + + async def close(self): + pass + + async def set_update_callback(self, callback): + pass + + class TestAsyncWatcherEx(IsolatedAsyncioTestCase): def get_enforcer(self, model=None, adapter=None): return casbin.AsyncEnforcer( @@ -365,3 +383,43 @@ async def test_auto_notify_disabled(self): await e.remove_policies(rules) self.assertEqual(w.notify_message, None) + + async def test_async_minimal_watcher(self): + """Test that a watcher with only async update() method works properly.""" + e = self.get_enforcer( + get_examples("basic_model.conf"), + get_examples("basic_policy.csv"), + ) + await e.load_policy() + + w = AsyncMinimalWatcher() + e.set_watcher(w) + e.enable_auto_notify_watcher(True) + + # Test save_policy + await e.save_policy() + self.assertEqual(w.update_count, 1) + + # Test add_policy - fallback to update() + await e.add_policy("admin", "data1", "read") + self.assertEqual(w.update_count, 2) + + # Test remove_policy - fallback to update() + await e.remove_policy("admin", "data1", "read") + self.assertEqual(w.update_count, 3) + + # Test remove_filtered_policy - fallback to update() + await e.remove_filtered_policy(1, "data1") + self.assertEqual(w.update_count, 4) + + # Test add_policies - fallback to update() + rules = [ + ["jack", "data4", "read"], + ["katy", "data4", "write"], + ] + await e.add_policies(rules) + self.assertEqual(w.update_count, 5) + + # Test remove_policies - fallback to update() + await e.remove_policies(rules) + self.assertEqual(w.update_count, 6)