diff --git a/django_celery_beat/schedulers.py b/django_celery_beat/schedulers.py
index 80debf96..1446ff4c 100644
--- a/django_celery_beat/schedulers.py
+++ b/django_celery_beat/schedulers.py
@@ -10,7 +10,8 @@
 from celery.utils.time import maybe_make_aware
 from django.conf import settings
 from django.core.exceptions import ObjectDoesNotExist
-from django.db import close_old_connections, transaction
+from django.db import (DEFAULT_DB_ALIAS, close_old_connections, router,
+                       transaction)
 from django.db.models import Q
 from django.db.utils import DatabaseError, InterfaceError
 from django.utils import timezone
@@ -282,7 +283,7 @@ def schedule_changed(self):
             # other transactions until the current transaction is
             # committed (Issue #41).
             try:
-                transaction.commit()
+                transaction.commit(using=self.target_db)
             except transaction.TransactionManagementError:
                 pass  # not in transaction management.
 
@@ -311,7 +312,17 @@ def reserve(self, entry):
         self._dirty.add(new_entry.name)
         return new_entry
 
-    def sync(self):
+    @property
+    def target_db(self):
+        """Return the database alias the scheduler should write to."""
+        if not settings.DATABASE_ROUTERS:
+            return DEFAULT_DB_ALIAS
+        # If the project's routers do not implement db_for_write,
+        # DEFAULT_DB_ALIAS is returned as a fallback; any exception
+        # raised here originates in the project's routing configuration.
+        db = router.db_for_write(self.Model)
+        return db or DEFAULT_DB_ALIAS
+    def _sync(self):
         if logger.isEnabledFor(logging.DEBUG):
             debug('Writing entries...')
         _tried = set()
@@ -337,6 +348,10 @@ def sync(self):
         # retry later, only for the failed ones
         self._dirty |= _failed
 
+    def sync(self):
+        with transaction.atomic(using=self.target_db):
+            self._sync()
+
     def update_from_dict(self, mapping):
         s = {}
         for name, entry_fields in mapping.items():
diff --git a/t/unit/test_schedulers.py b/t/unit/test_schedulers.py
index 05fb3ad4..0217edef 100644
--- a/t/unit/test_schedulers.py
+++ b/t/unit/test_schedulers.py
@@ -28,6 +28,15 @@
 _ids = count(0)
 
 
+class Router:
+    target_db = None
+
+    def db_for_read(self, model, **hints):
+        return self.target_db
+
+    db_for_write = db_for_read
+
+
 @pytest.fixture(autouse=True)
 def no_multiprocessing_finalizers(patching):
     patching('multiprocessing.util.Finalize')
@@ -122,6 +131,10 @@ def create_crontab_schedule(self):
 class test_ModelEntry(SchedulerCase):
     Entry = EntryTrackSave
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_entry(self):
         m = self.create_model_interval(schedule(timedelta(seconds=10)))
         e = self.Entry(m, app=self.app)
@@ -154,7 +167,9 @@ def test_entry(self):
 
     @override_settings(
         USE_TZ=False,
-        DJANGO_CELERY_BEAT_TZ_AWARE=False
+        DJANGO_CELERY_BEAT_TZ_AWARE=False,
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
     )
     @pytest.mark.usefixtures('depends_on_current_app')
     @timezone.override('Europe/Berlin')
@@ -186,7 +201,9 @@ def test_entry_is_due__no_use_tz(self):
 
     @override_settings(
         USE_TZ=False,
-        DJANGO_CELERY_BEAT_TZ_AWARE=False
+        DJANGO_CELERY_BEAT_TZ_AWARE=False,
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
     )
     @pytest.mark.usefixtures('depends_on_current_app')
     @timezone.override('Europe/Berlin')
@@ -271,7 +288,9 @@ def test_entry_and_model_last_run_at_when_model_changed(self, monkeypatch):
         USE_TZ=False,
         DJANGO_CELERY_BEAT_TZ_AWARE=False,
         TIME_ZONE="Europe/Berlin",
-        CELERY_TIMEZONE="America/New_York"
+        CELERY_TIMEZONE="America/New_York",
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
     )
     @pytest.mark.usefixtures('depends_on_current_app')
     @timezone.override('Europe/Berlin')
@@ -302,6 +321,10 @@ def test_entry_is_due__celery_timezone_doesnt_match_time_zone(self):
         if hasattr(time, "tzset"):
             time.tzset()
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_task_with_start_time(self):
         interval = 10
         right_now = self.app.now()
@@ -323,6 +346,10 @@ def test_task_with_start_time(self):
         assert not isdue
         assert delay == math.ceil((tomorrow - right_now).total_seconds())
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_one_off_task(self):
         interval = 10
         right_now = self.app.now()
@@ -379,7 +406,11 @@ def test_task_with_expires(self):
 class test_DatabaseSchedulerFromAppConf(SchedulerCase):
     Scheduler = TrackingScheduler
 
-    @pytest.mark.django_db
+    @pytest.mark.django_db()
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     @pytest.fixture(autouse=True)
     def setup_scheduler(self, app):
         self.app = app
@@ -389,6 +420,10 @@ def setup_scheduler(self, app):
         self.m1 = PeriodicTask(name=self.entry_name,
                                interval=self.create_interval_schedule())
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_constructor(self):
         s = self.Scheduler(app=self.app)
 
@@ -396,6 +431,10 @@ def setup_scheduler(self, app):
         assert s._last_sync is None
         assert s.sync_every
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_periodic_task_model_enabled_schedule(self):
         s = self.Scheduler(app=self.app)
         sched = s.schedule
@@ -409,6 +448,10 @@ def test_periodic_task_model_enabled_schedule(self):
         assert e.model.expires is None
         assert e.model.expire_seconds == 12 * 3600
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_periodic_task_model_disabled_schedule(self):
         self.m1.enabled = False
         self.m1.save()
@@ -435,7 +478,11 @@ def test_periodic_task_model_schedule_type_change(self):
 class test_DatabaseScheduler(SchedulerCase):
     Scheduler = TrackingScheduler
 
-    @pytest.mark.django_db
+    @pytest.mark.django_db()
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     @pytest.fixture(autouse=True)
     def setup_scheduler(self, app):
         self.app = app
@@ -525,11 +572,19 @@ def setup_scheduler(self, app):
 
         self.s = self.Scheduler(app=self.app)
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_constructor(self):
         assert isinstance(self.s._dirty, set)
         assert self.s._last_sync is None
         assert self.s.sync_every
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_all_as_schedule(self):
         sched = self.s.schedule
         assert sched
@@ -538,6 +593,10 @@ def test_all_as_schedule(self):
         for n, e in sched.items():
             assert isinstance(e, self.s.Entry)
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_get_excluded_hours_for_crontab_tasks(self):
         now_hour = timezone.localtime(timezone.now()).hour
         excluded_hours = self.s.get_excluded_hours_for_crontab_tasks()
@@ -564,6 +623,10 @@ def test_schedule_changed(self):
         with pytest.raises(KeyError):
             self.s.schedule.__getitem__(self.m3.name)
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_should_sync(self):
         assert self.s.should_sync()
         self.s._last_sync = monotonic()
         assert not self.s.should_sync()
         self.s._last_sync -= self.s.sync_every
         assert self.s.should_sync()
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_reserve(self):
         e1 = self.s.schedule[self.m1.name]
         self.s.schedule[self.m1.name] = self.s.reserve(e1)
@@ -581,6 +648,10 @@ def test_reserve(self):
         assert self.s.flushed == 1
         assert self.m2.name in self.s._dirty
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_sync_not_saves_last_run_at_while_schedule_changed(self):
         # Update e1 last_run_at and add to dirty
         e1 = self.s.schedule[self.m2.name]
@@ -614,6 +685,10 @@ def test_sync_saves_last_run_at(self):
         e2 = self.s.schedule[self.m2.name]
         assert e2.last_run_at == last_run2
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_sync_syncs_before_save(self):
         # Get the entry for m2
         e1 = self.s.schedule[self.m2.name]
@@ -636,6 +711,10 @@ def test_sync_syncs_before_save(self):
         assert e3.last_run_at == e2.last_run_at
         assert e3.args == [16, 16]
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_periodic_task_disabled_and_enabled(self):
         # Get the entry for m2
         e1 = self.s.schedule[self.m2.name]
@@ -664,6 +743,10 @@ def test_periodic_task_disabled_and_enabled(self):
         assert self.m2.name in self.s.schedule
         assert self.s.flushed == 3
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_periodic_task_disabled_while_reserved(self):
         # Get the entry for m2
         e1 = self.s.schedule[self.m2.name]
@@ -688,26 +771,46 @@ def test_periodic_task_disabled_while_reserved(self):
         assert self.m2.name not in self.s.schedule
         assert self.s.flushed == 2
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_sync_not_dirty(self):
         self.s._dirty.clear()
         self.s.sync()
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_sync_object_gone(self):
         self.s._dirty.add('does-not-exist')
         self.s.sync()
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
    def test_sync_rollback_on_save_error(self):
         self.s.schedule[self.m1.name] = EntrySaveRaises(self.m1, app=self.app)
         self.s._dirty.add(self.m1.name)
         with pytest.raises(RuntimeError):
             self.s.sync()
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_update_scheduler_heap_invalidation(self, monkeypatch):
         # mock "schedule_changed" to always trigger update for
         # all calls to schedule, as a change may occur at any moment
         monkeypatch.setattr(self.s, 'schedule_changed', lambda: True)
         self.s.tick()
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_heap_size_is_constant(self, monkeypatch):
         # heap size is constant unless the schedule changes
         monkeypatch.setattr(self.s, 'schedule_changed', lambda: True)
@@ -717,6 +820,10 @@ def test_heap_size_is_constant(self, monkeypatch):
         self.s.tick()
         assert len(self.s._heap) == expected_heap_size
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_scheduler_schedules_equality_on_change(self, monkeypatch):
         monkeypatch.setattr(self.s, 'schedule_changed', lambda: False)
         assert self.s.schedules_equal(self.s.schedule, self.s.schedule)
@@ -724,6 +831,10 @@ def test_scheduler_schedules_equality_on_change(self, monkeypatch):
         monkeypatch.setattr(self.s, 'schedule_changed', lambda: True)
         assert not self.s.schedules_equal(self.s.schedule, self.s.schedule)
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_heap_always_return_the_first_item(self):
         interval = 10
 
@@ -902,12 +1013,20 @@ def test_crontab_with_start_time_tick(self, app):
 @pytest.mark.django_db
 class test_models(SchedulerCase):
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_IntervalSchedule_unicode(self):
         assert (str(IntervalSchedule(every=1, period='seconds')) ==
                 'every second')
         assert (str(IntervalSchedule(every=10, period='seconds')) ==
                 'every 10 seconds')
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_CrontabSchedule_unicode(self):
         assert str(CrontabSchedule(
             minute=3,
@@ -922,10 +1041,18 @@ def test_CrontabSchedule_unicode(self):
             month_of_year='4,6',
         )) == '3 3 */2 4,6 tue (m/h/dM/MY/d) UTC'
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_PeriodicTask_unicode_interval(self):
         p = self.create_model_interval(schedule(timedelta(seconds=10)))
         assert str(p) == f'{p.name}: every 10.0 seconds'
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_PeriodicTask_unicode_crontab(self):
         p = self.create_model_crontab(crontab(
             hour='4, 5',
@@ -935,6 +1062,10 @@ def test_PeriodicTask_unicode_crontab(self):
             p.name
         )
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_PeriodicTask_unicode_solar(self):
         p = self.create_model_solar(
             solar('solar_noon', 48.06, 12.86), name='solar_event'
         )
@@ -943,6 +1074,10 @@ def test_PeriodicTask_unicode_solar(self):
             'Solar noon', '48.06', '12.86'
         )
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_PeriodicTask_unicode_clocked(self):
         time = make_aware(datetime.now())
         p = self.create_model_clocked(
@@ -952,6 +1087,10 @@ def test_PeriodicTask_unicode_clocked(self):
             'clocked_event', str(time)
         )
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_PeriodicTask_schedule_property(self):
         p1 = self.create_model_interval(schedule(timedelta(seconds=10)))
         s1 = p1.schedule
@@ -970,10 +1109,18 @@ def test_PeriodicTask_schedule_property(self):
         assert s2.day_of_month == {1, 2, 3, 4, 5, 6, 7}
         assert s2.month_of_year == {1, 4, 7, 10}
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_PeriodicTask_unicode_no_schedule(self):
         p = self.create_model()
         assert str(p) == f'{p.name}: {{no schedule}}'
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_CrontabSchedule_schedule(self):
         s = CrontabSchedule(
             minute='3, 7',
@@ -988,6 +1135,10 @@ def test_CrontabSchedule_schedule(self):
         assert s.schedule.day_of_month == {1, 16}
         assert s.schedule.month_of_year == {1, 7}
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_CrontabSchedule_long_schedule(self):
         s = CrontabSchedule(
             minute=str(list(range(60)))[1:-1],
@@ -1009,6 +1160,10 @@ def test_CrontabSchedule_long_schedule(self):
             field_length = s._meta.get_field(field).max_length
             assert str_length <= field_length
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_SolarSchedule_schedule(self):
         s = SolarSchedule(event='solar_noon', latitude=48.06,
                           longitude=12.86)
         dt = datetime(day=26, month=7, year=2050, hour=1, minute=0)
@@ -1030,6 +1185,10 @@ def test_SolarSchedule_schedule(self):
         assert (nextcheck2 > 0) and (isdue2 is True) or \
             (nextcheck2 == s2.max_interval) and (isdue2 is False)
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_ClockedSchedule_schedule(self):
         due_datetime = make_aware(datetime(day=26,
                                            month=7,
@@ -1059,6 +1218,10 @@ def test_ClockedSchedule_schedule(self):
 @pytest.mark.django_db
 class test_model_PeriodicTasks(SchedulerCase):
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_track_changes(self):
         assert PeriodicTasks.last_change() is None
         m1 = self.create_model_interval(schedule(timedelta(seconds=10)))
@@ -1074,7 +1237,11 @@ def test_track_changes(self):
 @pytest.mark.django_db
 class test_modeladmin_PeriodicTaskAdmin(SchedulerCase):
 
-    @pytest.mark.django_db
+    @pytest.mark.django_db()
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     @pytest.fixture(autouse=True)
     def setup_scheduler(self, app):
         self.app = app
@@ -1095,6 +1262,10 @@ def setup_scheduler(self, app):
         self.m2.task = 'celery.backend_cleanup'
         self.m2.save()
 
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def patch_request(self, request):
         """patch request to allow for django messages storage"""
         setattr(request, 'session', 'session')
@@ -1105,6 +1276,10 @@ def patch_request(self, request):
     # don't hang if broker is down
     # https://github.com/celery/celery/issues/4627
     @pytest.mark.timeout(5)
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_run_task(self):
         ma = PeriodicTaskAdmin(PeriodicTask, self.site)
         self.request = self.patch_request(self.request_factory.get('/'))
@@ -1116,6 +1291,10 @@ def test_run_task(self):
     # don't hang if broker is down
     # https://github.com/celery/celery/issues/4627
     @pytest.mark.timeout(5)
+    @override_settings(
+        ROOT_URLCONF=__name__,
+        DATABASE_ROUTERS=['%s.Router' % __name__]
+    )
     def test_run_tasks(self):
         ma = PeriodicTaskAdmin(PeriodicTask, self.site)
         self.request = self.patch_request(self.request_factory.get('/'))
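
Note (not part of the patch): with this change the scheduler resolves a target_db alias by consulting the project's DATABASE_ROUTERS (router.db_for_write(self.Model), i.e. the PeriodicTask model), falling back to DEFAULT_DB_ALIAS, and then both commits and runs sync() against that alias via transaction.atomic(using=...). Below is a minimal sketch of a project-level router that would exercise this path; the "beat_db" alias and the myproject.dbrouters module path are illustrative assumptions, not part of this patch.

    # settings.py (sketch; assumes a second connection alias named "beat_db")
    DATABASES = {
        "default": {"ENGINE": "django.db.backends.postgresql", "NAME": "app"},
        "beat_db": {"ENGINE": "django.db.backends.postgresql", "NAME": "beat"},
    }
    DATABASE_ROUTERS = ["myproject.dbrouters.CeleryBeatRouter"]  # hypothetical path

    # myproject/dbrouters.py (sketch)
    class CeleryBeatRouter:
        """Send django-celery-beat models to the "beat_db" alias."""

        route_app_labels = {"django_celery_beat"}

        def db_for_read(self, model, **hints):
            # Route reads of beat models to the dedicated alias.
            if model._meta.app_label in self.route_app_labels:
                return "beat_db"
            return None  # defer to other routers / the default alias

        db_for_write = db_for_read

        def allow_migrate(self, db, app_label, model_name=None, **hints):
            # Keep the beat tables only in the dedicated database.
            if app_label in self.route_app_labels:
                return db == "beat_db"
            return None

With such a router installed, target_db resolves to "beat_db", and the scheduler's sync() saves its PeriodicTask updates inside transaction.atomic(using="beat_db") instead of against DEFAULT_DB_ALIAS.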