diff --git a/tracker/clientcontext/injector.go b/tracker/clientcontext/injector.go new file mode 100644 index 0000000..3b4a569 --- /dev/null +++ b/tracker/clientcontext/injector.go @@ -0,0 +1,166 @@ +package clientcontext + +import ( + "context" + "encoding/json" + "fmt" + "net" + "sync" + + "github.com/sagernet/sing-box/adapter" + N "github.com/sagernet/sing/common/network" +) + +var ( + _ (adapter.ConnectionTracker) = (*ClientContextInjector)(nil) + _ (N.ConnHandshakeSuccess) = (*writeConn)(nil) + _ (N.PacketConnHandshakeSuccess) = (*writePacketConn)(nil) +) + +// ClientContextInjector is a connection tracker that sends client info to a ClientContext Manager. +type ClientContextInjector struct { + getInfo GetClientInfoFn + inboundRule *boundsRule + outboundRule *boundsRule + ruleMu sync.RWMutex +} + +// NewClientContextInjector creates a tracker for injecting client info. +func NewClientContextInjector(fn GetClientInfoFn, bounds MatchBounds) *ClientContextInjector { + return &ClientContextInjector{ + inboundRule: newBoundsRule(bounds.Inbound), + outboundRule: newBoundsRule(bounds.Outbound), + getInfo: fn, + } +} + +// RoutedConnection wraps the connection for writing client info. +func (t *ClientContextInjector) RoutedConnection( + ctx context.Context, + conn net.Conn, + metadata adapter.InboundContext, + matchedRule adapter.Rule, + matchOutbound adapter.Outbound, +) net.Conn { + if !t.match(metadata.Inbound, matchOutbound.Tag()) { + return conn + } + info := t.getInfo() + return newWriteConn(conn, &info) +} + +// RoutedPacketConnection wraps the packet connection for writing client info. +func (t *ClientContextInjector) RoutedPacketConnection( + ctx context.Context, + conn N.PacketConn, + metadata adapter.InboundContext, + matchedRule adapter.Rule, + matchOutbound adapter.Outbound, +) N.PacketConn { + if !t.match(metadata.Inbound, matchOutbound.Tag()) { + return conn + } + info := t.getInfo() + return newWritePacketConn(conn, metadata, &info) +} + +func (t *ClientContextInjector) match(inbound, outbound string) bool { + t.ruleMu.RLock() + defer t.ruleMu.RUnlock() + return t.inboundRule.match(inbound) && t.outboundRule.match(outbound) +} + +func (t *ClientContextInjector) UpdateBounds(bounds MatchBounds) { + t.ruleMu.Lock() + t.inboundRule = newBoundsRule(bounds.Inbound) + t.outboundRule = newBoundsRule(bounds.Outbound) + t.ruleMu.Unlock() +} + +// writeConn sends client info after handshake. +type writeConn struct { + net.Conn + info *ClientInfo +} + +func newWriteConn(conn net.Conn, info *ClientInfo) net.Conn { + return &writeConn{Conn: conn, info: info} +} + +// ConnHandshakeSuccess sends client info upon successful handshake with the server. +func (c *writeConn) ConnHandshakeSuccess(conn net.Conn) error { + if err := c.sendInfo(conn); err != nil { + return fmt.Errorf("sending client info: %w", err) + } + return nil +} + +// sendInfo marshals and sends client info as an HTTP POST, then waits for HTTP 200 OK. +func (c *writeConn) sendInfo(conn net.Conn) error { + buf, err := json.Marshal(c.info) + if err != nil { + return fmt.Errorf("marshaling client info: %w", err) + } + packet := append([]byte(packetPrefix), buf...) + if _, err = conn.Write(packet); err != nil { + return fmt.Errorf("writing client info: %w", err) + } + + // wait for `OK` response + var resp [2]byte + if _, err := conn.Read(resp[:]); err != nil { + return fmt.Errorf("reading response: %w", err) + } + if string(resp[:]) != "OK" { + return fmt.Errorf("invalid response: %s", resp) + } + return nil +} + +type writePacketConn struct { + N.PacketConn + metadata adapter.InboundContext + info *ClientInfo +} + +func newWritePacketConn( + conn N.PacketConn, + metadata adapter.InboundContext, + info *ClientInfo, +) N.PacketConn { + return &writePacketConn{ + PacketConn: conn, + metadata: metadata, + info: info, + } +} + +// PacketConnHandshakeSuccess sends client info upon successful handshake. +func (c *writePacketConn) PacketConnHandshakeSuccess(conn net.PacketConn) error { + if err := c.sendInfo(conn); err != nil { + return fmt.Errorf("sending client info: %w", err) + } + return nil +} + +// sendInfo marshals and sends client info as a CLIENTINFO packet, then waits for OK. +func (c *writePacketConn) sendInfo(conn net.PacketConn) error { + buf, err := json.Marshal(c.info) + if err != nil { + return fmt.Errorf("marshaling client info: %w", err) + } + packet := append([]byte(packetPrefix), buf...) + if _, err = conn.WriteTo(packet, c.metadata.Destination); err != nil { + return fmt.Errorf("writing packet: %w", err) + } + + // wait for `OK` response + var resp [2]byte + if _, _, err := conn.ReadFrom(resp[:]); err != nil { + return fmt.Errorf("reading response: %w", err) + } + if string(resp[:]) != "OK" { + return fmt.Errorf("invalid response: %s", resp) + } + return nil +} diff --git a/tracker/clientcontext/manager.go b/tracker/clientcontext/manager.go new file mode 100644 index 0000000..d9b3645 --- /dev/null +++ b/tracker/clientcontext/manager.go @@ -0,0 +1,211 @@ +package clientcontext + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var _ (adapter.ConnectionTracker) = (*Manager)(nil) + +type clientInfoKey struct{} + +// ContextWithClientInfo returns a new context with the given ClientInfo. +func ContextWithClientInfo(ctx context.Context, info ClientInfo) context.Context { + return context.WithValue(ctx, clientInfoKey{}, info) +} + +// ClientInfoFromContext retrieves the ClientInfo from the context. +func ClientInfoFromContext(ctx context.Context) (ClientInfo, bool) { + info, ok := ctx.Value(clientInfoKey{}).(ClientInfo) + return info, ok +} + +// Manager is a ConnectionTracker that manages ClientInfo for connections. +type Manager struct { + logger log.ContextLogger + trackers []adapter.ConnectionTracker + + inboundRule *boundsRule + outboundRule *boundsRule + ruleMu sync.RWMutex +} + +// NewManager creates a new ClientContext Manager. +func NewManager(bounds MatchBounds, logger log.ContextLogger) *Manager { + return &Manager{ + trackers: []adapter.ConnectionTracker{}, + logger: logger, + inboundRule: newBoundsRule(bounds.Inbound), + outboundRule: newBoundsRule(bounds.Outbound), + } +} + +// AppendTracker appends a ConnectionTracker to the Manager. +func (m *Manager) AppendTracker(tracker adapter.ConnectionTracker) { + m.trackers = append(m.trackers, tracker) +} + +func (m *Manager) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn { + if !m.match(metadata.Inbound, matchOutbound.Tag()) { + return conn + } + c := &readConn{ + Conn: conn, + reader: conn, + mgr: m, + } + info, err := c.readInfo() + if err != c.readErr { + m.logger.Error("failed to read client info ", "tag", "clientcontext-tracker", "error", err) + } + if err != nil { + return c + } + if info == nil { + return c + } + ctx = ContextWithClientInfo(ctx, *info) + conn = c + for _, tracker := range m.trackers { + conn = tracker.RoutedConnection(ctx, conn, metadata, matchedRule, matchOutbound) + } + return conn +} + +func (m *Manager) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) N.PacketConn { + if !m.match(metadata.Inbound, matchOutbound.Tag()) { + return conn + } + c := &readPacketConn{ + PacketConn: conn, + mgr: m, + } + info, err := c.readInfo() + if err != c.readErr { + m.logger.Error("failed to read client info ", "tag", "clientcontext-tracker", "error", err) + } + if err != nil { + return c + } + if info == nil { + return c + } + ctx = ContextWithClientInfo(ctx, *info) + conn = c + for _, tracker := range m.trackers { + conn = tracker.RoutedPacketConnection(ctx, conn, metadata, matchedRule, matchOutbound) + } + return conn +} + +func (m *Manager) match(inbound, outbound string) bool { + m.ruleMu.RLock() + defer m.ruleMu.RUnlock() + return m.inboundRule.match(inbound) && m.outboundRule.match(outbound) +} + +func (m *Manager) UpdateBounds(bounds MatchBounds) { + m.ruleMu.Lock() + m.inboundRule = newBoundsRule(bounds.Inbound) + m.outboundRule = newBoundsRule(bounds.Outbound) + m.ruleMu.Unlock() +} + +// readConn reads client info from the connection on creation. +type readConn struct { + net.Conn + mgr *Manager + reader io.Reader + n int + readErr error +} + +func (c *readConn) Read(b []byte) (n int, err error) { + if c.readErr != nil { + return c.n, c.readErr + } + return c.reader.Read(b) +} + +// readInfo reads and decodes client info, then sends an HTTP 200 OK response. +func (c *readConn) readInfo() (*ClientInfo, error) { + var buf [32]byte + n, err := c.Conn.Read(buf[:]) + if err != nil { + c.readErr = err + c.n = n + return nil, err + } + if !bytes.HasPrefix(buf[:n], []byte(packetPrefix)) { + c.reader = io.MultiReader(bytes.NewReader(buf[:n]), c.Conn) + return nil, nil + } + + var info ClientInfo + reader := io.MultiReader(bytes.NewReader(buf[len(packetPrefix):n]), c.Conn) + if err := json.NewDecoder(reader).Decode(&info); err != nil { + return nil, fmt.Errorf("decoding client info: %w", err) + } + + if _, err := c.Write([]byte("OK")); err != nil { + return nil, fmt.Errorf("writing OK response: %w", err) + } + return &info, nil +} + +type readPacketConn struct { + N.PacketConn + mgr *Manager + destination metadata.Socksaddr + readErr error +} + +func (c *readPacketConn) ReadPacket(b *buf.Buffer) (destination metadata.Socksaddr, err error) { + if c.readErr != nil { + return c.destination, c.readErr + } + return c.PacketConn.ReadPacket(b) +} + +// readInfo reads and decodes client info if the first packet is a CLIENTINFO packet, then sends an +// OK response. +func (c *readPacketConn) readInfo() (*ClientInfo, error) { + buffer := buf.NewPacket() + defer buffer.Release() + + destination, err := c.ReadPacket(buffer) + if err != nil { + c.destination = destination + c.readErr = err + return nil, err + } + data := buffer.Bytes() + if !bytes.HasPrefix(data, []byte(packetPrefix)) { + // not a client info packet, wrap with cached packet conn so the packet can be read again + c.PacketConn = bufio.NewCachedPacketConn(c.PacketConn, buffer, destination) + return nil, nil + } + var info ClientInfo + if err := json.Unmarshal(data[len(packetPrefix):], &info); err != nil { + return nil, fmt.Errorf("unmarshaling client info: %w", err) + } + + buffer.Reset() + buffer.WriteString("OK") + if err := c.WritePacket(buffer, destination); err != nil { + return nil, fmt.Errorf("writing OK response: %w", err) + } + return &info, nil +} diff --git a/tracker/clientcontext/tracker.go b/tracker/clientcontext/tracker.go index 33d95ec..2fcf18e 100644 --- a/tracker/clientcontext/tracker.go +++ b/tracker/clientcontext/tracker.go @@ -1,37 +1,41 @@ -// Package clientcontext provides a [adapter.ConnectionTracker] that sends and receives client -// metadata after connection handshake. The metadata is stored in the context for other trackers -// to use. +// Package clientcontext provides [adapter.ConnectionTracker]s that sends and receives client +// metadata after connection handshake between the client and server. The metadata is stored in the +// context for other trackers to use. // -// To use this tracker, create a [ClientContextTracker] with either [NewClientContextTracker], for -// clients, or [NewClientContextReader], for servers, then pass it to router.AppendTracker. The -// metadata can be retrieved from the context using [service.PtrFromContext]. -// Note that both client and server sides must use this tracker for it to work. +// Usage: +// Create a [ClientContextInjector] on the client side to send client info to the server, and a +// [Manager] on the server side to receive and store the info. Both trackers should be added to the +// router using router.AppendTracker. Trackers added to the Manager with [Manager.AppendTracker] can +// access the client info from the connection context using [ClientInfoFromContext]. package clientcontext -import ( - stdbufio "bufio" - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net" - "net/http" +// since sing-box only wraps inbound connections with trackers, conn on the client is from the user +// (e.g. tun connection), while conn on the server is from an outbound on the client. The connection +// to the server isn't established until after conn is wrapped on the client side and we don't have +// access to it until after the handshake. +// +// Client Server +// ------------- ------------- +// conn ---> tracker(conn) | +// (i.e. tun) | | +// dial server -----------> conn +// | | +// +<-------- handshake ------->+ +// | | +// handshakeSuccess <---------- tracker(conn) +// | | +// send client info ---------> read client info +// | | +// pipe traffic dial upstream +// ... +// pipe traffic +// +// This is why writeConn (client) doesn't send the client info until ConnHandshakeSuccess while +// readConn (server) reads it immediately upon creation. - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/service" -) +const packetPrefix = "CLIENTINFO " -var ( - _ (adapter.ConnectionTracker) = (*ClientContextTracker)(nil) - _ (N.ConnHandshakeSuccess) = (*writeConn)(nil) - _ (N.PacketConnHandshakeSuccess) = (*writePacketConn)(nil) -) +type GetClientInfoFn func() ClientInfo // ClientInfo holds information about the client user/device. type ClientInfo struct { @@ -43,80 +47,12 @@ type ClientInfo struct { } // MatchBounds specifies inbound and outbound matching rules. -// The empty string is treated as a wildcard. +// The empty string and "any" are treated as a wildcard. type MatchBounds struct { Inbound []string Outbound []string } -// ClientContextTracker tracks client context for connections. -type ClientContextTracker struct { - info ClientInfo - inboundRule *boundsRule - outboundRule *boundsRule - logger log.ContextLogger - isReader bool -} - -// NewClientContextTracker creates a tracker for writing client info. -func NewClientContextTracker(info ClientInfo, bounds MatchBounds, logger log.ContextLogger) *ClientContextTracker { - return &ClientContextTracker{ - info: info, - inboundRule: newBoundsRule(bounds.Inbound), - outboundRule: newBoundsRule(bounds.Outbound), - logger: logger, - } -} - -// NewClientContextReader creates a tracker for reading client info. -func NewClientContextReader(bounds MatchBounds, logger log.ContextLogger) *ClientContextTracker { - return &ClientContextTracker{ - inboundRule: newBoundsRule(bounds.Inbound), - outboundRule: newBoundsRule(bounds.Outbound), - logger: logger, - isReader: true, - } -} - -// RoutedConnection wraps the connection for reading or writing client info. -func (t *ClientContextTracker) RoutedConnection( - ctx context.Context, - conn net.Conn, - metadata adapter.InboundContext, - matchedRule adapter.Rule, - matchOutbound adapter.Outbound, -) net.Conn { - if !t.inboundRule.match(metadata.Inbound) || !t.outboundRule.match(matchOutbound.Tag()) { - return conn - } - if t.isReader { - return newReadConn(ctx, conn, t.logger) - } - return newWriteConn(ctx, conn, &t.info, t.logger) -} - -// RoutedPacketConnection wraps the packet connection for reading or writing client info. -func (t *ClientContextTracker) RoutedPacketConnection( - ctx context.Context, - conn N.PacketConn, - metadata adapter.InboundContext, - matchedRule adapter.Rule, - matchOutbound adapter.Outbound, -) N.PacketConn { - if !t.inboundRule.match(metadata.Inbound) || !t.outboundRule.match(matchOutbound.Tag()) { - return conn - } - if t.isReader { - return newReadPacketConn(ctx, conn, t.logger) - } - return newWritePacketConn(ctx, conn, metadata, &t.info, t.logger) -} - -func (t *ClientContextTracker) UpdateBounds(bounds MatchBounds) { - t.inboundRule = newBoundsRule(bounds.Inbound) - t.outboundRule = newBoundsRule(bounds.Outbound) -} - type boundsRule struct { tags []string tagMap map[string]bool @@ -138,267 +74,3 @@ func newBoundsRule(tags []string) *boundsRule { func (b *boundsRule) match(tag string) bool { return (b.matchAny && tag != "") || b.tagMap[tag] } - -// since sing-box only wraps inbound connections with trackers, conn on the client is from the user -// (e.g. tun connection), while conn on the server is from an outbound on the client. The connection -// to the server isn't established until after conn is wrapped on the client side and we don't have -// access to it until after the handshake. -// -// Client Server -// ------------- ------------- -// conn ---> tracker(conn) | -// (i.e. tun) | | -// dial server -----------> conn -// | | -// +<-------- handshake ------->+ -// | | -// handshakeSuccess <---------- tracker(conn) -// | | -// send client info ---------> read client info -// | | -// pipe traffic dial upstream -// ... -// pipe traffic -// -// This is why writeConn (client) doesn't send the client info until ConnHandshakeSuccess while -// readConn (server) reads it immediately upon creation. - -// readConn reads client info from the connection on creation. -type readConn struct { - net.Conn - ctx context.Context - info ClientInfo - logger log.ContextLogger - - reader io.Reader - n int - readErr error -} - -// newReadConn creates a readConn and reads client info from it. If successful, the info is stored -// in the context. -func newReadConn(ctx context.Context, conn net.Conn, logger log.ContextLogger) net.Conn { - c := &readConn{ - Conn: conn, - ctx: ctx, - reader: conn, - logger: logger, - } - if err := c.readInfo(); err != nil { - logger.Warn("reading client info: ", err) - } - return c -} - -func (c *readConn) Read(b []byte) (n int, err error) { - if c.readErr != nil { - return c.n, c.readErr - } - return c.reader.Read(b) -} - -// readInfo reads and decodes client info, then sends an HTTP 200 OK response. -func (c *readConn) readInfo() error { - var buf [32]byte - n, err := c.Conn.Read(buf[:]) - if err != nil { - c.readErr = err - c.n = n - return err - } - reader := io.MultiReader(bytes.NewReader(buf[:n]), c.Conn) - if !bytes.HasPrefix(buf[:n], []byte("POST /clientinfo")) { - c.reader = reader - return nil - } - - var info ClientInfo - req, err := http.ReadRequest(stdbufio.NewReader(reader)) - if err != nil { - return fmt.Errorf("reading HTTP request: %w", err) - } - defer req.Body.Close() - if err := json.NewDecoder(req.Body).Decode(&info); err != nil { - return fmt.Errorf("decoding client info: %w", err) - } - c.info = info - - resp := "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" - if _, err := c.Write([]byte(resp)); err != nil { - return fmt.Errorf("writing HTTP response: %w", err) - } - service.ContextWithPtr(c.ctx, &info) - return nil -} - -// writeConn sends client info after handshake. -type writeConn struct { - net.Conn - ctx context.Context - info *ClientInfo - logger log.ContextLogger -} - -func newWriteConn(ctx context.Context, conn net.Conn, info *ClientInfo, logger log.ContextLogger) net.Conn { - return &writeConn{Conn: conn, ctx: ctx, info: info, logger: logger} -} - -// ConnHandshakeSuccess sends client info upon successful handshake with the server. -func (c *writeConn) ConnHandshakeSuccess(conn net.Conn) error { - if err := c.sendInfo(conn); err != nil { - return fmt.Errorf("sending client info: %w", err) - } - return nil -} - -// sendInfo marshals and sends client info as an HTTP POST, then waits for HTTP 200 OK. -func (c *writeConn) sendInfo(conn net.Conn) error { - buf, err := json.Marshal(c.info) - if err != nil { - return fmt.Errorf("marshaling client info: %w", err) - } - // Write HTTP POST request - req := bytes.NewBuffer(nil) - fmt.Fprintf(req, "POST /clientinfo HTTP/1.1\r\n") - fmt.Fprintf(req, "Host: localhost\r\n") - fmt.Fprintf(req, "Content-Type: application/json\r\n") - fmt.Fprintf(req, "Content-Length: %d\r\n", len(buf)) - fmt.Fprintf(req, "\r\n") - req.Write(buf) - if _, err = conn.Write(req.Bytes()); err != nil { - return fmt.Errorf("writing client info: %w", err) - } - - // wait for HTTP 200 OK response - reader := stdbufio.NewReader(conn) - resp, err := http.ReadResponse(reader, nil) - if err != nil { - return fmt.Errorf("reading HTTP response: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - return fmt.Errorf("invalid server response: %s", resp.Status) - } - return nil -} - -const prefix = "CLIENTINFO " - -type readPacketConn struct { - N.PacketConn - ctx context.Context - info *ClientInfo - logger log.ContextLogger - - reader io.Reader - destination metadata.Socksaddr - readErr error -} - -// newReadPacketConn creates a readPacketConn and reads client info from it. If successful, the -// info is stored in the context. -func newReadPacketConn(ctx context.Context, conn N.PacketConn, logger log.ContextLogger) N.PacketConn { - c := &readPacketConn{ - PacketConn: conn, - ctx: ctx, - logger: logger, - } - if err := c.readInfo(); err != nil { - logger.Warn("reading client info: ", err) - } - return c -} - -func (c *readPacketConn) ReadPacket(b *buf.Buffer) (destination metadata.Socksaddr, err error) { - if c.readErr != nil { - return c.destination, c.readErr - } - return c.PacketConn.ReadPacket(b) -} - -// readInfo reads and decodes client info if the first packet is a CLIENTINFO packet, then sends an -// OK response. -func (c *readPacketConn) readInfo() error { - buffer := buf.NewPacket() - defer buffer.Release() - - destination, err := c.ReadPacket(buffer) - if err != nil { - c.readErr = err - return err - } - data := buffer.Bytes() - if !bytes.HasPrefix(data, []byte(prefix)) { - // not a client info packet, wrap with cached packet conn so the packet can be read again - c.PacketConn = bufio.NewCachedPacketConn(c.PacketConn, buffer, destination) - return nil - } - var info ClientInfo - if err := json.Unmarshal(data[len(prefix):], &info); err != nil { - return fmt.Errorf("unmarshaling client info: %w", err) - } - c.info = &info - - buffer.Reset() - buffer.WriteString("OK") - if err := c.WritePacket(buffer, destination); err != nil { - return fmt.Errorf("writing OK response: %w", err) - } - service.ContextWithPtr(c.ctx, &info) - return nil -} - -type writePacketConn struct { - N.PacketConn - ctx context.Context - metadata adapter.InboundContext - info *ClientInfo - logger log.ContextLogger -} - -func newWritePacketConn( - ctx context.Context, - conn N.PacketConn, - metadata adapter.InboundContext, - info *ClientInfo, - logger log.ContextLogger, -) N.PacketConn { - return &writePacketConn{ - PacketConn: conn, - ctx: ctx, - metadata: metadata, - info: info, - logger: logger, - } -} - -// PacketConnHandshakeSuccess sends client info upon successful handshake. -func (c *writePacketConn) PacketConnHandshakeSuccess(conn net.PacketConn) error { - if err := c.sendInfo(conn); err != nil { - return fmt.Errorf("sending client info: %w", err) - } - return nil -} - -// sendInfo marshals and sends client info as a CLIENTINFO packet, then waits for OK. -func (c *writePacketConn) sendInfo(conn net.PacketConn) error { - buf, err := json.Marshal(c.info) - if err != nil { - return fmt.Errorf("marshaling client info: %w", err) - } - packet := append([]byte(prefix), buf...) - _, err = conn.WriteTo(packet, c.metadata.Destination) - if err != nil { - return fmt.Errorf("writing packet: %w", err) - } - - // wait for `OK` response - resp := make([]byte, 2) - if _, _, err := conn.ReadFrom(resp); err != nil { - return fmt.Errorf("reading response: %w", err) - } - if string(resp) != "OK" { - return fmt.Errorf("invalid response: %s", resp) - } - return nil -} diff --git a/tracker/clientcontext/tracker_test.go b/tracker/clientcontext/tracker_test.go index e733727..5993370 100644 --- a/tracker/clientcontext/tracker_test.go +++ b/tracker/clientcontext/tracker_test.go @@ -13,12 +13,10 @@ import ( sbox "github.com/sagernet/sing-box" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common/json" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/service" "github.com/stretchr/testify/require" box "github.com/getlantern/lantern-box" @@ -27,20 +25,17 @@ import ( const testOptionsPath = "../../testdata/options" func TestIntegration(t *testing.T) { - cInfo := ClientInfo{ - DeviceID: "lantern-box", - Platform: "linux", - IsPro: false, - CountryCode: "US", - Version: "9.0", - } ctx := box.BaseContext() logger := log.NewNOPFactory().NewLogger("") - serverTracker := NewClientContextReader(MatchBounds{[]string{"any"}, []string{"any"}}, logger) - _, serverBox := newTestBox(ctx, t, testOptionsPath+"/http_server.json", serverTracker) + mgr := NewManager(MatchBounds{[]string{"any"}, []string{"any"}}, logger) + serverOpts := getOptions(ctx, t, testOptionsPath+"/http_server.json") + serverBox, err := sbox.New(sbox.Options{ + Context: ctx, + Options: serverOpts, + }) + require.NoError(t, err) - mTracker := &mockTracker{} - serverBox.Router().AppendTracker(mTracker) + serverBox.Router().AppendTracker(mgr) require.NoError(t, serverBox.Start()) defer serverBox.Close() @@ -48,17 +43,9 @@ func TestIntegration(t *testing.T) { httpServer := startHTTPServer() defer httpServer.Close() - clientOpts, clientBox := newTestBox(ctx, t, testOptionsPath+"/http_client.json", nil) - - httpInbound, exists := clientBox.Inbound().Get("http-client") - require.True(t, exists, "http-client inbound should exist") - require.Equal(t, constant.TypeHTTP, httpInbound.Type(), "http-client should be a HTTP inbound") - - // this cannot actually be empty or we would have failed to create the box instance + clientOpts := getOptions(ctx, t, testOptionsPath+"/http_client.json") proxyAddr := getProxyAddress(clientOpts.Inbounds) - - require.NoError(t, clientBox.Start()) - defer clientBox.Close() + require.NotEmpty(t, proxyAddr, "http-client inbound not found in client options") proxyURL, _ := url.Parse("http://" + proxyAddr) httpClient := &http.Client{ @@ -68,28 +55,63 @@ func TestIntegration(t *testing.T) { } addr := httpServer.URL + mTracker := &mockTracker{} + mgr.AppendTracker(mTracker) + cInfo := ClientInfo{ + DeviceID: "lantern-box", + Platform: "linux", + IsPro: false, + CountryCode: "US", + Version: "9.0", + } + infoFn := func() ClientInfo { return cInfo } + t.Run("with ClientContext tracker", func(t *testing.T) { + mTracker.info = nil + tracker := NewClientContextInjector(infoFn, MatchBounds{[]string{"any"}, []string{"any"}}) + runTrackerTest(ctx, t, clientOpts, tracker, httpClient, addr) + require.Equal(t, &cInfo, mTracker.info) + }) t.Run("without ClientContext tracker", func(t *testing.T) { - req, err := http.NewRequest("GET", addr+"/ip", nil) - require.NoError(t, err) - - _, err = httpClient.Do(req) - require.NoError(t, err) - + mTracker.info = nil + runTrackerTest(ctx, t, clientOpts, nil, httpClient, addr) require.Nil(t, mTracker.info) }) - t.Run("with ClientContext tracker", func(t *testing.T) { - clientTracker := NewClientContextTracker(cInfo, MatchBounds{[]string{"any"}, []string{"any"}}, logger) - clientBox.Router().AppendTracker(clientTracker) - req, err := http.NewRequest("GET", addr+"/ip", nil) - require.NoError(t, err) - - _, err = httpClient.Do(req) - require.NoError(t, err) +} - info := mTracker.info - require.NotNil(t, info) - require.Equal(t, cInfo, *info) +func runTrackerTest( + ctx context.Context, + t *testing.T, + opts option.Options, + tracker *ClientContextInjector, + client *http.Client, + addr string, +) { + instance, err := sbox.New(sbox.Options{ + Context: ctx, + Options: opts, }) + require.NoError(t, err) + if tracker != nil { + instance.Router().AppendTracker(tracker) + } + + require.NoError(t, instance.Start()) + defer instance.Close() + + req, err := http.NewRequest("GET", addr, nil) + require.NoError(t, err) + + _, err = client.Do(req) + require.NoError(t, err) +} + +func getOptions(ctx context.Context, t *testing.T, configPath string) option.Options { + buf, err := os.ReadFile(configPath) + require.NoError(t, err) + + options, err := json.UnmarshalExtendedContext[option.Options](ctx, buf) + require.NoError(t, err) + return options } func getProxyAddress(inbounds []option.Inbound) string { @@ -110,25 +132,6 @@ func startHTTPServer() *httptest.Server { return httptest.NewServer(handler) } -func newTestBox(ctx context.Context, t *testing.T, configPath string, tracker *ClientContextTracker) (option.Options, *sbox.Box) { - buf, err := os.ReadFile(configPath) - require.NoError(t, err) - - options, err := json.UnmarshalExtendedContext[option.Options](ctx, buf) - require.NoError(t, err) - - instance, err := sbox.New(sbox.Options{ - Context: ctx, - Options: options, - }) - require.NoError(t, err) - - if tracker != nil { - instance.Router().AppendTracker(tracker) - } - return options, instance -} - var _ (adapter.ConnectionTracker) = (*mockTracker)(nil) type mockTracker struct { @@ -136,7 +139,10 @@ type mockTracker struct { } func (t *mockTracker) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn { - t.info = service.PtrFromContext[ClientInfo](ctx) + info, ok := ClientInfoFromContext(ctx) + if ok { + t.info = &info + } return conn } func (t *mockTracker) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) N.PacketConn {