11 changes: 8 additions & 3 deletions authentik/lib/sync/outgoing/api.py
@@ -1,3 +1,4 @@
+from django.db.models import Model
 from dramatiq.actor import Actor
 from dramatiq.results.errors import ResultFailure
 from drf_spectacular.utils import extend_schema
@@ -11,7 +12,7 @@
 from authentik.events.logs import LogEventSerializer
 from authentik.lib.sync.api import SyncStatusSerializer
 from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
-from authentik.lib.utils.reflection import class_to_path
+from authentik.lib.utils.reflection import class_to_path, path_to_class
 from authentik.rbac.filters import ObjectFilter
 from authentik.tasks.models import Task, TaskStatus
 
@@ -101,16 +102,20 @@ def sync_object(self, request: Request, pk: int) -> Response:
         provider: OutgoingSyncProvider = self.get_object()
         params = SyncObjectSerializer(data=request.data)
         params.is_valid(raise_exception=True)
+        object_type = params.validated_data["sync_object_model"]
+        _object_type: type[Model] = path_to_class(object_type)
+        pk = params.validated_data["sync_object_id"]
         msg = self.sync_objects_task.send_with_options(
             kwargs={
-                "object_type": params.validated_data["sync_object_model"],
+                "object_type": object_type,
                 "page": 1,
                 "provider_pk": provider.pk,
                 "override_dry_run": params.validated_data["override_dry_run"],
-                "pk": params.validated_data["sync_object_id"],
+                "pk": pk,
             },
             retries=0,
             rel_obj=provider,
+            uid=f"{provider.name}:{_object_type._meta.model_name}:{pk}:manual",
         )
         try:
             msg.get_result(block=True)
12 changes: 6 additions & 6 deletions authentik/lib/sync/outgoing/tasks.py
@@ -1,7 +1,6 @@
 from django.core.paginator import Paginator
 from django.db.models import Model, QuerySet
 from django.db.models.query import Q
-from django.utils.text import slugify
 from dramatiq.actor import Actor
 from dramatiq.composition import group
 from dramatiq.errors import Retry
@@ -51,6 +50,7 @@ def sync_paginator(
                 time_limit=PAGE_TIMEOUT_MS,
                 # Assign tasks to the same schedule as the current one
                 rel_obj=current_task.rel_obj,
+                uid=f"{provider.name}:{object_type._meta.model_name}:{page}",
                 **options,
             )
             tasks.append(page_sync)
@@ -73,7 +73,6 @@ def sync(
         if not provider:
             task.warning("No provider found. Is it assigned to an application?")
             return
-        task.set_uid(slugify(provider.name))
         task.info("Starting full provider sync")
         self.logger.debug("Starting provider sync")
         with provider.sync_lock as lock_acquired:
@@ -125,14 +124,13 @@ def sync_objects(
             provider_pk=provider_pk,
             object_type=object_type,
         )
-        provider: OutgoingSyncProvider = self._provider_model.objects.filter(
+        provider: OutgoingSyncProvider | None = self._provider_model.objects.filter(
             Q(backchannel_application__isnull=False) | Q(application__isnull=False),
             pk=provider_pk,
         ).first()
         if not provider:
             task.warning("No provider found. Is it assigned to an application?")
             return
-        task.set_uid(f"{slugify(provider.name)}:{_object_type._meta.model_name}:{page}")
         # Override dry run mode if requested, however don't save the provider
         # so that scheduled sync tasks still run in dry_run mode
         if override_dry_run:
@@ -192,12 +190,14 @@ def sync_signal_direct_dispatch(
         pk: str | int,
         raw_op: str,
     ):
+        model_class: type[Model] = path_to_class(model)
         for provider in self._provider_model.objects.filter(
             Q(backchannel_application__isnull=False) | Q(application__isnull=False)
         ):
             task_sync_signal_direct.send_with_options(
                 args=(model, pk, provider.pk, raw_op),
                 rel_obj=provider,
+                uid=f"{provider.name}:{model_class._meta.model_name}:{pk}:direct",
             )
 
     def sync_signal_direct(
@@ -222,7 +222,6 @@ def sync_signal_direct(
         if not provider:
             task.warning("No provider found. Is it assigned to an application?")
             return
-        task.set_uid(slugify(provider.name))
         operation = Direction(raw_op)
         client = provider.client_for_model(instance.__class__)
         # Check if the object is allowed within the provider's restrictions
@@ -266,12 +265,14 @@ def sync_signal_m2m_dispatch(
             task_sync_signal_m2m.send_with_options(
                 args=(instance_pk, provider.pk, action, list(pk_set)),
                 rel_obj=provider,
+                uid=f"{provider.name}:group:{instance_pk}:m2m",
             )
         else:
             for pk in pk_set:
                 task_sync_signal_m2m.send_with_options(
                     args=(pk, provider.pk, action, [instance_pk]),
                     rel_obj=provider,
+                    uid=f"{provider.name}:group:{pk}:m2m",
                 )
 
     def sync_signal_m2m(
@@ -295,7 +296,6 @@ def sync_signal_m2m(
         if not provider:
             task.warning("No provider found. Is it assigned to an application?")
             return
-        task.set_uid(slugify(provider.name))
 
         # Check if the object is allowed within the provider's restrictions
         queryset: QuerySet = provider.get_object_qs(Group)
45 changes: 38 additions & 7 deletions authentik/outposts/signals.py
@@ -39,6 +39,7 @@ def outpost_pre_save(sender, instance: Outpost, **_):
         args=(instance.pk.hex,),
         kwargs={"action": "down", "from_cache": True},
         rel_obj=instance,
+        uid=instance.name,
     )
 
 
@@ -51,29 +52,51 @@ def outpost_m2m_changed(sender, instance: Outpost | Provider, action: str, **_):
         outpost_controller.send_with_options(
             args=(instance.pk,),
             rel_obj=instance.service_connection,
+            uid=instance.name,
         )
-        outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance)
+        outpost_send_update.send_with_options(
+            args=(instance.pk,),
+            rel_obj=instance,
+            uid=instance.name,
+        )
     elif isinstance(instance, OutpostModel):
         for outpost in instance.outpost_set.all():
             outpost_controller.send_with_options(
                 args=(instance.pk,),
                 rel_obj=instance.service_connection,
+                uid=instance.name,
             )
-            outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)
+            outpost_send_update.send_with_options(
+                args=(outpost.pk,),
+                rel_obj=outpost,
+                uid=outpost.name,
+            )
 
 
 @receiver(post_save, sender=Outpost)
 def outpost_post_save(sender, instance: Outpost, created: bool, **_):
     if created:
         LOGGER.info("New outpost saved, ensuring initial token and user are created")
         _ = instance.token
-    outpost_controller.send_with_options(args=(instance.pk,), rel_obj=instance.service_connection)
-    outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance)
+    outpost_controller.send_with_options(
+        args=(instance.pk,),
+        rel_obj=instance.service_connection,
+        uid=instance.name,
+    )
+    outpost_send_update.send_with_options(
+        args=(instance.pk,),
+        rel_obj=instance,
+        uid=instance.name,
+    )
 
 
 def outpost_related_post_save(sender, instance: OutpostServiceConnection | OutpostModel, **_):
     for outpost in instance.outpost_set.all():
-        outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)
+        outpost_send_update.send_with_options(
+            args=(outpost.pk,),
+            rel_obj=outpost,
+            uid=outpost.name,
+        )
 
 
 post_save.connect(outpost_related_post_save, sender=OutpostServiceConnection, weak=False)
@@ -102,7 +125,11 @@ def outpost_reverse_related_post_save(sender, instance: CertificateKeyPair | Bra
         for reverse in getattr(instance, field_name).all():
             if isinstance(reverse, OutpostModel):
                 for outpost in reverse.outpost_set.all():
-                    outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)
+                    outpost_send_update.send_with_options(
+                        args=(outpost.pk,),
+                        rel_obj=outpost,
+                        uid=outpost.name,
+                    )
 
 
 post_save.connect(outpost_reverse_related_post_save, sender=Brand, weak=False)
@@ -114,7 +141,11 @@ def outpost_pre_delete_cleanup(sender, instance: Outpost, **_):
     """Ensure that Outpost's user is deleted (which will delete the token through cascade)"""
     instance.user.delete()
     cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance)
-    outpost_controller.send(instance.pk.hex, action="down", from_cache=True)
+    outpost_controller.send_with_options(
+        args=(instance.pk.hex,),
+        kwargs={"action": "down", "from_cache": True},
+        uid=instance.name,
+    )
 
 
 @receiver(pre_delete, sender=AuthenticatedSession)
3 changes: 0 additions & 3 deletions authentik/outposts/tasks.py
@@ -10,7 +10,6 @@
 from asgiref.sync import async_to_sync
 from channels.layers import get_channel_layer
 from django.core.cache import cache
-from django.utils.text import slugify
 from django.utils.translation import gettext_lazy as _
 from docker.constants import DEFAULT_UNIX_SOCKET
 from dramatiq.actor import actor
@@ -108,7 +107,6 @@ def outpost_service_connection_monitor(connection_pk: Any):
 def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False):
     """Create/update/monitor/delete the deployment of an Outpost"""
     self = CurrentTask.get_task()
-    self.set_uid(outpost_pk)
     logs = []
     if from_cache:
         outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
@@ -119,7 +117,6 @@ def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = F
     if not outpost:
         LOGGER.warning("No outpost")
         return
-    self.set_uid(slugify(outpost.name))
     try:
         controller_type = controller_for_outpost(outpost)
         if not controller_type:
2 changes: 1 addition & 1 deletion authentik/root/settings.py
@@ -416,7 +416,7 @@
     ("dramatiq.results.middleware.Results", {"store_results": True}),
     ("authentik.tasks.middleware.CurrentTask", {}),
     ("authentik.tasks.middleware.TenantMiddleware", {}),
-    ("authentik.tasks.middleware.RelObjMiddleware", {}),
+    ("authentik.tasks.middleware.ModelDataMiddleware", {}),
     ("authentik.tasks.middleware.MessagesMiddleware", {}),
     ("authentik.tasks.middleware.LoggingMiddleware", {}),
     ("authentik.tasks.middleware.DescriptionMiddleware", {}),
7 changes: 3 additions & 4 deletions authentik/sources/ldap/tasks.py
@@ -57,7 +57,6 @@ def ldap_sync(source_pk: str):
     source: LDAPSource = LDAPSource.objects.filter(pk=source_pk, enabled=True).first()
     if not source:
         return
-    task.set_uid(f"{source.slug}")
     with source.sync_lock as lock_acquired:
         if not lock_acquired:
             task.info("Synchronization is already running. Skipping")
@@ -111,11 +110,13 @@ def ldap_sync_paginator(
     sync_inst: BaseLDAPSynchronizer = sync(source, task)
     messages = []
     for page in sync_inst.get_objects():
-        page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
+        page_uid = str(uuid4())
+        page_cache_key = CACHE_KEY_PREFIX + page_uid
         cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"))
         page_sync = ldap_sync_page.message_with_options(
             args=(source.pk, class_to_path(sync), page_cache_key),
             rel_obj=task.rel_obj,
+            uid=f"{source.slug}:{sync_inst.name()}:{page_uid}",
         )
         messages.append(page_sync)
     return messages
@@ -134,8 +135,6 @@ def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str):
         # to set the state with
         return
     sync: type[BaseLDAPSynchronizer] = path_to_class(sync_class)
-    uid = page_cache_key.replace(CACHE_KEY_PREFIX, "")
-    self.set_uid(f"{source.slug}:{sync.name()}:{uid}")
     try:
         sync_inst: BaseLDAPSynchronizer = sync(source, self)
         page = cache.get(page_cache_key)
6 changes: 4 additions & 2 deletions authentik/tasks/middleware.py
@@ -52,14 +52,16 @@ def after_process_message(self, *args, **kwargs):
     after_skip_message = after_process_message
 
 
-class RelObjMiddleware(Middleware):
+class ModelDataMiddleware(Middleware):
     @property
     def actor_options(self):
-        return {"rel_obj"}
+        return {"rel_obj", "uid"}
 
     def before_enqueue(self, broker: Broker, message: Message, delay: int):
         if "rel_obj" in message.options:
             message.options["model_defaults"]["rel_obj"] = message.options.pop("rel_obj")
+        if "uid" in message.options:
+            message.options["model_defaults"]["_uid"] = message.options.pop("uid")
 
 
 class MessagesMiddleware(Middleware):
5 changes: 4 additions & 1 deletion authentik/tasks/schedules/common.py
@@ -34,7 +34,10 @@ def get_kwargs(self) -> bytes:
         return pickle.dumps(self.kwargs)
 
     def get_options(self) -> bytes:
-        return pickle.dumps(self.options)
+        options = self.options
+        if self.uid is not None:
+            options["uid"] = self.uid
+        return pickle.dumps(options)
 
     def update_or_create(self) -> "Schedule":
         from authentik.tasks.schedules.models import Schedule