|
4 | 4 |
|
5 | 5 | #include <compat.h>
|
6 | 6 | #include <logging.h>
|
| 7 | +#include <threadinterrupt.h> |
7 | 8 | #include <tinyformat.h>
|
8 | 9 | #include <util/sock.h>
|
9 | 10 | #include <util/system.h>
|
|
12 | 13 | #include <codecvt>
|
13 | 14 | #include <cwchar>
|
14 | 15 | #include <locale>
|
| 16 | +#include <stdexcept> |
15 | 17 | #include <string>
|
16 | 18 |
|
17 | 19 | #ifdef USE_POLL
|
18 | 20 | #include <poll.h>
|
19 | 21 | #endif
|
20 | 22 |
|
| 23 | +static inline bool IOErrorIsPermanent(int err) |
| 24 | +{ |
| 25 | + return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS; |
| 26 | +} |
| 27 | + |
21 | 28 | Sock::Sock() : m_socket(INVALID_SOCKET) {}
|
22 | 29 |
|
23 | 30 | Sock::Sock(SOCKET s) : m_socket(s) {}
|
@@ -125,6 +132,124 @@ bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occur
|
125 | 132 | #endif /* USE_POLL */
|
126 | 133 | }
|
127 | 134 |
|
| 135 | +void Sock::SendComplete(const std::string& data, |
| 136 | + std::chrono::milliseconds timeout, |
| 137 | + CThreadInterrupt& interrupt) const |
| 138 | +{ |
| 139 | + const auto deadline = GetTime<std::chrono::milliseconds>() + timeout; |
| 140 | + size_t sent{0}; |
| 141 | + |
| 142 | + for (;;) { |
| 143 | + const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)}; |
| 144 | + |
| 145 | + if (ret > 0) { |
| 146 | + sent += static_cast<size_t>(ret); |
| 147 | + if (sent == data.size()) { |
| 148 | + break; |
| 149 | + } |
| 150 | + } else { |
| 151 | + const int err{WSAGetLastError()}; |
| 152 | + if (IOErrorIsPermanent(err)) { |
| 153 | + throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err))); |
| 154 | + } |
| 155 | + } |
| 156 | + |
| 157 | + const auto now = GetTime<std::chrono::milliseconds>(); |
| 158 | + |
| 159 | + if (now >= deadline) { |
| 160 | + throw std::runtime_error(strprintf( |
| 161 | + "Send timeout (sent only %u of %u bytes before that)", sent, data.size())); |
| 162 | + } |
| 163 | + |
| 164 | + if (interrupt) { |
| 165 | + throw std::runtime_error(strprintf( |
| 166 | + "Send interrupted (sent only %u of %u bytes before that)", sent, data.size())); |
| 167 | + } |
| 168 | + |
| 169 | + // Wait for a short while (or the socket to become ready for sending) before retrying |
| 170 | + // if nothing was sent. |
| 171 | + const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO}); |
| 172 | + Wait(wait_time, SEND); |
| 173 | + } |
| 174 | +} |
| 175 | + |
| 176 | +std::string Sock::RecvUntilTerminator(uint8_t terminator, |
| 177 | + std::chrono::milliseconds timeout, |
| 178 | + CThreadInterrupt& interrupt) const |
| 179 | +{ |
| 180 | + const auto deadline = GetTime<std::chrono::milliseconds>() + timeout; |
| 181 | + std::string data; |
| 182 | + bool terminator_found{false}; |
| 183 | + |
| 184 | + // We must not consume any bytes past the terminator from the socket. |
| 185 | + // One option is to read one byte at a time and check if we have read a terminator. |
| 186 | + // However that is very slow. Instead, we peek at what is in the socket and only read |
| 187 | + // as many bytes as possible without crossing the terminator. |
| 188 | + // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read |
| 189 | + // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte |
| 190 | + // at a time is about 50 times slower. |
| 191 | + |
| 192 | + for (;;) { |
| 193 | + char buf[512]; |
| 194 | + |
| 195 | + const ssize_t peek_ret{Recv(buf, sizeof(buf), MSG_PEEK)}; |
| 196 | + |
| 197 | + switch (peek_ret) { |
| 198 | + case -1: { |
| 199 | + const int err{WSAGetLastError()}; |
| 200 | + if (IOErrorIsPermanent(err)) { |
| 201 | + throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err))); |
| 202 | + } |
| 203 | + break; |
| 204 | + } |
| 205 | + case 0: |
| 206 | + throw std::runtime_error("Connection unexpectedly closed by peer"); |
| 207 | + default: |
| 208 | + auto end = buf + peek_ret; |
| 209 | + auto terminator_pos = std::find(buf, end, terminator); |
| 210 | + terminator_found = terminator_pos != end; |
| 211 | + |
| 212 | + const size_t try_len{terminator_found ? terminator_pos - buf + 1 : |
| 213 | + static_cast<size_t>(peek_ret)}; |
| 214 | + |
| 215 | + const ssize_t read_ret{Recv(buf, try_len, 0)}; |
| 216 | + |
| 217 | + if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) { |
| 218 | + throw std::runtime_error( |
| 219 | + strprintf("recv() returned %u bytes on attempt to read %u bytes but previous " |
| 220 | + "peek claimed %u bytes are available", |
| 221 | + read_ret, try_len, peek_ret)); |
| 222 | + } |
| 223 | + |
| 224 | + // Don't include the terminator in the output. |
| 225 | + const size_t append_len{terminator_found ? try_len - 1 : try_len}; |
| 226 | + |
| 227 | + data.append(buf, buf + append_len); |
| 228 | + |
| 229 | + if (terminator_found) { |
| 230 | + return data; |
| 231 | + } |
| 232 | + } |
| 233 | + |
| 234 | + const auto now = GetTime<std::chrono::milliseconds>(); |
| 235 | + |
| 236 | + if (now >= deadline) { |
| 237 | + throw std::runtime_error(strprintf( |
| 238 | + "Receive timeout (received %u bytes without terminator before that)", data.size())); |
| 239 | + } |
| 240 | + |
| 241 | + if (interrupt) { |
| 242 | + throw std::runtime_error(strprintf( |
| 243 | + "Receive interrupted (received %u bytes without terminator before that)", |
| 244 | + data.size())); |
| 245 | + } |
| 246 | + |
| 247 | + // Wait for a short while (or the socket to become ready for reading) before retrying. |
| 248 | + const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO}); |
| 249 | + Wait(wait_time, RECV); |
| 250 | + } |
| 251 | +} |
| 252 | + |
128 | 253 | #ifdef WIN32
|
129 | 254 | std::string NetworkErrorString(int err)
|
130 | 255 | {
|
|
0 commit comments