1- from uuid import UUID
2-
31import structlog
4- from sqlalchemy import delete , func , select
2+ from sqlalchemy import func , select
53
4+ from polar .exceptions import PolarError
65from polar .kit .services import ResourceServiceReader
76from polar .logging import Logger
87from polar .models import OAuthAccount , User
1211log : Logger = structlog .get_logger ()
1312
1413
14+ class OAuthError (PolarError ): ...
15+
16+
17+ class OAuthAccountNotFound (OAuthError ):
18+ def __init__ (self , platform : OAuthPlatform ) -> None :
19+ self .platform = platform
20+ message = f"No { platform } OAuth account found for this user."
21+ super ().__init__ (message , 404 )
22+
23+
24+ class CannotDisconnectLastAuthMethod (OAuthError ):
25+ def __init__ (self ) -> None :
26+ message = (
27+ "Cannot disconnect this OAuth account as it's your only authentication method. "
28+ "Please verify your email or connect another OAuth provider before disconnecting."
29+ )
30+ super ().__init__ (message , 400 )
31+
32+
1533class OAuthAccountService (ResourceServiceReader [OAuthAccount ]):
1634 async def get_by_platform_and_account_id (
1735 self , session : AsyncSession , platform : OAuthPlatform , account_id : str
@@ -23,44 +41,40 @@ async def get_by_platform_and_account_id(
2341 result = await session .execute (stmt )
2442 return result .scalars ().one_or_none ()
2543
26- async def get_by_platform_and_user_id (
27- self , session : AsyncSession , platform : OAuthPlatform , user_id : UUID
28- ) -> OAuthAccount | None :
29- stmt = select (OAuthAccount ).where (
44+ async def disconnect_platform (
45+ self , session : AsyncSession , user : User , platform : OAuthPlatform
46+ ) -> None :
47+ oauth_accounts_statement = select (OAuthAccount ).where (
3048 OAuthAccount .platform == platform ,
31- OAuthAccount .user_id == user_id ,
49+ OAuthAccount .user_id == user . id ,
3250 )
33- result = await session .execute (stmt )
34- return result .scalars ().one_or_none ()
51+ oauth_account_result = await session .execute (oauth_accounts_statement )
52+ # Some users have a buggy state with multiple OAuth accounts for the same platform
53+ oauth_accounts = oauth_account_result .scalars ().all ()
3554
36- async def can_disconnect_oauth_account (
37- self , session : AsyncSession , user : User , oauth_account_id : UUID
38- ) -> bool :
39- stmt = select (func .count (OAuthAccount .id )).where (
55+ if len ( oauth_accounts ) == 0 :
56+ raise OAuthAccountNotFound ( platform )
57+
58+ other_accounts_count_statement = select (func .count (OAuthAccount .id )).where (
4059 OAuthAccount .user_id == user .id ,
41- OAuthAccount .id != oauth_account_id ,
60+ OAuthAccount .id .not_in ([oa .id for oa in oauth_accounts ]),
61+ )
62+ other_accounts_count_result = await session .execute (
63+ other_accounts_count_statement
4264 )
43- active_oauth_count = await session . scalar ( stmt )
65+ other_accounts_count = other_accounts_count_result . scalar_one ( )
4466
45- if active_oauth_count == 0 and not user .email_verified :
46- return False
67+ if other_accounts_count == 0 and not user .email_verified :
68+ raise CannotDisconnectLastAuthMethod ()
4769
48- return True
70+ for oauth_account in oauth_accounts :
71+ await session .delete (oauth_account )
72+ log .info (
73+ "oauth_account.disconnect" ,
74+ oauth_account_id = oauth_account .id ,
75+ platform = platform ,
76+ )
4977
50- async def disconnect_oauth_account (
51- self ,
52- session : AsyncSession ,
53- user : User ,
54- oauth_account_id : UUID ,
55- platform : OAuthPlatform ,
56- ) -> None :
57- log .info (
58- "oauth_account.disconnect" ,
59- oauth_account_id = oauth_account_id ,
60- platform = platform ,
61- )
62- stmt = delete (OAuthAccount ).where (OAuthAccount .id == oauth_account_id )
63- await session .execute (stmt )
6478 await session .flush ()
6579
6680
0 commit comments