diff --git a/CHANGELOG.md b/CHANGELOG.md index 1268abb3..bd533c02 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # [Unreleased] - Update go version to 1.23.0 (#628) +- Do not restart proxies when using hostnames to specify listen address when updating a proxy + and populating a collection (#631, @robinbrandt) # [2.11.0] - 2024-10-16 diff --git a/api_test.go b/api_test.go index c1a91a69..82d58ce2 100644 --- a/api_test.go +++ b/api_test.go @@ -256,6 +256,7 @@ func TestPopulateExistingProxy(t *testing.T) { if err != nil { t.Fatal("Unable to create proxy:", err) } + _, err = client.CreateProxy("two", "localhost:7373", "localhost:7474") if err != nil { t.Fatal("Unable to create proxy:", err) @@ -270,7 +271,7 @@ func TestPopulateExistingProxy(t *testing.T) { testProxies, err := client.Populate([]tclient.Proxy{ { Name: "one", - Listen: "127.0.0.1:7070", + Listen: "localhost:7070", // intentional: this should be resolved to 127.0.0.1:7070 Upstream: "localhost:7171", Enabled: true, }, diff --git a/proxy.go b/proxy.go index 25155f4a..9bcda05e 100644 --- a/proxy.go +++ b/proxy.go @@ -81,7 +81,12 @@ func (proxy *Proxy) Update(input *Proxy) error { proxy.Lock() defer proxy.Unlock() - if input.Listen != proxy.Listen || input.Upstream != proxy.Upstream { + differs, err := proxy.Differs(input) + if err != nil { + return err + } + + if differs { stop(proxy) proxy.Listen = input.Listen proxy.Upstream = input.Upstream @@ -131,6 +136,19 @@ func (proxy *Proxy) close() { } } +func (proxy *Proxy) Differs(other *Proxy) (bool, error) { + newResolvedListen, err := net.ResolveTCPAddr("tcp", other.Listen) + if err != nil { + return false, err + } + + if proxy.Listen != newResolvedListen.String() || proxy.Upstream != other.Upstream { + return true, nil + } + + return false, nil +} + // This channel is to kill the blocking Accept() call below by closing the // net.Listener. func (proxy *Proxy) freeBlocker(acceptTomb *tomb.Tomb) { diff --git a/proxy_collection.go b/proxy_collection.go index a1cf9750..e359f0e6 100644 --- a/proxy_collection.go +++ b/proxy_collection.go @@ -43,13 +43,18 @@ func (collection *ProxyCollection) Add(proxy *Proxy, start bool) error { return nil } -func (collection *ProxyCollection) AddOrReplace(proxy *Proxy, start bool) error { +func (collection *ProxyCollection) AddOrReplace(proxy *Proxy, start bool) (*Proxy, error) { collection.Lock() defer collection.Unlock() if existing, exists := collection.proxies[proxy.Name]; exists { - if existing.Listen == proxy.Listen && existing.Upstream == proxy.Upstream { - return nil + differs, err := existing.Differs(proxy) + if err != nil { + return nil, err + } + + if !differs { + return existing, nil } existing.Stop() } @@ -57,13 +62,13 @@ func (collection *ProxyCollection) AddOrReplace(proxy *Proxy, start bool) error if start { err := proxy.Start() if err != nil { - return err + return nil, err } } collection.proxies[proxy.Name] = proxy - return nil + return proxy, nil } func (collection *ProxyCollection) PopulateJson( @@ -98,12 +103,12 @@ func (collection *ProxyCollection) PopulateJson( for i := range input { proxy := NewProxy(server, input[i].Name, input[i].Listen, input[i].Upstream) - err = collection.AddOrReplace(proxy, *input[i].Enabled) + addedOrReplaced, err := collection.AddOrReplace(proxy, *input[i].Enabled) if err != nil { return proxies, err } - proxies = append(proxies, proxy) + proxies = append(proxies, addedOrReplaced) } return proxies, err } diff --git a/proxy_test.go b/proxy_test.go index c50f355b..76aaa3ad 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -3,8 +3,10 @@ package toxiproxy_test import ( "bytes" "encoding/hex" + "errors" "io" "net" + "os" "testing" "time" @@ -177,6 +179,62 @@ func TestProxyUpdate(t *testing.T) { }) } +func TestProxyUpdateWithHostname(t *testing.T) { + testhelper.WithTCPServer(t, func(upstream string, response chan []byte) { + proxy := NewTestProxy("test", upstream) + err := proxy.Start() + if err != nil { + t.Error("Proxy failed to start", err) + } + AssertProxyUp(t, proxy.Listen, true) + + connectionLost := make(chan bool) + + // Start a goroutine to check if connection is maintained + go func() { + conn, err := net.Dial("tcp", proxy.Listen) + if err != nil { + t.Error("Failed to connect to proxy", err) + } + defer conn.Close() + + // Try to read from the connection + buf := make([]byte, 1024) + conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + _, err = conn.Read(buf) + if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { + connectionLost <- true + return + } + + connectionLost <- false + }() + + _, port, err := net.SplitHostPort(proxy.Listen) + if err != nil { + t.Error("Failed to split host and port", err) + } + + input := &toxiproxy.Proxy{ + Listen: net.JoinHostPort("localhost", port), + Upstream: proxy.Upstream, + Enabled: true, + } + err = proxy.Update(input) + if err != nil { + t.Error("Failed to update proxy", err) + } + + // Check if the connection was lost during the update + if lost := <-connectionLost; lost { + t.Error("Connection was lost during proxy update") + } + + // Verify proxy is still up after the update + AssertProxyUp(t, proxy.Listen, true) + }) +} + func TestRestartFailedToStartProxy(t *testing.T) { testhelper.WithTCPServer(t, func(upstream string, response chan []byte) { proxy := NewTestProxy("test", upstream) @@ -207,3 +265,28 @@ func TestRestartFailedToStartProxy(t *testing.T) { AssertProxyUp(t, proxy.Listen, false) }) } + +func TestProxyDiffers(t *testing.T) { + testhelper.WithTCPServer(t, func(upstream string, response chan []byte) { + proxy := NewTestProxy("test", upstream) + proxy.Start() + _, port, err := net.SplitHostPort(proxy.Listen) + if err != nil { + t.Error("Failed to split host and port", err) + } + otherProxy := &toxiproxy.Proxy{ + Name: "other", + Listen: net.JoinHostPort("localhost", port), + Upstream: upstream, + Enabled: true, + } + + differs, err := proxy.Differs(otherProxy) + if err != nil { + t.Error("Failed to check if proxy differs", err) + } + if differs { + t.Error("Proxy should not differ ") + } + }) +}