Skip to content

Commit e5f2ab6

Browse files
committed
Fixed client restarting session and closing streams
1 parent bd70235 commit e5f2ab6

File tree

3 files changed

+95
-62
lines changed

3 files changed

+95
-62
lines changed

client.py

Lines changed: 65 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -367,8 +367,14 @@ async def send_upload_mtu_test(
367367
packet_type = parsed_header["packet_type"] if parsed_header else None
368368

369369
if packet_type == Packet_Type.MTU_UP_RES:
370+
self.logger.info(
371+
f"<green>Upload Test Success: {mtu_size}B ({mtu_char_len} chars) via {dns_server} for {domain}</green>"
372+
)
370373
return True
371374
elif packet_type == Packet_Type.ERROR_DROP:
375+
self.logger.info(
376+
f"<yellow>Upload Test Dropped (Server MTU Limit): {mtu_size}B via {dns_server} for {domain}</yellow>"
377+
)
372378
return False
373379
return False
374380

@@ -405,8 +411,14 @@ async def send_download_mtu_test(
405411

406412
if packet_type == Packet_Type.MTU_DOWN_RES:
407413
if returned_data and len(returned_data) == mtu_size:
414+
self.logger.info(
415+
f"<green>Download Test Success: {mtu_size}B via {dns_server} for {domain}</green>"
416+
)
408417
return True
409418
else:
419+
self.logger.info(
420+
f"<yellow>Download Test Failed (Data Mismatch): {mtu_size}B via {dns_server} for {domain}</yellow>"
421+
)
410422
return False
411423
return False
412424

@@ -484,7 +496,7 @@ async def test_mtu_sizes(self) -> bool:
484496
self.min_upload_mtu > 0 and up_mtu_bytes < self.min_upload_mtu
485497
):
486498
self.logger.warning(
487-
f"Connection invalid for {domain} via {resolver}: Upload MTU failed."
499+
f"<red>❌ Connection invalid for {domain} via {resolver}: Upload MTU failed.</red>"
488500
)
489501
continue
490502

@@ -497,7 +509,7 @@ async def test_mtu_sizes(self) -> bool:
497509
self.min_download_mtu > 0 and down_mtu_bytes < self.min_download_mtu
498510
):
499511
self.logger.warning(
500-
f"Connection invalid for {domain} via {resolver}: Download MTU failed."
512+
f"<red>❌ Connection invalid for {domain} via {resolver}: Download MTU failed.</red>"
501513
)
502514
continue
503515

@@ -509,8 +521,8 @@ async def test_mtu_sizes(self) -> bool:
509521
connection["packet_loss"] = 0
510522

511523
self.logger.info(
512-
f"<green>Valid: <cyan>{domain}</cyan> via <cyan>{resolver}</cyan> | "
513-
f"UP: {up_mtu_bytes}B ({up_mtu_char}c) | DOWN: {down_mtu_bytes}B</green>"
524+
f"<cyan>✅ Valid: {domain} via <green>{resolver}</green> | "
525+
f"Upload MTU: <red>{up_mtu_bytes}</red> | Download MTU: <red>{down_mtu_bytes}</red></cyan>"
514526
)
515527

516528
valid_conns = [c for c in self.connections_map if c.get("is_valid")]
@@ -833,23 +845,30 @@ async def _main_tunnel_loop(self):
833845
finally:
834846
self.logger.info("Cleaning up tunnel resources...")
835847

848+
for w in getattr(self, "workers", []):
849+
if not w.done():
850+
w.cancel()
851+
836852
if server:
837853
try:
838854
server.close()
839-
await server.wait_closed()
855+
await asyncio.wait_for(server.wait_closed(), timeout=1.0)
840856
except Exception:
841857
pass
842858

843-
for w in getattr(self, "workers", []):
844-
w.cancel()
845-
859+
close_tasks = []
846860
for sid in list(self.active_streams.keys()):
861+
close_tasks.append(self.close_stream(sid, reason="Client App Closing"))
862+
863+
if close_tasks:
847864
try:
848865
await asyncio.wait_for(
849-
self.close_stream(sid, reason="Client App Closing"), timeout=1.5
866+
asyncio.gather(*close_tasks, return_exceptions=True),
867+
timeout=1.5,
850868
)
851869
except Exception:
852870
pass
871+
853872
self.active_streams.clear()
854873

855874
if hasattr(self, "tunnel_sock") and self.tunnel_sock:
@@ -895,9 +914,9 @@ async def _close_writer_safely(self, writer):
895914
try:
896915
if writer and not writer.is_closing():
897916
writer.close()
898-
await asyncio.wait_for(writer.wait_closed(), timeout=3.0)
899-
except Exception as e:
900-
self.logger.debug(f"Error closing writer: {e}")
917+
await asyncio.wait_for(writer.wait_closed(), timeout=0.5)
918+
except Exception:
919+
pass
901920

902921
def _new_get_stream_id(self):
903922
start = (self.last_stream_id + 1) or 1
@@ -920,6 +939,12 @@ def _new_get_stream_id(self):
920939
return False, 0
921940

922941
async def _handle_local_tcp_connection(self, reader, writer):
942+
if self.should_stop.is_set() or (
943+
self.session_restart_event and self.session_restart_event.is_set()
944+
):
945+
await self._close_writer_safely(writer)
946+
return
947+
923948
stream_id_status, stream_id = self._new_get_stream_id()
924949
if not stream_id_status:
925950
self.logger.error("No available Stream IDs! Too many connections.")
@@ -929,9 +954,14 @@ async def _handle_local_tcp_connection(self, reader, writer):
929954
self.logger.info(f"New local connection, assigning Stream ID: {stream_id}")
930955

931956
now = self.loop.time()
932-
await self.outbound_queue.put(
933-
(2, now, Packet_Type.STREAM_SYN, stream_id, 0, b"")
934-
)
957+
try:
958+
self.outbound_queue.put_nowait(
959+
(2, now, Packet_Type.STREAM_SYN, stream_id, 0, b"")
960+
)
961+
except asyncio.QueueFull:
962+
self.logger.debug("Queue is full, dropping new connection.")
963+
await self._close_writer_safely(writer)
964+
return
935965

936966
self.active_streams[stream_id] = {
937967
"reader": reader,
@@ -944,7 +974,7 @@ async def _handle_local_tcp_connection(self, reader, writer):
944974

945975
async def _clear_stream_from_queue(self, stream_id: int):
946976
"""Removes all packets of a specific stream from the outbound queue except FIN."""
947-
if self.outbound_queue.empty():
977+
if not hasattr(self, "outbound_queue") or self.outbound_queue.empty():
948978
return
949979

950980
items = []
@@ -957,13 +987,21 @@ async def _clear_stream_from_queue(self, stream_id: int):
957987
break
958988

959989
for item in items:
960-
await self.outbound_queue.put(item)
990+
try:
991+
self.outbound_queue.put_nowait(item)
992+
except asyncio.QueueFull:
993+
pass
961994

962995
self.logger.debug(f"Queue cleared for Stream {stream_id}")
963996

964997
async def _client_enqueue_tx(
965998
self, priority, stream_id, sn, data, is_ack=False, is_fin=False, is_resend=False
966999
):
1000+
if self.should_stop.is_set() or (
1001+
self.session_restart_event and self.session_restart_event.is_set()
1002+
):
1003+
return
1004+
9671005
ptype = Packet_Type.STREAM_DATA
9681006
effective_priority = 3
9691007

@@ -977,9 +1015,12 @@ async def _client_enqueue_tx(
9771015
ptype = Packet_Type.STREAM_RESEND if is_resend else ptype
9781016
effective_priority = 2
9791017

980-
await self.outbound_queue.put(
981-
(effective_priority, self.loop.time(), ptype, stream_id, sn, data)
982-
)
1018+
try:
1019+
self.outbound_queue.put_nowait(
1020+
(effective_priority, self.loop.time(), ptype, stream_id, sn, data)
1021+
)
1022+
except asyncio.QueueFull:
1023+
pass
9831024

9841025
async def _tx_worker(self):
9851026
self.logger.debug("<magenta>[TX]</magenta> TX Worker started.")
@@ -1151,11 +1192,10 @@ async def _handle_server_response(self, header, data):
11511192
await self.close_stream(stream_id, reason="Server sent FIN")
11521193

11531194
elif ptype == Packet_Type.ERROR_DROP:
1154-
self.logger.error(
1155-
"<red>Session dropped by server (Server Restarted or Invalid). Reconnecting...</red>"
1156-
)
1157-
1158-
if self.session_restart_event:
1195+
if not self.session_restart_event.is_set():
1196+
self.logger.error(
1197+
"<red>Session dropped by server (Server Restarted or Invalid). Reconnecting...</red>"
1198+
)
11591199
self.session_restart_event.set()
11601200

11611201
async def close_stream(self, stream_id: int, reason: str = "Unknown") -> None:

dns_utils/ARQ.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434
self.rcv_buf = {}
3535

3636
self.last_activity = time.time()
37-
self.rto = 2.0
37+
self.rto = 1.0
3838
self.closed = False
3939
self.logger = logger
4040
self._fin_sent = False
@@ -202,7 +202,7 @@ async def close(self, reason="Unknown"):
202202
):
203203
self.writer.close()
204204
try:
205-
await asyncio.wait_for(self.writer.wait_closed(), timeout=3.0)
205+
await asyncio.wait_for(self.writer.wait_closed(), timeout=0.5)
206206
except Exception:
207207
pass
208208
except Exception:

server.py

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,18 @@ async def _close_session(self, session_id: int) -> None:
9494
self.logger.debug(f"Closing Session {session_id} and all its streams...")
9595

9696
stream_ids = list(session.get("streams", {}).keys())
97-
for sid in stream_ids:
98-
await self.close_stream(session_id, sid, reason="Session Closing")
97+
close_tasks = [
98+
self.close_stream(session_id, sid, reason="Session Closing")
99+
for sid in stream_ids
100+
]
101+
102+
if close_tasks:
103+
try:
104+
await asyncio.wait_for(
105+
asyncio.gather(*close_tasks, return_exceptions=True), timeout=2.0
106+
)
107+
except Exception:
108+
pass
99109

100110
out_queue = session.get("outbound_queue")
101111
if out_queue:
@@ -861,42 +871,25 @@ async def start(self) -> None:
861871
async def stop(self) -> None:
862872
"""Signal the server to stop."""
863873
self.should_stop.set()
864-
for session_id in list(self.sessions.keys()):
865-
await self._close_session(session_id)
866874

867-
try:
868-
if getattr(self, "_retransmit_task", None):
869-
self._retransmit_task.cancel()
870-
except Exception:
871-
pass
872-
873-
try:
874-
if getattr(self, "_dns_task", None):
875-
self._dns_task.cancel()
876-
except Exception:
877-
pass
875+
for task in list(self._background_tasks):
876+
if not task.done():
877+
task.cancel()
878878

879-
try:
880-
if getattr(self, "_session_cleanup_task", None):
881-
self._session_cleanup_task.cancel()
882-
except Exception:
883-
pass
879+
for task_name in ["_retransmit_task", "_dns_task", "_session_cleanup_task"]:
880+
task = getattr(self, task_name, None)
881+
if task and not task.done():
882+
task.cancel()
884883

885-
try:
886-
await asyncio.gather(
887-
*(
888-
t
889-
for t in (
890-
getattr(self, "_dns_task", None),
891-
getattr(self, "_session_cleanup_task", None),
892-
getattr(self, "_retransmit_task", None),
893-
)
894-
if t
895-
),
896-
return_exceptions=True,
897-
)
898-
except Exception:
899-
pass
884+
session_ids = list(self.sessions.keys())
885+
close_tasks = [self._close_session(sid) for sid in session_ids]
886+
if close_tasks:
887+
try:
888+
await asyncio.wait_for(
889+
asyncio.gather(*close_tasks, return_exceptions=True), timeout=3.0
890+
)
891+
except Exception:
892+
pass
900893

901894
if self.udp_sock:
902895
try:

0 commit comments

Comments
 (0)