Skip to content

Commit bd70235

Browse files
committed
Refactoring Server Part 1, Working on client closing streams, Decrease time of MTU testing to Ignore low speed DNS servers.
1 parent cc8f958 commit bd70235

File tree

2 files changed

+123
-274
lines changed

2 files changed

+123
-274
lines changed

client.py

Lines changed: 36 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ async def _binary_search_mtu(
294294
self.logger.debug(
295295
f"<cyan>[MTU]</cyan> Starting binary search for MTU. Range: {min_mtu}-{max_mtu}"
296296
)
297+
297298
for _ in range(2):
298299
if await test_callable(max_mtu):
299300
self.logger.debug(f"<cyan>[MTU]</cyan> Max MTU {max_mtu} is valid.")
@@ -309,7 +310,8 @@ async def _binary_search_mtu(
309310
break
310311

311312
ok = False
312-
for _ in range(1):
313+
314+
for _ in range(2):
313315
try:
314316
ok = await test_callable(mid)
315317
if ok:
@@ -358,7 +360,7 @@ async def send_upload_mtu_test(
358360
return False
359361

360362
response = await self._send_and_receive_dns(
361-
dns_queries[0], dns_server, dns_port, 2
363+
dns_queries[0], dns_server, dns_port, 1
362364
)
363365

364366
parsed_header, _ = await self._process_received_packet(response)
@@ -396,7 +398,7 @@ async def send_download_mtu_test(
396398
return False
397399

398400
response = await self._send_and_receive_dns(
399-
dns_queries[0], dns_server, dns_port, 2
401+
dns_queries[0], dns_server, dns_port, 1
400402
)
401403
parsed_header, returned_data = await self._process_received_packet(response)
402404
packet_type = parsed_header["packet_type"] if parsed_header else None
@@ -705,6 +707,7 @@ async def run_client(self, MTU_TEST=False) -> None:
705707
"""Run the MasterDnsVPN Client main logic."""
706708
self.logger.info("Setting up connections...")
707709
try:
710+
self.session_restart_event = asyncio.Event()
708711
if MTU_TEST or len(self.connections_map) <= 0:
709712
await self.create_connection_map()
710713

@@ -840,28 +843,14 @@ async def _main_tunnel_loop(self):
840843
for w in getattr(self, "workers", []):
841844
w.cancel()
842845

843-
for _id, active_stream in list(self.active_streams.items()):
846+
for sid in list(self.active_streams.keys()):
844847
try:
845-
stream_obj = active_stream.get("stream")
846-
if stream_obj and not stream_obj.closed:
847-
stream_obj._fin_sent = True
848-
stream_obj.closed = True
849-
if (
850-
hasattr(stream_obj, "io_task")
851-
and stream_obj.io_task
852-
and not stream_obj.io_task.done()
853-
):
854-
stream_obj.io_task.cancel()
855-
except Exception:
856-
pass
857-
858-
try:
859-
writer = active_stream.get("writer")
860848
await asyncio.wait_for(
861-
self._close_writer_safely(writer), timeout=3.0
849+
self.close_stream(sid, reason="Client App Closing"), timeout=1.5
862850
)
863851
except Exception:
864852
pass
853+
self.active_streams.clear()
865854

866855
if hasattr(self, "tunnel_sock") and self.tunnel_sock:
867856
try:
@@ -880,7 +869,7 @@ async def _main_tunnel_loop(self):
880869
async def _rx_worker(self):
881870
"""Continuously listen for incoming VPN packets on the tunnel socket."""
882871
self.logger.debug("<magenta>[RX]</magenta> RX Worker started.")
883-
while not self.should_stop.is_set():
872+
while not self.should_stop.is_set() and not self.session_restart_event.is_set():
884873
try:
885874
data, addr = await asyncio.wait_for(
886875
async_recvfrom(self.loop, self.tunnel_sock, 65536), timeout=1.0
@@ -915,7 +904,7 @@ def _new_get_stream_id(self):
915904
stream_id = start
916905
wrapped = False
917906

918-
while not self.should_stop.is_set():
907+
while not self.should_stop.is_set() and not self.session_restart_event.is_set():
919908
if stream_id > 65535:
920909
if wrapped:
921910
return False, 0
@@ -994,7 +983,7 @@ async def _client_enqueue_tx(
994983

995984
async def _tx_worker(self):
996985
self.logger.debug("<magenta>[TX]</magenta> TX Worker started.")
997-
while not self.should_stop.is_set():
986+
while not self.should_stop.is_set() and not self.session_restart_event.is_set():
998987
try:
999988
item = await asyncio.wait_for(self.outbound_queue.get(), timeout=0.2)
1000989
await self._send_single_packet(item)
@@ -1159,48 +1148,7 @@ async def _handle_server_response(self, header, data):
11591148
)
11601149

11611150
elif ptype == Packet_Type.STREAM_FIN and stream_id_exists:
1162-
self.logger.info(f"<y>Stream {stream_id} Closed by server.</y>")
1163-
stream = self.active_streams.get(stream_id)
1164-
if stream:
1165-
stream_obj = stream.get("stream")
1166-
if stream_obj:
1167-
stream_obj._fin_sent = True
1168-
stream_obj.closed = True
1169-
await self._client_enqueue_tx(1, stream_id, 0, b"", is_fin=True)
1170-
1171-
self.active_streams.pop(stream_id, None)
1172-
await self._clear_stream_from_queue(stream_id)
1173-
1174-
if stream_obj:
1175-
if (
1176-
hasattr(stream_obj, "io_task")
1177-
and stream_obj.io_task
1178-
and not stream_obj.io_task.done()
1179-
):
1180-
stream_obj.io_task.cancel()
1181-
try:
1182-
await asyncio.wait_for(stream_obj.io_task, timeout=0.1)
1183-
except Exception:
1184-
pass
1185-
try:
1186-
writer_tcp = stream_obj.writer
1187-
if (
1188-
writer_tcp
1189-
and hasattr(writer_tcp, "is_closing")
1190-
and not writer_tcp.is_closing()
1191-
):
1192-
writer_tcp.close()
1193-
await asyncio.wait_for(
1194-
writer_tcp.wait_closed(), timeout=3.0
1195-
)
1196-
except Exception:
1197-
pass
1198-
1199-
try:
1200-
local_writer = stream.get("writer")
1201-
await self._close_writer_safely(local_writer)
1202-
except Exception:
1203-
pass
1151+
await self.close_stream(stream_id, reason="Server sent FIN")
12041152

12051153
elif ptype == Packet_Type.ERROR_DROP:
12061154
self.logger.error(
@@ -1210,9 +1158,28 @@ async def _handle_server_response(self, header, data):
12101158
if self.session_restart_event:
12111159
self.session_restart_event.set()
12121160

1161+
async def close_stream(self, stream_id: int, reason: str = "Unknown") -> None:
1162+
"""Safely and fully close a specific local stream."""
1163+
if stream_id not in self.active_streams:
1164+
return
1165+
1166+
self.logger.info(f"<y>Closing Client Stream {stream_id}. Reason: {reason}</y>")
1167+
stream_data = self.active_streams.pop(stream_id)
1168+
1169+
await self._clear_stream_from_queue(stream_id)
1170+
1171+
stream_obj = stream_data.get("stream")
1172+
if stream_obj:
1173+
await stream_obj.close(reason=reason)
1174+
else:
1175+
await self._client_enqueue_tx(1, stream_id, 0, b"", is_fin=True)
1176+
1177+
writer = stream_data.get("writer")
1178+
await self._close_writer_safely(writer)
1179+
12131180
async def _retransmit_worker(self):
12141181
self.logger.debug("<magenta>[RETRANS]</magenta> Retransmit Worker started.")
1215-
while not self.should_stop.is_set():
1182+
while not self.should_stop.is_set() and not self.session_restart_event.is_set():
12161183
await asyncio.sleep(0.1)
12171184

12181185
dead_streams = [
@@ -1238,51 +1205,11 @@ async def _retransmit_worker(self):
12381205
s.get("status") == "PENDING"
12391206
and self.loop.time() - s.get("create_time", 0) > 30.0
12401207
):
1241-
self.logger.warning(f"Stream {sid} handshake timeout. Closing.")
1208+
reason = "Handshake timeout (No SYN_ACK from server)"
12421209
else:
1243-
self.logger.info(f"Stream {sid} closed locally. Notifying server.")
1210+
reason = "Closed locally or Inactivity Timeout"
12441211

1245-
try:
1246-
stream_obj = s.get("stream")
1247-
fin_already_sent = False
1248-
if stream_obj and getattr(stream_obj, "_fin_sent", False):
1249-
fin_already_sent = True
1250-
1251-
if not fin_already_sent:
1252-
self.ping_manager.update_activity()
1253-
target_conns = self.balancer.get_unique_servers(
1254-
self.packet_duplication
1255-
)
1256-
if target_conns:
1257-
for conn in target_conns:
1258-
query_packets = await self.dns_packet_parser.build_request_dns_query(
1259-
domain=conn["domain"],
1260-
session_id=self.session_id,
1261-
packet_type=Packet_Type.STREAM_FIN,
1262-
data=b"",
1263-
mtu_chars=self.synced_upload_mtu_chars,
1264-
encode_data=True,
1265-
qType=DNS_Record_Type.TXT,
1266-
stream_id=sid,
1267-
sequence_num=0,
1268-
)
1269-
if query_packets:
1270-
for qp in query_packets:
1271-
await async_sendto(
1272-
self.loop,
1273-
self.tunnel_sock,
1274-
qp,
1275-
(conn["resolver"], 53),
1276-
)
1277-
1278-
stream_data = self.active_streams.pop(sid, None)
1279-
await self._clear_stream_from_queue(sid)
1280-
if stream_data:
1281-
writer = stream_data.get("writer")
1282-
await self._close_writer_safely(writer)
1283-
1284-
except Exception as e:
1285-
self.logger.debug(f"Error handling dead stream {sid}: {e}")
1212+
await self.close_stream(sid, reason=reason)
12861213

12871214
for s in list(self.active_streams.values()):
12881215
arq = s.get("stream")

0 commit comments

Comments
 (0)