@@ -80,13 +80,15 @@ def __init__(self):
80
80
self ._buffer = memoryview (bytearray (MAX_MESSAGE_SIZE ))
81
81
self .expected_length = 0
82
82
self .expecting_header = False
83
- self .bytes_read = 0
83
+ self .ready_offset = 0
84
+ self .empty_offset = 0
84
85
self .op_code = None
85
86
self ._done = None
86
87
self ._connection_lost = False
87
88
self ._paused = False
88
89
self ._drain_waiter = None
89
90
self ._loop = asyncio .get_running_loop ()
91
+ self ._messages = collections .deque ()
90
92
91
93
def connection_made (self , transport ):
92
94
self .transport = transport
@@ -96,24 +98,71 @@ async def write(self, message: bytes):
96
98
await self ._drain_helper ()
97
99
98
100
async def read (self ):
99
- self ._done = self ._loop .create_future ()
100
- await self ._done
101
- return self .expected_length , self .op_code
101
+ data , opcode , to_remove = None , None , None
102
+ for message in self ._messages :
103
+ if message .done ():
104
+ data , opcode = self .unpack_message (message )
105
+ to_remove = message
106
+ if to_remove :
107
+ self ._messages .remove (to_remove )
108
+ else :
109
+ message = self ._loop .create_future ()
110
+ self ._messages .append (message )
111
+ try :
112
+ await message
113
+ finally :
114
+ self ._messages .remove (message )
115
+ data , opcode = self .unpack_message (message )
116
+ return data , opcode
117
+
118
+ def unpack_message (self , message ):
119
+ start , end , opcode = message .result ()
120
+ if isinstance (start , tuple ):
121
+ return memoryview (
122
+ self ._buffer [start [0 ]:end [0 ]].tobytes () + self ._buffer [start [1 ]:end [1 ]].tobytes ()), opcode
123
+ else :
124
+ return self ._buffer [start :end ], opcode
102
125
103
126
def get_buffer (self , sizehint : int ):
104
- return self ._buffer [self .bytes_read :]
127
+ if self .empty_offset + sizehint >= MAX_MESSAGE_SIZE - 1 :
128
+ self .empty_offset = 0
129
+ if self .empty_offset < self .ready_offset :
130
+ return self ._buffer [self .empty_offset :self .ready_offset ]
131
+ else :
132
+ return self ._buffer [self .empty_offset :]
105
133
106
134
def buffer_updated (self , nbytes : int ):
107
135
if nbytes == 0 :
108
136
self .connection_lost (OSError ("connection closed" ))
109
137
self ._done .set_result (None )
110
- self .bytes_read += nbytes
138
+ self .empty_offset += nbytes
111
139
if self .expecting_header :
112
- self .expected_length , _ , _ , self .op_code = _UNPACK_HEADER (self ._buffer [: 16 ])
140
+ self .expected_length , _ , _ , self .op_code = _UNPACK_HEADER (self ._buffer [self . ready_offset : self . ready_offset + 16 ])
113
141
self .expecting_header = False
114
142
115
- if self .bytes_read == self .expected_length :
116
- self ._done .set_result ((self .expected_length , self .op_code ))
143
+ if self .ready_offset < self .empty_offset :
144
+ if self .empty_offset - self .ready_offset >= self .expected_length :
145
+ self .store_message (self .ready_offset + 16 , self .ready_offset + self .expected_length , self .op_code )
146
+ self .ready_offset += self .expected_length
147
+ else :
148
+ if self .ready_offset + self .expected_length <= MAX_MESSAGE_SIZE - 1 :
149
+ self .store_message (self .ready_offset + 16 , self .ready_offset + self .expected_length , self .op_code )
150
+ self .ready_offset += self .expected_length
151
+ elif MAX_MESSAGE_SIZE - 1 - self .ready_offset + self .empty_offset >= self .expected_length :
152
+ self .store_message ((self .ready_offset , 0 ), (MAX_MESSAGE_SIZE - 1 , self .expected_length - (MAX_MESSAGE_SIZE - 1 - self .ready_offset )), self .op_code )
153
+ self .ready_offset = self .expected_length - (MAX_MESSAGE_SIZE - 1 - self .ready_offset )
154
+
155
+ def store_message (self , start , end , opcode ):
156
+ stored = False
157
+ for message in self ._messages :
158
+ if not message .done ():
159
+ message .set_result ((start , end , opcode ))
160
+ stored = True
161
+ if not stored :
162
+ message = self ._loop .create_future ()
163
+ message .set_result ((start , end , opcode ))
164
+ self ._messages .append (message )
165
+ self .expecting_header = True
117
166
118
167
def pause_writing (self ):
119
168
assert not self ._paused
0 commit comments