1818
1919#include < atomic>
2020#include < fcntl.h>
21+ #include < functional>
2122#include < thread>
2223
2324#ifndef _WIN32
@@ -177,70 +178,89 @@ Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
177178#endif // _WIN32
178179}
179180
180- Expected<std::unique_ptr<raw_socket_stream>>
181- ListeningSocket::accept (std::chrono::milliseconds Timeout) {
182-
183- struct pollfd FDs[2 ];
184- FDs[0 ].events = POLLIN;
181+ // If a file descriptor being monitored by ::poll is closed by another thread,
182+ // the result is unspecified. In the case ::poll does not unblock and return,
183+ // when ActiveFD is closed, you can provide another file descriptor via CancelFD
184+ // that when written to will cause poll to return. Typically CancelFD is the
185+ // read end of a unidirectional pipe.
186+ //
187+ // Timeout should be -1 to block indefinitly
188+ //
189+ // getActiveFD is a callback to handle ActiveFD's of std::atomic<int> and int
190+ static std::error_code
191+ manageTimeout (const std::chrono::milliseconds &Timeout,
192+ const std::function<int ()> &getActiveFD,
193+ const std::optional<int > &CancelFD = std::nullopt ) {
194+ struct pollfd FD[2 ];
195+ FD[0 ].events = POLLIN;
185196#ifdef _WIN32
186- SOCKET WinServerSock = _get_osfhandle (FD );
187- FDs [0 ].fd = WinServerSock;
197+ SOCKET WinServerSock = _get_osfhandle (getActiveFD () );
198+ FD [0 ].fd = WinServerSock;
188199#else
189- FDs [0 ].fd = FD ;
200+ FD [0 ].fd = getActiveFD () ;
190201#endif
191- FDs[1 ].events = POLLIN;
192- FDs[1 ].fd = PipeFD[0 ];
193-
194- // Keep track of how much time has passed in case poll is interupted by a
195- // signal and needs to be recalled
196- int RemainingTime = Timeout.count ();
197- std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds (0 );
198- int PollStatus = -1 ;
199-
200- while (PollStatus == -1 && (Timeout.count () == -1 || ElapsedTime < Timeout)) {
201- if (Timeout.count () != -1 )
202- RemainingTime -= ElapsedTime.count ();
202+ uint8_t FDCount = 1 ;
203+ if (CancelFD.has_value ()) {
204+ FD[1 ].events = POLLIN;
205+ FD[1 ].fd = CancelFD.value ();
206+ FDCount++;
207+ }
203208
204- auto Start = std::chrono::steady_clock::now ();
209+ // Keep track of how much time has passed in case ::poll or WSAPoll are
210+ // interupted by a signal and need to be recalled
211+ auto Start = std::chrono::steady_clock::now ();
212+ auto RemainingTimeout = Timeout;
213+ int PollStatus = 0 ;
214+ do {
215+ // If Timeout is -1 then poll should block and RemainingTimeout does not
216+ // need to be recalculated
217+ if (PollStatus != 0 && Timeout != std::chrono::milliseconds (-1 )) {
218+ auto TotalElapsedTime =
219+ std::chrono::duration_cast<std::chrono::milliseconds>(
220+ std::chrono::steady_clock::now () - Start);
221+
222+ if (TotalElapsedTime >= Timeout)
223+ return std::make_error_code (std::errc::operation_would_block);
224+
225+ RemainingTimeout = Timeout - TotalElapsedTime;
226+ }
205227#ifdef _WIN32
206- PollStatus = WSAPoll (FDs, 2 , RemainingTime);
228+ PollStatus = WSAPoll (FD, FDCount, RemainingTimeout.count ());
229+ } while (PollStatus == SOCKET_ERROR &&
230+ getLastSocketErrorCode () == std::errc::interrupted);
207231#else
208- PollStatus = ::poll (FDs, 2 , RemainingTime);
232+ PollStatus = ::poll (FD, FDCount, RemainingTimeout.count ());
233+ } while (PollStatus == -1 &&
234+ getLastSocketErrorCode () == std::errc::interrupted);
209235#endif
210- // If FD equals -1 then ListeningSocket::shutdown has been called and it is
211- // appropriate to return operation_canceled
212- if (FD.load () == -1 )
213- return llvm::make_error<StringError>(
214- std::make_error_code (std::errc::operation_canceled),
215- " Accept canceled" );
216236
237+ // If ActiveFD equals -1 or CancelFD has data to be read then the operation
238+ // has been canceled by another thread
239+ if (getActiveFD () == -1 || (CancelFD.has_value () && FD[1 ].revents & POLLIN))
240+ return std::make_error_code (std::errc::operation_canceled);
217241#if _WIN32
218- if (PollStatus == SOCKET_ERROR) {
242+ if (PollStatus == SOCKET_ERROR)
219243#else
220- if (PollStatus == -1 ) {
244+ if (PollStatus == -1 )
221245#endif
222- std::error_code PollErrCode = getLastSocketErrorCode ();
223- // Ignore EINTR (signal occured before any request event) and retry
224- if (PollErrCode != std::errc::interrupted)
225- return llvm::make_error<StringError>(PollErrCode, " FD poll failed" );
226- }
227- if (PollStatus == 0 )
228- return llvm::make_error<StringError>(
229- std::make_error_code (std::errc::timed_out),
230- " No client requests within timeout window" );
231-
232- if (FDs[0 ].revents & POLLNVAL)
233- return llvm::make_error<StringError>(
234- std::make_error_code (std::errc::bad_file_descriptor));
246+ return getLastSocketErrorCode ();
247+ if (PollStatus == 0 )
248+ return std::make_error_code (std::errc::timed_out);
249+ if (FD[0 ].revents & POLLNVAL)
250+ return std::make_error_code (std::errc::bad_file_descriptor);
251+ return std::error_code ();
252+ }
235253
236- auto Stop = std::chrono::steady_clock::now ();
237- ElapsedTime +=
238- std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
239- }
254+ Expected<std::unique_ptr<raw_socket_stream>>
255+ ListeningSocket::accept (const std::chrono::milliseconds &Timeout) {
256+ auto getActiveFD = [this ]() -> int { return FD; };
257+ std::error_code TimeoutErr = manageTimeout (Timeout, getActiveFD, PipeFD[0 ]);
258+ if (TimeoutErr)
259+ return llvm::make_error<StringError>(TimeoutErr, " Timeout error" );
240260
241261 int AcceptFD;
242262#ifdef _WIN32
243- SOCKET WinAcceptSock = ::accept (WinServerSock , NULL , NULL );
263+ SOCKET WinAcceptSock = ::accept (_get_osfhandle (FD) , NULL , NULL );
244264 AcceptFD = _open_osfhandle (WinAcceptSock, 0 );
245265#else
246266 AcceptFD = ::accept (FD, NULL , NULL );
@@ -295,6 +315,8 @@ ListeningSocket::~ListeningSocket() {
295315raw_socket_stream::raw_socket_stream (int SocketFD)
296316 : raw_fd_stream (SocketFD, true ) {}
297317
318+ raw_socket_stream::~raw_socket_stream () {}
319+
298320Expected<std::unique_ptr<raw_socket_stream>>
299321raw_socket_stream::createConnectedUnix (StringRef SocketPath) {
300322#ifdef _WIN32
@@ -306,4 +328,14 @@ raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
306328 return std::make_unique<raw_socket_stream>(*FD);
307329}
308330
309- raw_socket_stream::~raw_socket_stream () {}
331+ ssize_t raw_socket_stream::read (char *Ptr, size_t Size,
332+ const std::chrono::milliseconds &Timeout) {
333+ auto getActiveFD = [this ]() -> int { return this ->get_fd (); };
334+ std::error_code Err = manageTimeout (Timeout, getActiveFD);
335+ // Mimic raw_fd_stream::read error handling behavior
336+ if (Err) {
337+ raw_fd_stream::error_detected (Err);
338+ return -1 ;
339+ }
340+ return raw_fd_stream::read (Ptr, Size);
341+ }
0 commit comments