@@ -17,8 +17,9 @@ def __init__(self, conn, initiator):
1717 self .initiator = initiator
1818 self .buffers = {}
1919 self .streams = {}
20+ self .stream_queue = asyncio .Queue ()
2021
21- self .add_incoming_task ( )
22+ asyncio . ensure_future ( self .handle_incoming () )
2223
2324 def close (self ):
2425 """
@@ -33,7 +34,12 @@ def is_closed(self):
3334 """
3435 pass
3536
36- def read_buffer (self , stream_id ):
37+ async def read_buffer (self , stream_id ):
38+ # Empty buffer or nonexistent stream
39+ # TODO: propagate up timeout exception and catch
40+ if stream_id not in self .buffers or not self .buffers [stream_id ]:
41+ await self .handle_incoming ()
42+
3743 data = self .buffers [stream_id ]
3844 self .buffers [stream_id ] = bytearray ()
3945 return data
@@ -43,37 +49,22 @@ def open_stream(self, protocol_id, stream_id, peer_id, multi_addr):
4349 creates a new muxed_stream
4450 :return: a new stream
4551 """
46- stream = MuxedStream (peer_id , multi_addr , self )
52+ stream = MuxedStream (stream_id , multi_addr , self )
4753 self .streams [stream_id ] = stream
48- self .buffers [stream_id ] = bytearray ()
4954 return stream
5055
51- def accept_stream (self ):
56+ async def accept_stream (self ):
5257 """
5358 accepts a muxed stream opened by the other end
5459 :return: the accepted stream
5560 """
56- data = bytearray ()
57- while True :
58- chunk = self .raw_conn .reader .read (100 )
59- if not chunk :
60- break
61- data += chunk
62- header , end_index = decode_uvarint (data , 0 )
63- length , end_index = decode_uvarint (data , end_index )
64- message = data [end_index , end_index + length ]
65-
66- flag = header & 0x07
67- stream_id = header >> 3
68-
6961 # TODO update to pull out protocol_id from message
7062 protocol_id = "/echo/1.0.0"
71-
63+ stream_id = await self . stream_queue . get ()
7264 stream = MuxedStream (stream_id , False , self )
73-
7465 return stream , stream_id , protocol_id
7566
76- def send_message (self , flag , data , stream_id ):
67+ async def send_message (self , flag , data , stream_id ):
7768 """
7869 sends a message over the connection
7970 :param header: header to use
@@ -86,7 +77,8 @@ def send_message(self, flag, data, stream_id):
8677 header = encode_uvarint (header )
8778 data_length = encode_uvarint (len (data ))
8879 _bytes = header + data_length + data
89- return self .write_to_stream (_bytes )
80+
81+ return await self .write_to_stream (_bytes )
9082
9183 async def write_to_stream (self , _bytes ):
9284 self .raw_conn .writer .write (_bytes )
@@ -95,25 +87,23 @@ async def write_to_stream(self, _bytes):
9587
9688 async def handle_incoming (self ):
9789 data = bytearray ()
98- while True :
99- chunk = self .raw_conn .reader .read (100 )
100- if not chunk :
101- break
90+ try :
91+ chunk = await asyncio .wait_for (self .raw_conn .reader .read (1024 ), timeout = 5 )
10292 data += chunk
103- header , end_index = decode_uvarint ( data , 0 )
104- length , end_index = decode_uvarint (data , end_index )
105- message = data [ end_index , end_index + length ]
106-
107- # Deal with other types of messages
108- flag = header & 0x07
109- stream_id = header >> 3
110-
111- self . buffers [ stream_id ] = self . buffers [ stream_id ] + message
112- # Read header
113- # Read message length
114- # Read message into corresponding buffer
115-
116- def add_incoming_task ( self ) :
117- loop = asyncio . get_event_loop ()
118- handle_incoming_task = loop . create_task ( self . handle_incoming ())
119- handle_incoming_task . add_done_callback ( self . add_incoming_task )
93+
94+ header , end_index = decode_uvarint (data , 0 )
95+ length , end_index = decode_uvarint ( data , end_index )
96+
97+ message = data [ end_index : end_index + length + 1 ]
98+
99+ # Deal with other types of messages
100+ flag = header & 0x07
101+ stream_id = header >> 3
102+
103+ if stream_id not in self . buffers :
104+ self . buffers [ stream_id ] = message
105+ await self . stream_queue . put ( stream_id )
106+ else :
107+ self . buffers [ stream_id ] = self . buffers [ stream_id ] + message
108+ except asyncio . TimeoutError :
109+ print ( 'timeout!' )
0 commit comments