Skip to content

Commit 4bbe660

Browse files
committed
feat: add account associations data model
Resolves #19052 Signed-off-by: Mike Fiedler <[email protected]>
1 parent 292ca9e commit 4bbe660

File tree

4 files changed

+204
-1
lines changed

4 files changed

+204
-1
lines changed

tests/common/db/accounts.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from argon2 import PasswordHasher
99

1010
from warehouse.accounts.models import (
11+
AccountAssociation,
1112
Email,
1213
ProhibitedEmailDomain,
1314
ProhibitedUserName,
@@ -130,3 +131,17 @@ class Meta:
130131
# TODO: Replace when factory_boy supports `unique`.
131132
# See https://github.com/FactoryBoy/factory_boy/pull/997
132133
name = factory.Sequence(lambda _: fake.unique.user_name())
134+
135+
136+
class AccountAssociationFactory(WarehouseFactory):
137+
class Meta:
138+
model = AccountAssociation
139+
140+
user = factory.SubFactory(UserFactory)
141+
service = "github"
142+
external_user_id = factory.Sequence(lambda n: f"{n}")
143+
external_username = factory.Faker("user_name")
144+
access_token = factory.Faker("sha256")
145+
refresh_token = None
146+
token_expires_at = None
147+
metadata_ = {}

tests/unit/accounts/test_models.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from warehouse.utils.security_policy import principals_for
1313

1414
from ...common.db.accounts import (
15+
AccountAssociationFactory as DBAccountAssociationFactory,
1516
EmailFactory as DBEmailFactory,
1617
UserEventFactory as DBUserEventFactory,
1718
UserFactory as DBUserFactory,
@@ -309,3 +310,17 @@ def test_user_projects_is_ordered_by_name(self, db_session):
309310
DBRoleFactory.create(project=project3, user=user)
310311

311312
assert user.projects == [project2, project3, project1]
313+
314+
def test_account_associations_is_ordered_by_created_desc(self, db_session):
315+
user = DBUserFactory.create()
316+
assoc1 = DBAccountAssociationFactory.create(
317+
user=user, created=datetime.datetime(2020, 1, 1)
318+
)
319+
assoc2 = DBAccountAssociationFactory.create(
320+
user=user, created=datetime.datetime(2021, 1, 1)
321+
)
322+
assoc3 = DBAccountAssociationFactory.create(
323+
user=user, created=datetime.datetime(2022, 1, 1)
324+
)
325+
326+
assert user.account_associations == [assoc3, assoc2, assoc1]

warehouse/accounts/models.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
select,
2121
sql,
2222
)
23-
from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, UUID as PG_UUID
23+
from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, JSONB, UUID as PG_UUID
2424
from sqlalchemy.exc import NoResultFound
2525
from sqlalchemy.ext.hybrid import hybrid_property
2626
from sqlalchemy.orm import Mapped, mapped_column
@@ -180,6 +180,13 @@ class User(SitemapMixin, HasObservers, HasObservations, HasEvents, db.Model):
180180
)
181181
)
182182

183+
account_associations: Mapped[list[AccountAssociation]] = orm.relationship(
184+
back_populates="user",
185+
cascade="all, delete-orphan",
186+
lazy=True,
187+
order_by="AccountAssociation.created.desc()",
188+
)
189+
183190
@property
184191
def primary_email(self):
185192
primaries = [x for x in self.emails if x.primary]
@@ -475,3 +482,67 @@ class ProhibitedUserName(db.Model):
475482
)
476483
prohibited_by: Mapped[User] = orm.relationship(User)
477484
comment: Mapped[str] = mapped_column(server_default="")
485+
486+
487+
class AccountAssociation(db.Model):
488+
"""
489+
External account associations (e.g., Oauth Providers) linked to PyPI user accounts.
490+
491+
Allows users to connect multiple external accounts from
492+
the same third-party service to their PyPI account.
493+
"""
494+
495+
__tablename__ = "account_associations"
496+
__table_args__ = (
497+
# Prevent the same external account from being linked to multiple PyPI accounts
498+
UniqueConstraint(
499+
"service", "external_user_id", name="account_associations_service_external"
500+
),
501+
Index("account_associations_user_service", "user_id", "service"),
502+
)
503+
504+
__repr__ = make_repr("service", "external_username")
505+
506+
# Timestamps
507+
created: Mapped[datetime_now]
508+
updated: Mapped[datetime.datetime | None] = mapped_column(onupdate=sql.func.now())
509+
510+
# User relationship
511+
_user_id: Mapped[UUID] = mapped_column(
512+
"user_id",
513+
PG_UUID(as_uuid=True),
514+
ForeignKey("users.id", ondelete="CASCADE"),
515+
nullable=False,
516+
index=True,
517+
)
518+
user: Mapped[User] = orm.relationship(User, back_populates="account_associations")
519+
520+
# Service information
521+
service: Mapped[str] = mapped_column(
522+
String(50), nullable=False, comment="External service name (e.g., 'github')"
523+
)
524+
external_user_id: Mapped[str] = mapped_column(
525+
String(255), nullable=False, comment="User ID from external service"
526+
)
527+
external_username: Mapped[str] = mapped_column(
528+
String(255), nullable=False, comment="Username from external service"
529+
)
530+
531+
# OAuth tokens (encrypted at application layer before storage)
532+
access_token: Mapped[str | None] = mapped_column(
533+
comment="Encrypted OAuth access token"
534+
)
535+
refresh_token: Mapped[str | None] = mapped_column(
536+
comment="Encrypted OAuth refresh token"
537+
)
538+
token_expires_at: Mapped[datetime.datetime | None] = mapped_column(
539+
comment="When the access token expires"
540+
)
541+
542+
# Additional service-specific metadata
543+
metadata_: Mapped[dict | None] = mapped_column(
544+
"metadata",
545+
JSONB,
546+
server_default=sql.text("'{}'"),
547+
comment="Service-specific metadata (profile info, scopes, etc.)",
548+
)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""
3+
Add account_associations table
4+
5+
Revision ID: 500a38d28fab
6+
Revises: 6c0f7fea7b1b
7+
Create Date: 2025-11-12 17:25:42.687250
8+
"""
9+
10+
import sqlalchemy as sa
11+
12+
from alembic import op
13+
from sqlalchemy.dialects import postgresql
14+
15+
from warehouse.utils.db.types import TZDateTime
16+
17+
revision = "500a38d28fab"
18+
down_revision = "6c0f7fea7b1b"
19+
20+
21+
def upgrade():
22+
op.create_table(
23+
"account_associations",
24+
sa.Column(
25+
"created", sa.DateTime(), server_default=sa.text("now()"), nullable=False
26+
),
27+
sa.Column("updated", TZDateTime(), nullable=True),
28+
sa.Column("user_id", sa.UUID(), nullable=False),
29+
sa.Column(
30+
"service",
31+
sa.String(length=50),
32+
nullable=False,
33+
comment="External service name (e.g., 'github')",
34+
),
35+
sa.Column(
36+
"external_user_id",
37+
sa.String(length=255),
38+
nullable=False,
39+
comment="User ID from external service",
40+
),
41+
sa.Column(
42+
"external_username",
43+
sa.String(length=255),
44+
nullable=False,
45+
comment="Username from external service",
46+
),
47+
sa.Column(
48+
"access_token",
49+
sa.String(),
50+
nullable=True,
51+
comment="Encrypted OAuth access token",
52+
),
53+
sa.Column(
54+
"refresh_token",
55+
sa.String(),
56+
nullable=True,
57+
comment="Encrypted OAuth refresh token",
58+
),
59+
sa.Column(
60+
"token_expires_at",
61+
TZDateTime(),
62+
nullable=True,
63+
comment="When the access token expires",
64+
),
65+
sa.Column(
66+
"metadata",
67+
postgresql.JSONB(astext_type=sa.Text()),
68+
server_default=sa.text("'{}'"),
69+
nullable=True,
70+
comment="Service-specific metadata (profile info, scopes, etc.)",
71+
),
72+
sa.Column(
73+
"id", sa.UUID(), server_default=sa.text("gen_random_uuid()"), nullable=False
74+
),
75+
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
76+
sa.PrimaryKeyConstraint("id"),
77+
sa.UniqueConstraint(
78+
"service", "external_user_id", name="account_associations_service_external"
79+
),
80+
)
81+
op.create_index(
82+
"account_associations_user_service",
83+
"account_associations",
84+
["user_id", "service"],
85+
unique=False,
86+
)
87+
op.create_index(
88+
op.f("ix_account_associations_user_id"),
89+
"account_associations",
90+
["user_id"],
91+
unique=False,
92+
)
93+
94+
95+
def downgrade():
96+
op.drop_index(
97+
op.f("ix_account_associations_user_id"), table_name="account_associations"
98+
)
99+
op.drop_index(
100+
"account_associations_user_service", table_name="account_associations"
101+
)
102+
op.drop_table("account_associations")

0 commit comments

Comments
 (0)