@@ -4,7 +4,9 @@ use std::path::PathBuf;
4
4
use std:: sync:: Arc ;
5
5
use std:: sync:: atomic:: AtomicBool ;
6
6
use std:: sync:: atomic:: Ordering ;
7
+ use std:: sync:: mpsc;
7
8
use std:: thread;
9
+ use std:: time:: Duration ;
8
10
9
11
use crate :: AuthDotJson ;
10
12
use crate :: get_auth_file;
@@ -27,6 +29,7 @@ pub struct ServerOptions {
27
29
pub port : u16 ,
28
30
pub open_browser : bool ,
29
31
pub force_state : Option < String > ,
32
+ pub login_timeout : Option < Duration > ,
30
33
}
31
34
32
35
impl ServerOptions {
@@ -38,16 +41,17 @@ impl ServerOptions {
38
41
port : DEFAULT_PORT ,
39
42
open_browser : true ,
40
43
force_state : None ,
44
+ login_timeout : None ,
41
45
}
42
46
}
43
47
}
44
48
45
- #[ derive( Debug ) ]
46
49
pub struct LoginServer {
47
50
pub auth_url : String ,
48
51
pub actual_port : u16 ,
49
52
pub server_handle : thread:: JoinHandle < io:: Result < ( ) > > ,
50
53
pub shutdown_flag : Arc < AtomicBool > ,
54
+ pub server : Arc < Server > ,
51
55
}
52
56
53
57
impl LoginServer {
@@ -59,10 +63,34 @@ impl LoginServer {
59
63
}
60
64
61
65
pub fn cancel ( & self ) {
62
- self . shutdown_flag . store ( true , Ordering :: SeqCst ) ;
66
+ shutdown ( & self . shutdown_flag , & self . server ) ;
67
+ }
68
+
69
+ pub fn cancel_handle ( & self ) -> ShutdownHandle {
70
+ ShutdownHandle {
71
+ shutdown_flag : self . shutdown_flag . clone ( ) ,
72
+ server : self . server . clone ( ) ,
73
+ }
74
+ }
75
+ }
76
+
77
+ #[ derive( Clone ) ]
78
+ pub struct ShutdownHandle {
79
+ shutdown_flag : Arc < AtomicBool > ,
80
+ server : Arc < Server > ,
81
+ }
82
+
83
+ impl ShutdownHandle {
84
+ pub fn cancel ( & self ) {
85
+ shutdown ( & self . shutdown_flag , & self . server ) ;
63
86
}
64
87
}
65
88
89
+ pub fn shutdown ( shutdown_flag : & AtomicBool , server : & Server ) {
90
+ shutdown_flag. store ( true , Ordering :: SeqCst ) ;
91
+ server. unblock ( ) ;
92
+ }
93
+
66
94
pub fn run_login_server (
67
95
opts : ServerOptions ,
68
96
shutdown_flag : Option < Arc < AtomicBool > > ,
@@ -80,6 +108,7 @@ pub fn run_login_server(
80
108
) ) ;
81
109
}
82
110
} ;
111
+ let server = Arc :: new ( server) ;
83
112
84
113
let redirect_uri = format ! ( "http://localhost:{actual_port}/auth/callback" ) ;
85
114
let auth_url = build_authorize_url ( & opts. issuer , & opts. client_id , & redirect_uri, & pkce, & state) ;
@@ -89,11 +118,35 @@ pub fn run_login_server(
89
118
}
90
119
let shutdown_flag = shutdown_flag. unwrap_or_else ( || Arc :: new ( AtomicBool :: new ( false ) ) ) ;
91
120
let shutdown_flag_clone = shutdown_flag. clone ( ) ;
121
+ let timeout_flag = Arc :: new ( AtomicBool :: new ( false ) ) ;
122
+
123
+ // Channel used to signal completion to timeout watcher.
124
+ let ( done_tx, done_rx) = mpsc:: channel :: < ( ) > ( ) ;
125
+
126
+ if let Some ( timeout) = opts. login_timeout {
127
+ spawn_timeout_watcher (
128
+ done_rx,
129
+ timeout,
130
+ shutdown_flag. clone ( ) ,
131
+ timeout_flag. clone ( ) ,
132
+ server. clone ( ) ,
133
+ ) ;
134
+ }
135
+
136
+ let server_for_thread = server. clone ( ) ;
92
137
let server_handle = thread:: spawn ( move || {
93
138
while !shutdown_flag. load ( Ordering :: SeqCst ) {
94
- let req = match server . recv ( ) {
139
+ let req = match server_for_thread . recv ( ) {
95
140
Ok ( r) => r,
96
- Err ( e) => return Err ( io:: Error :: other ( e) ) ,
141
+ Err ( e) => {
142
+ // If we've been asked to shut down, break gracefully so that
143
+ // we can report timeout or cancellation status uniformly.
144
+ if shutdown_flag. load ( Ordering :: SeqCst ) {
145
+ break ;
146
+ } else {
147
+ return Err ( io:: Error :: other ( e) ) ;
148
+ }
149
+ }
97
150
} ;
98
151
99
152
let url_raw = req. url ( ) . to_string ( ) ;
@@ -198,24 +251,59 @@ pub fn run_login_server(
198
251
}
199
252
let _ = req. respond ( resp) ;
200
253
shutdown_flag. store ( true , Ordering :: SeqCst ) ;
254
+
255
+ // Login has succeeded, so disarm the timeout watcher.
256
+ let _ = done_tx. send ( ( ) ) ;
201
257
return Ok ( ( ) ) ;
202
258
}
203
259
_ => {
204
260
let _ = req. respond ( Response :: from_string ( "Not Found" ) . with_status_code ( 404 ) ) ;
205
261
}
206
262
}
207
263
}
208
- Err ( io:: Error :: other ( "Login flow was not completed" ) )
264
+
265
+ // Login has failed or timed out, so disarm the timeout watcher.
266
+ let _ = done_tx. send ( ( ) ) ;
267
+
268
+ if timeout_flag. load ( Ordering :: SeqCst ) {
269
+ Err ( io:: Error :: other ( "Login timed out" ) )
270
+ } else {
271
+ Err ( io:: Error :: other ( "Login was not completed" ) )
272
+ }
209
273
} ) ;
210
274
211
275
Ok ( LoginServer {
212
276
auth_url : auth_url. clone ( ) ,
213
277
actual_port,
214
278
server_handle,
215
279
shutdown_flag : shutdown_flag_clone,
280
+ server,
216
281
} )
217
282
}
218
283
284
+ /// Spawns a detached thread that waits for either a completion signal on `done_rx`
285
+ /// or the specified `timeout` to elapse. If the timeout elapses first it marks
286
+ /// the `shutdown_flag`, records `timeout_flag`, and unblocks the HTTP server so
287
+ /// that the main server loop can exit promptly.
288
+ fn spawn_timeout_watcher (
289
+ done_rx : mpsc:: Receiver < ( ) > ,
290
+ timeout : Duration ,
291
+ shutdown_flag : Arc < AtomicBool > ,
292
+ timeout_flag : Arc < AtomicBool > ,
293
+ server : Arc < Server > ,
294
+ ) {
295
+ thread:: spawn ( move || {
296
+ if done_rx. recv_timeout ( timeout) . is_err ( )
297
+ && shutdown_flag
298
+ . compare_exchange ( false , true , Ordering :: SeqCst , Ordering :: SeqCst )
299
+ . is_ok ( )
300
+ {
301
+ timeout_flag. store ( true , Ordering :: SeqCst ) ;
302
+ server. unblock ( ) ;
303
+ }
304
+ } ) ;
305
+ }
306
+
219
307
fn build_authorize_url (
220
308
issuer : & str ,
221
309
client_id : & str ,
0 commit comments