Skip to content

Commit 86d9606

Browse files
committed
ping: Add gVisor destination
1 parent 12c9fb6 commit 86d9606

File tree

9 files changed

+332
-221
lines changed

9 files changed

+332
-221
lines changed

ping/destination_gvisor.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
//go:build with_gvisor
2+
3+
package ping
4+
5+
import (
6+
"context"
7+
"net/netip"
8+
9+
"github.com/sagernet/gvisor/pkg/tcpip"
10+
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
11+
"github.com/sagernet/gvisor/pkg/tcpip/header"
12+
"github.com/sagernet/gvisor/pkg/tcpip/stack"
13+
"github.com/sagernet/gvisor/pkg/waiter"
14+
"github.com/sagernet/sing-tun"
15+
"github.com/sagernet/sing/common"
16+
"github.com/sagernet/sing/common/buf"
17+
E "github.com/sagernet/sing/common/exceptions"
18+
"github.com/sagernet/sing/common/logger"
19+
)
20+
21+
var _ tun.DirectRouteDestination = (*GVisorDestination)(nil)
22+
23+
type GVisorDestination struct {
24+
ctx context.Context
25+
logger logger.ContextLogger
26+
conn *gonet.TCPConn
27+
rewriter *Rewriter
28+
}
29+
30+
func ConnectGVisor(
31+
ctx context.Context, logger logger.ContextLogger,
32+
sourceAddress, destinationAddress netip.Addr,
33+
routeContext tun.DirectRouteContext,
34+
stack *stack.Stack,
35+
bindAddress4, bindAddress6 netip.Addr,
36+
) (*GVisorDestination, error) {
37+
var (
38+
bindAddress tcpip.Address
39+
wq waiter.Queue
40+
endpoint tcpip.Endpoint
41+
gErr tcpip.Error
42+
)
43+
if !destinationAddress.Is6() {
44+
if !bindAddress4.IsValid() {
45+
return nil, E.New("missing IPv4 interface address")
46+
}
47+
bindAddress = tun.AddressFromAddr(bindAddress4)
48+
endpoint, gErr = stack.NewRawEndpoint(header.ICMPv4ProtocolNumber, header.IPv4ProtocolNumber, &wq, true)
49+
} else {
50+
if !bindAddress6.IsValid() {
51+
return nil, E.New("missing IPv6 interface address")
52+
}
53+
bindAddress = tun.AddressFromAddr(bindAddress6)
54+
endpoint, gErr = stack.NewRawEndpoint(header.ICMPv6ProtocolNumber, header.IPv6ProtocolNumber, &wq, true)
55+
}
56+
if gErr != nil {
57+
return nil, gonet.TranslateNetstackError(gErr)
58+
}
59+
gErr = endpoint.Bind(tcpip.FullAddress{
60+
NIC: 1,
61+
Addr: bindAddress,
62+
})
63+
if gErr != nil {
64+
return nil, gonet.TranslateNetstackError(gErr)
65+
}
66+
gErr = endpoint.Connect(tcpip.FullAddress{
67+
NIC: 1,
68+
Addr: tun.AddressFromAddr(destinationAddress),
69+
})
70+
if gErr != nil {
71+
return nil, gonet.TranslateNetstackError(gErr)
72+
}
73+
endpoint.SocketOptions().SetHeaderIncluded(true)
74+
rewriter := NewRewriter(bindAddress4, bindAddress6)
75+
rewriter.CreateSession(tun.DirectRouteSession{Source: sourceAddress, Destination: destinationAddress}, routeContext)
76+
destination := &GVisorDestination{
77+
ctx: ctx,
78+
logger: logger,
79+
conn: gonet.NewTCPConn(&wq, endpoint),
80+
rewriter: rewriter,
81+
}
82+
go destination.loopRead()
83+
return destination, nil
84+
}
85+
86+
func (d *GVisorDestination) loopRead() {
87+
for {
88+
buffer := buf.NewPacket()
89+
n, err := d.conn.Read(buffer.FreeBytes())
90+
if err != nil {
91+
buffer.Release()
92+
if !E.IsClosed(err) {
93+
d.logger.ErrorContext(d.ctx, E.Cause(err, "receive ICMP echo reply"))
94+
}
95+
return
96+
}
97+
buffer.Truncate(n)
98+
_, err = d.rewriter.WriteBack(buffer.Bytes())
99+
if err != nil {
100+
d.logger.ErrorContext(d.ctx, E.Cause(err, "write ICMP echo reply"))
101+
}
102+
buffer.Release()
103+
}
104+
}
105+
106+
func (d *GVisorDestination) WritePacket(packet *buf.Buffer) error {
107+
d.rewriter.RewritePacket(packet.Bytes())
108+
return common.Error(d.conn.Write(packet.Bytes()))
109+
}
110+
111+
func (d *GVisorDestination) Close() error {
112+
return d.conn.Close()
113+
}

ping/rewriter.go

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
package ping
2+
3+
import (
4+
"net/netip"
5+
"sync"
6+
7+
"github.com/sagernet/sing-tun"
8+
"github.com/sagernet/sing-tun/internal/gtcpip/header"
9+
)
10+
11+
type Rewriter struct {
12+
access sync.RWMutex
13+
sessions map[tun.DirectRouteSession]tun.DirectRouteContext
14+
source4Address map[uint16]netip.Addr
15+
source6Address map[uint16]netip.Addr
16+
inet4Address netip.Addr
17+
inet6Address netip.Addr
18+
}
19+
20+
func NewRewriter(inet4Address netip.Addr, inet6Address netip.Addr) *Rewriter {
21+
return &Rewriter{
22+
sessions: make(map[tun.DirectRouteSession]tun.DirectRouteContext),
23+
inet4Address: inet4Address,
24+
inet6Address: inet6Address,
25+
}
26+
}
27+
28+
func (m *Rewriter) CreateSession(session tun.DirectRouteSession, context tun.DirectRouteContext) {
29+
m.access.Lock()
30+
m.sessions[session] = context
31+
m.access.Unlock()
32+
}
33+
34+
func (m *Rewriter) DeleteSession(session tun.DirectRouteSession) {
35+
m.access.Lock()
36+
delete(m.sessions, session)
37+
m.access.Unlock()
38+
}
39+
40+
func (m *Rewriter) RewritePacket(packet []byte) {
41+
var ipHdr header.Network
42+
var bindAddr netip.Addr
43+
switch header.IPVersion(packet) {
44+
case header.IPv4Version:
45+
ipHdr = header.IPv4(packet)
46+
bindAddr = m.inet4Address
47+
case header.IPv6Version:
48+
ipHdr = header.IPv6(packet)
49+
bindAddr = m.inet6Address
50+
default:
51+
return
52+
}
53+
sourceAddr := ipHdr.SourceAddr()
54+
ipHdr.SetSourceAddr(bindAddr)
55+
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
56+
ipHdr4.SetChecksum(0)
57+
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
58+
}
59+
switch ipHdr.TransportProtocol() {
60+
case header.ICMPv4ProtocolNumber:
61+
icmpHdr := header.ICMPv4(ipHdr.Payload())
62+
m.access.Lock()
63+
m.source4Address[icmpHdr.Ident()] = sourceAddr
64+
m.access.Lock()
65+
case header.ICMPv6ProtocolNumber:
66+
icmpHdr := header.ICMPv6(ipHdr.Payload())
67+
icmpHdr.SetChecksum(0)
68+
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
69+
Header: icmpHdr,
70+
Src: ipHdr.SourceAddressSlice(),
71+
Dst: ipHdr.DestinationAddressSlice(),
72+
}))
73+
m.access.Lock()
74+
m.source6Address[icmpHdr.Ident()] = sourceAddr
75+
m.access.Lock()
76+
}
77+
}
78+
79+
func (m *Rewriter) WriteBack(packet []byte) (bool, error) {
80+
var ipHdr header.Network
81+
var routeSession tun.DirectRouteSession
82+
switch header.IPVersion(packet) {
83+
case header.IPv4Version:
84+
ipHdr = header.IPv4(packet)
85+
routeSession.Destination = ipHdr.SourceAddr()
86+
case header.IPv6Version:
87+
ipHdr = header.IPv6(packet)
88+
routeSession.Destination = ipHdr.SourceAddr()
89+
default:
90+
return false, nil
91+
}
92+
switch ipHdr.TransportProtocol() {
93+
case header.ICMPv4ProtocolNumber:
94+
icmpHdr := header.ICMPv4(ipHdr.Payload())
95+
m.access.Lock()
96+
ident := icmpHdr.Ident()
97+
source, loaded := m.source4Address[ident]
98+
if !loaded {
99+
m.access.Unlock()
100+
return false, nil
101+
}
102+
delete(m.source4Address, icmpHdr.Ident())
103+
m.access.Lock()
104+
routeSession.Source = source
105+
case header.ICMPv6ProtocolNumber:
106+
icmpHdr := header.ICMPv6(ipHdr.Payload())
107+
m.access.Lock()
108+
ident := icmpHdr.Ident()
109+
source, loaded := m.source6Address[ident]
110+
if !loaded {
111+
m.access.Unlock()
112+
return false, nil
113+
}
114+
delete(m.source6Address, icmpHdr.Ident())
115+
m.access.Lock()
116+
routeSession.Source = source
117+
default:
118+
return false, nil
119+
}
120+
m.access.RLock()
121+
context, loaded := m.sessions[routeSession]
122+
m.access.RUnlock()
123+
if !loaded {
124+
return false, nil
125+
}
126+
ipHdr.SetDestinationAddr(routeSession.Source)
127+
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
128+
ipHdr4.SetChecksum(0)
129+
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
130+
}
131+
switch ipHdr.TransportProtocol() {
132+
case header.ICMPv6ProtocolNumber:
133+
icmpHdr := header.ICMPv6(ipHdr.Payload())
134+
icmpHdr.SetChecksum(0)
135+
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
136+
Header: icmpHdr,
137+
Src: ipHdr.SourceAddressSlice(),
138+
Dst: ipHdr.DestinationAddressSlice(),
139+
}))
140+
}
141+
return true, context.WritePacket(packet)
142+
}

route_direct.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package tun
2+
3+
import (
4+
"net/netip"
5+
"time"
6+
7+
"github.com/sagernet/sing/common"
8+
"github.com/sagernet/sing/common/buf"
9+
"github.com/sagernet/sing/contrab/freelru"
10+
"github.com/sagernet/sing/contrab/maphash"
11+
)
12+
13+
type DirectRouteDestination interface {
14+
WritePacket(packet *buf.Buffer) error
15+
Close() error
16+
}
17+
18+
type DirectRouteSession struct {
19+
// IPVersion uint8
20+
// Network uint8
21+
Source netip.Addr
22+
Destination netip.Addr
23+
}
24+
25+
type DirectRouteMapping struct {
26+
mapping freelru.Cache[DirectRouteSession, DirectRouteDestination]
27+
}
28+
29+
func NewDirectRouteMapping(timeout time.Duration) *DirectRouteMapping {
30+
mapping := common.Must1(freelru.NewSharded[DirectRouteSession, DirectRouteDestination](1024, maphash.NewHasher[DirectRouteSession]().Hash32))
31+
mapping.SetOnEvict(func(session DirectRouteSession, action DirectRouteDestination) {
32+
action.Close()
33+
})
34+
mapping.SetLifetime(timeout)
35+
return &DirectRouteMapping{mapping}
36+
}
37+
38+
func (m *DirectRouteMapping) Lookup(session DirectRouteSession, constructor func() (DirectRouteDestination, error)) (DirectRouteDestination, error) {
39+
var (
40+
created DirectRouteDestination
41+
err error
42+
)
43+
action, _, ok := m.mapping.GetAndRefreshOrAdd(session, func() (DirectRouteDestination, bool) {
44+
created, err = constructor()
45+
return created, err == nil
46+
})
47+
if !ok {
48+
return nil, err
49+
}
50+
return action, nil
51+
}

route_mapping.go

Lines changed: 0 additions & 45 deletions
This file was deleted.

0 commit comments

Comments
 (0)