Skip to content

Commit e440f3b

Browse files
committed
server/user: fix OAuth account disconnection when users have multiple ones for the same platform
1 parent 45fc31c commit e440f3b

File tree

2 files changed

+48
-68
lines changed

2 files changed

+48
-68
lines changed

server/polar/user/endpoints.py

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from polar.customer_portal.endpoints.license_keys import router as license_keys_router
77
from polar.customer_portal.endpoints.order import router as order_router
88
from polar.customer_portal.endpoints.subscription import router as subscription_router
9-
from polar.exceptions import PolarError
109
from polar.models import User
1110
from polar.models.user import OAuthPlatform
1211
from polar.openapi import APITag
@@ -48,22 +47,6 @@ async def create_identity_verification(
4847
)
4948

5049

51-
class OAuthAccountNotFound(PolarError):
52-
def __init__(self, platform: OAuthPlatform) -> None:
53-
self.platform = platform
54-
message = f"No {platform} OAuth account found for this user."
55-
super().__init__(message, 404)
56-
57-
58-
class CannotDisconnectLastAuthMethod(PolarError):
59-
def __init__(self) -> None:
60-
message = (
61-
"Cannot disconnect this OAuth account as it's your only authentication method. "
62-
"Please verify your email or connect another OAuth provider before disconnecting."
63-
)
64-
super().__init__(message, 400)
65-
66-
6750
@router.delete(
6851
"/me/oauth-accounts/{platform}",
6952
status_code=204,
@@ -86,21 +69,4 @@ async def disconnect_oauth_account(
8669
Note: You cannot disconnect your last authentication method if your email is not verified.
8770
"""
8871
user = auth_subject.subject
89-
90-
oauth_account = await oauth_account_service.get_by_platform_and_user_id(
91-
session, platform, user.id
92-
)
93-
94-
if oauth_account is None:
95-
raise OAuthAccountNotFound(platform)
96-
97-
can_disconnect = await oauth_account_service.can_disconnect_oauth_account(
98-
session, user, oauth_account.id
99-
)
100-
101-
if not can_disconnect:
102-
raise CannotDisconnectLastAuthMethod()
103-
104-
await oauth_account_service.disconnect_oauth_account(
105-
session, user, oauth_account.id, platform
106-
)
72+
await oauth_account_service.disconnect_platform(session, user, platform)

server/polar/user/oauth_service.py

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from uuid import UUID
2-
31
import structlog
4-
from sqlalchemy import delete, func, select
2+
from sqlalchemy import func, select
53

4+
from polar.exceptions import PolarError
65
from polar.kit.services import ResourceServiceReader
76
from polar.logging import Logger
87
from polar.models import OAuthAccount, User
@@ -12,6 +11,25 @@
1211
log: Logger = structlog.get_logger()
1312

1413

14+
class OAuthError(PolarError): ...
15+
16+
17+
class OAuthAccountNotFound(OAuthError):
18+
def __init__(self, platform: OAuthPlatform) -> None:
19+
self.platform = platform
20+
message = f"No {platform} OAuth account found for this user."
21+
super().__init__(message, 404)
22+
23+
24+
class CannotDisconnectLastAuthMethod(OAuthError):
25+
def __init__(self) -> None:
26+
message = (
27+
"Cannot disconnect this OAuth account as it's your only authentication method. "
28+
"Please verify your email or connect another OAuth provider before disconnecting."
29+
)
30+
super().__init__(message, 400)
31+
32+
1533
class OAuthAccountService(ResourceServiceReader[OAuthAccount]):
1634
async def get_by_platform_and_account_id(
1735
self, session: AsyncSession, platform: OAuthPlatform, account_id: str
@@ -23,44 +41,40 @@ async def get_by_platform_and_account_id(
2341
result = await session.execute(stmt)
2442
return result.scalars().one_or_none()
2543

26-
async def get_by_platform_and_user_id(
27-
self, session: AsyncSession, platform: OAuthPlatform, user_id: UUID
28-
) -> OAuthAccount | None:
29-
stmt = select(OAuthAccount).where(
44+
async def disconnect_platform(
45+
self, session: AsyncSession, user: User, platform: OAuthPlatform
46+
) -> None:
47+
oauth_accounts_statement = select(OAuthAccount).where(
3048
OAuthAccount.platform == platform,
31-
OAuthAccount.user_id == user_id,
49+
OAuthAccount.user_id == user.id,
3250
)
33-
result = await session.execute(stmt)
34-
return result.scalars().one_or_none()
51+
oauth_account_result = await session.execute(oauth_accounts_statement)
52+
# Some users have a buggy state with multiple OAuth accounts for the same platform
53+
oauth_accounts = oauth_account_result.scalars().all()
3554

36-
async def can_disconnect_oauth_account(
37-
self, session: AsyncSession, user: User, oauth_account_id: UUID
38-
) -> bool:
39-
stmt = select(func.count(OAuthAccount.id)).where(
55+
if len(oauth_accounts) == 0:
56+
raise OAuthAccountNotFound(platform)
57+
58+
other_accounts_count_statement = select(func.count(OAuthAccount.id)).where(
4059
OAuthAccount.user_id == user.id,
41-
OAuthAccount.id != oauth_account_id,
60+
OAuthAccount.id.not_in([oa.id for oa in oauth_accounts]),
61+
)
62+
other_accounts_count_result = await session.execute(
63+
other_accounts_count_statement
4264
)
43-
active_oauth_count = await session.scalar(stmt)
65+
other_accounts_count = other_accounts_count_result.scalar_one()
4466

45-
if active_oauth_count == 0 and not user.email_verified:
46-
return False
67+
if other_accounts_count == 0 and not user.email_verified:
68+
raise CannotDisconnectLastAuthMethod()
4769

48-
return True
70+
for oauth_account in oauth_accounts:
71+
await session.delete(oauth_account)
72+
log.info(
73+
"oauth_account.disconnect",
74+
oauth_account_id=oauth_account.id,
75+
platform=platform,
76+
)
4977

50-
async def disconnect_oauth_account(
51-
self,
52-
session: AsyncSession,
53-
user: User,
54-
oauth_account_id: UUID,
55-
platform: OAuthPlatform,
56-
) -> None:
57-
log.info(
58-
"oauth_account.disconnect",
59-
oauth_account_id=oauth_account_id,
60-
platform=platform,
61-
)
62-
stmt = delete(OAuthAccount).where(OAuthAccount.id == oauth_account_id)
63-
await session.execute(stmt)
6478
await session.flush()
6579

6680

0 commit comments

Comments
 (0)