Skip to content

Commit 0eea674

Browse files
Refactor loadBalanceRandom (#100)
* Refactor loadballanceRandom * Fix variable naming for consistency in load balancing logic --------- Co-authored-by: Balazs Zomborszki <balazs.zomborszki@nokia.com>
1 parent e435ae8 commit 0eea674

File tree

2 files changed

+40
-34
lines changed

2 files changed

+40
-34
lines changed

client.go

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ func (c *Client) doSetAuth(ctx context.Context, req *http.Request) error {
577577
return nil
578578
}
579579

580-
func (c *Client) doPre(req *http.Request) (*http.Response, error) {
580+
func (c *Client) doMonitorPre(req *http.Request) (*http.Response, error) {
581581
for i := len(c.monitor) - 1; i >= 0; i-- {
582582
if c.monitor[i].pre != nil {
583583
resp, err := c.monitor[i].pre(req)
@@ -589,7 +589,7 @@ func (c *Client) doPre(req *http.Request) (*http.Response, error) {
589589
return nil, nil
590590
}
591591

592-
func (c *Client) doPost(req *http.Request, resp *http.Response, err error) (*http.Response, error) {
592+
func (c *Client) doMonitorPost(req *http.Request, resp *http.Response, err error) (*http.Response, error) {
593593
for i := range c.monitor {
594594
if c.monitor[i].post != nil {
595595
newResp := c.monitor[i].post(req, resp, err)
@@ -615,7 +615,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
615615
req.Header = make(http.Header)
616616
}
617617

618-
target, err := c.setReqTarget(req)
618+
targetForLog, err := c.setReqTarget(req)
619619
if err != nil {
620620
return nil, err
621621
}
@@ -626,16 +626,15 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
626626
return nil, err
627627
}
628628

629-
if resp, err := c.doPre(req); resp != nil || err != nil {
629+
if resp, err := c.doMonitorPre(req); resp != nil || err != nil {
630630
return resp, err
631631
}
632632

633-
target = c.setLoadBalanceTarget(req, target)
634633
req, spanStr, spanEndFunc := doSpan(req)
635634

636-
resp, err := c.doLog(spanStr, req, target)
635+
resp, err := c.doLog(spanStr, req, targetForLog)
637636

638-
resp, err = c.doPost(req, resp, err)
637+
resp, err = c.doMonitorPost(req, resp, err)
639638

640639
if spanEndFunc != nil {
641640
spanEndFunc()
@@ -644,7 +643,12 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
644643
return resp, err
645644
}
646645

647-
func (c *Client) doWithRetry(req *http.Request, spanStr, target string) (*http.Response, error) {
646+
func (c *Client) doWithRetry(req *http.Request, spanStr, targetForLog string) (*http.Response, error) {
647+
originalHost := req.URL.Hostname()
648+
targetForLog = c.setLoadBalanceTarget(req, targetForLog, originalHost)
649+
650+
log.Debugf("[%s] Sent req: %s %s", spanStr, req.Method, targetForLog)
651+
648652
clonedBody := c.cloneBody(req)
649653
resp, err := c.do(req)
650654

@@ -653,22 +657,23 @@ func (c *Client) doWithRetry(req *http.Request, spanStr, target string) (*http.R
653657
_ = resp.Body.Close()
654658
}
655659

660+
targetForLog = c.setLoadBalanceTarget(req, targetForLog, originalHost) // Set target again
661+
656662
req.Body = clonedBody
657663
clonedBody = c.cloneBody(req)
658664

659665
time.Sleep(c.calcBackoff(retries))
660-
log.Debugf("[%s] Send rty(%d): %s %s: err=%v", spanStr, retries, req.Method, target, err)
666+
log.Debugf("[%s] Send rty(%d): %s %s: err=%v", spanStr, retries, req.Method, targetForLog, err)
661667
resp, err = c.do(req)
662668
}
663669

664670
return resp, err
665671
}
666672

667-
func (c *Client) doLog(spanStr string, req *http.Request, target string) (*http.Response, error) {
668-
log.Debugf("[%s] Sent req: %s %s", spanStr, req.Method, target)
669-
resp, err := c.doWithRetry(req, spanStr, target)
673+
func (c *Client) doLog(spanStr string, req *http.Request, targetForLog string) (*http.Response, error) {
674+
resp, err := c.doWithRetry(req, spanStr, targetForLog)
670675
if err != nil {
671-
log.Debugf("[%s] Fail req: %s %s", spanStr, req.Method, target)
676+
log.Debugf("[%s] Fail req: %s %s", spanStr, req.Method, targetForLog)
672677
} else {
673678
log.Debugf("[%s] Recv rsp: %s", spanStr, resp.Status)
674679
}
@@ -1072,23 +1077,23 @@ var netLookupHost = func(ctx context.Context, host string) ([]string, error) {
10721077
return net.DefaultResolver.LookupHost(ctx, host)
10731078
}
10741079

1075-
func (c *Client) setLoadBalanceTarget(req *http.Request, target string) (targetOut string) {
1080+
func (c *Client) setLoadBalanceTarget(req *http.Request, target, originalHost string) (targetOut string) {
10761081
targetOut = target
10771082
if !c.LoadBalanceRandom {
10781083
return
10791084
}
1080-
if net.ParseIP(req.URL.Hostname()) != nil {
1085+
if net.ParseIP(originalHost) != nil {
10811086
log.Debugf("Host %s is an IP address, not a hostname. Load balancing is not applied.", req.URL.Hostname())
10821087
return // Do not apply load balancing if Host is an IP address.
10831088
}
10841089

1085-
IPs, err := netLookupHost(req.Context(), req.URL.Hostname())
1090+
IPs, err := netLookupHost(req.Context(), originalHost)
10861091
if err != nil {
1087-
log.Debugf("Failed to resolve host %s: %v", req.URL.Hostname(), err)
1092+
log.Debugf("Failed to resolve host %s: %v", originalHost, err)
10881093
return
10891094
}
10901095
if len(IPs) > 1 {
1091-
log.Debugf("Multiple IPs for %s: %v", req.URL.Hostname(), IPs)
1096+
log.Debugf("Multiple IPs for %s: %v", originalHost, IPs)
10921097
if req.Host == "" { // MonitorPre maybe already change req.URL.Host. And set req.Host to the original Host.
10931098
req.Host = req.URL.Host // Set Host header to original Host. This is used for TLS SNI and other purposes.
10941099
}
@@ -1103,9 +1108,9 @@ func chooseIPFromList(IPs []string) string {
11031108
return IPs[index] // Return the randomly chosen IP
11041109
}
11051110

1106-
// EnableLoadBalanceRandom enables load balancing by randomly choosing one of the IP addresses
1107-
// returned by net.LookupHost for the target hostname.
1108-
func (c *Client) EnableLoadBalanceRandom() *Client {
1109-
c.LoadBalanceRandom = true
1111+
// EnableLoadBalanceRandom enables or disables load balancing by random IP address.
1112+
// If enabled, the client will resolve the hostname and choose a random IP address from the list
1113+
func (c *Client) EnableLoadBalanceRandom(enable bool) *Client {
1114+
c.LoadBalanceRandom = enable
11101115
return c
11111116
}

client_test.go

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -963,27 +963,27 @@ func TestEnableLoadBalanceRandom(t *testing.T) {
963963
client := NewClient()
964964
assert.False(t, client.LoadBalanceRandom, "LoadBalanceRandom should be false by default")
965965

966-
client.EnableLoadBalanceRandom()
966+
client.EnableLoadBalanceRandom(true)
967967
assert.True(t, client.LoadBalanceRandom, "LoadBalanceRandom should be true after calling EnableLoadBalanceRandom")
968968
}
969969
func TestSetLoadBalanceTarget_NoLoadBalance(t *testing.T) {
970970
client := NewClient()
971971
req, _ := http.NewRequest("GET", "http://example.com/resource", nil)
972972
target := "http://example.com/resource"
973-
out := client.setLoadBalanceTarget(req, target)
973+
out := client.setLoadBalanceTarget(req, target, req.URL.Hostname())
974974
assert.Equal(t, target, out)
975975
}
976976

977977
func TestSetLoadBalanceTarget_IPAddressHost(t *testing.T) {
978-
client := NewClient().EnableLoadBalanceRandom()
978+
client := NewClient().EnableLoadBalanceRandom(true)
979979
req, _ := http.NewRequest("GET", "http://127.0.0.1/resource", nil)
980980
target := "http://127.0.0.1/resource"
981-
out := client.setLoadBalanceTarget(req, target)
981+
out := client.setLoadBalanceTarget(req, target, req.URL.Hostname())
982982
assert.Equal(t, target, out)
983983
}
984984

985985
func TestSetLoadBalanceTarget_ResolveError(t *testing.T) {
986-
client := NewClient().EnableLoadBalanceRandom()
986+
client := NewClient().EnableLoadBalanceRandom(true)
987987
req, _ := http.NewRequest("GET", "http://nonexistent.invalid/resource", nil)
988988
target := "http://nonexistent.invalid/resource"
989989

@@ -994,12 +994,12 @@ func TestSetLoadBalanceTarget_ResolveError(t *testing.T) {
994994
}
995995
defer func() { netLookupHost = origLookupHost }()
996996

997-
out := client.setLoadBalanceTarget(req, target)
997+
out := client.setLoadBalanceTarget(req, target, req.URL.Hostname())
998998
assert.Equal(t, target, out)
999999
}
10001000

10011001
func TestSetLoadBalanceTarget_SingleIP(t *testing.T) {
1002-
client := NewClient().EnableLoadBalanceRandom()
1002+
client := NewClient().EnableLoadBalanceRandom(true)
10031003
req, _ := http.NewRequest("GET", "http://example.com/resource", nil)
10041004
target := "http://example.com/resource"
10051005

@@ -1010,15 +1010,16 @@ func TestSetLoadBalanceTarget_SingleIP(t *testing.T) {
10101010
}
10111011
defer func() { netLookupHost = origLookupHost }()
10121012

1013-
out := client.setLoadBalanceTarget(req, target)
1013+
out := client.setLoadBalanceTarget(req, target, req.URL.Hostname())
10141014
assert.NotContains(t, out, "->")
10151015
assert.Contains(t, req.URL.Host, "example.com")
10161016
assert.Equal(t, "example.com", req.Host)
10171017
}
10181018

10191019
func TestSetLoadBalanceTarget_DoubleIP(t *testing.T) {
1020-
client := NewClient().EnableLoadBalanceRandom()
1021-
req, _ := http.NewRequest("GET", "http://example.com/resource", nil)
1020+
URL := "http://example-headless.com/resource"
1021+
client := NewClient().EnableLoadBalanceRandom(strings.Contains(URL, "headless"))
1022+
req, _ := http.NewRequest("GET", URL, nil)
10221023
target := "http://example.com/resource"
10231024

10241025
// Patch net.LookupHost to return a single IP
@@ -1028,8 +1029,8 @@ func TestSetLoadBalanceTarget_DoubleIP(t *testing.T) {
10281029
}
10291030
defer func() { netLookupHost = origLookupHost }()
10301031

1031-
out := client.setLoadBalanceTarget(req, target)
1032+
out := client.setLoadBalanceTarget(req, target, req.URL.Hostname())
10321033
assert.Contains(t, req.URL.Host, "192.0.2.")
1033-
assert.Equal(t, "example.com", req.Host)
1034+
assert.Equal(t, "example-headless.com", req.Host)
10341035
assert.Regexp(t, `\[192\.0\.2\.\d+]`, out, "Expected output to contain one IP in the format '[IP]'")
10351036
}

0 commit comments

Comments
 (0)