@@ -242,7 +242,6 @@ def __init__(
242
242
self .src_port = 0
243
243
self ._dns = b"\x00 \x00 \x00 \x00 "
244
244
# udp related
245
- self .udp_datasize = [0 ] * self .max_sockets
246
245
self .udp_from_ip = [b"\x00 \x00 \x00 \x00 " ] * self .max_sockets
247
246
self .udp_from_port = [0 ] * self .max_sockets
248
247
@@ -513,22 +512,27 @@ def socket_available(self, socket_num: int, sock_type: int = _SNMR_TCP) -> int:
513
512
self ._sock_num_in_range (socket_num )
514
513
515
514
number_of_bytes = self ._get_rx_rcv_size (socket_num )
516
-
517
- if sock_type == _SNMR_TCP :
518
- return number_of_bytes
519
- if number_of_bytes > 0 :
520
- if self .udp_datasize [socket_num ]:
521
- return self .udp_datasize [socket_num ]
522
- # parse the udp rx packet
523
- # read the first 8 header bytes
524
- udp_bytes , self ._pbuff [:8 ] = self .socket_read (socket_num , 8 )
525
- if udp_bytes > 0 :
526
- self .udp_from_ip [socket_num ] = self ._pbuff [:4 ]
527
- self .udp_from_port [socket_num ] = (self ._pbuff [4 ] << 8 ) + self ._pbuff [5 ]
528
- self .udp_datasize [socket_num ] = (self ._pbuff [6 ] << 8 ) + self ._pbuff [7 ]
529
- udp_bytes = self .udp_datasize [socket_num ]
530
- return udp_bytes
531
- return 0
515
+ if self .read_snsr (socket_num ) == SNMR_UDP :
516
+ number_of_bytes -= 8 # Subtract UDP header from packet size.
517
+ if number_of_bytes < 0 :
518
+ raise ValueError ("Negative number of bytes found on socket." )
519
+ return number_of_bytes
520
+
521
+ # if sock_type == _SNMR_TCP:
522
+ # return number_of_bytes
523
+ # if number_of_bytes > 0:
524
+ # if self.udp_datasize[socket_num]:
525
+ # return self.udp_datasize[socket_num]
526
+ # # parse the udp rx packet
527
+ # # read the first 8 header bytes
528
+ # udp_bytes, self._pbuff[:8] = self.socket_read(socket_num, 8)
529
+ # if udp_bytes > 0:
530
+ # self.udp_from_ip[socket_num] = self._pbuff[:4]
531
+ # self.udp_from_port[socket_num] = (self._pbuff[4] << 8) + self._pbuff[5]
532
+ # self.udp_datasize[socket_num] = (self._pbuff[6] << 8) + self._pbuff[7]
533
+ # udp_bytes = self.udp_datasize[socket_num]
534
+ # return udp_bytes
535
+ # return 0
532
536
533
537
def socket_status (self , socket_num : int ) -> int :
534
538
"""
@@ -590,8 +594,6 @@ def socket_connect(
590
594
)
591
595
if self .socket_status (socket_num ) == SNSR_SOCK_CLOSED :
592
596
raise ConnectionError ("Failed to establish connection." )
593
- elif conn_mode == SNMR_UDP :
594
- self .udp_datasize [socket_num ] = 0
595
597
return 1
596
598
597
599
def get_socket (self , * , reserve_socket = False ) -> int :
@@ -805,7 +807,7 @@ def socket_read(self, socket_num: int, length: int) -> Tuple[int, bytes]:
805
807
:param int socket_num: The socket to read data from.
806
808
:param int length: The number of bytes to read from the socket.
807
809
808
- :return Tuple[int, bytes]: If the read was successful then the first
810
+ :returns Tuple[int, bytes]: If the read was successful then the first
809
811
item of the tuple is the length of the data and the second is the data.
810
812
If the read was unsuccessful then 0, b"" is returned.
811
813
@@ -858,16 +860,24 @@ def read_udp(self, socket_num: int, length: int) -> Tuple[int, bytes]:
858
860
"""
859
861
self ._sock_num_in_range (socket_num )
860
862
bytes_on_socket , bytes_read = 0 , b""
861
- if self .udp_datasize [socket_num ] > 0 :
862
- if self .udp_datasize [socket_num ] <= length :
863
+ # Parse the UDP Rx packet.
864
+ _ , self ._pbuff [:8 ] = self .socket_read (socket_num , 8 )
865
+ try :
866
+ self .udp_from_ip [socket_num ] = self ._pbuff [:4 ]
867
+ self .udp_from_port [socket_num ] = int .from_bytes (self ._pbuff [4 :6 ], "big" )
868
+ udp_data_bytes = int .from_bytes (self ._pbuff [6 :8 ], "big" )
869
+ except IndexError as err :
870
+ raise IndexError ("Invalid UDP packet header." ) from err
871
+ # Read the UDP packet data.
872
+ if udp_data_bytes :
873
+ if udp_data_bytes <= length :
863
874
bytes_on_socket , bytes_read = self .socket_read (
864
- socket_num , self . udp_datasize [ socket_num ]
875
+ socket_num , udp_data_bytes
865
876
)
866
877
else :
867
878
bytes_on_socket , bytes_read = self .socket_read (socket_num , length )
868
879
# just consume the rest, it is lost to the higher layers
869
- self .socket_read (socket_num , self .udp_datasize [socket_num ] - length )
870
- self .udp_datasize [socket_num ] = 0
880
+ self .socket_read (socket_num , udp_data_bytes - length )
871
881
return bytes_on_socket , bytes_read
872
882
873
883
def socket_write (
0 commit comments