@@ -21,9 +21,11 @@ import (
2121 "masterdnsvpn-go/internal/compression"
2222 "masterdnsvpn-go/internal/config"
2323 dnsCache "masterdnsvpn-go/internal/dnscache"
24+ Enums "masterdnsvpn-go/internal/enums"
2425 fragmentStore "masterdnsvpn-go/internal/fragmentstore"
2526 "masterdnsvpn-go/internal/logger"
2627 "masterdnsvpn-go/internal/security"
28+ VpnProto "masterdnsvpn-go/internal/vpnproto"
2729)
2830
2931type Client struct {
@@ -76,6 +78,7 @@ type Client struct {
7678 stream0Runtime * stream0Runtime
7779 streamsMu sync.RWMutex
7880 streams map [uint16 ]* clientStream
81+ closedStreams map [uint16 ]int64
7982 mtuTestRetries int
8083 mtuTestTimeout time.Duration
8184 packetDuplicationCount int
@@ -104,6 +107,11 @@ type Client struct {
104107 sessionInitBusyUnix atomic.Int64
105108}
106109
110+ const (
111+ clientClosedStreamRecordTTL = 45 * time .Second
112+ clientClosedStreamRecordCap = 2000
113+ )
114+
107115type Connection struct {
108116 Domain string
109117 Resolver string
@@ -194,6 +202,7 @@ func New(cfg config.ClientConfig, log *logger.Logger, codec *security.Codec) *Cl
194202 ),
195203 localDNSFragTTL : time .Duration (cfg .LocalDNSFragmentTimeoutSec * float64 (time .Second )),
196204 streams : make (map [uint16 ]* clientStream , 16 ),
205+ closedStreams : make (map [uint16 ]int64 , 16 ),
197206 mtuTestRetries : cfg .MTUTestRetries ,
198207 mtuTestTimeout : time .Duration (cfg .MTUTestTimeout * float64 (time .Second )),
199208 packetDuplicationCount : cfg .PacketDuplicationCount ,
@@ -348,6 +357,7 @@ func (c *Client) ResetRuntimeState(resetSessionCookie bool) {
348357 c .closeAllStreams ()
349358 c .streamsMu .Lock ()
350359 c .streams = make (map [uint16 ]* clientStream , 16 )
360+ c .closedStreams = make (map [uint16 ]int64 , 16 )
351361 c .streamsMu .Unlock ()
352362 c .resolverHealthMu .Lock ()
353363 c .resolverHealth = make (map [string ]* resolverHealthState , len (c .connections ))
@@ -555,6 +565,7 @@ func (c *Client) deleteStream(streamID uint16) {
555565 c .streamsMu .Lock ()
556566 stream := c .streams [streamID ]
557567 delete (c .streams , streamID )
568+ c .noteClosedStreamLocked (streamID , time .Now ())
558569 c .streamsMu .Unlock ()
559570 if stream != nil && stream .Conn != nil {
560571 stream .stopOnce .Do (func () {
@@ -571,6 +582,7 @@ func (c *Client) closeAllStreams() {
571582 c .streamsMu .Lock ()
572583 streams := c .streams
573584 c .streams = make (map [uint16 ]* clientStream , 16 )
585+ c .closedStreams = make (map [uint16 ]int64 , 16 )
574586 c .streamsMu .Unlock ()
575587 for _ , stream := range streams {
576588 if stream == nil {
@@ -585,6 +597,120 @@ func (c *Client) closeAllStreams() {
585597 }
586598}
587599
600+ func (c * Client ) handleClosedStreamPacket (packet VpnProto.Packet , timeout time.Duration ) (VpnProto.Packet , bool , error ) {
601+ if c == nil || packet .StreamID == 0 || ! c .isRecentlyClosedStream (packet .StreamID , time .Now ()) {
602+ return VpnProto.Packet {}, false , nil
603+ }
604+
605+ responsePacket := VpnProto.Packet {
606+ StreamID : packet .StreamID ,
607+ HasStreamID : true ,
608+ SequenceNum : packet .SequenceNum ,
609+ HasSequenceNum : packet .SequenceNum != 0 ,
610+ }
611+
612+ outgoingType := uint8 (0 )
613+ switch packet .PacketType {
614+ case Enums .PACKET_STREAM_FIN :
615+ outgoingType = Enums .PACKET_STREAM_FIN_ACK
616+ responsePacket .PacketType = outgoingType
617+ case Enums .PACKET_STREAM_RST :
618+ outgoingType = Enums .PACKET_STREAM_RST_ACK
619+ responsePacket .PacketType = outgoingType
620+ case Enums .PACKET_STREAM_DATA , Enums .PACKET_STREAM_RESEND , Enums .PACKET_STREAM_DATA_ACK :
621+ outgoingType = Enums .PACKET_STREAM_RST
622+ responsePacket .PacketType = outgoingType
623+ responsePacket .SequenceNum = 0
624+ responsePacket .HasSequenceNum = false
625+ default :
626+ return VpnProto.Packet {}, false , nil
627+ }
628+
629+ _ = c .sendClosedStreamOneWayPacket (outgoingType , packet .StreamID , responsePacket .SequenceNum , timeout )
630+ return responsePacket , true , nil
631+ }
632+
633+ func (c * Client ) sendClosedStreamOneWayPacket (packetType uint8 , streamID uint16 , sequenceNum uint16 , timeout time.Duration ) error {
634+ if c == nil || ! c .SessionReady () {
635+ return nil
636+ }
637+
638+ connections , err := c .selectTargetConnectionsForPacket (packetType , streamID )
639+ if err != nil {
640+ return err
641+ }
642+ deadline := time .Now ().Add (normalizeTimeout (timeout , defaultRuntimeTimeout ))
643+ var firstErr error
644+ for _ , connection := range connections {
645+ query , buildErr := c .buildStreamQuery (connection .Domain , packetType , streamID , sequenceNum , 0 , 1 , nil )
646+ if buildErr != nil {
647+ if firstErr == nil {
648+ firstErr = buildErr
649+ }
650+ continue
651+ }
652+ if sendErr := c .sendOneWaySessionPacket (connection , query , deadline ); sendErr != nil && firstErr == nil {
653+ firstErr = sendErr
654+ }
655+ }
656+ return firstErr
657+ }
658+
659+ func (c * Client ) isRecentlyClosedStream (streamID uint16 , now time.Time ) bool {
660+ if c == nil || streamID == 0 {
661+ return false
662+ }
663+
664+ c .streamsMu .RLock ()
665+ closedAt , ok := c .closedStreams [streamID ]
666+ c .streamsMu .RUnlock ()
667+ if ! ok {
668+ return false
669+ }
670+ if now .UnixNano ()- closedAt <= clientClosedStreamRecordTTL .Nanoseconds () {
671+ return true
672+ }
673+
674+ c .streamsMu .Lock ()
675+ if staleAt , stale := c .closedStreams [streamID ]; stale && staleAt == closedAt && now .UnixNano ()- staleAt > clientClosedStreamRecordTTL .Nanoseconds () {
676+ delete (c .closedStreams , streamID )
677+ }
678+ c .streamsMu .Unlock ()
679+ return false
680+ }
681+
682+ func (c * Client ) noteClosedStreamLocked (streamID uint16 , now time.Time ) {
683+ if c == nil || streamID == 0 {
684+ return
685+ }
686+ if c .closedStreams == nil {
687+ c .closedStreams = make (map [uint16 ]int64 , 16 )
688+ }
689+ nowUnix := now .UnixNano ()
690+ expiredBefore := nowUnix - clientClosedStreamRecordTTL .Nanoseconds ()
691+ for closedID , closedAt := range c .closedStreams {
692+ if closedAt < expiredBefore {
693+ delete (c .closedStreams , closedID )
694+ }
695+ }
696+ c .closedStreams [streamID ] = nowUnix
697+ if len (c .closedStreams ) <= clientClosedStreamRecordCap {
698+ return
699+ }
700+
701+ var oldestID uint16
702+ var oldestAt int64
703+ first := true
704+ for closedID , closedAt := range c .closedStreams {
705+ if first || closedAt < oldestAt {
706+ oldestID = closedID
707+ oldestAt = closedAt
708+ first = false
709+ }
710+ }
711+ delete (c .closedStreams , oldestID )
712+ }
713+
588714func (c * Client ) activeStreamCount () int {
589715 if c == nil {
590716 return 0
0 commit comments