5
5
from fastapi .responses import RedirectResponse
6
6
from fastapi .staticfiles import StaticFiles
7
7
from fastapi .templating import Jinja2Templates
8
- from fastapi .exceptions import RequestValidationError , StarletteHTTPException
9
- from routers import auth , organization , role , user
10
- from utils .auth import get_authenticated_user , get_optional_user , NeedsNewTokens
11
- from utils .db import User
8
+ from fastapi .exceptions import RequestValidationError , StarletteHTTPException , HTTPException
9
+ from sqlmodel import Session
10
+ from routers import authentication , organization , role , user
11
+ from utils .auth import get_authenticated_user , get_optional_user , NeedsNewTokens , get_user_from_reset_token
12
+ from utils .db import User , get_session
12
13
13
14
14
15
logger = logging .getLogger ("uvicorn.error" )
@@ -51,14 +52,15 @@ async def needs_new_tokens_handler(request: Request, exc: NeedsNewTokens):
51
52
)
52
53
return response
53
54
54
-
55
- @app .exception_handler (StarletteHTTPException )
56
- async def http_exception_handler (request : Request , exc : StarletteHTTPException ):
57
- return templates .TemplateResponse (
58
- "errors/error.html" ,
59
- {"request" : request , "status_code" : exc .status_code , "detail" : exc .detail },
60
- status_code = exc .status_code ,
61
- )
55
+ # TODO: Make sure this only catches server errors and not 307 redirects
56
+ # Create a custom server error class that inherits from StarletteHTTPException?
57
+ # @app.exception_handler(StarletteHTTPException)
58
+ # async def http_exception_handler(request: Request, exc: StarletteHTTPException):
59
+ # return templates.TemplateResponse(
60
+ # "errors/error.html",
61
+ # {"request": request, "status_code": exc.status_code, "detail": exc.detail},
62
+ # status_code=exc.status_code,
63
+ # )
62
64
63
65
64
66
@app .exception_handler (RequestValidationError )
@@ -72,151 +74,124 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
72
74
73
75
# -- Unauthenticated Routes --
74
76
75
-
76
- @app .get ("/" )
77
- async def read_home (
77
+ # Define a dependency for common parameters
78
+ async def common_unauthenticated_parameters (
78
79
request : Request ,
79
80
user : Optional [User ] = Depends (get_optional_user ),
80
81
error_message : Optional [str ] = None ,
82
+ ) -> dict :
83
+ return {"request" : request , "user" : user , "error_message" : error_message }
84
+
85
+
86
+ @app .get ("/" )
87
+ async def read_home (
88
+ params : dict = Depends (common_unauthenticated_parameters )
81
89
):
82
- if user :
90
+ if params [ " user" ] :
83
91
return RedirectResponse (url = "/dashboard" , status_code = 302 )
84
- return templates .TemplateResponse (
85
- "index.html" , {"request" : request , "user" : user ,
86
- "error_message" : error_message }
87
- )
92
+ return templates .TemplateResponse ("index.html" , params )
88
93
89
94
90
95
@app .get ("/login" )
91
96
async def read_login (
92
- request : Request ,
93
- user : Optional [User ] = Depends (get_optional_user ),
94
- error_message : Optional [str ] = None ,
97
+ params : dict = Depends (common_unauthenticated_parameters )
95
98
):
96
- if user :
99
+ if params [ " user" ] :
97
100
return RedirectResponse (url = "/dashboard" , status_code = 302 )
98
- return templates .TemplateResponse (
99
- "authentication/login.html" ,
100
- {"request" : request , "user" : user , "error_message" : error_message },
101
- )
101
+ return templates .TemplateResponse ("authentication/login.html" , params )
102
102
103
103
104
104
@app .get ("/register" )
105
105
async def read_register (
106
- request : Request ,
107
- user : Optional [User ] = Depends (get_optional_user ),
108
- error_message : Optional [str ] = None ,
106
+ params : dict = Depends (common_unauthenticated_parameters )
109
107
):
110
- if user :
108
+ if params [ " user" ] :
111
109
return RedirectResponse (url = "/dashboard" , status_code = 302 )
112
- return templates .TemplateResponse (
113
- "authentication/register.html" ,
114
- {"request" : request , "user" : user , "error_message" : error_message },
115
- )
110
+ return templates .TemplateResponse ("authentication/register.html" , params )
116
111
117
112
118
113
@app .get ("/forgot_password" )
119
114
async def read_forgot_password (
120
- request : Request ,
121
- user : Optional [User ] = Depends (get_optional_user ),
122
- error_message : Optional [str ] = None ,
115
+ params : dict = Depends (common_unauthenticated_parameters ),
116
+ show_form : Optional [bool ] = True ,
123
117
):
124
- if user :
118
+ if params [ " user" ] :
125
119
return RedirectResponse (url = "/dashboard" , status_code = 302 )
126
- return templates .TemplateResponse (
127
- "authentication/forgot_password.html" ,
128
- {"request" : request , "user" : user , "error_message" : error_message },
129
- )
120
+ params ["show_form" ] = show_form
130
121
131
-
132
- @app .get ("/reset_password" )
133
- async def read_reset_password (
134
- request : Request ,
135
- token : str ,
136
- user : Optional [User ] = Depends (get_optional_user ),
137
- error_message : Optional [str ] = None ,
138
- ):
139
- if user :
140
- return RedirectResponse (url = "/dashboard" , status_code = 302 )
141
- # TODO: Validate the token here?
142
- return templates .TemplateResponse (
143
- "authentication/reset_password.html" ,
144
- {
145
- "request" : request ,
146
- "token" : token ,
147
- "user" : user ,
148
- "error_message" : error_message ,
149
- },
150
- )
122
+ return templates .TemplateResponse ("authentication/forgot_password.html" , params )
151
123
152
124
153
125
@app .get ("/about" )
154
- async def read_about (
155
- request : Request ,
156
- user : Optional [User ] = Depends (get_optional_user ),
157
- error_message : Optional [str ] = None ,
158
- ):
159
- return templates .TemplateResponse (
160
- "about.html" ,
161
- {"request" : request , "user" : user , "error_message" : error_message }
162
- )
126
+ async def read_about (params : dict = Depends (common_unauthenticated_parameters )):
127
+ return templates .TemplateResponse ("about.html" , params )
163
128
164
129
165
130
@app .get ("/privacy_policy" )
166
- async def read_privacy_policy (
167
- request : Request ,
168
- user : Optional [User ] = Depends (get_optional_user ),
169
- error_message : Optional [str ] = None ,
170
- ):
171
- return templates .TemplateResponse (
172
- "privacy_policy.html" ,
173
- {"request" : request , "user" : user , "error_message" : error_message },
174
- )
131
+ async def read_privacy_policy (params : dict = Depends (common_unauthenticated_parameters )):
132
+ return templates .TemplateResponse ("privacy_policy.html" , params )
175
133
176
134
177
135
@app .get ("/terms_of_service" )
178
- async def read_terms_of_service (
179
- request : Request ,
180
- user : Optional [User ] = Depends (get_optional_user ),
181
- error_message : Optional [str ] = None ,
136
+ async def read_terms_of_service (params : dict = Depends (common_unauthenticated_parameters )):
137
+ return templates .TemplateResponse ("terms_of_service.html" , params )
138
+
139
+
140
+ @app .get ("/reset_password" )
141
+ async def read_reset_password (
142
+ email : str ,
143
+ token : str ,
144
+ params : dict = Depends (common_unauthenticated_parameters ),
145
+ session : Session = Depends (get_session )
182
146
):
183
- return templates .TemplateResponse (
184
- "terms_of_service.html" ,
185
- {"request" : request , "user" : user , "error_message" : error_message },
186
- )
147
+ authorized_user , _ = get_user_from_reset_token (email , token , session )
148
+
149
+ # Raise informative error to let user know the token is invalid and may have expired
150
+ if not authorized_user :
151
+ raise HTTPException (status_code = 400 , detail = "Invalid or expired token" )
152
+
153
+ params ["email" ] = email
154
+ params ["token" ] = token
155
+
156
+ return templates .TemplateResponse ("authentication/reset_password.html" , params )
187
157
188
158
189
159
# -- Authenticated Routes --
190
160
191
161
192
- @ app . get ( "/dashboard" )
193
- async def read_dashboard (
162
+ # Define a dependency for common parameters
163
+ async def common_authenticated_parameters (
194
164
request : Request ,
195
165
user : User = Depends (get_authenticated_user ),
196
166
error_message : Optional [str ] = None ,
167
+ ) -> dict :
168
+ return {"request" : request , "user" : user , "error_message" : error_message }
169
+
170
+
171
+ # Redirect to home if user is not authenticated
172
+ @app .get ("/dashboard" )
173
+ async def read_dashboard (
174
+ params : dict = Depends (common_authenticated_parameters )
197
175
):
198
- return templates .TemplateResponse (
199
- "dashboard/index.html" ,
200
- {"request" : request , "user" : user , "error_message" : error_message },
201
- )
176
+ if not params ["user" ]:
177
+ return RedirectResponse (url = "/login" , status_code = status .HTTP_302_FOUND )
178
+ return templates .TemplateResponse ("dashboard/index.html" , params )
202
179
203
180
204
- @app .get ("/user_profile" )
205
- async def read_user_profile (
206
- request : Request ,
207
- user : User = Depends (get_authenticated_user ),
208
- error_message : Optional [str ] = None ,
181
+ @app .get ("/profile" )
182
+ async def read_profile (
183
+ params : dict = Depends (common_authenticated_parameters )
209
184
):
210
- return templates . TemplateResponse (
211
- "users/profile.html" ,
212
- { "request" : request , "user" : user , "error_message" : error_message },
213
- )
185
+ if not params [ "user" ]:
186
+ # Changed to 302
187
+ return RedirectResponse ( url = "/login" , status_code = status . HTTP_302_FOUND )
188
+ return templates . TemplateResponse ( "users/profile.html" , params )
214
189
215
190
216
191
# -- Include Routers --
217
192
218
193
219
- app .include_router (auth .router )
194
+ app .include_router (authentication .router )
220
195
app .include_router (organization .router )
221
196
app .include_router (role .router )
222
197
app .include_router (user .router )
0 commit comments