diff --git a/build/config.yaml.example b/build/config.yaml.example index 37e870adf..d194bf9db 100644 --- a/build/config.yaml.example +++ b/build/config.yaml.example @@ -3,6 +3,8 @@ # cloud_provider: AWS # region: us-west-2 # api_endpoint: http://127.0.0.1:8080/api +# optional dataplane_api_key +# dataplane_api_key: your-api-key # sync_interval: 5s # upstreams: # - name: backend-one @@ -24,6 +26,8 @@ # subscription_id: my_subscription_id # resource_group_name: my_resource_group # api_endpoint: http://127.0.0.1:8080/api +# optional dataplane_api_key +# dataplane_api_key: your-api-key # sync_interval: 5s # upstreams: # - name: backend-one diff --git a/cmd/sync/config.go b/cmd/sync/config.go index 97fa5fba5..b1515ae62 100644 --- a/cmd/sync/config.go +++ b/cmd/sync/config.go @@ -10,9 +10,10 @@ import ( // commonConfig stores the configuration parameters common to all providers. type commonConfig struct { - APIEndpoint string `yaml:"api_endpoint"` - CloudProvider string `yaml:"cloud_provider"` - SyncInterval time.Duration `yaml:"sync_interval"` + APIEndpoint string `yaml:"api_endpoint"` + CloudProvider string `yaml:"cloud_provider"` + DataplaneAPIKey string `yaml:"dataplane_api_key,omitempty"` + SyncInterval time.Duration `yaml:"sync_interval"` } func parseCommonConfig(data []byte) (*commonConfig, error) { diff --git a/cmd/sync/main.go b/cmd/sync/main.go index 263382d6f..205b202c5 100644 --- a/cmd/sync/main.go +++ b/cmd/sync/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "encoding/base64" "flag" "fmt" "io" @@ -21,7 +22,10 @@ var ( version string ) -const connTimeoutInSecs = 10 +const ( + connTimeoutInSecs = 10 + maxHeaders = 100 +) func main() { flag.Parse() @@ -63,7 +67,7 @@ func main() { os.Exit(10) } - httpClient := &http.Client{Timeout: connTimeoutInSecs * time.Second} + httpClient := NewHTTPClient(commonConfig) nginxClient, err := nginx.NewNginxClient(commonConfig.APIEndpoint, nginx.WithHTTPClient(httpClient)) if err != nil { log.Printf("Couldn't create NGINX client: %v", err) @@ -189,3 +193,57 @@ func getStreamUpstreamServerAddresses(server []nginx.StreamUpstreamServer) []str } return streamUpstreamServerAddr } + +// headerTransport wraps an http.RoundTripper and adds custom headers to all requests. +type headerTransport struct { + headers http.Header + transport http.RoundTripper +} + +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + clonedReq := req.Clone(req.Context()) + + for key, values := range t.headers { + for _, value := range values { + clonedReq.Header.Add(key, value) + } + } + + if len(clonedReq.Header) > maxHeaders { + return nil, fmt.Errorf("number of headers in request exceeds the maximum allowed (%d)", maxHeaders) + } + + resp, err := t.transport.RoundTrip(clonedReq) + if err != nil { + return nil, fmt.Errorf("headerTransport RoundTrip failed: %w", err) + } + + return resp, nil +} + +func NewHTTPClient(cfg *commonConfig) *http.Client { + headers := NewHeaders(cfg) + baseTransport := &http.Transport{} + + return &http.Client{ + Transport: &headerTransport{ + headers: headers, + transport: baseTransport, + }, + Timeout: connTimeoutInSecs * time.Second, + } +} + +func NewHeaders(cfg *commonConfig) http.Header { + headers := http.Header{} + headers.Set("Content-Type", "application/json") + + if cfg.DataplaneAPIKey != "" { + authValue := "ApiKey " + base64.StdEncoding.EncodeToString([]byte(cfg.DataplaneAPIKey)) + headers.Set("Authorization", authValue) + } else { + log.Printf("[optional] DataplaneAPIKey not configured") + } + + return headers +} diff --git a/cmd/sync/main_test.go b/cmd/sync/main_test.go new file mode 100644 index 000000000..154d63f32 --- /dev/null +++ b/cmd/sync/main_test.go @@ -0,0 +1,401 @@ +package main + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + nginx "github.com/nginx/nginx-plus-go-client/v2/client" +) + +func TestGetUpstreamServerAddresses(t *testing.T) { + t.Parallel() + tests := []struct { + name string + servers []nginx.UpstreamServer + expected []string + }{ + { + name: "empty server list", + servers: []nginx.UpstreamServer{}, + expected: []string{}, + }, + { + name: "single server", + servers: []nginx.UpstreamServer{ + {Server: "10.0.0.1:80"}, + }, + expected: []string{"10.0.0.1:80"}, + }, + { + name: "multiple servers", + servers: []nginx.UpstreamServer{ + {Server: "10.0.0.1:80"}, + {Server: "10.0.0.2:80"}, + {Server: "10.0.0.3:8080"}, + }, + expected: []string{"10.0.0.1:80", "10.0.0.2:80", "10.0.0.3:8080"}, + }, + { + name: "servers with additional fields", + servers: []nginx.UpstreamServer{ + {Server: "192.168.1.1:443", MaxConns: intPtr(100), Weight: intPtr(5)}, + {Server: "192.168.1.2:443", MaxFails: intPtr(3)}, + }, + expected: []string{"192.168.1.1:443", "192.168.1.2:443"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := getUpstreamServerAddresses(tt.servers) + if len(result) != len(tt.expected) { + t.Errorf("expected %d addresses, got %d", len(tt.expected), len(result)) + return + } + for i, addr := range result { + if addr != tt.expected[i] { + t.Errorf("expected address[%d] = %s, got %s", i, tt.expected[i], addr) + } + } + }) + } +} + +func TestGetStreamUpstreamServerAddresses(t *testing.T) { + t.Parallel() + tests := []struct { + name string + servers []nginx.StreamUpstreamServer + expected []string + }{ + { + name: "empty server list", + servers: []nginx.StreamUpstreamServer{}, + expected: []string{}, + }, + { + name: "single stream server", + servers: []nginx.StreamUpstreamServer{ + {Server: "10.0.0.1:3306"}, + }, + expected: []string{"10.0.0.1:3306"}, + }, + { + name: "multiple stream servers", + servers: []nginx.StreamUpstreamServer{ + {Server: "10.0.0.1:3306"}, + {Server: "10.0.0.2:3306"}, + {Server: "10.0.0.3:5432"}, + }, + expected: []string{"10.0.0.1:3306", "10.0.0.2:3306", "10.0.0.3:5432"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := getStreamUpstreamServerAddresses(tt.servers) + if len(result) != len(tt.expected) { + t.Errorf("expected %d addresses, got %d", len(tt.expected), len(result)) + return + } + for i, addr := range result { + if addr != tt.expected[i] { + t.Errorf("expected address[%d] = %s, got %s", i, tt.expected[i], addr) + } + } + }) + } +} + +func TestNewHeaders(t *testing.T) { + t.Parallel() + tests := []struct { + config *commonConfig + name string + expectedContentType string + authPrefix string + hasAuthorization bool + }{ + { + name: "with API key", + config: &commonConfig{ + DataplaneAPIKey: "test-api-key-123", + }, + expectedContentType: "application/json", + hasAuthorization: true, + authPrefix: "ApiKey ", + }, + { + name: "without API key", + config: &commonConfig{}, + expectedContentType: "application/json", + hasAuthorization: false, + }, + { + name: "with empty API key", + config: &commonConfig{ + DataplaneAPIKey: "", + }, + expectedContentType: "application/json", + hasAuthorization: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + headers := NewHeaders(tt.config) + + contentType := headers.Get("Content-Type") + if contentType != tt.expectedContentType { + t.Errorf("expected Content-Type %s, got %s", tt.expectedContentType, contentType) + } + + auth := headers.Get("Authorization") + if tt.hasAuthorization { + if auth == "" { + t.Error("expected Authorization header, got empty") + } + if len(tt.authPrefix) > 0 && len(auth) < len(tt.authPrefix) { + t.Errorf("expected Authorization to start with %s", tt.authPrefix) + } + } else if auth != "" { + t.Errorf("expected no Authorization header, got %s", auth) + } + }) + } +} + +func TestHeaderTransport_RoundTrip(t *testing.T) { + t.Parallel() + tests := []struct { + headers http.Header + name string + expectedStatusCode int + expectError bool + }{ + { + name: "successful request with headers", + headers: http.Header{ + "Content-Type": []string{"application/json"}, + "Authorization": []string{"ApiKey test123"}, + }, + expectedStatusCode: http.StatusOK, + expectError: false, + }, + { + name: "request without custom headers", + headers: http.Header{ + "Content-Type": []string{"application/json"}, + }, + expectedStatusCode: http.StatusOK, + expectError: false, + }, + { + name: "request with empty headers", + headers: http.Header{}, + expectedStatusCode: http.StatusOK, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify custom headers were added + for key := range tt.headers { + if r.Header.Get(key) == "" { + t.Errorf("expected header %s to be present", key) + } + } + w.WriteHeader(tt.expectedStatusCode) + if _, err := w.Write([]byte(`{"status":"ok"}`)); err != nil { + t.Errorf("failed to write response: %v", err) + } + })) + defer server.Close() + + transport := &headerTransport{ + headers: tt.headers, + transport: http.DefaultTransport, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := transport.RoundTrip(req) + + if tt.expectError { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != tt.expectedStatusCode { + t.Errorf("expected status %d, got %d", tt.expectedStatusCode, resp.StatusCode) + } + }) + } +} + +func TestHeaderTransport_RoundTrip_HeaderLimits(t *testing.T) { + t.Parallel() + tests := []struct { + name string + errorMessage string + numHeaders int + expectError bool + }{ + { + name: "well below header limit", + numHeaders: 10, + expectError: false, + }, + { + name: "at half header limit", + numHeaders: maxHeaders / 2, + expectError: false, + }, + { + name: "exactly at header limit", + numHeaders: maxHeaders, + expectError: false, + }, + { + name: "one over header limit", + numHeaders: maxHeaders + 1, + expectError: true, + errorMessage: "number of headers in request exceeds the maximum allowed", + }, + { + name: "zero headers", + numHeaders: 0, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(`{"status":"ok"}`)); err != nil { + t.Errorf("failed to write response: %v", err) + } + })) + defer server.Close() + + headers := http.Header{} + for i := range tt.numHeaders { + headerName := fmt.Sprintf("X-Custom-Header-%d", i) + headers.Add(headerName, "value") + } + + transport := &headerTransport{ + headers: headers, + transport: http.DefaultTransport, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := transport.RoundTrip(req) + + if tt.expectError { + if err == nil { + t.Errorf("expected error for %d headers, got nil", tt.numHeaders) + } else if tt.errorMessage != "" && !containsString(err.Error(), tt.errorMessage) { + t.Errorf("expected error message to contain %q, got %q", tt.errorMessage, err.Error()) + } + } else { + if err != nil { + t.Errorf("unexpected error for %d headers: %v", tt.numHeaders, err) + } + if resp != nil { + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } + } + } + }) + } +} + +// Helper function to check if a string contains a substring. +func containsString(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && + (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || + func() bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false + }())) +} + +func TestNewHTTPClient(t *testing.T) { + t.Parallel() + tests := []struct { + config *commonConfig + name string + }{ + { + name: "with API key", + config: &commonConfig{ + DataplaneAPIKey: "test-key-123", + }, + }, + { + name: "without API key", + config: &commonConfig{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + client := NewHTTPClient(tt.config) + + if client == nil { + t.Fatal("expected non-nil client") + } + + if client.Timeout != connTimeoutInSecs*1000000000 { // nanoseconds + t.Errorf("expected timeout %v, got %v", connTimeoutInSecs*1000000000, client.Timeout) + } + + if client.Transport == nil { + t.Error("expected non-nil transport") + } + + if _, ok := client.Transport.(*headerTransport); !ok { + t.Error("expected transport to be *headerTransport") + } + }) + } +} + +// Helper function. +func intPtr(i int) *int { + return &i +} diff --git a/examples/aws.md b/examples/aws.md index bd6c8210c..46f589c37 100644 --- a/examples/aws.md +++ b/examples/aws.md @@ -25,6 +25,7 @@ nginx-asg-sync is configured in **/etc/nginx/config.yaml**. ```yaml region: us-west-2 api_endpoint: http://127.0.0.1:8080/api +dataplane_api_key: your_api_key_here (optional) sync_interval: 5s cloud_provider: AWS profile: default @@ -49,6 +50,7 @@ upstreams: ``` - The `api_endpoint` key defines the NGINX Plus API endpoint. +- The `dataplane_api_key` key (optional) defines the API key for authenticating with the [Dataplane API](https://docs.nginx.com/nginxaas/azure/loadbalancer-kubernetes/#view-nginxaas-data-plane-api-endpoint-using-the-azure-portal) - The `sync_interval` key defines the synchronization interval: nginx-asg-sync checks for scaling updates every 5 seconds. The value is a string that represents a duration (e.g., `5s`). The maximum unit is hours. - The `cloud_provider` key defines a cloud provider that will be used. The default is `AWS`. This means the key can be diff --git a/examples/azure.md b/examples/azure.md index 229dd5b0a..436e3784c 100644 --- a/examples/azure.md +++ b/examples/azure.md @@ -26,6 +26,7 @@ nginx-asg-sync is configured in **/etc/nginx/config.yaml**. ```yaml api_endpoint: http://127.0.0.1:8080/api +dataplane_api_key: your_api_key_here (optional) sync_interval: 5s cloud_provider: Azure subscription_id: my_subscription_id @@ -50,6 +51,8 @@ upstreams: ``` - The `api_endpoint` key defines the NGINX Plus API endpoint. +- The `dataplane_api_key` key (optional) defines the API key for authenticating with the [Dataplane API](https://docs.nginx.com/nginxaas/azure/loadbalancer-kubernetes/#view-nginxaas-data-plane-api-endpoint-using-the-azure-portal) + of [NGINXaaS for Azure](https://docs.nginx.com/nginxaas/azure). - The `sync_interval` key defines the synchronization interval: nginx-asg-sync checks for scaling updates every 5 seconds. The value is a string that represents a duration (e.g., `5s`). The maximum unit is hours. - The `cloud_provider` key defines a Cloud Provider that will be used. The default is `AWS`. This means the key can be