@@ -22,6 +22,7 @@ use alloc::vec::Vec;
2222use core:: net:: SocketAddr ;
2323use core:: time:: Duration ;
2424use feo_log:: warn;
25+ use feo_time:: Instant ;
2526use mio:: net:: { TcpListener , UnixListener } ;
2627use mio:: { Events , Token } ;
2728use std:: collections:: { HashMap , HashSet } ;
4647
4748 all_activities : Vec < ActivityId > ,
4849 all_recorders : Vec < AgentId > ,
50+ connection_timeout : Duration ,
4951}
5052
5153impl < L > SchedulerConnector < L >
5759 server : SocketServer < L > ,
5860 activity_ids : impl IntoIterator < Item = ActivityId > ,
5961 recorder_ids : impl IntoIterator < Item = AgentId > ,
62+ connection_timeout : Duration ,
6063 ) -> Self {
6164 let events = Events :: with_capacity ( 32 ) ;
6265
7376 recorder_id_token_map,
7477 all_activities,
7578 all_recorders,
79+ connection_timeout,
7680 }
7781 }
7882}
@@ -83,9 +87,10 @@ impl TcpSchedulerConnector {
8387 bind_address : SocketAddr ,
8488 activity_ids : impl IntoIterator < Item = ActivityId > ,
8589 recorder_ids : impl IntoIterator < Item = AgentId > ,
90+ connection_timeout : Duration ,
8691 ) -> Self {
8792 let tcp_server = TcpServer :: new ( bind_address) ;
88- Self :: new_with_server ( tcp_server, activity_ids, recorder_ids)
93+ Self :: new_with_server ( tcp_server, activity_ids, recorder_ids, connection_timeout )
8994 }
9095}
9196
@@ -95,9 +100,10 @@ impl UnixSchedulerConnector {
95100 path : & Path ,
96101 activity_ids : impl IntoIterator < Item = ActivityId > ,
97102 recorder_ids : impl IntoIterator < Item = AgentId > ,
103+ connection_timeout : Duration ,
98104 ) -> Self {
99105 let unix_server = UnixServer :: new ( path) ;
100- Self :: new_with_server ( unix_server, activity_ids, recorder_ids)
106+ Self :: new_with_server ( unix_server, activity_ids, recorder_ids, connection_timeout )
101107 }
102108}
103109
@@ -109,11 +115,19 @@ where
109115 let mut missing_activities: HashSet < ActivityId > =
110116 self . all_activities . iter ( ) . cloned ( ) . collect ( ) ;
111117 let mut missing_recorders: HashSet < AgentId > = self . all_recorders . iter ( ) . cloned ( ) . collect ( ) ;
118+ let start_time = Instant :: now ( ) ;
112119
113120 while !missing_activities. is_empty ( ) || !missing_recorders. is_empty ( ) {
114- if let Some ( ( token, signal) ) = self
115- . server
116- . receive ( & mut self . events , Duration :: from_secs ( 1 ) )
121+ let elapsed = start_time. elapsed ( ) ;
122+ if elapsed >= self . connection_timeout {
123+ return Err ( Error :: Io ( (
124+ std:: io:: ErrorKind :: TimedOut . into ( ) ,
125+ "CONNECTION_TIMEOUT" ,
126+ ) ) ) ;
127+ }
128+ let remaining_timeout = self . connection_timeout . saturating_sub ( elapsed) ;
129+ // Wait for a new connection, but no longer than the remaining overall timeout.
130+ if let Some ( ( token, signal) ) = self . server . receive ( & mut self . events , remaining_timeout)
117131 {
118132 match signal {
119133 ProtocolSignal :: ActivityHello ( activity_id) => {
0 commit comments