@@ -99,43 +99,53 @@ std::string decodeURIComponent(std::string url) {
9999 return result_url_.str ();
100100}
101101
102- WebSocketClient::WebSocketClient () {}
102+ WebSocketClient::WebSocketClient () : socket_guard_( nullptr ) {}
103103
104- WebSocketClient::~WebSocketClient () { Disconnect (); }
104+ WebSocketClient::~WebSocketClient () { DisconnectInternal (); }
105105
106106void WebSocketClient::Init () {}
107107
108108bool WebSocketClient::Connect (const std::string &url) {
109- Disconnect ();
109+ auto self = std::static_pointer_cast<WebSocketClient>(shared_from_this ());
110+ work_thread_.submit ([client_ptr = self, url]() {
111+ client_ptr->DisconnectInternal ();
112+ client_ptr->ConnectInternal (url);
113+ });
114+ return true ;
115+ }
110116
111- mutex_. lock ();
117+ void WebSocketClient::ConnectInternal ( const std::string &url) {
112118 url_ = url;
113119 thread_ = std::make_unique<std::thread>([this ]() { run (); });
114- mutex_.unlock ();
115- return true ;
116120}
117121
118122void WebSocketClient::Disconnect () {
119- mutex_.lock ();
120- if (socket_) {
121- CLOSESOCKET (socket_);
122- socket_ = 0 ;
123- }
123+ auto self = std::static_pointer_cast<WebSocketClient>(shared_from_this ());
124+ work_thread_.submit (
125+ [client_ptr = self]() { client_ptr->DisconnectInternal (); });
126+ }
127+
128+ void WebSocketClient::DisconnectInternal () {
124129 if (thread_) {
125130 if (thread_->joinable ()) {
126131 thread_->join ();
127132 LOGI (" WebSocketClient thread exit successfully." );
128133 }
129134 thread_.reset ();
130135 }
131- mutex_.unlock ();
132136}
133137
134138core::ConnectionType WebSocketClient::GetType () {
135139 return core::ConnectionType::kWebSocket ;
136140}
137141
138142void WebSocketClient::Send (const std::string &data) {
143+ auto self = std::static_pointer_cast<WebSocketClient>(shared_from_this ());
144+ work_thread_.submit (
145+ [client_ptr = self, data]() { client_ptr->SendInternal (data); });
146+ }
147+
148+ void WebSocketClient::SendInternal (const std::string &data) {
139149 const char *buf = data.data ();
140150 size_t payloadLen = data.size ();
141151 uint8_t prefix[14 ];
@@ -165,10 +175,8 @@ void WebSocketClient::Send(const std::string &data) {
165175 *reinterpret_cast <uint32_t *>(prefix + prefix_len) = 0 ;
166176 prefix_len += 4 ;
167177
168- mutex_.lock ();
169- send (socket_, (char *)prefix, prefix_len, 0 );
170- send (socket_, buf, payloadLen, 0 );
171- mutex_.unlock ();
178+ send (socket_guard_->Get (), (char *)prefix, prefix_len, 0 );
179+ send (socket_guard_->Get (), buf, payloadLen, 0 );
172180}
173181
174182void WebSocketClient::run () {
@@ -234,7 +242,7 @@ bool WebSocketClient::do_connect() {
234242 continue ;
235243 }
236244 if (connect (sockfd, p->ai_addr , p->ai_addrlen ) != -1 ) {
237- socket_ = sockfd;
245+ socket_guard_ = std::make_unique<base::SocketGuard>( sockfd) ;
238246 break ;
239247 }
240248 CLOSESOCKET (sockfd);
@@ -250,17 +258,18 @@ bool WebSocketClient::do_connect() {
250258 " Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n "
251259 " Sec-WebSocket-Version: 13\r\n\r\n " ,
252260 path, host, port);
253- send (socket_ , buf, strlen (buf), 0 );
261+ send (socket_guard_-> Get () , buf, strlen (buf), 0 );
254262
255263 int status;
256- if (readline (socket_ , buf, sizeof (buf)) < 10 ||
264+ if (readline (socket_guard_-> Get () , buf, sizeof (buf)) < 10 ||
257265 sscanf (buf, " HTTP/1.1 %d Switching Protocols\r\n " , &status) != 1 ||
258266 status != 101 ) {
259267 LOGE (" Connect Error: " << url_.c_str ());
260268 return false ;
261269 }
262270
263- while (readline (socket_, buf, sizeof (buf)) > 0 && buf[0 ] != ' \r ' ) {
271+ while (readline (socket_guard_->Get (), buf, sizeof (buf)) > 0 &&
272+ buf[0 ] != ' \r ' ) {
264273 size_t len = strlen (buf);
265274 buf[len - 2 ] = ' \0 ' ;
266275 LOGI (buf);
@@ -275,7 +284,8 @@ bool WebSocketClient::do_read(std::string &msg) {
275284 } head;
276285 auto self = std::static_pointer_cast<WebSocketClient>(shared_from_this ());
277286
278- if (recv (socket_, (char *)&head, sizeof (head), 0 ) != sizeof (head)) {
287+ if (recv (socket_guard_->Get (), (char *)&head, sizeof (head), 0 ) !=
288+ sizeof (head)) {
279289 LOGE (" failed to read websocket message" );
280290 delegate ()->OnFailure (self);
281291 return false ;
@@ -301,18 +311,18 @@ bool WebSocketClient::do_read(std::string &msg) {
301311
302312 if (payloadLen == 126 ) {
303313 uint8_t len[2 ];
304- recv (socket_ , (char *)&len, sizeof (len), 0 );
314+ recv (socket_guard_-> Get () , (char *)&len, sizeof (len), 0 );
305315 payloadLen = (len[0 ] << 8 ) | len[1 ];
306316 } else if (payloadLen == 127 ) {
307317 uint8_t len[8 ];
308- recv (socket_ , (char *)&len, sizeof (len), 0 );
318+ recv (socket_guard_-> Get () , (char *)&len, sizeof (len), 0 );
309319 payloadLen = (len[4 ] << 24 ) | (len[5 ] << 16 ) | (len[6 ] << 8 ) | len[7 ];
310320 }
311321
312322 msg.resize (payloadLen);
313323
314- if (recv (socket_ , const_cast <char *>(msg.data ()), payloadLen, 0 ) !=
315- payloadLen) {
324+ if (recv (socket_guard_-> Get () , const_cast <char *>(msg.data ()), payloadLen,
325+ 0 ) != payloadLen) {
316326 LOGE (" failed to read websocket message" );
317327 delegate ()->OnFailure (self);
318328 return false ;
0 commit comments