1+ use std:: io:: Cursor ;
12use std:: io:: { self } ;
23use std:: path:: Path ;
34use std:: path:: PathBuf ;
@@ -15,6 +16,8 @@ use crate::pkce::generate_pkce;
1516use base64:: Engine ;
1617use chrono:: Utc ;
1718use rand:: RngCore ;
19+ use tiny_http:: Header ;
20+ use tiny_http:: Request ;
1821use tiny_http:: Response ;
1922use tiny_http:: Server ;
2023
@@ -149,116 +152,23 @@ pub fn run_login_server(
149152 }
150153 } ;
151154
152- let url_raw = req. url ( ) . to_string ( ) ;
153- let parsed_url = match url:: Url :: parse ( & format ! ( "http://localhost{url_raw}" ) ) {
154- Ok ( u) => u,
155- Err ( e) => {
156- eprintln ! ( "URL parse error: {e}" ) ;
157- let _ = req. respond ( Response :: from_string ( "Bad Request" ) . with_status_code ( 400 ) ) ;
158- continue ;
155+ let response = process_request ( & req, & opts, & redirect_uri, & pkce, actual_port, & state) ;
156+ let is_login_complete = matches ! ( response, HandledRequest :: ResponseAndExit ( _) ) ;
157+ match response {
158+ HandledRequest :: Response ( r) | HandledRequest :: ResponseAndExit ( r) => {
159+ let _ = req. respond ( r) ;
159160 }
160- } ;
161- let path = parsed_url. path ( ) . to_string ( ) ;
162-
163- match path. as_str ( ) {
164- "/auth/callback" => {
165- let params: std:: collections:: HashMap < String , String > =
166- parsed_url. query_pairs ( ) . into_owned ( ) . collect ( ) ;
167- if params. get ( "state" ) . map ( String :: as_str) != Some ( state. as_str ( ) ) {
168- let _ = req
169- . respond ( Response :: from_string ( "State mismatch" ) . with_status_code ( 400 ) ) ;
170- continue ;
171- }
172- let code = match params. get ( "code" ) {
173- Some ( c) if !c. is_empty ( ) => c. clone ( ) ,
174- _ => {
175- let _ = req. respond (
176- Response :: from_string ( "Missing authorization code" )
177- . with_status_code ( 400 ) ,
178- ) ;
179- continue ;
180- }
181- } ;
182-
183- match exchange_code_for_tokens (
184- & opts. issuer ,
185- & opts. client_id ,
186- & redirect_uri,
187- & pkce,
188- & code,
189- ) {
190- Ok ( tokens) => {
191- // Obtain API key via token-exchange and persist
192- let api_key =
193- obtain_api_key ( & opts. issuer , & opts. client_id , & tokens. id_token )
194- . ok ( ) ;
195- if let Err ( err) = persist_tokens (
196- & opts. codex_home ,
197- api_key. clone ( ) ,
198- tokens. id_token . clone ( ) ,
199- Some ( tokens. access_token . clone ( ) ) ,
200- Some ( tokens. refresh_token . clone ( ) ) ,
201- ) {
202- eprintln ! ( "Persist error: {err}" ) ;
203- let _ = req. respond (
204- Response :: from_string ( format ! (
205- "Unable to persist auth file: {err}"
206- ) )
207- . with_status_code ( 500 ) ,
208- ) ;
209- continue ;
210- }
211-
212- let success_url = compose_success_url (
213- actual_port,
214- & opts. issuer ,
215- & tokens. id_token ,
216- & tokens. access_token ,
217- ) ;
218- match tiny_http:: Header :: from_bytes (
219- & b"Location" [ ..] ,
220- success_url. as_bytes ( ) ,
221- ) {
222- Ok ( h) => {
223- let response = tiny_http:: Response :: empty ( 302 ) . with_header ( h) ;
224- let _ = req. respond ( response) ;
225- }
226- Err ( _) => {
227- let _ = req. respond (
228- Response :: from_string ( "Internal Server Error" )
229- . with_status_code ( 500 ) ,
230- ) ;
231- }
232- }
233- }
234- Err ( err) => {
235- eprintln ! ( "Token exchange error: {err}" ) ;
236- let _ = req. respond (
237- Response :: from_string ( format ! ( "Token exchange failed: {err}" ) )
238- . with_status_code ( 500 ) ,
239- ) ;
240- }
241- }
161+ HandledRequest :: RedirectWithHeader ( header) => {
162+ let redirect = Response :: empty ( 302 ) . with_header ( header) ;
163+ let _ = req. respond ( redirect) ;
242164 }
243- "/success" => {
244- let body = include_str ! ( "assets/success.html" ) ;
245- let mut resp = Response :: from_data ( body. as_bytes ( ) ) ;
246- if let Ok ( h) = tiny_http:: Header :: from_bytes (
247- & b"Content-Type" [ ..] ,
248- & b"text/html; charset=utf-8" [ ..] ,
249- ) {
250- resp. add_header ( h) ;
251- }
252- let _ = req. respond ( resp) ;
253- shutdown_flag. store ( true , Ordering :: SeqCst ) ;
165+ }
254166
255- // Login has succeeded, so disarm the timeout watcher.
256- let _ = done_tx. send ( ( ) ) ;
257- return Ok ( ( ) ) ;
258- }
259- _ => {
260- let _ = req. respond ( Response :: from_string ( "Not Found" ) . with_status_code ( 404 ) ) ;
261- }
167+ if is_login_complete {
168+ shutdown_flag. store ( true , Ordering :: SeqCst ) ;
169+ // Login has succeeded, so disarm the timeout watcher.
170+ let _ = done_tx. send ( ( ) ) ;
171+ return Ok ( ( ) ) ;
262172 }
263173 }
264174
@@ -281,6 +191,107 @@ pub fn run_login_server(
281191 } )
282192}
283193
194+ enum HandledRequest {
195+ Response ( Response < Cursor < Vec < u8 > > > ) ,
196+ RedirectWithHeader ( Header ) ,
197+ ResponseAndExit ( Response < Cursor < Vec < u8 > > > ) ,
198+ }
199+
200+ fn process_request (
201+ req : & Request ,
202+ opts : & ServerOptions ,
203+ redirect_uri : & str ,
204+ pkce : & PkceCodes ,
205+ actual_port : u16 ,
206+ state : & str ,
207+ ) -> HandledRequest {
208+ let url_raw = req. url ( ) . to_string ( ) ;
209+ let parsed_url = match url:: Url :: parse ( & format ! ( "http://localhost{url_raw}" ) ) {
210+ Ok ( u) => u,
211+ Err ( e) => {
212+ eprintln ! ( "URL parse error: {e}" ) ;
213+ return HandledRequest :: Response (
214+ Response :: from_string ( "Bad Request" ) . with_status_code ( 400 ) ,
215+ ) ;
216+ }
217+ } ;
218+ let path = parsed_url. path ( ) . to_string ( ) ;
219+
220+ match path. as_str ( ) {
221+ "/auth/callback" => {
222+ let params: std:: collections:: HashMap < String , String > =
223+ parsed_url. query_pairs ( ) . into_owned ( ) . collect ( ) ;
224+ if params. get ( "state" ) . map ( String :: as_str) != Some ( state) {
225+ return HandledRequest :: Response (
226+ Response :: from_string ( "State mismatch" ) . with_status_code ( 400 ) ,
227+ ) ;
228+ }
229+ let code = match params. get ( "code" ) {
230+ Some ( c) if !c. is_empty ( ) => c. clone ( ) ,
231+ _ => {
232+ return HandledRequest :: Response (
233+ Response :: from_string ( "Missing authorization code" ) . with_status_code ( 400 ) ,
234+ ) ;
235+ }
236+ } ;
237+
238+ match exchange_code_for_tokens ( & opts. issuer , & opts. client_id , redirect_uri, pkce, & code)
239+ {
240+ Ok ( tokens) => {
241+ // Obtain API key via token-exchange and persist
242+ let api_key =
243+ obtain_api_key ( & opts. issuer , & opts. client_id , & tokens. id_token ) . ok ( ) ;
244+ if let Err ( err) = persist_tokens (
245+ & opts. codex_home ,
246+ api_key. clone ( ) ,
247+ tokens. id_token . clone ( ) ,
248+ Some ( tokens. access_token . clone ( ) ) ,
249+ Some ( tokens. refresh_token . clone ( ) ) ,
250+ ) {
251+ eprintln ! ( "Persist error: {err}" ) ;
252+ return HandledRequest :: Response (
253+ Response :: from_string ( format ! ( "Unable to persist auth file: {err}" ) )
254+ . with_status_code ( 500 ) ,
255+ ) ;
256+ }
257+
258+ let success_url = compose_success_url (
259+ actual_port,
260+ & opts. issuer ,
261+ & tokens. id_token ,
262+ & tokens. access_token ,
263+ ) ;
264+ match tiny_http:: Header :: from_bytes ( & b"Location" [ ..] , success_url. as_bytes ( ) ) {
265+ Ok ( header) => HandledRequest :: RedirectWithHeader ( header) ,
266+ Err ( _) => HandledRequest :: Response (
267+ Response :: from_string ( "Internal Server Error" ) . with_status_code ( 500 ) ,
268+ ) ,
269+ }
270+ }
271+ Err ( err) => {
272+ eprintln ! ( "Token exchange error: {err}" ) ;
273+ HandledRequest :: Response (
274+ Response :: from_string ( format ! ( "Token exchange failed: {err}" ) )
275+ . with_status_code ( 500 ) ,
276+ )
277+ }
278+ }
279+ }
280+ "/success" => {
281+ let body = include_str ! ( "assets/success.html" ) ;
282+ let mut resp = Response :: from_data ( body. as_bytes ( ) ) ;
283+ if let Ok ( h) = tiny_http:: Header :: from_bytes (
284+ & b"Content-Type" [ ..] ,
285+ & b"text/html; charset=utf-8" [ ..] ,
286+ ) {
287+ resp. add_header ( h) ;
288+ }
289+ HandledRequest :: ResponseAndExit ( resp)
290+ }
291+ _ => HandledRequest :: Response ( Response :: from_string ( "Not Found" ) . with_status_code ( 404 ) ) ,
292+ }
293+ }
294+
284295/// Spawns a detached thread that waits for either a completion signal on `done_rx`
285296/// or the specified `timeout` to elapse. If the timeout elapses first it marks
286297/// the `shutdown_flag`, records `timeout_flag`, and unblocks the HTTP server so
0 commit comments