2
2
from logging import getLogger
3
3
from typing import Optional , Tuple
4
4
from urllib .parse import urlparse
5
- from fastapi import APIRouter , Depends , BackgroundTasks , Form , Request
5
+ from fastapi import APIRouter , Depends , BackgroundTasks , Form , Request , Query
6
6
from fastapi .responses import RedirectResponse
7
7
from fastapi .templating import Jinja2Templates
8
+ from fastapi .exceptions import HTTPException
8
9
from starlette .datastructures import URLPath
9
10
from pydantic import EmailStr
10
11
from sqlmodel import Session , select
11
- from utils .models import User , DataIntegrityError , Account
12
+ from utils .models import User , DataIntegrityError , Account , Invitation
12
13
from utils .db import get_session
13
14
from utils .auth import (
14
15
HTML_PASSWORD_PATTERN ,
32
33
from exceptions .http_exceptions import (
33
34
EmailAlreadyRegisteredError ,
34
35
CredentialsError ,
35
- PasswordValidationError
36
+ PasswordValidationError ,
37
+ InvalidInvitationTokenError ,
38
+ InvitationEmailMismatchError ,
39
+ InvitationProcessingError
36
40
)
37
41
from routers .dashboard import router as dashboard_router
38
42
from routers .user import router as user_router
43
+ from routers .organization import router as org_router
44
+ from utils .invitations import process_invitation
39
45
logger = getLogger ("uvicorn.error" )
40
46
41
47
router = APIRouter (prefix = "/account" , tags = ["account" ])
@@ -97,7 +103,8 @@ def logout():
97
103
async def read_login (
98
104
request : Request ,
99
105
user : Optional [User ] = Depends (get_optional_user ),
100
- email_updated : Optional [str ] = "false"
106
+ email_updated : Optional [str ] = Query ("false" ),
107
+ invitation_token : Optional [str ] = Query (None )
101
108
):
102
109
"""
103
110
Render login page or redirect to dashboard if already logged in.
@@ -106,14 +113,21 @@ async def read_login(
106
113
return RedirectResponse (url = dashboard_router .url_path_for ("read_dashboard" ), status_code = 302 )
107
114
return templates .TemplateResponse (
108
115
"account/login.html" ,
109
- {"request" : request , "user" : user , "email_updated" : email_updated }
116
+ {
117
+ "request" : request ,
118
+ "user" : user ,
119
+ "email_updated" : email_updated ,
120
+ "invitation_token" : invitation_token
121
+ }
110
122
)
111
123
112
124
113
125
@router .get ("/register" )
114
126
async def read_register (
115
127
request : Request ,
116
- user : Optional [User ] = Depends (get_optional_user )
128
+ user : Optional [User ] = Depends (get_optional_user ),
129
+ email : Optional [EmailStr ] = Query (None ),
130
+ invitation_token : Optional [str ] = Query (None )
117
131
):
118
132
"""
119
133
Render registration page or redirect to dashboard if already logged in.
@@ -123,7 +137,13 @@ async def read_register(
123
137
124
138
return templates .TemplateResponse (
125
139
"account/register.html" ,
126
- {"request" : request , "user" : user , "password_pattern" : HTML_PASSWORD_PATTERN }
140
+ {
141
+ "request" : request ,
142
+ "user" : user ,
143
+ "password_pattern" : HTML_PASSWORD_PATTERN ,
144
+ "email" : email ,
145
+ "invitation_token" : invitation_token
146
+ }
127
147
)
128
148
129
149
@@ -204,38 +224,98 @@ async def register(
204
224
email : EmailStr = Form (...),
205
225
session : Session = Depends (get_session ),
206
226
_ : None = Depends (validate_password_strength_and_match ),
207
- password : str = Form (...)
227
+ password : str = Form (...),
228
+ invitation_token : Optional [str ] = Form (None )
208
229
) -> RedirectResponse :
209
230
"""
210
- Register a new user account.
231
+ Register a new user account, optionally processing an invitation .
211
232
"""
212
233
# Check if the email is already registered
213
- account : Optional [Account ] = session .exec (select (Account ).where (
234
+ existing_account : Optional [Account ] = session .exec (select (Account ).where (
214
235
Account .email == email )).one_or_none ()
215
236
216
- if account :
237
+ if existing_account :
217
238
raise EmailAlreadyRegisteredError ()
218
239
219
240
# Hash the password
220
241
hashed_password = get_password_hash (password )
221
242
222
- # Create the account
243
+ # Create the account and user instances (don't commit yet)
223
244
account = Account (email = email , hashed_password = hashed_password )
224
245
session .add (account )
225
- session .flush () # Flush to get the account ID
226
-
227
- # Create the user
228
- account .user = User (name = name , account_id = account .id )
229
- session .add (account )
230
- session .commit ()
246
+ session .flush () # Flush here to get account.id before creating User
247
+
248
+ # Ensure account has an ID after flush
249
+ if not account .id :
250
+ logger .error (f"Account ID not generated after flush for email { email } . Aborting registration." )
251
+ session .rollback () # Rollback the account add
252
+ raise DataIntegrityError (resource = "Account ID generation" )
253
+
254
+ new_user = User (name = name , account_id = account .id ) # Use account.id
255
+ session .add (new_user )
256
+
257
+ # Default redirect target
258
+ redirect_url = dashboard_router .url_path_for ("read_dashboard" )
259
+
260
+ # Process invitation if token is provided (BEFORE final commit)
261
+ if invitation_token :
262
+ logger .info (f"Registration attempt with invitation token: { invitation_token } for email { email } " )
263
+ # Fetch the invitation
264
+ statement = select (Invitation ).where (Invitation .token == invitation_token )
265
+ invitation = session .exec (statement ).first ()
266
+
267
+ if not invitation or not invitation .is_active ():
268
+ logger .warning (f"Invalid or inactive invitation token provided during registration: { invitation_token } " )
269
+ # Consider raising a more generic error to avoid exposing token validity
270
+ raise InvalidInvitationTokenError ()
271
+
272
+ # Verify email matches
273
+ if email != invitation .invitee_email :
274
+ logger .warning (
275
+ f"Invitation email mismatch for token { invitation_token } during registration. "
276
+ f"Account: { email } , Invitation: { invitation .invitee_email } "
277
+ )
278
+ # Consider raising a more generic error to avoid confirming email existence
279
+ raise InvitationEmailMismatchError ()
280
+
281
+ # Process the invitation (adds changes to the session)
282
+ try :
283
+ logger .info (f"Processing invitation { invitation .id } for new user { new_user .name } ({ email } ) during registration." )
284
+ process_invitation (invitation , new_user , session )
285
+ # Set redirect to the organization page
286
+ redirect_url = org_router .url_path_for ("read_organization" , org_id = invitation .organization_id )
287
+ logger .info (f"Redirecting new user { new_user .name } to organization { invitation .organization_id } after accepting invitation { invitation .id } ." )
288
+ except Exception as e :
289
+ logger .error (
290
+ f"Error processing invitation { invitation .id } for new user { new_user .name } ({ email } ) during registration: { e } " ,
291
+ exc_info = True
292
+ )
293
+ session .rollback ()
294
+ raise InvitationProcessingError ()
295
+
296
+ else :
297
+ logger .info (f"Standard registration for email { email } . Redirecting to dashboard." )
298
+
299
+ # Commit all changes (Account, User, potentially Invitation)
300
+ try :
301
+ session .commit ()
302
+ except Exception as e :
303
+ logger .error (f"Error committing transaction during registration for { email } : { e } " , exc_info = True )
304
+ session .rollback ()
305
+ # Use DataIntegrityError for commit failure
306
+ raise DataIntegrityError (resource = "Account/User registration" )
307
+
308
+ # Refresh the account to ensure all relationships (like user) are loaded after commit
231
309
session .refresh (account )
310
+ # We might need the user object refreshed too if process_invitation modified it directly
311
+ # session.refresh(new_user) # Let's assume process_invitation only modifies the invitation object for now
232
312
233
- # Create access token
234
- access_token = create_access_token (data = {"sub" : email })
235
- refresh_token = create_refresh_token (data = {"sub" : email })
313
+ # Create access token using the committed account's email
314
+ access_token = create_access_token (data = {"sub" : account . email , "fresh" : True })
315
+ refresh_token = create_refresh_token (data = {"sub" : account . email })
236
316
237
317
# Set cookie
238
- response = RedirectResponse (url = dashboard_router . url_path_for ( "read_dashboard" ), status_code = 303 )
318
+ response = RedirectResponse (url = str ( redirect_url ), status_code = 303 ) # Use determined redirect_url
239
319
response .set_cookie (
240
320
key = "access_token" ,
241
321
value = access_token ,
@@ -256,21 +336,72 @@ async def register(
256
336
257
337
@router .post ("/login" , response_class = RedirectResponse )
258
338
async def login (
259
- account_and_session : Tuple [Account , Session ] = Depends (get_account_from_credentials )
339
+ account_and_session : Tuple [Account , Session ] = Depends (get_account_from_credentials ),
340
+ invitation_token : Optional [str ] = Form (None )
260
341
) -> RedirectResponse :
261
342
"""
262
- Log in a user with valid credentials.
343
+ Log in a user with valid credentials and process invitation if token is provided .
263
344
"""
264
345
account , session = account_and_session
265
346
347
+ # Default redirect target
348
+ redirect_url = dashboard_router .url_path_for ("read_dashboard" )
349
+
350
+ if invitation_token :
351
+ logger .info (f"Login attempt with invitation token: { invitation_token } for account { account .email } " )
352
+ # Fetch the invitation
353
+ statement = select (Invitation ).where (Invitation .token == invitation_token )
354
+ invitation = session .exec (statement ).first ()
355
+
356
+ if not invitation or not invitation .is_active ():
357
+ logger .warning (f"Invalid or inactive invitation token provided during login: { invitation_token } " )
358
+ raise InvalidInvitationTokenError ()
359
+
360
+ # Verify email matches
361
+ if account .email != invitation .invitee_email :
362
+ logger .warning (
363
+ f"Invitation email mismatch for token { invitation_token } . "
364
+ f"Account: { account .email } , Invitation: { invitation .invitee_email } "
365
+ )
366
+ raise InvitationEmailMismatchError ()
367
+
368
+ # Ensure user relationship is loaded for process_invitation
369
+ if not account .user :
370
+ logger .debug (f"Refreshing user relationship for account { account .id } " )
371
+ session .refresh (account , attribute_names = ["user" ])
372
+ if not account .user :
373
+ # This should not happen if the account has a valid user relationship
374
+ logger .error (f"Failed to load user for account { account .id } during invitation processing." )
375
+ raise DataIntegrityError (resource = "User relation" )
376
+
377
+ # Process the invitation
378
+ try :
379
+ logger .info (f"Processing invitation { invitation .id } for user { account .user .id } during login." )
380
+ process_invitation (invitation , account .user , session )
381
+ session .commit ()
382
+ # Set redirect to the organization page
383
+ redirect_url = org_router .url_path_for ("read_organization" , org_id = invitation .organization_id )
384
+ logger .info (f"Redirecting user { account .user .id } to organization { invitation .organization_id } after accepting invitation { invitation .id } ." )
385
+ except Exception as e :
386
+ logger .error (
387
+ f"Error processing invitation { invitation .id } for user { account .user .id } during login: { e } " ,
388
+ exc_info = True
389
+ )
390
+ session .rollback ()
391
+ # Raise the specific invitation processing error
392
+ raise InvitationProcessingError ()
393
+
394
+ else :
395
+ logger .info (f"Standard login for account { account .email } . Redirecting to dashboard." )
396
+
266
397
# Create access token
267
398
access_token = create_access_token (
268
399
data = {"sub" : account .email , "fresh" : True }
269
400
)
270
401
refresh_token = create_refresh_token (data = {"sub" : account .email })
271
402
272
403
# Set cookie
273
- response = RedirectResponse (url = dashboard_router . url_path_for ( "read_dashboard" ), status_code = 303 )
404
+ response = RedirectResponse (url = str ( redirect_url ), status_code = 303 )
274
405
response .set_cookie (
275
406
key = "access_token" ,
276
407
value = access_token ,
0 commit comments