1212from abc import ABC , abstractmethod
1313from typing import List , Union , Optional
1414
15+ from .leak import Leakage
1516from .types import tcpConnectionTypes
1617from ...common .types import TcpOrTlsSocket
1718from ...common .constants import DEFAULT_BUFFER_SIZE , DEFAULT_MAX_SEND_SIZE
1819
1920
2021logger = logging .getLogger (__name__ )
2122
23+ EMPTY_MV = memoryview (b"" )
2224
2325class TcpConnectionUninitializedException (Exception ):
2426 pass
@@ -34,12 +36,23 @@ class TcpConnection(ABC):
3436 a socket connection object.
3537 """
3638
37- def __init__ (self , tag : int ) -> None :
39+ def __init__ (
40+ self ,
41+ tag : int ,
42+ flush_bw_in_bps : int = 512 ,
43+ recv_bw_in_bps : int = 512 ,
44+ ) -> None :
3845 self .tag : str = 'server' if tag == tcpConnectionTypes .SERVER else 'client'
3946 self .buffer : List [memoryview ] = []
4047 self .closed : bool = False
4148 self ._reusable : bool = False
4249 self ._num_buffer = 0
50+ self ._flush_leakage = (
51+ Leakage (rate = flush_bw_in_bps ) if flush_bw_in_bps > 0 else None
52+ )
53+ self ._recv_leakage = (
54+ Leakage (rate = recv_bw_in_bps ) if recv_bw_in_bps > 0 else None
55+ )
4356
4457 @property
4558 @abstractmethod
@@ -55,14 +68,19 @@ def recv(
5568 self , buffer_size : int = DEFAULT_BUFFER_SIZE ,
5669 ) -> Optional [memoryview ]:
5770 """Users must handle socket.error exceptions"""
71+ if self ._recv_leakage is not None :
72+ allowed_bytes = self ._recv_leakage .consume (buffer_size )
73+ if allowed_bytes == 0 :
74+ return EMPTY_MV
75+ buffer_size = min (buffer_size , allowed_bytes )
5876 data : bytes = self .connection .recv (buffer_size )
59- if len (data ) == 0 :
77+ size = len (data )
78+ if self ._recv_leakage is not None :
79+ self ._recv_leakage .putback (buffer_size - size )
80+ if size == 0 :
6081 return None
61- logger .debug (
62- 'received %d bytes from %s' %
63- (len (data ), self .tag ),
64- )
65- # logger.info(data)
82+ logger .debug ("received %d bytes from %s" % (size , self .tag ))
83+ logger .info (data )
6684 return memoryview (data )
6785
6886 def close (self ) -> bool :
@@ -75,6 +93,8 @@ def has_buffer(self) -> bool:
7593 return self ._num_buffer != 0
7694
7795 def queue (self , mv : memoryview ) -> None :
96+ if len (mv ) == 0 :
97+ return
7898 self .buffer .append (mv )
7999 self ._num_buffer += 1
80100
@@ -83,21 +103,38 @@ def flush(self, max_send_size: Optional[int] = None) -> int:
83103 if not self .has_buffer ():
84104 return 0
85105 mv = self .buffer [0 ]
106+ print (self .buffer )
107+ print (mv .tobytes ())
86108 # TODO: Assemble multiple packets if total
87109 # size remains below max send size.
88110 max_send_size = max_send_size or DEFAULT_MAX_SEND_SIZE
89- try :
90- sent : int = self .send (mv [:max_send_size ])
91- except BlockingIOError :
92- logger .warning ('BlockingIOError when trying send to {0}' .format (self .tag ))
93- return 0
111+ allowed_bytes = (
112+ self ._flush_leakage .consume (min (len (mv ), max_send_size ))
113+ if self ._flush_leakage is not None
114+ else max_send_size
115+ )
116+ sent : int = 0
117+ if allowed_bytes > 0 :
118+ try :
119+ sent = self .send (mv [:allowed_bytes ])
120+ if self ._flush_leakage is not None :
121+ self ._flush_leakage .putback (allowed_bytes - sent )
122+ except BlockingIOError :
123+ logger .warning (
124+ "BlockingIOError when trying send to {0}" .format (self .tag )
125+ )
126+ del mv
127+ return 0
128+ # if sent == 0:
129+ # return 0
94130 if sent == len (mv ):
95131 self .buffer .pop (0 )
96132 self ._num_buffer -= 1
97133 else :
98134 self .buffer [0 ] = mv [sent :]
99- logger .debug ('flushed %d bytes to %s' % (sent , self .tag ))
100- # logger.info(mv[:sent].tobytes())
135+ # if sent > 0:
136+ logger .debug ("flushed %d bytes to %s" % (sent , self .tag ))
137+ logger .info (mv [:sent ].tobytes ())
101138 del mv
102139 return sent
103140
0 commit comments