Skip to content

Commit 69ee7ba

Browse files
committed
Add lazy conn support for gVisor
1 parent 3185844 commit 69ee7ba

File tree

5 files changed

+242
-160
lines changed

5 files changed

+242
-160
lines changed

stack_gvisor.go

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -76,43 +76,17 @@ func (t *GVisor) Start() error {
7676
return err
7777
}
7878
tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
79-
var wq waiter.Queue
80-
handshakeCtx, cancel := context.WithCancel(context.Background())
81-
go func() {
82-
select {
83-
case <-t.ctx.Done():
84-
wq.Notify(wq.Events())
85-
case <-handshakeCtx.Done():
86-
}
87-
}()
88-
endpoint, err := r.CreateEndpoint(&wq)
89-
cancel()
90-
if err != nil {
91-
r.Complete(true)
92-
return
93-
}
94-
r.Complete(false)
95-
endpoint.SocketOptions().SetKeepAlive(true)
96-
keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
97-
endpoint.SetSockOpt(&keepAliveIdle)
98-
keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
99-
endpoint.SetSockOpt(&keepAliveInterval)
100-
tcpConn := gonet.NewTCPConn(&wq, endpoint)
101-
lAddr := tcpConn.RemoteAddr()
102-
rAddr := tcpConn.LocalAddr()
103-
if lAddr == nil || rAddr == nil {
104-
tcpConn.Close()
105-
return
79+
var metadata M.Metadata
80+
metadata.Source = M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
81+
metadata.Destination = M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
82+
conn := &gLazyConn{
83+
parentCtx: t.ctx,
84+
stack: t.stack,
85+
request: r,
86+
localAddr: metadata.Source.TCPAddr(),
87+
remoteAddr: metadata.Destination.TCPAddr(),
10688
}
107-
go func() {
108-
var metadata M.Metadata
109-
metadata.Source = M.SocksaddrFromNet(lAddr)
110-
metadata.Destination = M.SocksaddrFromNet(rAddr)
111-
hErr := t.handler.NewConnection(t.ctx, &gTCPConn{tcpConn}, metadata)
112-
if hErr != nil {
113-
endpoint.Abort()
114-
}
115-
}()
89+
_ = t.handler.NewConnection(t.ctx, conn, metadata)
11690
})
11791
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
11892
if !t.endpointIndependentNat {
@@ -129,12 +103,11 @@ func (t *GVisor) Start() error {
129103
endpoint.Abort()
130104
return
131105
}
132-
gConn := &gUDPConn{UDPConn: udpConn}
133106
go func() {
134107
var metadata M.Metadata
135108
metadata.Source = M.SocksaddrFromNet(lAddr)
136109
metadata.Destination = M.SocksaddrFromNet(rAddr)
137-
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(gConn, metadata.Destination), time.Duration(t.udpTimeout)*time.Second)
110+
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, metadata.Destination), time.Duration(t.udpTimeout)*time.Second)
138111
hErr := t.handler.NewPacketConnection(ctx, conn, metadata)
139112
if hErr != nil {
140113
endpoint.Abort()
@@ -191,7 +164,7 @@ func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
191164
})
192165
tErr := ipStack.CreateNIC(defaultNIC, ep)
193166
if tErr != nil {
194-
return nil, E.New("create nic: ", wrapStackError(tErr))
167+
return nil, E.New("create nic: ", gonet.TranslateNetstackError(tErr))
195168
}
196169
ipStack.SetRouteTable([]tcpip.Route{
197170
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},

stack_gvisor_err.go

Lines changed: 0 additions & 50 deletions
This file was deleted.

stack_gvisor_lazy.go

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
//go:build with_gvisor
2+
3+
package tun
4+
5+
import (
6+
"context"
7+
"errors"
8+
"net"
9+
"os"
10+
"syscall"
11+
"time"
12+
13+
"github.com/sagernet/gvisor/pkg/tcpip"
14+
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
15+
"github.com/sagernet/gvisor/pkg/tcpip/header"
16+
"github.com/sagernet/gvisor/pkg/tcpip/stack"
17+
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
18+
"github.com/sagernet/gvisor/pkg/waiter"
19+
)
20+
21+
type gLazyConn struct {
22+
tcpConn *gonet.TCPConn
23+
parentCtx context.Context
24+
stack *stack.Stack
25+
request *tcp.ForwarderRequest
26+
localAddr net.Addr
27+
remoteAddr net.Addr
28+
handshakeDone bool
29+
handshakeErr error
30+
}
31+
32+
func (c *gLazyConn) HandshakeContext(ctx context.Context) error {
33+
if c.handshakeDone {
34+
return nil
35+
}
36+
defer func() {
37+
c.handshakeDone = true
38+
}()
39+
var (
40+
wq waiter.Queue
41+
endpoint tcpip.Endpoint
42+
)
43+
handshakeCtx, cancel := context.WithCancel(ctx)
44+
go func() {
45+
select {
46+
case <-c.parentCtx.Done():
47+
wq.Notify(wq.Events())
48+
case <-handshakeCtx.Done():
49+
}
50+
}()
51+
endpoint, err := c.request.CreateEndpoint(&wq)
52+
cancel()
53+
if err != nil {
54+
gErr := gonet.TranslateNetstackError(err)
55+
c.handshakeErr = gErr
56+
c.request.Complete(true)
57+
return gErr
58+
}
59+
c.request.Complete(false)
60+
endpoint.SocketOptions().SetKeepAlive(true)
61+
keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
62+
endpoint.SetSockOpt(&keepAliveIdle)
63+
keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
64+
endpoint.SetSockOpt(&keepAliveInterval)
65+
tcpConn := gonet.NewTCPConn(&wq, endpoint)
66+
c.tcpConn = tcpConn
67+
return nil
68+
}
69+
70+
func (c *gLazyConn) HandshakeFailure(err error) error {
71+
if c.handshakeDone {
72+
return nil
73+
}
74+
wErr := gWriteUnreachable(c.stack, c.request.Packet(), err)
75+
c.request.Complete(wErr == os.ErrInvalid)
76+
c.handshakeDone = true
77+
c.handshakeErr = err
78+
return nil
79+
}
80+
81+
func (c *gLazyConn) HandshakeSuccess() error {
82+
return c.HandshakeContext(context.Background())
83+
}
84+
85+
func (c *gLazyConn) Read(b []byte) (n int, err error) {
86+
if !c.handshakeDone {
87+
err = c.HandshakeContext(context.Background())
88+
if err != nil {
89+
return
90+
}
91+
} else if c.handshakeErr != nil {
92+
return 0, c.handshakeErr
93+
}
94+
return c.tcpConn.Read(b)
95+
}
96+
97+
func (c *gLazyConn) Write(b []byte) (n int, err error) {
98+
if !c.handshakeDone {
99+
err = c.HandshakeContext(context.Background())
100+
if err != nil {
101+
return
102+
}
103+
} else if c.handshakeErr != nil {
104+
return 0, c.handshakeErr
105+
}
106+
return c.tcpConn.Write(b)
107+
}
108+
109+
func (c *gLazyConn) LocalAddr() net.Addr {
110+
return c.localAddr
111+
}
112+
113+
func (c *gLazyConn) RemoteAddr() net.Addr {
114+
return c.remoteAddr
115+
}
116+
117+
func (c *gLazyConn) SetDeadline(t time.Time) error {
118+
if !c.handshakeDone {
119+
err := c.HandshakeContext(context.Background())
120+
if err != nil {
121+
return err
122+
}
123+
} else if c.handshakeErr != nil {
124+
return c.handshakeErr
125+
}
126+
return c.tcpConn.SetDeadline(t)
127+
}
128+
129+
func (c *gLazyConn) SetReadDeadline(t time.Time) error {
130+
if !c.handshakeDone {
131+
err := c.HandshakeContext(context.Background())
132+
if err != nil {
133+
return err
134+
}
135+
} else if c.handshakeErr != nil {
136+
return c.handshakeErr
137+
}
138+
return c.tcpConn.SetReadDeadline(t)
139+
}
140+
141+
func (c *gLazyConn) SetWriteDeadline(t time.Time) error {
142+
if !c.handshakeDone {
143+
err := c.HandshakeContext(context.Background())
144+
if err != nil {
145+
return err
146+
}
147+
} else if c.handshakeErr != nil {
148+
return c.handshakeErr
149+
}
150+
return c.tcpConn.SetWriteDeadline(t)
151+
}
152+
153+
func (c *gLazyConn) Close() error {
154+
if !c.handshakeDone {
155+
c.request.Complete(true)
156+
c.handshakeErr = net.ErrClosed
157+
return nil
158+
} else if c.handshakeErr != nil {
159+
return nil
160+
}
161+
return c.tcpConn.Close()
162+
}
163+
164+
func (c *gLazyConn) CloseRead() error {
165+
if !c.handshakeDone {
166+
c.request.Complete(true)
167+
c.handshakeErr = net.ErrClosed
168+
return nil
169+
} else if c.handshakeErr != nil {
170+
return nil
171+
}
172+
return c.tcpConn.CloseRead()
173+
}
174+
175+
func (c *gLazyConn) CloseWrite() error {
176+
if !c.handshakeDone {
177+
c.request.Complete(true)
178+
c.handshakeErr = net.ErrClosed
179+
return nil
180+
} else if c.handshakeErr != nil {
181+
return nil
182+
}
183+
return c.tcpConn.CloseRead()
184+
}
185+
186+
func (c *gLazyConn) ReaderReplaceable() bool {
187+
return c.handshakeDone && c.handshakeErr == nil
188+
}
189+
190+
func (c *gLazyConn) WriterReplaceable() bool {
191+
return c.handshakeDone && c.handshakeErr == nil
192+
}
193+
194+
func (c *gLazyConn) Upstream() any {
195+
return c.tcpConn
196+
}
197+
198+
func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) error {
199+
if errors.Is(err, syscall.ENETUNREACH) {
200+
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
201+
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable)
202+
} else {
203+
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
204+
}
205+
} else if errors.Is(err, syscall.EHOSTUNREACH) {
206+
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
207+
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostProhibited)
208+
} else {
209+
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
210+
}
211+
} else if errors.Is(err, syscall.ECONNREFUSED) {
212+
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
213+
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable)
214+
} else {
215+
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable)
216+
}
217+
}
218+
return os.ErrInvalid
219+
}
220+
221+
func gWriteUnreachable4(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv4WithICMPType) error {
222+
return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true))
223+
}
224+
225+
func gWriteUnreachable6(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv6WithICMPType) error {
226+
return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true))
227+
}

0 commit comments

Comments
 (0)