Skip to content
79 changes: 76 additions & 3 deletions src/sentry/models/apitoken.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import contextlib
import hashlib
import logging
import secrets
from collections.abc import Collection, Mapping
from collections.abc import Collection, Generator, Mapping
from datetime import timedelta
from typing import Any, ClassVar

Expand All @@ -17,8 +19,10 @@
from sentry.constants import SentryAppStatus
from sentry.db.models import FlexibleForeignKey, control_silo_model, sane_repr
from sentry.db.models.fields.hybrid_cloud_foreign_key import HybridCloudForeignKey
from sentry.hybridcloud.models.outbox import ControlOutbox, outbox_context
from sentry.hybridcloud.outbox.base import ControlOutboxProducingManager, ReplicatedControlModel
from sentry.hybridcloud.outbox.category import OutboxCategory
from sentry.hybridcloud.outbox.category import OutboxCategory, OutboxScope
from sentry.hybridcloud.tasks.deliver_from_outbox import drain_outbox_shards_control
from sentry.locks import locks
from sentry.models.apiapplication import ApiApplicationStatus
from sentry.models.apigrant import ApiGrant, ExpiredGrantError, InvalidGrantError
Expand All @@ -29,6 +33,8 @@
DEFAULT_EXPIRATION = timedelta(days=30)
TOKEN_REDACTED = "***REDACTED***"

logger = logging.getLogger("sentry.apitoken")


def default_expiration():
return timezone.now() + DEFAULT_EXPIRATION
Expand Down Expand Up @@ -231,7 +237,16 @@ def save(self, *args: Any, **kwargs: Any) -> None:
token_last_characters = self.token[-4:]
self.token_last_characters = token_last_characters

return super().save(*args, **kwargs)
result = super().save(*args, **kwargs)

# Schedule async replication if using async mode
if not self._should_flush_outbox():
transaction.on_commit(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok I'm not actually sure if this does anything since we've already completed all the prev save steps? But I saw this being used for drain_shard and we'd want to enqueue the replication task after the token has committed 🤷‍♀️

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because you're scheduling a task, it is best to do that after the transaction has commit as you can ensure that all the records are saved. Without this it is possible for the task to be processed while the transaction has not complete if postgres is slow and kafka is fast.

lambda: self._schedule_async_replication(),
using=router.db_for_write(type(self)),
)

return result

def update(self, *args: Any, **kwargs: Any) -> int:
# if the token or refresh_token was updated, we need to
Expand All @@ -252,6 +267,64 @@ def update(self, *args: Any, **kwargs: Any) -> int:
def outbox_region_names(self) -> Collection[str]:
return list(find_all_region_names())

def _should_flush_outbox(self) -> bool:
from sentry import options

has_async_flush = self.user_id in options.get("users:api-token-async-flush")
logger.info(
"async_flush_check",
extra={
"has_async_flush": has_async_flush,
"user_id": self.user_id,
"token_id": self.id,
},
)
if has_async_flush:
return False

return True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed in person, but I don't think we need a user-level flag for this. I think a global option that sets the value of the class's default_flush property should be sufficient for us. Having an all-or-nothing toggle is probably okay for this change, and I don't expect the majority of users to even notice the impact of this change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting class properties based on options will be tricky as classes are imported during django's startup before the ORM is ready. The current approach of overriding _maybe_prepare_outboxes will work though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ended up doing some property decorator shenanigans since that should be loaded after the startup 🤔


@contextlib.contextmanager
def _maybe_prepare_outboxes(self, *, outbox_before_super: bool) -> Generator[None]:
# Overriding to get around how default_flush cannot be cleanly feature flagged
flush = self._should_flush_outbox()

with outbox_context(
transaction.atomic(router.db_for_write(type(self))),
flush=flush,
):
if not outbox_before_super:
yield
for outbox in self.outboxes_for_update():
outbox.save()
if outbox_before_super:
yield

def _schedule_async_replication(self) -> None:
# Query for the outboxes we just created for this specific token
outboxes = ControlOutbox.objects.filter(
shard_scope=OutboxScope.USER_SCOPE,
shard_identifier=self.user_id,
category=OutboxCategory.API_TOKEN_UPDATE,
object_identifier=self.id,
).order_by("id")

if not outboxes.exists():
return

# Get the ID range of our specific token's outboxes
first_row = outboxes.first()
last_row = outboxes.last()

if first_row is None or last_row is None:
return

drain_outbox_shards_control.delay(
outbox_identifier_low=first_row.id,
outbox_identifier_hi=last_row.id,
outbox_name="sentry.ControlOutbox",
)

def handle_async_replication(self, region_name: str, shard_identifier: int) -> None:
from sentry.auth.services.auth.serial import serialize_api_token
from sentry.hybridcloud.services.replica import region_replica_service
Expand Down
8 changes: 8 additions & 0 deletions src/sentry/options/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -3811,3 +3811,11 @@
default=[],
flags=FLAG_ALLOW_EMPTY | FLAG_AUTOMATOR_MODIFIABLE,
)

# Enabled organizations for API token async flush
register(
"users:api-token-async-flush",
default=[],
type=Sequence,
flags=FLAG_ALLOW_EMPTY | FLAG_AUTOMATOR_MODIFIABLE,
)
78 changes: 73 additions & 5 deletions tests/sentry/models/test_apitoken.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@

from sentry.conf.server import SENTRY_SCOPE_HIERARCHY_MAPPING, SENTRY_SCOPES
from sentry.hybridcloud.models import ApiTokenReplica
from sentry.hybridcloud.models.outbox import ControlOutbox
from sentry.hybridcloud.outbox.category import OutboxCategory, OutboxScope
from sentry.models.apitoken import ApiToken, NotSupported, PlaintextSecretAlreadyRead
from sentry.sentry_apps.models.sentry_app_installation import SentryAppInstallation
from sentry.sentry_apps.models.sentry_app_installation_token import SentryAppInstallationToken
from sentry.silo.base import SiloMode
from sentry.testutils.cases import TestCase
from sentry.testutils.helpers.options import override_options
from sentry.testutils.outbox import outbox_runner
from sentry.testutils.silo import assume_test_silo_mode, control_silo_test
from sentry.types.token import AuthTokenType
Expand Down Expand Up @@ -51,14 +54,15 @@ def test_enforces_scope_hierarchy(self) -> None:
assert set(token.get_scopes()) == SENTRY_SCOPE_HIERARCHY_MAPPING[scope]

def test_organization_id_for_non_internal(self) -> None:
install = self.create_sentry_app_installation()
token = install.api_token
org_id = token.organization_id
with outbox_runner(), self.tasks():
install = self.create_sentry_app_installation()
token = install.api_token
org_id = token.organization_id

with assume_test_silo_mode(SiloMode.REGION):
assert ApiTokenReplica.objects.get(apitoken_id=token.id).organization_id == org_id

with outbox_runner():
with outbox_runner(), self.tasks():
install.delete()

with assume_test_silo_mode(SiloMode.REGION):
Expand Down Expand Up @@ -143,7 +147,8 @@ def test_default_string_serialization(self) -> None:

def test_replica_string_serialization(self) -> None:
user = self.create_user()
token = ApiToken.objects.create(user_id=user.id)
with outbox_runner(), self.tasks():
token = ApiToken.objects.create(user_id=user.id)
with assume_test_silo_mode(SiloMode.REGION):
replica = ApiTokenReplica.objects.get(apitoken_id=token.id)
assert (
Expand Down Expand Up @@ -186,6 +191,69 @@ def test_handle_async_deletion_called(self, mock_delete_replica: mock.MagicMock)
region_name=mock.ANY,
)

def test_outboxes_created_with_default_flush_false(self) -> None:
user = self.create_user()

with override_options({"users:api-token-async-flush": [user.id]}):
with self.tasks():
token = ApiToken.objects.create(user_id=user.id)

outboxes = ControlOutbox.objects.filter(
shard_scope=OutboxScope.USER_SCOPE,
shard_identifier=user.id,
category=OutboxCategory.API_TOKEN_UPDATE,
object_identifier=token.id,
)
assert outboxes.exists()
assert outboxes.count() > 0 # Should have one per region

# Verify replica does NOT exist yet (because outboxes haven't been processed)
with assume_test_silo_mode(SiloMode.REGION):
assert not ApiTokenReplica.objects.filter(apitoken_id=token.id).exists()

def test_async_replication_creates_replica_after_processing(self) -> None:
user = self.create_user()

with override_options({"users:api-token-async-flush": [user.id]}):
with outbox_runner(), self.tasks():
token = ApiToken.objects.create(user_id=user.id)

# Verify outboxes were processed (should be deleted after processing)
remaining_outboxes = ControlOutbox.objects.filter(
shard_scope=OutboxScope.USER_SCOPE,
shard_identifier=user.id,
category=OutboxCategory.API_TOKEN_UPDATE,
object_identifier=token.id,
)
assert not remaining_outboxes.exists()

with assume_test_silo_mode(SiloMode.REGION):
replica = ApiTokenReplica.objects.get(apitoken_id=token.id)
assert replica.hashed_token == token.hashed_token
assert replica.user_id == user.id

def test_async_replication_updates_existing_replica(self) -> None:
user = self.create_user()
initial_expires_at = timezone.now() + timedelta(days=1)
updated_expires_at = timezone.now() + timedelta(days=30)

with override_options({"users:api-token-async-flush": [user.id]}):
with outbox_runner(), self.tasks():
token = ApiToken.objects.create(user_id=user.id, expires_at=initial_expires_at)

with assume_test_silo_mode(SiloMode.REGION):
replica = ApiTokenReplica.objects.get(apitoken_id=token.id)
assert replica.expires_at is not None
assert abs((replica.expires_at - initial_expires_at).total_seconds()) < 1

with outbox_runner(), self.tasks():
token.update(expires_at=updated_expires_at)

with assume_test_silo_mode(SiloMode.REGION):
replica = ApiTokenReplica.objects.get(apitoken_id=token.id)
assert replica.expires_at is not None
assert abs((replica.expires_at - updated_expires_at).total_seconds()) < 1


@control_silo_test
class ApiTokenInternalIntegrationTest(TestCase):
Expand Down
Loading