diff --git a/authentik/lib/sync/outgoing/api.py b/authentik/lib/sync/outgoing/api.py index b5bbbae35512..2059ce122800 100644 --- a/authentik/lib/sync/outgoing/api.py +++ b/authentik/lib/sync/outgoing/api.py @@ -1,3 +1,4 @@ +from django.db.models import Model from dramatiq.actor import Actor from dramatiq.results.errors import ResultFailure from drf_spectacular.utils import extend_schema @@ -11,7 +12,7 @@ from authentik.events.logs import LogEventSerializer from authentik.lib.sync.api import SyncStatusSerializer from authentik.lib.sync.outgoing.models import OutgoingSyncProvider -from authentik.lib.utils.reflection import class_to_path +from authentik.lib.utils.reflection import class_to_path, path_to_class from authentik.rbac.filters import ObjectFilter from authentik.tasks.models import Task, TaskStatus @@ -101,16 +102,20 @@ def sync_object(self, request: Request, pk: int) -> Response: provider: OutgoingSyncProvider = self.get_object() params = SyncObjectSerializer(data=request.data) params.is_valid(raise_exception=True) + object_type = params.validated_data["sync_object_model"] + _object_type: type[Model] = path_to_class(object_type) + pk = params.validated_data["sync_object_id"] msg = self.sync_objects_task.send_with_options( kwargs={ - "object_type": params.validated_data["sync_object_model"], + "object_type": object_type, "page": 1, "provider_pk": provider.pk, "override_dry_run": params.validated_data["override_dry_run"], - "pk": params.validated_data["sync_object_id"], + "pk": pk, }, retries=0, rel_obj=provider, + uid=f"{provider.name}:{_object_type._meta.model_name}:{pk}:manual", ) try: msg.get_result(block=True) diff --git a/authentik/lib/sync/outgoing/tasks.py b/authentik/lib/sync/outgoing/tasks.py index dbb22bbf8c0d..7650d43d2c13 100644 --- a/authentik/lib/sync/outgoing/tasks.py +++ b/authentik/lib/sync/outgoing/tasks.py @@ -1,7 +1,6 @@ from django.core.paginator import Paginator from django.db.models import Model, QuerySet from django.db.models.query import Q -from django.utils.text import slugify from dramatiq.actor import Actor from dramatiq.composition import group from dramatiq.errors import Retry @@ -51,6 +50,7 @@ def sync_paginator( time_limit=PAGE_TIMEOUT_MS, # Assign tasks to the same schedule as the current one rel_obj=current_task.rel_obj, + uid=f"{provider.name}:{object_type._meta.model_name}:{page}", **options, ) tasks.append(page_sync) @@ -73,7 +73,6 @@ def sync( if not provider: task.warning("No provider found. Is it assigned to an application?") return - task.set_uid(slugify(provider.name)) task.info("Starting full provider sync") self.logger.debug("Starting provider sync") with provider.sync_lock as lock_acquired: @@ -125,14 +124,13 @@ def sync_objects( provider_pk=provider_pk, object_type=object_type, ) - provider: OutgoingSyncProvider = self._provider_model.objects.filter( + provider: OutgoingSyncProvider | None = self._provider_model.objects.filter( Q(backchannel_application__isnull=False) | Q(application__isnull=False), pk=provider_pk, ).first() if not provider: task.warning("No provider found. Is it assigned to an application?") return - task.set_uid(f"{slugify(provider.name)}:{_object_type._meta.model_name}:{page}") # Override dry run mode if requested, however don't save the provider # so that scheduled sync tasks still run in dry_run mode if override_dry_run: @@ -192,12 +190,14 @@ def sync_signal_direct_dispatch( pk: str | int, raw_op: str, ): + model_class: type[Model] = path_to_class(model) for provider in self._provider_model.objects.filter( Q(backchannel_application__isnull=False) | Q(application__isnull=False) ): task_sync_signal_direct.send_with_options( args=(model, pk, provider.pk, raw_op), rel_obj=provider, + uid=f"{provider.name}:{model_class._meta.model_name}:{pk}:direct", ) def sync_signal_direct( @@ -222,7 +222,6 @@ def sync_signal_direct( if not provider: task.warning("No provider found. Is it assigned to an application?") return - task.set_uid(slugify(provider.name)) operation = Direction(raw_op) client = provider.client_for_model(instance.__class__) # Check if the object is allowed within the provider's restrictions @@ -266,12 +265,14 @@ def sync_signal_m2m_dispatch( task_sync_signal_m2m.send_with_options( args=(instance_pk, provider.pk, action, list(pk_set)), rel_obj=provider, + uid=f"{provider.name}:group:{instance_pk}:m2m", ) else: for pk in pk_set: task_sync_signal_m2m.send_with_options( args=(pk, provider.pk, action, [instance_pk]), rel_obj=provider, + uid=f"{provider.name}:group:{pk}:m2m", ) def sync_signal_m2m( @@ -295,7 +296,6 @@ def sync_signal_m2m( if not provider: task.warning("No provider found. Is it assigned to an application?") return - task.set_uid(slugify(provider.name)) # Check if the object is allowed within the provider's restrictions queryset: QuerySet = provider.get_object_qs(Group) diff --git a/authentik/outposts/signals.py b/authentik/outposts/signals.py index b08c0ecef2b5..aa8e9e472d7c 100644 --- a/authentik/outposts/signals.py +++ b/authentik/outposts/signals.py @@ -39,6 +39,7 @@ def outpost_pre_save(sender, instance: Outpost, **_): args=(instance.pk.hex,), kwargs={"action": "down", "from_cache": True}, rel_obj=instance, + uid=instance.name, ) @@ -51,15 +52,25 @@ def outpost_m2m_changed(sender, instance: Outpost | Provider, action: str, **_): outpost_controller.send_with_options( args=(instance.pk,), rel_obj=instance.service_connection, + uid=instance.name, + ) + outpost_send_update.send_with_options( + args=(instance.pk,), + rel_obj=instance, + uid=instance.name, ) - outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance) elif isinstance(instance, OutpostModel): for outpost in instance.outpost_set.all(): outpost_controller.send_with_options( args=(instance.pk,), rel_obj=instance.service_connection, + uid=instance.name, + ) + outpost_send_update.send_with_options( + args=(outpost.pk,), + rel_obj=outpost, + uid=outpost.name, ) - outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost) @receiver(post_save, sender=Outpost) @@ -67,13 +78,25 @@ def outpost_post_save(sender, instance: Outpost, created: bool, **_): if created: LOGGER.info("New outpost saved, ensuring initial token and user are created") _ = instance.token - outpost_controller.send_with_options(args=(instance.pk,), rel_obj=instance.service_connection) - outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance) + outpost_controller.send_with_options( + args=(instance.pk,), + rel_obj=instance.service_connection, + uid=instance.name, + ) + outpost_send_update.send_with_options( + args=(instance.pk,), + rel_obj=instance, + uid=instance.name, + ) def outpost_related_post_save(sender, instance: OutpostServiceConnection | OutpostModel, **_): for outpost in instance.outpost_set.all(): - outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost) + outpost_send_update.send_with_options( + args=(outpost.pk,), + rel_obj=outpost, + uid=outpost.name, + ) post_save.connect(outpost_related_post_save, sender=OutpostServiceConnection, weak=False) @@ -102,7 +125,11 @@ def outpost_reverse_related_post_save(sender, instance: CertificateKeyPair | Bra for reverse in getattr(instance, field_name).all(): if isinstance(reverse, OutpostModel): for outpost in reverse.outpost_set.all(): - outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost) + outpost_send_update.send_with_options( + args=(outpost.pk,), + rel_obj=outpost, + uid=outpost.name, + ) post_save.connect(outpost_reverse_related_post_save, sender=Brand, weak=False) @@ -114,7 +141,11 @@ def outpost_pre_delete_cleanup(sender, instance: Outpost, **_): """Ensure that Outpost's user is deleted (which will delete the token through cascade)""" instance.user.delete() cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance) - outpost_controller.send(instance.pk.hex, action="down", from_cache=True) + outpost_controller.send_with_options( + args=(instance.pk.hex,), + kwargs={"action": "down", "from_cache": True}, + uid=instance.name, + ) @receiver(pre_delete, sender=AuthenticatedSession) diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index 3fd5618eb2e2..65dac91a774c 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -10,7 +10,6 @@ from asgiref.sync import async_to_sync from channels.layers import get_channel_layer from django.core.cache import cache -from django.utils.text import slugify from django.utils.translation import gettext_lazy as _ from docker.constants import DEFAULT_UNIX_SOCKET from dramatiq.actor import actor @@ -108,7 +107,6 @@ def outpost_service_connection_monitor(connection_pk: Any): def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False): """Create/update/monitor/delete the deployment of an Outpost""" self = CurrentTask.get_task() - self.set_uid(outpost_pk) logs = [] if from_cache: outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk) @@ -119,7 +117,6 @@ def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = F if not outpost: LOGGER.warning("No outpost") return - self.set_uid(slugify(outpost.name)) try: controller_type = controller_for_outpost(outpost) if not controller_type: diff --git a/authentik/providers/scim/tests/test_user.py b/authentik/providers/scim/tests/test_user.py index 6a2be7ddf396..f61139e343fc 100644 --- a/authentik/providers/scim/tests/test_user.py +++ b/authentik/providers/scim/tests/test_user.py @@ -3,7 +3,6 @@ from json import loads from django.test import TestCase -from django.utils.text import slugify from jsonschema import validate from requests_mock import Mocker @@ -436,7 +435,7 @@ def test_sync_task_dry_run(self): task = list( Task.objects.filter( actor_name=scim_sync_objects.actor_name, - _uid__startswith=slugify(self.provider.name), + _uid__startswith=self.provider.name, ).order_by("-mtime") )[1] self.assertIsNotNone(task) diff --git a/authentik/root/settings.py b/authentik/root/settings.py index 3192dddca556..8980c7a89890 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -416,7 +416,7 @@ ("dramatiq.results.middleware.Results", {"store_results": True}), ("authentik.tasks.middleware.CurrentTask", {}), ("authentik.tasks.middleware.TenantMiddleware", {}), - ("authentik.tasks.middleware.RelObjMiddleware", {}), + ("authentik.tasks.middleware.ModelDataMiddleware", {}), ("authentik.tasks.middleware.MessagesMiddleware", {}), ("authentik.tasks.middleware.LoggingMiddleware", {}), ("authentik.tasks.middleware.DescriptionMiddleware", {}), diff --git a/authentik/sources/ldap/tasks.py b/authentik/sources/ldap/tasks.py index b9e0ea585df7..797f090eb05d 100644 --- a/authentik/sources/ldap/tasks.py +++ b/authentik/sources/ldap/tasks.py @@ -57,7 +57,6 @@ def ldap_sync(source_pk: str): source: LDAPSource = LDAPSource.objects.filter(pk=source_pk, enabled=True).first() if not source: return - task.set_uid(f"{source.slug}") with source.sync_lock as lock_acquired: if not lock_acquired: task.info("Synchronization is already running. Skipping") @@ -111,11 +110,13 @@ def ldap_sync_paginator( sync_inst: BaseLDAPSynchronizer = sync(source, task) messages = [] for page in sync_inst.get_objects(): - page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) + page_uid = str(uuid4()) + page_cache_key = CACHE_KEY_PREFIX + page_uid cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) page_sync = ldap_sync_page.message_with_options( args=(source.pk, class_to_path(sync), page_cache_key), rel_obj=task.rel_obj, + uid=f"{source.slug}:{sync_inst.name()}:{page_uid}", ) messages.append(page_sync) return messages @@ -134,8 +135,6 @@ def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str): # to set the state with return sync: type[BaseLDAPSynchronizer] = path_to_class(sync_class) - uid = page_cache_key.replace(CACHE_KEY_PREFIX, "") - self.set_uid(f"{source.slug}:{sync.name()}:{uid}") try: sync_inst: BaseLDAPSynchronizer = sync(source, self) page = cache.get(page_cache_key) diff --git a/authentik/tasks/middleware.py b/authentik/tasks/middleware.py index 41ef60355b6f..4a71924ab153 100644 --- a/authentik/tasks/middleware.py +++ b/authentik/tasks/middleware.py @@ -52,14 +52,16 @@ def after_process_message(self, *args, **kwargs): after_skip_message = after_process_message -class RelObjMiddleware(Middleware): +class ModelDataMiddleware(Middleware): @property def actor_options(self): - return {"rel_obj"} + return {"rel_obj", "uid"} def before_enqueue(self, broker: Broker, message: Message, delay: int): if "rel_obj" in message.options: message.options["model_defaults"]["rel_obj"] = message.options.pop("rel_obj") + if "uid" in message.options: + message.options["model_defaults"]["_uid"] = message.options.pop("uid") class MessagesMiddleware(Middleware): diff --git a/authentik/tasks/schedules/common.py b/authentik/tasks/schedules/common.py index 9d93c9787c3a..067d3f2df88a 100644 --- a/authentik/tasks/schedules/common.py +++ b/authentik/tasks/schedules/common.py @@ -34,7 +34,10 @@ def get_kwargs(self) -> bytes: return pickle.dumps(self.kwargs) def get_options(self) -> bytes: - return pickle.dumps(self.options) + options = self.options + if self.uid is not None: + options["uid"] = self.uid + return pickle.dumps(options) def update_or_create(self) -> "Schedule": from authentik.tasks.schedules.models import Schedule