Skip to content

Commit f68e2fd

Browse files
committed
Add closed stream tracking for late packet handling
1 parent d90f474 commit f68e2fd

File tree

7 files changed

+383
-2
lines changed

7 files changed

+383
-2
lines changed

TODO

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
Check Clean for server and client.
2-
Client Reconnect.
3-
loop for reconnect/test mtu, and notify end of session.
2+
Config Recommendations

internal/client/client.go

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@ import (
2121
"masterdnsvpn-go/internal/compression"
2222
"masterdnsvpn-go/internal/config"
2323
dnsCache "masterdnsvpn-go/internal/dnscache"
24+
Enums "masterdnsvpn-go/internal/enums"
2425
fragmentStore "masterdnsvpn-go/internal/fragmentstore"
2526
"masterdnsvpn-go/internal/logger"
2627
"masterdnsvpn-go/internal/security"
28+
VpnProto "masterdnsvpn-go/internal/vpnproto"
2729
)
2830

2931
type Client struct {
@@ -76,6 +78,7 @@ type Client struct {
7678
stream0Runtime *stream0Runtime
7779
streamsMu sync.RWMutex
7880
streams map[uint16]*clientStream
81+
closedStreams map[uint16]int64
7982
mtuTestRetries int
8083
mtuTestTimeout time.Duration
8184
packetDuplicationCount int
@@ -104,6 +107,11 @@ type Client struct {
104107
sessionInitBusyUnix atomic.Int64
105108
}
106109

110+
const (
111+
clientClosedStreamRecordTTL = 45 * time.Second
112+
clientClosedStreamRecordCap = 2000
113+
)
114+
107115
type Connection struct {
108116
Domain string
109117
Resolver string
@@ -194,6 +202,7 @@ func New(cfg config.ClientConfig, log *logger.Logger, codec *security.Codec) *Cl
194202
),
195203
localDNSFragTTL: time.Duration(cfg.LocalDNSFragmentTimeoutSec * float64(time.Second)),
196204
streams: make(map[uint16]*clientStream, 16),
205+
closedStreams: make(map[uint16]int64, 16),
197206
mtuTestRetries: cfg.MTUTestRetries,
198207
mtuTestTimeout: time.Duration(cfg.MTUTestTimeout * float64(time.Second)),
199208
packetDuplicationCount: cfg.PacketDuplicationCount,
@@ -348,6 +357,7 @@ func (c *Client) ResetRuntimeState(resetSessionCookie bool) {
348357
c.closeAllStreams()
349358
c.streamsMu.Lock()
350359
c.streams = make(map[uint16]*clientStream, 16)
360+
c.closedStreams = make(map[uint16]int64, 16)
351361
c.streamsMu.Unlock()
352362
c.resolverHealthMu.Lock()
353363
c.resolverHealth = make(map[string]*resolverHealthState, len(c.connections))
@@ -555,6 +565,7 @@ func (c *Client) deleteStream(streamID uint16) {
555565
c.streamsMu.Lock()
556566
stream := c.streams[streamID]
557567
delete(c.streams, streamID)
568+
c.noteClosedStreamLocked(streamID, time.Now())
558569
c.streamsMu.Unlock()
559570
if stream != nil && stream.Conn != nil {
560571
stream.stopOnce.Do(func() {
@@ -571,6 +582,7 @@ func (c *Client) closeAllStreams() {
571582
c.streamsMu.Lock()
572583
streams := c.streams
573584
c.streams = make(map[uint16]*clientStream, 16)
585+
c.closedStreams = make(map[uint16]int64, 16)
574586
c.streamsMu.Unlock()
575587
for _, stream := range streams {
576588
if stream == nil {
@@ -585,6 +597,120 @@ func (c *Client) closeAllStreams() {
585597
}
586598
}
587599

600+
func (c *Client) handleClosedStreamPacket(packet VpnProto.Packet, timeout time.Duration) (VpnProto.Packet, bool, error) {
601+
if c == nil || packet.StreamID == 0 || !c.isRecentlyClosedStream(packet.StreamID, time.Now()) {
602+
return VpnProto.Packet{}, false, nil
603+
}
604+
605+
responsePacket := VpnProto.Packet{
606+
StreamID: packet.StreamID,
607+
HasStreamID: true,
608+
SequenceNum: packet.SequenceNum,
609+
HasSequenceNum: packet.SequenceNum != 0,
610+
}
611+
612+
outgoingType := uint8(0)
613+
switch packet.PacketType {
614+
case Enums.PACKET_STREAM_FIN:
615+
outgoingType = Enums.PACKET_STREAM_FIN_ACK
616+
responsePacket.PacketType = outgoingType
617+
case Enums.PACKET_STREAM_RST:
618+
outgoingType = Enums.PACKET_STREAM_RST_ACK
619+
responsePacket.PacketType = outgoingType
620+
case Enums.PACKET_STREAM_DATA, Enums.PACKET_STREAM_RESEND, Enums.PACKET_STREAM_DATA_ACK:
621+
outgoingType = Enums.PACKET_STREAM_RST
622+
responsePacket.PacketType = outgoingType
623+
responsePacket.SequenceNum = 0
624+
responsePacket.HasSequenceNum = false
625+
default:
626+
return VpnProto.Packet{}, false, nil
627+
}
628+
629+
_ = c.sendClosedStreamOneWayPacket(outgoingType, packet.StreamID, responsePacket.SequenceNum, timeout)
630+
return responsePacket, true, nil
631+
}
632+
633+
func (c *Client) sendClosedStreamOneWayPacket(packetType uint8, streamID uint16, sequenceNum uint16, timeout time.Duration) error {
634+
if c == nil || !c.SessionReady() {
635+
return nil
636+
}
637+
638+
connections, err := c.selectTargetConnectionsForPacket(packetType, streamID)
639+
if err != nil {
640+
return err
641+
}
642+
deadline := time.Now().Add(normalizeTimeout(timeout, defaultRuntimeTimeout))
643+
var firstErr error
644+
for _, connection := range connections {
645+
query, buildErr := c.buildStreamQuery(connection.Domain, packetType, streamID, sequenceNum, 0, 1, nil)
646+
if buildErr != nil {
647+
if firstErr == nil {
648+
firstErr = buildErr
649+
}
650+
continue
651+
}
652+
if sendErr := c.sendOneWaySessionPacket(connection, query, deadline); sendErr != nil && firstErr == nil {
653+
firstErr = sendErr
654+
}
655+
}
656+
return firstErr
657+
}
658+
659+
func (c *Client) isRecentlyClosedStream(streamID uint16, now time.Time) bool {
660+
if c == nil || streamID == 0 {
661+
return false
662+
}
663+
664+
c.streamsMu.RLock()
665+
closedAt, ok := c.closedStreams[streamID]
666+
c.streamsMu.RUnlock()
667+
if !ok {
668+
return false
669+
}
670+
if now.UnixNano()-closedAt <= clientClosedStreamRecordTTL.Nanoseconds() {
671+
return true
672+
}
673+
674+
c.streamsMu.Lock()
675+
if staleAt, stale := c.closedStreams[streamID]; stale && staleAt == closedAt && now.UnixNano()-staleAt > clientClosedStreamRecordTTL.Nanoseconds() {
676+
delete(c.closedStreams, streamID)
677+
}
678+
c.streamsMu.Unlock()
679+
return false
680+
}
681+
682+
func (c *Client) noteClosedStreamLocked(streamID uint16, now time.Time) {
683+
if c == nil || streamID == 0 {
684+
return
685+
}
686+
if c.closedStreams == nil {
687+
c.closedStreams = make(map[uint16]int64, 16)
688+
}
689+
nowUnix := now.UnixNano()
690+
expiredBefore := nowUnix - clientClosedStreamRecordTTL.Nanoseconds()
691+
for closedID, closedAt := range c.closedStreams {
692+
if closedAt < expiredBefore {
693+
delete(c.closedStreams, closedID)
694+
}
695+
}
696+
c.closedStreams[streamID] = nowUnix
697+
if len(c.closedStreams) <= clientClosedStreamRecordCap {
698+
return
699+
}
700+
701+
var oldestID uint16
702+
var oldestAt int64
703+
first := true
704+
for closedID, closedAt := range c.closedStreams {
705+
if first || closedAt < oldestAt {
706+
oldestID = closedID
707+
oldestAt = closedAt
708+
first = false
709+
}
710+
}
711+
delete(c.closedStreams, oldestID)
712+
}
713+
588714
func (c *Client) activeStreamCount() int {
589715
if c == nil {
590716
return 0

internal/client/client_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,6 +1991,79 @@ func TestSelectTargetConnectionsForPacketUsesSetupDuplicationCount(t *testing.T)
19911991
}
19921992
}
19931993

1994+
func TestDeleteStreamTracksClosedStreamRecord(t *testing.T) {
1995+
c := New(config.ClientConfig{}, nil, nil)
1996+
stream := c.createStream(31, nil)
1997+
1998+
c.deleteStream(stream.ID)
1999+
2000+
if !c.isRecentlyClosedStream(stream.ID, time.Now()) {
2001+
t.Fatal("expected deleted stream to be tracked as recently closed")
2002+
}
2003+
}
2004+
2005+
func TestHandleClosedStreamPacketSendsOneWayResetForLateData(t *testing.T) {
2006+
codec, err := security.NewCodec(0, "")
2007+
if err != nil {
2008+
t.Fatalf("NewCodec returned error: %v", err)
2009+
}
2010+
2011+
c := New(config.ClientConfig{
2012+
PacketDuplicationCount: 1,
2013+
Domains: []string{"v.example.com"},
2014+
}, nil, codec)
2015+
c.connections = []Connection{{
2016+
Domain: "v.example.com",
2017+
Resolver: "127.0.0.1",
2018+
ResolverPort: 5353,
2019+
ResolverLabel: "127.0.0.1:5353",
2020+
Key: "127.0.0.1|5353|v.example.com",
2021+
IsValid: true,
2022+
}}
2023+
c.connectionsByKey = map[string]int{c.connections[0].Key: 0}
2024+
c.rebuildBalancer()
2025+
c.sessionID = 7
2026+
c.sessionCookie = 9
2027+
c.sessionReady = true
2028+
2029+
stream := c.createStream(41, nil)
2030+
c.deleteStream(stream.ID)
2031+
2032+
var captured VpnProto.Packet
2033+
c.sendOneWayPacketFn = func(conn Connection, packet []byte, deadline time.Time) error {
2034+
parsed, err := DnsParser.ParsePacketLite(packet)
2035+
if err != nil {
2036+
t.Fatalf("ParsePacketLite returned error: %v", err)
2037+
}
2038+
if !parsed.HasQuestion {
2039+
t.Fatal("expected one-way stream query question")
2040+
}
2041+
captured, err = VpnProto.ParseFromLabels(extractTestTunnelLabels(parsed.FirstQuestion.Name, conn.Domain), c.codec)
2042+
if err != nil {
2043+
t.Fatalf("ParseFromLabels returned error: %v", err)
2044+
}
2045+
return nil
2046+
}
2047+
2048+
response, handled, err := c.handleClosedStreamPacket(VpnProto.Packet{
2049+
PacketType: Enums.PACKET_STREAM_DATA,
2050+
StreamID: 41,
2051+
SequenceNum: 77,
2052+
}, time.Second)
2053+
if err != nil {
2054+
t.Fatalf("handleClosedStreamPacket returned error: %v", err)
2055+
}
2056+
if !handled {
2057+
t.Fatal("expected closed stream packet to be handled")
2058+
}
2059+
if response.PacketType != Enums.PACKET_STREAM_RST || response.StreamID != 41 || response.SequenceNum != 0 {
2060+
t.Fatalf("unexpected synthetic response: %+v", response)
2061+
}
2062+
if captured.PacketType != Enums.PACKET_STREAM_RST || captured.StreamID != 41 || captured.SequenceNum != 0 {
2063+
t.Fatalf("unexpected one-way reset packet: %+v", captured)
2064+
}
2065+
}
2066+
19942067
func extractTestTunnelLabels(qName string, baseDomain string) string {
19952068
suffix := "." + baseDomain
19962069
if len(qName) <= len(suffix) || qName[len(qName)-len(suffix):] != suffix {

internal/client/stream_runtime.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ func (c *Client) handlePackedServerControlBlocks(payload []byte, timeout time.Du
201201
func (c *Client) handleInboundStreamPacket(packet VpnProto.Packet, timeout time.Duration) (VpnProto.Packet, error) {
202202
stream, ok := c.getStream(packet.StreamID)
203203
if !ok || stream == nil {
204+
if closedResponse, handled, err := c.handleClosedStreamPacket(packet, timeout); handled {
205+
return closedResponse, err
206+
}
204207
return c.exchangeStreamControlPacket(Enums.PACKET_STREAM_RST, packet.StreamID, packet.SequenceNum, nil, timeout)
205208
}
206209

internal/udpserver/server.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,10 @@ func (s *Server) handleTunnelCandidate(packet []byte, parsed DnsParser.LitePacke
423423
}
424424

425425
func (s *Server) handlePostSessionPacket(decision domainMatcher.Decision, vpnPacket VpnProto.Packet, sessionRecord *sessionRuntimeView) bool {
426+
if handled := s.handleClosedStreamPacket(vpnPacket); handled {
427+
return true
428+
}
429+
426430
switch vpnPacket.PacketType {
427431
case Enums.PACKET_PACKED_CONTROL_BLOCKS:
428432
return s.handlePackedControlBlocksRequest(vpnPacket, sessionRecord)
@@ -449,6 +453,35 @@ func (s *Server) handlePostSessionPacket(decision domainMatcher.Decision, vpnPac
449453
}
450454
}
451455

456+
func (s *Server) handleClosedStreamPacket(vpnPacket VpnProto.Packet) bool {
457+
if s == nil || vpnPacket.StreamID == 0 || !isClosedStreamAwarePacketType(vpnPacket.PacketType) {
458+
return false
459+
}
460+
response, handled := s.streams.HandleClosedPacket(vpnPacket.SessionID, vpnPacket.StreamID, vpnPacket.PacketType, vpnPacket.SequenceNum, time.Now())
461+
if !handled {
462+
return false
463+
}
464+
if response.PacketType != 0 {
465+
_ = s.queueSessionPacket(vpnPacket.SessionID, response)
466+
}
467+
return true
468+
}
469+
470+
func isClosedStreamAwarePacketType(packetType uint8) bool {
471+
switch packetType {
472+
case Enums.PACKET_STREAM_SYN,
473+
Enums.PACKET_SOCKS5_SYN,
474+
Enums.PACKET_STREAM_DATA,
475+
Enums.PACKET_STREAM_RESEND,
476+
Enums.PACKET_STREAM_DATA_ACK,
477+
Enums.PACKET_STREAM_FIN,
478+
Enums.PACKET_STREAM_RST:
479+
return true
480+
default:
481+
return false
482+
}
483+
}
484+
452485
func (s *Server) validatePostSessionPacket(questionPacket []byte, requestName string, vpnPacket VpnProto.Packet) postSessionValidation {
453486
now := time.Now()
454487
validation := s.sessions.ValidateAndTouch(vpnPacket.SessionID, vpnPacket.SessionCookie, now)

internal/udpserver/server_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,54 @@ func TestHandleStreamSynConnectsForwardTargetForTCPMode(t *testing.T) {
380380
}
381381
}
382382

383+
func TestHandlePacketReturnsResetForLateClosedStreamData(t *testing.T) {
384+
codec, err := security.NewCodec(0, "")
385+
if err != nil {
386+
t.Fatalf("NewCodec returned error: %v", err)
387+
}
388+
389+
srv := New(config.ServerConfig{
390+
MaxPacketSize: 65535,
391+
Domain: []string{"a.com"},
392+
MinVPNLabelLength: 3,
393+
}, nil, codec)
394+
395+
verifyCode := []byte{1, 2, 3, 4}
396+
initPayload := []byte{0, 0x00, 0x00, 0x96, 0x00, 0xC8, verifyCode[0], verifyCode[1], verifyCode[2], verifyCode[3]}
397+
initResponse := srv.handlePacket(buildTunnelQueryWithSessionID(t, codec, "a.com", 0, Enums.PACKET_SESSION_INIT, initPayload))
398+
packet, err := DnsParser.ExtractVPNResponse(initResponse, false)
399+
if err != nil {
400+
t.Fatalf("ExtractVPNResponse returned error: %v", err)
401+
}
402+
403+
sessionID := packet.Payload[0]
404+
sessionCookie := packet.Payload[1]
405+
now := time.Now()
406+
if _, created := srv.streams.EnsureOpen(sessionID, 9, now); !created {
407+
t.Fatal("expected fresh stream state")
408+
}
409+
if !srv.streams.MarkReset(sessionID, 9, 5, now) {
410+
t.Fatal("expected stream reset to succeed")
411+
}
412+
413+
query := buildTunnelStreamQuery(t, codec, "a.com", sessionID, sessionCookie, Enums.PACKET_STREAM_DATA, 9, 77, []byte("late"))
414+
response := srv.handlePacket(query)
415+
if len(response) == 0 {
416+
t.Fatal("expected late closed stream packet to get a response")
417+
}
418+
419+
vpnResponse, err := DnsParser.ExtractVPNResponse(response, false)
420+
if err != nil {
421+
t.Fatalf("ExtractVPNResponse returned error: %v", err)
422+
}
423+
if vpnResponse.PacketType != Enums.PACKET_STREAM_RST {
424+
t.Fatalf("unexpected packet type: got=%d want=%d", vpnResponse.PacketType, Enums.PACKET_STREAM_RST)
425+
}
426+
if vpnResponse.StreamID != 9 || vpnResponse.SequenceNum != 0 {
427+
t.Fatalf("unexpected stream reset routing: stream=%d seq=%d", vpnResponse.StreamID, vpnResponse.SequenceNum)
428+
}
429+
}
430+
383431
func TestHandlePacketRejectsMalformedSessionInit(t *testing.T) {
384432
codec, err := security.NewCodec(0, "")
385433
if err != nil {

0 commit comments

Comments
 (0)