Skip to content

Commit 0bdec06

Browse files
committed
ping: Add timeout to destinations
1 parent 8f6cc9f commit 0bdec06

File tree

4 files changed

+50
-4
lines changed

4 files changed

+50
-4
lines changed

ping/destination.go

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"net/netip"
77
"os"
88
"runtime"
9+
"time"
910

1011
"github.com/sagernet/sing-tun"
1112
"github.com/sagernet/sing/common/buf"
@@ -17,13 +18,21 @@ import (
1718
var _ tun.DirectRouteDestination = (*Destination)(nil)
1819

1920
type Destination struct {
21+
conn *Conn
2022
ctx context.Context
2123
logger logger.ContextLogger
2224
routeContext tun.DirectRouteContext
23-
conn *Conn
25+
timeout time.Duration
2426
}
2527

26-
func ConnectDestination(ctx context.Context, logger logger.ContextLogger, controlFunc control.Func, address netip.Addr, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) {
28+
func ConnectDestination(
29+
ctx context.Context,
30+
logger logger.ContextLogger,
31+
controlFunc control.Func,
32+
address netip.Addr,
33+
routeContext tun.DirectRouteContext,
34+
timeout time.Duration,
35+
) (tun.DirectRouteDestination, error) {
2736
var (
2837
conn *Conn
2938
err error
@@ -41,19 +50,25 @@ func ConnectDestination(ctx context.Context, logger logger.ContextLogger, contro
4150
return nil, err
4251
}
4352
d := &Destination{
53+
conn: conn,
4454
ctx: ctx,
4555
logger: logger,
4656
routeContext: routeContext,
47-
conn: conn,
57+
timeout: timeout,
4858
}
4959
go d.loopRead()
5060
return d, nil
5161
}
5262

5363
func (d *Destination) loopRead() {
64+
defer d.Close()
5465
for {
5566
buffer := buf.NewPacket()
56-
err := d.conn.ReadIP(buffer)
67+
err := d.conn.SetReadDeadline(time.Now().Add(d.timeout))
68+
if err != nil {
69+
d.logger.ErrorContext(d.ctx, E.Cause(err, "set read deadline for ICMP conn"))
70+
}
71+
err = d.conn.ReadIP(buffer)
5772
if err != nil {
5873
buffer.Release()
5974
if !E.IsClosed(err) {
@@ -76,3 +91,11 @@ func (d *Destination) WritePacket(packet *buf.Buffer) error {
7691
func (d *Destination) Close() error {
7792
return d.conn.Close()
7893
}
94+
95+
func (d *Destination) IsClosed() bool {
96+
_, err := d.conn.conn.Write([]byte{})
97+
if err != nil {
98+
return false
99+
}
100+
return true
101+
}

ping/destination_gvisor.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ package ping
55
import (
66
"context"
77
"net/netip"
8+
"time"
89

910
"github.com/sagernet/gvisor/pkg/tcpip"
1011
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
1112
"github.com/sagernet/gvisor/pkg/tcpip/header"
1213
"github.com/sagernet/gvisor/pkg/tcpip/stack"
14+
"github.com/sagernet/gvisor/pkg/tcpip/transport"
1315
"github.com/sagernet/gvisor/pkg/waiter"
1416
"github.com/sagernet/sing-tun"
1517
"github.com/sagernet/sing/common"
@@ -23,8 +25,10 @@ var _ tun.DirectRouteDestination = (*GVisorDestination)(nil)
2325
type GVisorDestination struct {
2426
ctx context.Context
2527
logger logger.ContextLogger
28+
endpoint tcpip.Endpoint
2629
conn *gonet.TCPConn
2730
rewriter *Rewriter
31+
timeout time.Duration
2832
}
2933

3034
func ConnectGVisor(
@@ -33,6 +37,7 @@ func ConnectGVisor(
3337
routeContext tun.DirectRouteContext,
3438
stack *stack.Stack,
3539
bindAddress4, bindAddress6 netip.Addr,
40+
timeout time.Duration,
3641
) (*GVisorDestination, error) {
3742
var (
3843
bindAddress tcpip.Address
@@ -76,16 +81,23 @@ func ConnectGVisor(
7681
destination := &GVisorDestination{
7782
ctx: ctx,
7883
logger: logger,
84+
endpoint: endpoint,
7985
conn: gonet.NewTCPConn(&wq, endpoint),
8086
rewriter: rewriter,
87+
timeout: timeout,
8188
}
8289
go destination.loopRead()
8390
return destination, nil
8491
}
8592

8693
func (d *GVisorDestination) loopRead() {
94+
defer d.endpoint.Close()
8795
for {
8896
buffer := buf.NewPacket()
97+
err := d.conn.SetReadDeadline(time.Now().Add(d.timeout))
98+
if err != nil {
99+
d.logger.ErrorContext(d.ctx, E.Cause(err, "set read deadline for ICMP conn"))
100+
}
89101
n, err := d.conn.Read(buffer.FreeBytes())
90102
if err != nil {
91103
buffer.Release()
@@ -111,3 +123,7 @@ func (d *GVisorDestination) WritePacket(packet *buf.Buffer) error {
111123
func (d *GVisorDestination) Close() error {
112124
return d.conn.Close()
113125
}
126+
127+
func (d *GVisorDestination) IsClosed() bool {
128+
return transport.DatagramEndpointState(d.endpoint.State()) == transport.DatagramEndpointStateClosed
129+
}

ping/socket_linux_unprivileged.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ func (c *UnprivilegedConn) ReadMsg(b []byte, oob []byte) (n, oobn int, addr neti
7474
}
7575

7676
func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
77+
if len(b) == 0 {
78+
return
79+
}
7780
conn, err := connect(false, c.controlFunc, c.destination)
7881
if err != nil {
7982
return

route_direct.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
type DirectRouteDestination interface {
1414
WritePacket(packet *buf.Buffer) error
1515
Close() error
16+
IsClosed() bool
1617
}
1718

1819
type DirectRouteSession struct {
@@ -28,6 +29,9 @@ type DirectRouteMapping struct {
2829

2930
func NewDirectRouteMapping(timeout time.Duration) *DirectRouteMapping {
3031
mapping := common.Must1(freelru.NewSharded[DirectRouteSession, DirectRouteDestination](1024, maphash.NewHasher[DirectRouteSession]().Hash32))
32+
mapping.SetHealthCheck(func(session DirectRouteSession, destination DirectRouteDestination) bool {
33+
return !destination.IsClosed()
34+
})
3135
mapping.SetOnEvict(func(session DirectRouteSession, action DirectRouteDestination) {
3236
action.Close()
3337
})

0 commit comments

Comments
 (0)