Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tests/common/db/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from argon2 import PasswordHasher

from warehouse.accounts.models import (
AccountAssociation,
Email,
ProhibitedEmailDomain,
ProhibitedUserName,
Expand Down Expand Up @@ -140,3 +141,17 @@ class Meta:

user = factory.SubFactory(UserFactory)
ip_address = REMOTE_ADDR


class AccountAssociationFactory(WarehouseFactory):
class Meta:
model = AccountAssociation

user = factory.SubFactory(UserFactory)
service = "github"
external_user_id = factory.Sequence(lambda n: f"{n}")
external_username = factory.Faker("user_name")
access_token = factory.Faker("sha256")
refresh_token = None
token_expires_at = None
metadata_ = {}
15 changes: 15 additions & 0 deletions tests/unit/accounts/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from warehouse.utils.security_policy import principals_for

from ...common.db.accounts import (
AccountAssociationFactory as DBAccountAssociationFactory,
EmailFactory as DBEmailFactory,
UserEventFactory as DBUserEventFactory,
UserFactory as DBUserFactory,
Expand Down Expand Up @@ -317,6 +318,20 @@ def test_user_projects_is_ordered_by_name(self, db_session):

assert user.projects == [project2, project3, project1]

def test_account_associations_is_ordered_by_created_desc(self, db_session):
user = DBUserFactory.create()
assoc1 = DBAccountAssociationFactory.create(
user=user, created=datetime.datetime(2020, 1, 1)
)
assoc2 = DBAccountAssociationFactory.create(
user=user, created=datetime.datetime(2021, 1, 1)
)
assoc3 = DBAccountAssociationFactory.create(
user=user, created=datetime.datetime(2022, 1, 1)
)

assert user.account_associations == [assoc3, assoc2, assoc1]


class TestUserUniqueLogin:
def test_repr(self, db_session):
Expand Down
71 changes: 71 additions & 0 deletions warehouse/accounts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,13 @@ class User(SitemapMixin, HasObservers, HasObservations, HasEvents, db.Model):
)
)

account_associations: Mapped[list[AccountAssociation]] = orm.relationship(
back_populates="user",
cascade="all, delete-orphan",
lazy=True,
order_by="AccountAssociation.created.desc()",
)

@property
def primary_email(self):
primaries = [x for x in self.emails if x.primary]
Expand Down Expand Up @@ -526,3 +533,67 @@ def __repr__(self):
f"ip_address={self.ip_address!r}, "
f"status={self.status!r})>"
)


class AccountAssociation(db.Model):
"""
External account associations (e.g., Oauth Providers) linked to PyPI user accounts.

Allows users to connect multiple external accounts from
the same third-party service to their PyPI account.
"""

__tablename__ = "account_associations"
__table_args__ = (
# Prevent the same external account from being linked to multiple PyPI accounts
UniqueConstraint(
"service", "external_user_id", name="account_associations_service_external"
),
Index("account_associations_user_service", "user_id", "service"),
)

__repr__ = make_repr("service", "external_username")

# Timestamps
created: Mapped[datetime_now]
updated: Mapped[datetime.datetime | None] = mapped_column(onupdate=sql.func.now())

# User relationship
_user_id: Mapped[UUID] = mapped_column(
"user_id",
PG_UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
user: Mapped[User] = orm.relationship(User, back_populates="account_associations")

# Service information
service: Mapped[str] = mapped_column(
String(50), nullable=False, comment="External service name (e.g., 'github')"
)
external_user_id: Mapped[str] = mapped_column(
String(255), nullable=False, comment="User ID from external service"
)
external_username: Mapped[str] = mapped_column(
String(255), nullable=False, comment="Username from external service"
)

# OAuth tokens (encrypted at application layer before storage)
access_token: Mapped[str | None] = mapped_column(
comment="Encrypted OAuth access token"
)
refresh_token: Mapped[str | None] = mapped_column(
comment="Encrypted OAuth refresh token"
)
token_expires_at: Mapped[datetime.datetime | None] = mapped_column(
comment="When the access token expires"
)

# Additional service-specific metadata
metadata_: Mapped[dict | None] = mapped_column(
"metadata",
JSONB,
server_default=sql.text("'{}'"),
comment="Service-specific metadata (profile info, scopes, etc.)",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
"""
Add account_associations table

Revision ID: 500a38d28fab
Revises: 4c20f2342bba
Create Date: 2025-11-12 17:25:42.687250
"""

import sqlalchemy as sa

from alembic import op
from sqlalchemy.dialects import postgresql

from warehouse.utils.db.types import TZDateTime

revision = "500a38d28fab"
down_revision = "4c20f2342bba"


def upgrade():
op.create_table(
"account_associations",
sa.Column(
"created", sa.DateTime(), server_default=sa.text("now()"), nullable=False
),
sa.Column("updated", TZDateTime(), nullable=True),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column(
"service",
sa.String(length=50),
nullable=False,
comment="External service name (e.g., 'github')",
),
sa.Column(
"external_user_id",
sa.String(length=255),
nullable=False,
comment="User ID from external service",
),
sa.Column(
"external_username",
sa.String(length=255),
nullable=False,
comment="Username from external service",
),
sa.Column(
"access_token",
sa.String(),
nullable=True,
comment="Encrypted OAuth access token",
),
sa.Column(
"refresh_token",
sa.String(),
nullable=True,
comment="Encrypted OAuth refresh token",
),
sa.Column(
"token_expires_at",
TZDateTime(),
nullable=True,
comment="When the access token expires",
),
sa.Column(
"metadata",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'"),
nullable=True,
comment="Service-specific metadata (profile info, scopes, etc.)",
),
sa.Column(
"id", sa.UUID(), server_default=sa.text("gen_random_uuid()"), nullable=False
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"service", "external_user_id", name="account_associations_service_external"
),
)
op.create_index(
"account_associations_user_service",
"account_associations",
["user_id", "service"],
unique=False,
)
op.create_index(
op.f("ix_account_associations_user_id"),
"account_associations",
["user_id"],
unique=False,
)


def downgrade():
op.drop_index(
op.f("ix_account_associations_user_id"), table_name="account_associations"
)
op.drop_index(
"account_associations_user_service", table_name="account_associations"
)
op.drop_table("account_associations")