diff --git a/pkg/clusteragent/api/leader_forwarder.go b/pkg/clusteragent/api/leader_forwarder.go index 8440bf609526de..5ecf24e8073782 100644 --- a/pkg/clusteragent/api/leader_forwarder.go +++ b/pkg/clusteragent/api/leader_forwarder.go @@ -87,6 +87,7 @@ func (lf *LeaderForwarder) Forward(rw http.ResponseWriter, req *http.Request) { if req.Header.Get(forwardHeader) != "" { http.Error(rw, "Query was already forwarded from: "+req.RemoteAddr, http.StatusLoopDetected) + return } var currentProxy *httputil.ReverseProxy diff --git a/pkg/clusteragent/api/leader_forwarder_test.go b/pkg/clusteragent/api/leader_forwarder_test.go new file mode 100644 index 00000000000000..70edface61a265 --- /dev/null +++ b/pkg/clusteragent/api/leader_forwarder_test.go @@ -0,0 +1,104 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016-present Datadog, Inc. + +//go:build test + +package api + +import ( + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLeaderForwarder_SetLeaderIP(t *testing.T) { + lf := NewLeaderForwarder(5005, 10) + + // Initially no leader IP + assert.Equal(t, "", lf.GetLeaderIP()) + assert.Nil(t, lf.proxy) + + // Set leader IP + lf.SetLeaderIP("1.1.1.1") + assert.Equal(t, "1.1.1.1", lf.GetLeaderIP()) + assert.NotNil(t, lf.proxy) + + // Update leader IP + lf.SetLeaderIP("2.2.2.2") + assert.Equal(t, "2.2.2.2", lf.GetLeaderIP()) + assert.NotNil(t, lf.proxy) + + // Clear proxy with empty string - note: leaderIP is NOT cleared (returns early) + lf.SetLeaderIP("") + assert.Equal(t, "2.2.2.2", lf.GetLeaderIP()) // leaderIP unchanged + assert.Nil(t, lf.proxy) // but proxy is cleared +} + +func TestLeaderForwarder_Forward_NilProxy(t *testing.T) { + lf := NewLeaderForwarder(5005, 10) + + // No leader set, proxy is nil + rw := httptest.NewRecorder() + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + + lf.Forward(rw, req) + + assert.Equal(t, http.StatusServiceUnavailable, rw.Code) + assert.Equal(t, "true", rw.Header().Get("X-DCA-Forwarded")) +} + +func TestLeaderForwarder_Forward_LoopDetection(t *testing.T) { + // Track if leader server was called + leaderCalled := false + leaderServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + leaderCalled = true + w.WriteHeader(http.StatusOK) + })) + defer leaderServer.Close() + + port := leaderServer.Listener.Addr().(*net.TCPAddr).Port + lf := NewLeaderForwarder(port, 10) + lf.SetLeaderIP("127.0.0.1") + + // Request already has forward header (loop detection) + rw := httptest.NewRecorder() + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + req.Header.Set("X-DCA-Follower-Forwarded", "true") + + lf.Forward(rw, req) + + // Loop detection should return 508 and NOT forward to leader + assert.Equal(t, http.StatusLoopDetected, rw.Code) + assert.Equal(t, "true", rw.Header().Get("X-DCA-Forwarded")) + assert.False(t, leaderCalled, "Request should not be forwarded to leader when loop is detected") +} + +func TestLeaderForwarder_Forward_WithLeader(t *testing.T) { + // Create a test server to act as the leader + leaderServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify the forward header was added + assert.Equal(t, "true", r.Header.Get("X-DCA-Follower-Forwarded")) + w.WriteHeader(http.StatusOK) + w.Write([]byte("leader response")) + })) + defer leaderServer.Close() + + // Extract port from test server + port := leaderServer.Listener.Addr().(*net.TCPAddr).Port + lf := NewLeaderForwarder(port, 10) + lf.SetLeaderIP("127.0.0.1") + + rw := httptest.NewRecorder() + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + + lf.Forward(rw, req) + + assert.Equal(t, http.StatusOK, rw.Code) + assert.Equal(t, "true", rw.Header().Get("X-DCA-Forwarded")) + assert.Equal(t, "leader response", rw.Body.String()) +} diff --git a/pkg/clusteragent/api/leader_handler_test.go b/pkg/clusteragent/api/leader_handler_test.go index f6efb638bdff40..f07e65ded68b0e 100644 --- a/pkg/clusteragent/api/leader_handler_test.go +++ b/pkg/clusteragent/api/leader_handler_test.go @@ -29,19 +29,25 @@ func (m *mockLeaderEngine) GetLeaderIP() (string, error) { return m.leaderIP, nil } -// fakeLeaderForwarder is a fake implementation of the forwarder for testing purposes -type fakeLeaderForwarder struct{} +// fakeLeaderForwarder is a fake implementation of the forwarder for testing purposes. +// It tracks leader IP changes and forward calls for verifying leadership transition behavior. +type fakeLeaderForwarder struct { + currentLeaderIP string + leaderIPChangeCount int + forwardCallCount int +} -// SetLeaderIP does nothing -func (f *fakeLeaderForwarder) SetLeaderIP(_ string) {} +func (f *fakeLeaderForwarder) SetLeaderIP(ip string) { + f.currentLeaderIP = ip + f.leaderIPChangeCount++ +} -// GetLeaderIP does nothing func (f *fakeLeaderForwarder) GetLeaderIP() string { - return "" + return f.currentLeaderIP } -// Forward returns ok func (f *fakeLeaderForwarder) Forward(w http.ResponseWriter, _ *http.Request) { + f.forwardCallCount++ w.WriteHeader(http.StatusOK) } @@ -92,3 +98,89 @@ func TestRejectOrForwardLeaderQuery_AsLeader(t *testing.T) { assert.False(t, lph.rejectOrForwardLeaderQuery(rw, req)) } + +// TestRejectOrForwardLeaderQuery_LeadershipTransition tests the behavior when +// leadership changes between requests (leader to follower and back). +func TestRejectOrForwardLeaderQuery_LeadershipTransition(t *testing.T) { + mockEngine := &mockLeaderEngine{ + isLeader: true, + leaderIP: "1.1.1.1", + } + forwarder := &fakeLeaderForwarder{} + + lph := &LeaderProxyHandler{ + leaderElectionEnabled: true, + le: mockEngine, + leaderForwarder: forwarder, + } + + // First request: we are the leader, should handle locally + rw1 := httptest.NewRecorder() + req1 := httptest.NewRequest("GET", "http://example.com/foo", nil) + assert.False(t, lph.rejectOrForwardLeaderQuery(rw1, req1), "Should handle locally as leader") + assert.Equal(t, 0, forwarder.forwardCallCount, "Should not forward when leader") + + // Simulate leadership loss + mockEngine.isLeader = false + mockEngine.leaderIP = "2.2.2.2" + + // Second request: we lost leadership, should forward to new leader + rw2 := httptest.NewRecorder() + req2 := httptest.NewRequest("GET", "http://example.com/foo", nil) + assert.True(t, lph.rejectOrForwardLeaderQuery(rw2, req2), "Should forward as follower") + assert.Equal(t, 1, forwarder.forwardCallCount, "Should forward once") + assert.Equal(t, "2.2.2.2", forwarder.currentLeaderIP, "Should update to new leader IP") + + // Simulate regaining leadership + mockEngine.isLeader = true + + // Third request: we became the leader again, should handle locally + rw3 := httptest.NewRecorder() + req3 := httptest.NewRequest("GET", "http://example.com/foo", nil) + assert.False(t, lph.rejectOrForwardLeaderQuery(rw3, req3), "Should handle locally as new leader") + assert.Equal(t, 1, forwarder.forwardCallCount, "Should not forward additional requests") +} + +// TestRejectOrForwardLeaderQuery_LeaderIPChange tests that the forwarder is updated +// when the leader IP changes while we remain a follower. +func TestRejectOrForwardLeaderQuery_LeaderIPChange(t *testing.T) { + mockEngine := &mockLeaderEngine{ + isLeader: false, + leaderIP: "1.1.1.1", + } + forwarder := &fakeLeaderForwarder{ + currentLeaderIP: "1.1.1.1", // Already knows old leader + } + + lph := &LeaderProxyHandler{ + leaderElectionEnabled: true, + le: mockEngine, + leaderForwarder: forwarder, + } + + // First request: forward to current leader + rw1 := httptest.NewRecorder() + req1 := httptest.NewRequest("GET", "http://example.com/foo", nil) + assert.True(t, lph.rejectOrForwardLeaderQuery(rw1, req1)) + assert.Equal(t, 1, forwarder.forwardCallCount) + // IP didn't change, so SetLeaderIP should not have been called + assert.Equal(t, 0, forwarder.leaderIPChangeCount, "Should not update IP when unchanged") + + // Simulate leader failover - new leader elected + mockEngine.leaderIP = "2.2.2.2" + + // Second request: should detect IP change and update forwarder + rw2 := httptest.NewRecorder() + req2 := httptest.NewRequest("GET", "http://example.com/foo", nil) + assert.True(t, lph.rejectOrForwardLeaderQuery(rw2, req2)) + assert.Equal(t, 2, forwarder.forwardCallCount) + assert.Equal(t, 1, forwarder.leaderIPChangeCount, "Should update IP once") + assert.Equal(t, "2.2.2.2", forwarder.currentLeaderIP, "Should have new leader IP") + + // Third request: IP hasn't changed again + rw3 := httptest.NewRecorder() + req3 := httptest.NewRequest("GET", "http://example.com/foo", nil) + assert.True(t, lph.rejectOrForwardLeaderQuery(rw3, req3)) + assert.Equal(t, 3, forwarder.forwardCallCount) + assert.Equal(t, 1, forwarder.leaderIPChangeCount, "Should not update IP when unchanged") +}