diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 10498c9ad..af97301a1 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -36,9 +36,11 @@ func main() { flags.AddFlagSet(o.Flags()) local := flag.NewFlagSet(os.Args[0], flag.ExitOnError) klog.InitFlags(local) - err := local.Set("v", "4") - if err != nil { - fmt.Fprintf(os.Stderr, "error setting klog flags: %v", err) + if local.Lookup("v") == nil { + err := local.Set("v", "4") + if err != nil { + fmt.Fprintf(os.Stderr, "error setting klog flags: %v", err) + } } local.VisitAll(func(fl *flag.Flag) { fl.Name = util.Normalize(fl.Name) diff --git a/cmd/server/main.go b/cmd/server/main.go index fdb1a736b..3bad77eec 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -36,9 +36,11 @@ func main() { flags.AddFlagSet(o.Flags()) local := flag.NewFlagSet(os.Args[0], flag.ExitOnError) klog.InitFlags(local) - err := local.Set("v", "4") - if err != nil { - fmt.Fprintf(os.Stderr, "error setting klog flags: %v", err) + if local.Lookup("v") == nil { + err := local.Set("v", "4") + if err != nil { + fmt.Fprintf(os.Stderr, "error setting klog flags: %v", err) + } } local.VisitAll(func(fl *flag.Flag) { fl.Name = util.Normalize(fl.Name) diff --git a/pkg/agent/client.go b/pkg/agent/client.go index 5a19ec9ee..1469d1b33 100644 --- a/pkg/agent/client.go +++ b/pkg/agent/client.go @@ -354,7 +354,13 @@ func (a *Client) Serve() { if status.Code(err) == codes.Canceled { klog.V(2).InfoS("stream canceled", "serverID", a.serverID, "agentID", a.agentID) } else { - klog.ErrorS(err, "could not read stream", "serverID", a.serverID, "agentID", a.agentID) + select { + case <-a.stopCh: + klog.V(5).InfoS("could not read stream because agent client is shutting down", "serverID", a.serverID, "agentID", a.agentID, "err", err) + default: + // If stopCh is not closed, this is a legitimate, unexpected error. 
+ klog.ErrorS(err, "could not read stream", "serverID", a.serverID, "agentID", a.agentID) + } } return } @@ -407,7 +413,13 @@ func (a *Client) Serve() { closePkt.GetCloseResponse().ConnectID = connID } if err := a.Send(closePkt); err != nil { - klog.ErrorS(err, "close response failure", "") + if err == io.EOF { + klog.V(4).InfoS("received EOF; connection already closed", "connectionID", connID, "dialID", dialReq.Random, "err", err) + } else if _, ok := a.connManager.Get(connID); !ok { + klog.V(5).InfoS("connection already closed", "connectionID", connID, "dialID", dialReq.Random, "err", err) + } else { + klog.ErrorS(err, "close response failure", "connectionID", connID, "dialID", dialReq.Random) + } } close(dataCh) a.connManager.Delete(connID) diff --git a/pkg/agent/clientset.go b/pkg/agent/clientset.go index a17948b99..7fa2610fa 100644 --- a/pkg/agent/clientset.go +++ b/pkg/agent/clientset.go @@ -96,15 +96,12 @@ func (cs *ClientSet) HealthyClientsCount() int { } -func (cs *ClientSet) hasIDLocked(serverID string) bool { - _, ok := cs.clients[serverID] - return ok -} - +// HasID returns true if the ClientSet has a client to the specified serverID. func (cs *ClientSet) HasID(serverID string) bool { cs.mu.Lock() defer cs.mu.Unlock() - return cs.hasIDLocked(serverID) + _, exists := cs.clients[serverID] + return exists } type DuplicateServerError struct { @@ -115,20 +112,19 @@ func (dse *DuplicateServerError) Error() string { return "duplicate server: " + dse.ServerID } -func (cs *ClientSet) addClientLocked(serverID string, c *Client) error { - if cs.hasIDLocked(serverID) { +// AddClient adds the specified client to our set of clients. +// If we already have a connection with the same serverID, we will return *DuplicateServerError. 
+func (cs *ClientSet) AddClient(serverID string, c *Client) error { + cs.mu.Lock() + defer cs.mu.Unlock() + + _, exists := cs.clients[serverID] + if exists { return &DuplicateServerError{ServerID: serverID} } cs.clients[serverID] = c metrics.Metrics.SetServerConnectionsCount(len(cs.clients)) return nil - -} - -func (cs *ClientSet) AddClient(serverID string, c *Client) error { - cs.mu.Lock() - defer cs.mu.Unlock() - return cs.addClientLocked(serverID, c) } func (cs *ClientSet) RemoveClient(serverID string) { diff --git a/pkg/server/backend_manager.go b/pkg/server/backend_manager.go index f42631d28..1bd9bf8ab 100644 --- a/pkg/server/backend_manager.go +++ b/pkg/server/backend_manager.go @@ -299,10 +299,10 @@ func containIDType(idTypes []header.IdentifierType, idType header.IdentifierType // addBackend adds a backend. func (s *DefaultBackendStorage) addBackend(identifier string, idType header.IdentifierType, backend *Backend) { if !containIDType(s.idTypes, idType) { - klog.V(4).InfoS("fail to add backend", "backend", identifier, "error", &ErrWrongIDType{idType, s.idTypes}) + klog.V(3).InfoS("fail to add backend", "backend", identifier, "error", &ErrWrongIDType{idType, s.idTypes}) return } - klog.V(5).InfoS("Register backend for agent", "agentID", identifier) + klog.V(2).InfoS("Register backend for agent", "agentID", identifier) s.mu.Lock() defer s.mu.Unlock() _, ok := s.backends[identifier] @@ -327,7 +327,7 @@ func (s *DefaultBackendStorage) removeBackend(identifier string, idType header.I klog.ErrorS(&ErrWrongIDType{idType, s.idTypes}, "fail to remove backend") return } - klog.V(5).InfoS("Remove connection for agent", "agentID", identifier) + klog.V(2).InfoS("Remove connection for agent", "agentID", identifier) s.mu.Lock() defer s.mu.Unlock() backends, ok := s.backends[identifier] @@ -400,7 +400,7 @@ func (s *DefaultBackendStorage) GetRandomBackend() (*Backend, error) { return nil, &ErrNotFound{} } agentID := s.agentIDs[s.random.Intn(len(s.agentIDs))] - 
klog.V(5).InfoS("Pick agent as backend", "agentID", agentID) + klog.V(3).InfoS("Pick agent as backend", "agentID", agentID) // always return the first connection to an agent, because the agent // will close later connections if there are multiple. return s.backends[agentID][0], nil diff --git a/pkg/server/server.go b/pkg/server/server.go index 9122c48f8..eb46e8352 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -107,6 +107,53 @@ const ( destHostKey key = iota ) +// mapDialErrorToHTTPStatus maps common TCP/network error strings to appropriate HTTP status codes +func mapDialErrorToHTTPStatus(errStr string) int { + // Convert to lowercase for case-insensitive matching + errLower := strings.ToLower(errStr) + + // Check each error pattern and return appropriate status code + switch { + // Timeouts - backend didn't respond in time -> 504 Gateway Timeout + case strings.Contains(errLower, "i/o timeout"), + strings.Contains(errLower, "deadline exceeded"), + strings.Contains(errLower, "context deadline exceeded"), + strings.Contains(errLower, "timeout"): + return 504 + + // Resource exhaustion errors -> 503 Service Unavailable + case strings.Contains(errLower, "too many open files"), + strings.Contains(errLower, "socket: too many open files"): + return 503 + + // Connection errors -> 502 Bad Gateway + case strings.Contains(errLower, "connection refused"), + strings.Contains(errLower, "connection reset by peer"), + strings.Contains(errLower, "broken pipe"), + strings.Contains(errLower, "network is unreachable"), + strings.Contains(errLower, "no route to host"), + strings.Contains(errLower, "host is unreachable"), + strings.Contains(errLower, "network is down"): + return 502 + + // DNS resolution failures -> 502 Bad Gateway + case strings.Contains(errLower, "no such host"), + strings.Contains(errLower, "name resolution"), + strings.Contains(errLower, "lookup") && strings.Contains(errLower, "no such host"): + return 502 + + // TLS/SSL errors -> 502 Bad Gateway + 
case strings.Contains(errLower, "tls"), + strings.Contains(errLower, "ssl"), + strings.Contains(errLower, "certificate"): + return 502 + + // Default to 502 Bad Gateway for unknown proxy errors + default: + return 502 + } +} + func (c *ProxyClientConnection) send(pkt *client.Packet) error { defer func(start time.Time) { metrics.Metrics.ObserveFrontendWriteLatency(time.Since(start)) }(time.Now()) if c.Mode == ModeGRPC { @@ -121,11 +168,21 @@ func (c *ProxyClientConnection) send(pkt *client.Packet) error { _, err := c.HTTP.Write(pkt.GetData().Data) return err } else if pkt.Type == client.PacketType_DIAL_RSP { - if pkt.GetDialResponse().Error != "" { - body := bytes.NewBufferString(pkt.GetDialResponse().Error) + dialErr := pkt.GetDialResponse().Error + if dialErr != "" { + // Map the error to appropriate HTTP status code + statusCode := mapDialErrorToHTTPStatus(dialErr) + statusText := http.StatusText(statusCode) + body := bytes.NewBufferString(dialErr) t := http.Response{ - StatusCode: 503, + StatusCode: statusCode, + Status: fmt.Sprintf("%d %s", statusCode, statusText), Body: io.NopCloser(body), + Header: http.Header{ + "Content-Type": []string{"text/plain; charset=utf-8"}, + }, + Proto: "HTTP/1.1", + ProtoMinor: 1, } t.Write(c.HTTP) @@ -718,7 +775,7 @@ func (s *ProxyServer) Connect(stream agent.AgentService_ConnectServer) error { } agentID := backend.GetAgentID() - klog.V(5).InfoS("Connect request from agent", "agentID", agentID, "serverID", s.serverID) + klog.V(2).InfoS("Connect request from agent", "agentID", agentID, "serverID", s.serverID) labels := runpprof.Labels( "serverCount", strconv.Itoa(s.serverCount), "agentID", agentID, @@ -945,7 +1002,7 @@ func (s *ProxyServer) serveRecvBackend(backend *Backend, agentID string, recvCh klog.V(5).InfoS("Ignoring unrecognized packet from backend", "packet", pkt, "agentID", agentID) } } - klog.V(5).InfoS("Close backend of agent", "agentID", agentID) + klog.V(3).InfoS("Close backend of agent", "agentID", agentID) } func 
(s *ProxyServer) sendBackendClose(backend *Backend, connectID int64, random int64, reason string) { diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index 3d3ec18d2..ad81179ca 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -40,7 +40,7 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer metrics.Metrics.HTTPConnectionDec() klog.V(2).InfoS("Received request for host", "method", r.Method, "host", r.Host, "userAgent", r.UserAgent()) - if r.TLS != nil { + if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { klog.V(2).InfoS("TLS", "commonName", r.TLS.PeerCertificates[0].Subject.CommonName) } if r.Method != http.MethodConnect { @@ -60,14 +60,6 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Send the HTTP 200 OK status after a successful hijack - _, err = conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) - if err != nil { - klog.ErrorS(err, "failed to send 200 connection established") - conn.Close() - return - } - var closeOnce sync.Once defer closeOnce.Do(func() { conn.Close() }) @@ -104,15 +96,23 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) { connected: connected, start: time.Now(), backend: backend, + dialID: random, + agentID: backend.GetAgentID(), } t.Server.PendingDial.Add(random, connection) if err := backend.Send(dialRequest); err != nil { - klog.ErrorS(err, "failed to tunnel dial request") + klog.ErrorS(err, "failed to tunnel dial request", "host", r.Host, "dialID", connection.dialID, "agentID", connection.agentID) + // Send proper HTTP error response + conn.Write([]byte(fmt.Sprintf("HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain; charset=utf-8\r\n\r\nFailed to tunnel dial request: %v\r\n", err))) + conn.Close() return } ctxt := backend.Context() if ctxt.Err() != nil { - klog.ErrorS(err, "context reports failure") + klog.ErrorS(ctxt.Err(), "context reports failure") + conn.Write([]byte(fmt.Sprintf("HTTP/1.1 502 Bad 
Gateway\r\nContent-Type: text/plain; charset=utf-8\r\n\r\nBackend context error: %v\r\n", ctxt.Err()))) + conn.Close() + return } select { @@ -123,6 +123,15 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) { select { case <-connection.connected: // Waiting for response before we begin full communication. + // Now that connection is established, send 200 OK to switch to tunnel mode + _, err = conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) + if err != nil { + klog.ErrorS(err, "failed to send 200 connection established", "host", r.Host, "agentID", connection.agentID) + conn.Close() + return + } + klog.V(3).InfoS("Connection established, sent 200 OK", "host", r.Host, "agentID", connection.agentID, "connectionID", connection.connectID) + case <-closed: // Connection was closed before being established } @@ -142,22 +151,22 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) { conn.Close() }() - klog.V(3).InfoS("Starting proxy to host", "host", r.Host) - pkt := make([]byte, 1<<15) // Match GRPC Window size - connID := connection.connectID agentID := connection.agentID + klog.V(3).InfoS("Starting proxy to host", "host", r.Host, "agentID", agentID, "connectionID", connID) + + pkt := make([]byte, 1<<15) // Match GRPC Window size var acc int for { n, err := bufrw.Read(pkt[:]) acc += n if err == io.EOF { - klog.V(1).InfoS("EOF from host", "host", r.Host) + klog.V(1).InfoS("EOF from host", "host", r.Host, "agentID", agentID, "connectionID", connID) break } if err != nil { - klog.ErrorS(err, "Received failure on connection") + klog.ErrorS(err, "Received failure on connection", "host", r.Host, "agentID", agentID, "connectionID", connID) break } @@ -172,7 +181,7 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) { } err = backend.Send(packet) if err != nil { - klog.ErrorS(err, "error sending packet") + klog.ErrorS(err, "error sending packet", "host", r.Host, "agentID", agentID, "connectionID", connID) 
break } klog.V(5).InfoS("Forwarding data on tunnel to agent", diff --git a/tests/proxy_test.go b/tests/proxy_test.go index aac247006..8c8faa591 100644 --- a/tests/proxy_test.go +++ b/tests/proxy_test.go @@ -689,7 +689,7 @@ func TestFailedDNSLookupProxy_HTTPCONN(t *testing.T) { t.Error(err) } - urlString := "http://thissssssxxxxx.com:80" + urlString := "http://thisdefinitelydoesnotexist.com:80" serverURL, _ := url.Parse(urlString) // Send HTTP-Connect request @@ -705,36 +705,12 @@ func TestFailedDNSLookupProxy_HTTPCONN(t *testing.T) { t.Errorf("reading HTTP response from CONNECT: %v", err) } - if res.StatusCode != 200 { - t.Errorf("expect 200; got %d", res.StatusCode) - } - if br.Buffered() > 0 { - t.Error("unexpected extra buffer") - } - dialer := func(_, _ string) (net.Conn, error) { - return conn, nil - } - - c := &http.Client{ - Transport: &http.Transport{ - Dial: dialer, - }, - } - - resp, err := c.Get(urlString) - if err != nil { - t.Error(err) + if res.StatusCode != 502 { + t.Errorf("expect 502; got %d", res.StatusCode) } - if resp.StatusCode != 503 { - t.Errorf("expect 503; got %d", res.StatusCode) - } - - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - if !strings.Contains(err.Error(), "connection reset by peer") { - t.Error(err) - } + body, err := io.ReadAll(res.Body) + res.Body.Close() if !strings.Contains(string(body), "no such host") { t.Errorf("Unexpected error: %v", err) @@ -779,37 +755,21 @@ func TestFailedDial_HTTPCONN(t *testing.T) { br := bufio.NewReader(conn) res, err := http.ReadResponse(br, nil) if err != nil { - t.Fatalf("reading HTTP response from CONNECT: %v", err) - } - if res.StatusCode != 200 { - t.Fatalf("expect 200; got %d", res.StatusCode) + t.Errorf("reading HTTP response from CONNECT: %v", err) } - dialer := func(_, _ string) (net.Conn, error) { - return conn, nil + if res.StatusCode != 502 { + t.Errorf("expect 502; got %d", res.StatusCode) } - c := &http.Client{ - Transport: &http.Transport{ - Dial: dialer, - }, - } - - 
resp, err := c.Get(server.URL) + body, err := io.ReadAll(res.Body) + res.Body.Close() if err != nil { - t.Fatal(err) - } - - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - if err == nil { t.Fatalf("Expected error reading response body; response=%q", body) - } else if !strings.Contains(err.Error(), "connection reset by peer") { - t.Error(err) } if !strings.Contains(string(body), "connection refused") { - t.Errorf("Unexpected error: %v", err) + t.Errorf("Expected 'connection refused' in error body, got: %s", string(body)) } if err := ps.Metrics().ExpectServerDialFailure(metricsserver.DialFailureErrorResponse, 1); err != nil {