1
+ use std:: io:: Cursor ;
1
2
use std:: io:: { self } ;
2
3
use std:: path:: Path ;
3
4
use std:: path:: PathBuf ;
@@ -15,6 +16,8 @@ use crate::pkce::generate_pkce;
15
16
use base64:: Engine ;
16
17
use chrono:: Utc ;
17
18
use rand:: RngCore ;
19
+ use tiny_http:: Header ;
20
+ use tiny_http:: Request ;
18
21
use tiny_http:: Response ;
19
22
use tiny_http:: Server ;
20
23
@@ -149,116 +152,23 @@ pub fn run_login_server(
149
152
}
150
153
} ;
151
154
152
- let url_raw = req. url ( ) . to_string ( ) ;
153
- let parsed_url = match url:: Url :: parse ( & format ! ( "http://localhost{url_raw}" ) ) {
154
- Ok ( u) => u,
155
- Err ( e) => {
156
- eprintln ! ( "URL parse error: {e}" ) ;
157
- let _ = req. respond ( Response :: from_string ( "Bad Request" ) . with_status_code ( 400 ) ) ;
158
- continue ;
155
+ let response = process_request ( & req, & opts, & redirect_uri, & pkce, actual_port, & state) ;
156
+ let is_login_complete = matches ! ( response, HandledRequest :: ResponseAndExit ( _) ) ;
157
+ match response {
158
+ HandledRequest :: Response ( r) | HandledRequest :: ResponseAndExit ( r) => {
159
+ let _ = req. respond ( r) ;
159
160
}
160
- } ;
161
- let path = parsed_url. path ( ) . to_string ( ) ;
162
-
163
- match path. as_str ( ) {
164
- "/auth/callback" => {
165
- let params: std:: collections:: HashMap < String , String > =
166
- parsed_url. query_pairs ( ) . into_owned ( ) . collect ( ) ;
167
- if params. get ( "state" ) . map ( String :: as_str) != Some ( state. as_str ( ) ) {
168
- let _ = req
169
- . respond ( Response :: from_string ( "State mismatch" ) . with_status_code ( 400 ) ) ;
170
- continue ;
171
- }
172
- let code = match params. get ( "code" ) {
173
- Some ( c) if !c. is_empty ( ) => c. clone ( ) ,
174
- _ => {
175
- let _ = req. respond (
176
- Response :: from_string ( "Missing authorization code" )
177
- . with_status_code ( 400 ) ,
178
- ) ;
179
- continue ;
180
- }
181
- } ;
182
-
183
- match exchange_code_for_tokens (
184
- & opts. issuer ,
185
- & opts. client_id ,
186
- & redirect_uri,
187
- & pkce,
188
- & code,
189
- ) {
190
- Ok ( tokens) => {
191
- // Obtain API key via token-exchange and persist
192
- let api_key =
193
- obtain_api_key ( & opts. issuer , & opts. client_id , & tokens. id_token )
194
- . ok ( ) ;
195
- if let Err ( err) = persist_tokens (
196
- & opts. codex_home ,
197
- api_key. clone ( ) ,
198
- tokens. id_token . clone ( ) ,
199
- Some ( tokens. access_token . clone ( ) ) ,
200
- Some ( tokens. refresh_token . clone ( ) ) ,
201
- ) {
202
- eprintln ! ( "Persist error: {err}" ) ;
203
- let _ = req. respond (
204
- Response :: from_string ( format ! (
205
- "Unable to persist auth file: {err}"
206
- ) )
207
- . with_status_code ( 500 ) ,
208
- ) ;
209
- continue ;
210
- }
211
-
212
- let success_url = compose_success_url (
213
- actual_port,
214
- & opts. issuer ,
215
- & tokens. id_token ,
216
- & tokens. access_token ,
217
- ) ;
218
- match tiny_http:: Header :: from_bytes (
219
- & b"Location" [ ..] ,
220
- success_url. as_bytes ( ) ,
221
- ) {
222
- Ok ( h) => {
223
- let response = tiny_http:: Response :: empty ( 302 ) . with_header ( h) ;
224
- let _ = req. respond ( response) ;
225
- }
226
- Err ( _) => {
227
- let _ = req. respond (
228
- Response :: from_string ( "Internal Server Error" )
229
- . with_status_code ( 500 ) ,
230
- ) ;
231
- }
232
- }
233
- }
234
- Err ( err) => {
235
- eprintln ! ( "Token exchange error: {err}" ) ;
236
- let _ = req. respond (
237
- Response :: from_string ( format ! ( "Token exchange failed: {err}" ) )
238
- . with_status_code ( 500 ) ,
239
- ) ;
240
- }
241
- }
161
+ HandledRequest :: RedirectWithHeader ( header) => {
162
+ let redirect = Response :: empty ( 302 ) . with_header ( header) ;
163
+ let _ = req. respond ( redirect) ;
242
164
}
243
- "/success" => {
244
- let body = include_str ! ( "assets/success.html" ) ;
245
- let mut resp = Response :: from_data ( body. as_bytes ( ) ) ;
246
- if let Ok ( h) = tiny_http:: Header :: from_bytes (
247
- & b"Content-Type" [ ..] ,
248
- & b"text/html; charset=utf-8" [ ..] ,
249
- ) {
250
- resp. add_header ( h) ;
251
- }
252
- let _ = req. respond ( resp) ;
253
- shutdown_flag. store ( true , Ordering :: SeqCst ) ;
165
+ }
254
166
255
- // Login has succeeded, so disarm the timeout watcher.
256
- let _ = done_tx. send ( ( ) ) ;
257
- return Ok ( ( ) ) ;
258
- }
259
- _ => {
260
- let _ = req. respond ( Response :: from_string ( "Not Found" ) . with_status_code ( 404 ) ) ;
261
- }
167
+ if is_login_complete {
168
+ shutdown_flag. store ( true , Ordering :: SeqCst ) ;
169
+ // Login has succeeded, so disarm the timeout watcher.
170
+ let _ = done_tx. send ( ( ) ) ;
171
+ return Ok ( ( ) ) ;
262
172
}
263
173
}
264
174
@@ -281,6 +191,107 @@ pub fn run_login_server(
281
191
} )
282
192
}
283
193
194
+ enum HandledRequest {
195
+ Response ( Response < Cursor < Vec < u8 > > > ) ,
196
+ RedirectWithHeader ( Header ) ,
197
+ ResponseAndExit ( Response < Cursor < Vec < u8 > > > ) ,
198
+ }
199
+
200
+ fn process_request (
201
+ req : & Request ,
202
+ opts : & ServerOptions ,
203
+ redirect_uri : & str ,
204
+ pkce : & PkceCodes ,
205
+ actual_port : u16 ,
206
+ state : & str ,
207
+ ) -> HandledRequest {
208
+ let url_raw = req. url ( ) . to_string ( ) ;
209
+ let parsed_url = match url:: Url :: parse ( & format ! ( "http://localhost{url_raw}" ) ) {
210
+ Ok ( u) => u,
211
+ Err ( e) => {
212
+ eprintln ! ( "URL parse error: {e}" ) ;
213
+ return HandledRequest :: Response (
214
+ Response :: from_string ( "Bad Request" ) . with_status_code ( 400 ) ,
215
+ ) ;
216
+ }
217
+ } ;
218
+ let path = parsed_url. path ( ) . to_string ( ) ;
219
+
220
+ match path. as_str ( ) {
221
+ "/auth/callback" => {
222
+ let params: std:: collections:: HashMap < String , String > =
223
+ parsed_url. query_pairs ( ) . into_owned ( ) . collect ( ) ;
224
+ if params. get ( "state" ) . map ( String :: as_str) != Some ( state) {
225
+ return HandledRequest :: Response (
226
+ Response :: from_string ( "State mismatch" ) . with_status_code ( 400 ) ,
227
+ ) ;
228
+ }
229
+ let code = match params. get ( "code" ) {
230
+ Some ( c) if !c. is_empty ( ) => c. clone ( ) ,
231
+ _ => {
232
+ return HandledRequest :: Response (
233
+ Response :: from_string ( "Missing authorization code" ) . with_status_code ( 400 ) ,
234
+ ) ;
235
+ }
236
+ } ;
237
+
238
+ match exchange_code_for_tokens ( & opts. issuer , & opts. client_id , redirect_uri, pkce, & code)
239
+ {
240
+ Ok ( tokens) => {
241
+ // Obtain API key via token-exchange and persist
242
+ let api_key =
243
+ obtain_api_key ( & opts. issuer , & opts. client_id , & tokens. id_token ) . ok ( ) ;
244
+ if let Err ( err) = persist_tokens (
245
+ & opts. codex_home ,
246
+ api_key. clone ( ) ,
247
+ tokens. id_token . clone ( ) ,
248
+ Some ( tokens. access_token . clone ( ) ) ,
249
+ Some ( tokens. refresh_token . clone ( ) ) ,
250
+ ) {
251
+ eprintln ! ( "Persist error: {err}" ) ;
252
+ return HandledRequest :: Response (
253
+ Response :: from_string ( format ! ( "Unable to persist auth file: {err}" ) )
254
+ . with_status_code ( 500 ) ,
255
+ ) ;
256
+ }
257
+
258
+ let success_url = compose_success_url (
259
+ actual_port,
260
+ & opts. issuer ,
261
+ & tokens. id_token ,
262
+ & tokens. access_token ,
263
+ ) ;
264
+ match tiny_http:: Header :: from_bytes ( & b"Location" [ ..] , success_url. as_bytes ( ) ) {
265
+ Ok ( header) => HandledRequest :: RedirectWithHeader ( header) ,
266
+ Err ( _) => HandledRequest :: Response (
267
+ Response :: from_string ( "Internal Server Error" ) . with_status_code ( 500 ) ,
268
+ ) ,
269
+ }
270
+ }
271
+ Err ( err) => {
272
+ eprintln ! ( "Token exchange error: {err}" ) ;
273
+ HandledRequest :: Response (
274
+ Response :: from_string ( format ! ( "Token exchange failed: {err}" ) )
275
+ . with_status_code ( 500 ) ,
276
+ )
277
+ }
278
+ }
279
+ }
280
+ "/success" => {
281
+ let body = include_str ! ( "assets/success.html" ) ;
282
+ let mut resp = Response :: from_data ( body. as_bytes ( ) ) ;
283
+ if let Ok ( h) = tiny_http:: Header :: from_bytes (
284
+ & b"Content-Type" [ ..] ,
285
+ & b"text/html; charset=utf-8" [ ..] ,
286
+ ) {
287
+ resp. add_header ( h) ;
288
+ }
289
+ HandledRequest :: ResponseAndExit ( resp)
290
+ }
291
+ _ => HandledRequest :: Response ( Response :: from_string ( "Not Found" ) . with_status_code ( 404 ) ) ,
292
+ }
293
+ }
294
+
284
295
/// Spawns a detached thread that waits for either a completion signal on `done_rx`
285
296
/// or the specified `timeout` to elapse. If the timeout elapses first it marks
286
297
/// the `shutdown_flag`, records `timeout_flag`, and unblocks the HTTP server so
0 commit comments