Skip to content

Commit d0b953b

Browse files
committed
Add best-effort session close notifications on client shutdown
1 parent 48d0aa3 commit d0b953b

File tree

10 files changed

+397
-6
lines changed

10 files changed

+397
-6
lines changed

cmd/client/main.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"strings"
1616
"sync"
1717
"syscall"
18+
"time"
1819

1920
"masterdnsvpn-go/internal/client"
2021
"masterdnsvpn-go/internal/logger"
@@ -138,12 +139,24 @@ func main() {
138139

139140
log.Infof("\U0001F3AF <green>Client Bootstrap Ready</green>")
140141

142+
var sessionCloseOnce sync.Once
143+
notifySessionClose := func() {
144+
sessionCloseOnce.Do(func() {
145+
app.BestEffortSessionClose(time.Second)
146+
})
147+
}
148+
141149
if !cfg.LocalDNSEnabled && !cfg.LocalSOCKS5Enabled && cfg.ProtocolType != "TCP" {
150+
notifySessionClose()
142151
return
143152
}
144153

145154
runCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
146155
defer stop()
156+
go func() {
157+
<-runCtx.Done()
158+
notifySessionClose()
159+
}()
147160

148161
enabledListeners := enabledClientListenerCount(cfg.LocalDNSEnabled, cfg.LocalSOCKS5Enabled, cfg.ProtocolType)
149162
errCh := make(chan error, enabledListeners)
@@ -162,6 +175,7 @@ func main() {
162175
}
163176

164177
listenersWG.Wait()
178+
notifySessionClose()
165179
select {
166180
case err := <-errCh:
167181
exitWithStderrf("%v\n", err)

internal/client/client.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ type Client struct {
7070
mtuOutputMu sync.Mutex
7171

7272
exchangeQueryFn func(Connection, []byte, time.Duration) ([]byte, error)
73+
sendOneWayPacketFn func(Connection, []byte, time.Time) error
7374
fragmentLimits sync.Map
7475
stream0Runtime *stream0Runtime
7576
streamsMu sync.RWMutex

internal/client/session_close.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// ==============================================================================
2+
// MasterDnsVPN
3+
// Author: MasterkinG32
4+
// Github: https://github.com/masterking32
5+
// Year: 2026
6+
// ==============================================================================
7+
8+
package client
9+
10+
import (
11+
"sync"
12+
"time"
13+
14+
Enums "masterdnsvpn-go/internal/enums"
15+
VpnProto "masterdnsvpn-go/internal/vpnproto"
16+
)
17+
18+
const (
19+
sessionCloseFanoutLimit = 10
20+
sessionCloseDefaultWindow = time.Second
21+
)
22+
23+
func (c *Client) BestEffortSessionClose(timeout time.Duration) {
24+
if c == nil || !c.sessionReady || c.sessionID == 0 {
25+
return
26+
}
27+
28+
targets := c.activeSessionCloseTargets(sessionCloseFanoutLimit)
29+
if len(targets) == 0 {
30+
return
31+
}
32+
33+
timeout = normalizeTimeout(timeout, sessionCloseDefaultWindow)
34+
deadline := time.Now().Add(timeout)
35+
queries := make(map[string][]byte, len(targets))
36+
37+
var wg sync.WaitGroup
38+
done := make(chan struct{})
39+
for _, conn := range targets {
40+
query, ok := queries[conn.Domain]
41+
if !ok {
42+
built, err := c.buildSessionCloseQuery(conn.Domain)
43+
if err != nil {
44+
continue
45+
}
46+
query = built
47+
queries[conn.Domain] = query
48+
}
49+
50+
connCopy := conn
51+
packetCopy := query
52+
wg.Go(func() {
53+
_ = c.sendOneWaySessionPacket(connCopy, packetCopy, deadline)
54+
})
55+
}
56+
57+
go func() {
58+
wg.Wait()
59+
close(done)
60+
}()
61+
62+
timer := time.NewTimer(timeout)
63+
defer timer.Stop()
64+
65+
select {
66+
case <-done:
67+
case <-timer.C:
68+
}
69+
}
70+
71+
func (c *Client) activeSessionCloseTargets(limit int) []Connection {
72+
if c == nil || limit <= 0 {
73+
return nil
74+
}
75+
76+
limit = min(limit, sessionCloseFanoutLimit)
77+
seen := make(map[string]struct{}, limit)
78+
targets := make([]Connection, 0, limit)
79+
candidateCount := min(len(c.connections), max(limit, limit*max(1, len(c.cfg.Domains))))
80+
81+
for _, conn := range c.balancer.GetUniqueConnections(candidateCount) {
82+
if !conn.IsValid || conn.ResolverLabel == "" {
83+
continue
84+
}
85+
if _, ok := seen[conn.ResolverLabel]; ok {
86+
continue
87+
}
88+
seen[conn.ResolverLabel] = struct{}{}
89+
targets = append(targets, conn)
90+
if len(targets) >= limit {
91+
return targets
92+
}
93+
}
94+
95+
for _, conn := range c.connections {
96+
if !conn.IsValid || conn.ResolverLabel == "" {
97+
continue
98+
}
99+
if _, ok := seen[conn.ResolverLabel]; ok {
100+
continue
101+
}
102+
seen[conn.ResolverLabel] = struct{}{}
103+
targets = append(targets, conn)
104+
if len(targets) >= limit {
105+
break
106+
}
107+
}
108+
109+
return targets
110+
}
111+
112+
func (c *Client) buildSessionCloseQuery(domain string) ([]byte, error) {
113+
return c.buildTunnelTXTQueryRaw(domain, VpnProto.BuildOptions{
114+
SessionID: c.sessionID,
115+
PacketType: Enums.PACKET_SESSION_CLOSE,
116+
SessionCookie: c.sessionCookie,
117+
})
118+
}
119+
120+
func (c *Client) sendOneWaySessionPacket(connection Connection, packet []byte, deadline time.Time) error {
121+
if c != nil && c.sendOneWayPacketFn != nil {
122+
return c.sendOneWayPacketFn(connection, packet, deadline)
123+
}
124+
return sendOneWayUDPQuery(connection.ResolverLabel, packet, deadline)
125+
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// ==============================================================================
2+
// MasterDnsVPN
3+
// Author: MasterkinG32
4+
// Github: https://github.com/masterking32
5+
// Year: 2026
6+
// ==============================================================================
7+
8+
package client
9+
10+
import (
11+
"strconv"
12+
"sync"
13+
"testing"
14+
"time"
15+
16+
"masterdnsvpn-go/internal/config"
17+
"masterdnsvpn-go/internal/security"
18+
)
19+
20+
func TestBestEffortSessionCloseUsesUpToTenUniqueResolvers(t *testing.T) {
21+
resolvers := make([]config.ResolverAddress, 0, 12)
22+
for idx := 1; idx <= 12; idx++ {
23+
resolvers = append(resolvers, config.ResolverAddress{
24+
IP: "10.0.0." + strconv.Itoa(idx),
25+
Port: 53,
26+
})
27+
}
28+
29+
codec, err := security.NewCodec(0, "")
30+
if err != nil {
31+
t.Fatalf("NewCodec returned error: %v", err)
32+
}
33+
34+
c := New(config.ClientConfig{
35+
Domains: []string{
36+
"a.example.com",
37+
"b.example.com",
38+
},
39+
Resolvers: resolvers,
40+
}, nil, codec)
41+
c.BuildConnectionMap()
42+
c.sessionReady = true
43+
c.sessionID = 7
44+
c.sessionCookie = 9
45+
46+
var (
47+
mu sync.Mutex
48+
targets = make(map[string]int)
49+
)
50+
c.sendOneWayPacketFn = func(conn Connection, packet []byte, deadline time.Time) error {
51+
if conn.ResolverLabel == "" {
52+
t.Fatal("resolver label must not be empty")
53+
}
54+
if len(packet) == 0 {
55+
t.Fatal("session close packet must not be empty")
56+
}
57+
if deadline.IsZero() {
58+
t.Fatal("session close deadline must be set")
59+
}
60+
61+
mu.Lock()
62+
targets[conn.ResolverLabel]++
63+
mu.Unlock()
64+
return nil
65+
}
66+
67+
c.BestEffortSessionClose(50 * time.Millisecond)
68+
69+
if got := len(targets); got != 10 {
70+
t.Fatalf("unexpected unique resolver fanout: got=%d want=10", got)
71+
}
72+
for resolverLabel, count := range targets {
73+
if count != 1 {
74+
t.Fatalf("resolver %s received duplicate shutdown notifications: %d", resolverLabel, count)
75+
}
76+
}
77+
}
78+
79+
func TestBestEffortSessionCloseSkipsWithoutEstablishedSession(t *testing.T) {
80+
codec, err := security.NewCodec(0, "")
81+
if err != nil {
82+
t.Fatalf("NewCodec returned error: %v", err)
83+
}
84+
85+
c := New(config.ClientConfig{
86+
Domains: []string{"a.example.com"},
87+
Resolvers: []config.ResolverAddress{
88+
{IP: "8.8.8.8", Port: 53},
89+
},
90+
}, nil, codec)
91+
c.BuildConnectionMap()
92+
93+
var calls int
94+
c.sendOneWayPacketFn = func(Connection, []byte, time.Time) error {
95+
calls++
96+
return nil
97+
}
98+
99+
c.BestEffortSessionClose(20 * time.Millisecond)
100+
101+
if calls != 0 {
102+
t.Fatalf("expected no shutdown packets without an established session, got=%d", calls)
103+
}
104+
}

internal/client/tunnel_runtime.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,13 +313,16 @@ type udpQueryTransport struct {
313313
buffer []byte
314314
}
315315

316-
func newUDPQueryTransport(resolverLabel string) (*udpQueryTransport, error) {
316+
func dialUDPResolver(resolverLabel string) (*net.UDPConn, error) {
317317
addr, err := net.ResolveUDPAddr("udp", resolverLabel)
318318
if err != nil {
319319
return nil, err
320320
}
321+
return net.DialUDP("udp", nil, addr)
322+
}
321323

322-
conn, err := net.DialUDP("udp", nil, addr)
324+
func newUDPQueryTransport(resolverLabel string) (*udpQueryTransport, error) {
325+
conn, err := dialUDPResolver(resolverLabel)
323326
if err != nil {
324327
return nil, err
325328
}
@@ -329,6 +332,24 @@ func newUDPQueryTransport(resolverLabel string) (*udpQueryTransport, error) {
329332
}, nil
330333
}
331334

335+
func sendOneWayUDPQuery(resolverLabel string, packet []byte, deadline time.Time) error {
336+
if len(packet) == 0 {
337+
return nil
338+
}
339+
340+
conn, err := dialUDPResolver(resolverLabel)
341+
if err != nil {
342+
return err
343+
}
344+
defer conn.Close()
345+
346+
if err := conn.SetWriteDeadline(deadline); err != nil {
347+
return err
348+
}
349+
_, err = conn.Write(packet)
350+
return err
351+
}
352+
332353
func exchangeUDPQuery(transport *udpQueryTransport, packet []byte, timeout time.Duration) ([]byte, error) {
333354
if transport == nil || transport.conn == nil {
334355
return nil, net.ErrClosed

internal/enums/dns.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ const (
5252
PACKET_DNS_QUERY_RES = 0x2A
5353
PACKET_DNS_QUERY_REQ_ACK = 0x2B
5454
PACKET_DNS_QUERY_RES_ACK = 0x2C
55+
PACKET_SESSION_CLOSE = 0x2D
5556
PACKET_ERROR_DROP = 0xFF
5657
)
5758

internal/enums/dns_names.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ func PacketTypeName(packetType uint8) string {
112112
return "DNS_QUERY_REQ_ACK"
113113
case PACKET_DNS_QUERY_RES_ACK:
114114
return "DNS_QUERY_RES_ACK"
115+
case PACKET_SESSION_CLOSE:
116+
return "SESSION_CLOSE"
115117
case PACKET_ERROR_DROP:
116118
return "ERROR_DROP"
117119
default:

0 commit comments

Comments
 (0)