diff --git a/client/nginx.go b/client/nginx.go index 47c2fb3e..4317ad95 100644 --- a/client/nginx.go +++ b/client/nginx.go @@ -808,6 +808,7 @@ func (client *NginxClient) DeleteHTTPServer(ctx context.Context, upstream string // Servers that are in the slice, but don't exist in NGINX will be added to NGINX. // Servers that aren't in the slice, but exist in NGINX, will be removed from NGINX. // Servers that are in the slice and exist in NGINX, but have different parameters, will be updated. +// The client will attempt to update all servers, returning all the errors that occurred. func (client *NginxClient) UpdateHTTPServers(ctx context.Context, upstream string, servers []UpstreamServer) (added []UpstreamServer, deleted []UpstreamServer, updated []UpstreamServer, err error) { serversInNginx, err := client.GetHTTPServers(ctx, upstream) if err != nil { @@ -824,27 +825,37 @@ func (client *NginxClient) UpdateHTTPServers(ctx context.Context, upstream strin toAdd, toDelete, toUpdate := determineUpdates(formattedServers, serversInNginx) for _, server := range toAdd { - err := client.AddHTTPServer(ctx, upstream, server) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to update servers of %v upstream: %w", upstream, err) + addErr := client.AddHTTPServer(ctx, upstream, server) + if addErr != nil { + err = errors.Join(err, addErr) + continue } + added = append(added, server) } for _, server := range toDelete { - err := client.DeleteHTTPServer(ctx, upstream, server.Server) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to update servers of %v upstream: %w", upstream, err) + deleteErr := client.DeleteHTTPServer(ctx, upstream, server.Server) + if deleteErr != nil { + err = errors.Join(err, deleteErr) + continue } + deleted = append(deleted, server) } for _, server := range toUpdate { - err := client.UpdateHTTPServer(ctx, upstream, server) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to update servers of %v upstream: %w", upstream, err) + updateErr := client.UpdateHTTPServer(ctx, upstream, server) + if updateErr != nil { + err = errors.Join(err, updateErr) + continue } + updated = append(updated, server) + } + + if err != nil { + err = fmt.Errorf("failed to update servers of %s upstream: %w", upstream, err) } - return toAdd, toDelete, toUpdate, nil + return added, deleted, updated, err } // haveSameParameters checks if a given server has the same parameters as a server already present in NGINX. Order matters. @@ -1108,6 +1119,7 @@ func (client *NginxClient) DeleteStreamServer(ctx context.Context, upstream stri // Servers that are in the slice, but don't exist in NGINX will be added to NGINX. // Servers that aren't in the slice, but exist in NGINX, will be removed from NGINX. // Servers that are in the slice and exist in NGINX, but have different parameters, will be updated. +// The client will attempt to update all servers, returning all the errors that occurred. func (client *NginxClient) UpdateStreamServers(ctx context.Context, upstream string, servers []StreamUpstreamServer) (added []StreamUpstreamServer, deleted []StreamUpstreamServer, updated []StreamUpstreamServer, err error) { serversInNginx, err := client.GetStreamServers(ctx, upstream) if err != nil { @@ -1123,27 +1135,37 @@ func (client *NginxClient) UpdateStreamServers(ctx context.Context, upstream str toAdd, toDelete, toUpdate := determineStreamUpdates(formattedServers, serversInNginx) for _, server := range toAdd { - err := client.AddStreamServer(ctx, upstream, server) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to update stream servers of %v upstream: %w", upstream, err) + addErr := client.AddStreamServer(ctx, upstream, server) + if addErr != nil { + err = errors.Join(err, addErr) + continue } + added = append(added, server) } for _, server := range toDelete { - err := client.DeleteStreamServer(ctx, upstream, server.Server) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to update stream servers of %v upstream: %w", upstream, err) + deleteErr := client.DeleteStreamServer(ctx, upstream, server.Server) + if deleteErr != nil { + err = errors.Join(err, deleteErr) + continue } + deleted = append(deleted, server) } for _, server := range toUpdate { - err := client.UpdateStreamServer(ctx, upstream, server) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to update stream servers of %v upstream: %w", upstream, err) + updateErr := client.UpdateStreamServer(ctx, upstream, server) + if updateErr != nil { + err = errors.Join(err, updateErr) + continue } + updated = append(updated, server) + } + + if err != nil { + err = fmt.Errorf("failed to update stream servers of %s upstream: %w", upstream, err) } - return toAdd, toDelete, toUpdate, nil + return added, deleted, updated, err } func (client *NginxClient) getIDOfStreamServer(ctx context.Context, upstream string, name string) (int, error) { diff --git a/client/nginx_test.go b/client/nginx_test.go index 36cb87cf..920e9c72 100644 --- a/client/nginx_test.go +++ b/client/nginx_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "encoding/json" "net/http" "net/http/httptest" "reflect" @@ -980,3 +981,182 @@ func TestExtractPlusVersionNegativeCase(t *testing.T) { }) } } + +func TestClientHTTPUpdateServers(t *testing.T) { + t.Parallel() + + responses := []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []UpstreamServer{}, + }, + // response for AddHTTPServer GET servers for http server + { + statusCode: http.StatusOK, + servers: []UpstreamServer{}, + }, + // response for AddHTTPServer POST server for http server + { + statusCode: http.StatusInternalServerError, + servers: []UpstreamServer{}, + }, + // response for AddHTTPServer GET servers for https server + { + statusCode: http.StatusOK, + servers: []UpstreamServer{}, + }, + // response for AddHTTPServer POST server for https server + { + statusCode: http.StatusCreated, + servers: []UpstreamServer{}, + }, + } + + handler := &fakeHandler{ + func(w http.ResponseWriter, _ *http.Request) { + if len(responses) == 0 { + t.Fatal("ran out of responses") + } + + re := responses[0] + responses = responses[1:] + + w.WriteHeader(re.statusCode) + + resp, err := json.Marshal(re.servers) + if err != nil { + t.Fatal(err) + } + _, err = w.Write(resp) + if err != nil { + t.Fatal(err) + } + }, + } + + server := httptest.NewServer(handler) + defer server.Close() + + client, err := NewNginxClient(server.URL, WithHTTPClient(&http.Client{})) + if err != nil { + t.Fatal(err) + } + + httpServer := UpstreamServer{Server: "127.0.0.1:80"} + httpsServer := UpstreamServer{Server: "127.0.0.1:443"} + + // we expect that we will get an error for the 500 error encountered when putting the http server + // but we also expect that we have the https server added + added, _, _, err := client.UpdateHTTPServers(context.TODO(), "fakeUpstream", []UpstreamServer{ + httpServer, + httpsServer, + }) + if err == nil { + t.Fatal("expected to receive an error for 500 response when adding first server") + } + + if len(added) != 1 { + t.Fatalf("expected to get one added server, instead got %d", len(added)) + } + + if !reflect.DeepEqual(httpsServer, added[0]) { + t.Errorf("expected: %v got: %v", httpsServer, added[0]) + } +} + +func TestClientStreamUpdateServers(t *testing.T) { + t.Parallel() + + responses := []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []UpstreamServer{}, + }, + // response for AddStreamServer GET servers for streamServer1 + { + statusCode: http.StatusOK, + servers: []UpstreamServer{}, + }, + // response for AddStreamServer POST server for streamServer1 + { + statusCode: http.StatusInternalServerError, + servers: []UpstreamServer{}, + }, + // response for AddStreamServer GET servers for streamServer2 + { + statusCode: http.StatusOK, + servers: []UpstreamServer{}, + }, + // response for AddStreamServer POST server for streamServer2 + { + statusCode: http.StatusCreated, + servers: []UpstreamServer{}, + }, + } + + handler := &fakeHandler{ + func(w http.ResponseWriter, _ *http.Request) { + if len(responses) == 0 { + t.Fatal("ran out of responses") + } + + re := responses[0] + responses = responses[1:] + + w.WriteHeader(re.statusCode) + + resp, err := json.Marshal(re.servers) + if err != nil { + t.Fatal(err) + } + _, err = w.Write(resp) + if err != nil { + t.Fatal(err) + } + }, + } + + server := httptest.NewServer(handler) + defer server.Close() + + client, err := NewNginxClient(server.URL, WithHTTPClient(&http.Client{})) + if err != nil { + t.Fatal(err) + } + + streamServer1 := StreamUpstreamServer{Server: "127.0.0.1:2000"} + streamServer2 := StreamUpstreamServer{Server: "127.0.0.1:3000"} + + // we expect that we will get an error for the 500 error encountered when putting server1 + // but we also expect that we get the second server added + added, _, _, err := client.UpdateStreamServers(context.TODO(), "fakeUpstream", []StreamUpstreamServer{ + streamServer1, + streamServer2, + }) + if err == nil { + t.Fatal("expected to receive an error for 500 response when adding first server") + } + + if len(added) != 1 { + t.Fatalf("expected to get one added server, instead got %d", len(added)) + } + + if !reflect.DeepEqual(streamServer2, added[0]) { + t.Errorf("expected: %v got: %v", streamServer2, added[0]) + } +} + +type response struct { + servers []UpstreamServer + statusCode int +} + +type fakeHandler struct { + handler func(w http.ResponseWriter, r *http.Request) +} + +func (h *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.handler(w, r) +}