Skip to content

Commit f3038eb

Browse files
authored
GG-33808 IGNITE-15479 Fix partial read from socket (#57)
(cherry picked from commit 3bf1cc1)
1 parent 18b2aee commit f3038eb

File tree

3 files changed

+72
-16
lines changed

3 files changed

+72
-16
lines changed

pygridgain/connection/connection.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ def _connection_listener(self):
160160
return self.client._event_listeners
161161

162162

163+
DEFAULT_INITIAL_BUF_SIZE = 1024
164+
165+
163166
class Connection(BaseConnection):
164167
"""
165168
This is a `pygridgain` class, that represents a connection to GridGain
@@ -354,39 +357,35 @@ def recv(self, flags=None, reconnect=True) -> bytearray:
354357
if flags is not None:
355358
kwargs['flags'] = flags
356359

357-
data = bytearray(1024)
360+
data = bytearray(DEFAULT_INITIAL_BUF_SIZE)
358361
buffer = memoryview(data)
359-
bytes_total_received, bytes_to_receive = 0, 0
362+
total_rcvd, packet_len = 0, 0
360363
while True:
361364
try:
362-
bytes_received = self._socket.recv_into(buffer, len(buffer), **kwargs)
363-
if bytes_received == 0:
365+
bytes_rcvd = self._socket.recv_into(buffer, len(buffer), **kwargs)
366+
if bytes_rcvd == 0:
364367
raise SocketError('Connection broken.')
365-
bytes_total_received += bytes_received
368+
total_rcvd += bytes_rcvd
366369
except connection_errors as e:
367370
self.failed = True
368371
if reconnect:
369372
self._on_connection_lost(e)
370373
self.reconnect()
371374
raise e
372375

373-
if bytes_total_received < 4:
374-
continue
375-
elif bytes_to_receive == 0:
376-
response_len = int.from_bytes(data[0:4], PROTOCOL_BYTE_ORDER)
377-
bytes_to_receive = response_len
378-
379-
if response_len + 4 > len(data):
376+
if packet_len == 0 and total_rcvd > 4:
377+
packet_len = int.from_bytes(data[0:4], PROTOCOL_BYTE_ORDER, signed=True) + 4
378+
if packet_len > len(data):
380379
buffer.release()
381-
data.extend(bytearray(response_len + 4 - len(data)))
382-
buffer = memoryview(data)[bytes_total_received:]
380+
data.extend(bytearray(packet_len - len(data)))
381+
buffer = memoryview(data)[total_rcvd:]
383382
continue
384383

385-
if bytes_total_received >= bytes_to_receive:
384+
if 0 < packet_len <= total_rcvd:
386385
buffer.release()
387386
break
388387

389-
buffer = buffer[bytes_received:]
388+
buffer = buffer[bytes_rcvd:]
390389

391390
return data
392391

tests/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#
2+
# Copyright 2021 GridGain Systems, Inc. and Contributors.
3+
#
4+
# Licensed under the GridGain Community Edition License (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.gridgain.com/products/software/community-edition/gridgain-community-edition-license
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#

tests/common/test_sync_socket.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#
2+
# Copyright 2021 GridGain Systems, Inc. and Contributors.
3+
#
4+
# Licensed under the GridGain Community Edition License (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.gridgain.com/products/software/community-edition/gridgain-community-edition-license
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
import secrets
17+
import socket
18+
import unittest.mock as mock
19+
20+
import pytest
21+
22+
from pygridgain import Client
23+
from tests.util import get_or_create_cache
24+
25+
old_recv_into = socket.socket.recv_into
26+
27+
28+
def patched_recv_into_factory(buf_len):
29+
def patched_recv_into(self, buffer, nbytes, **kwargs):
30+
return old_recv_into(self, buffer, min(nbytes, buf_len) if buf_len else nbytes, **kwargs)
31+
return patched_recv_into
32+
33+
34+
@pytest.mark.parametrize('buf_len', [0, 1, 4, 16, 32, 64, 128, 256, 512, 1024])
35+
def test_get_large_value(buf_len):
36+
with mock.patch.object(socket.socket, 'recv_into', new=patched_recv_into_factory(buf_len)):
37+
c = Client()
38+
with c.connect("127.0.0.1", 10801):
39+
with get_or_create_cache(c, 'test') as cache:
40+
value = secrets.token_hex((1 << 16) + 1)
41+
cache.put(1, value)
42+
assert value == cache.get(1)

0 commit comments

Comments
 (0)