diff --git a/go.mod b/go.mod index 2d9186f658..a8eea74d34 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,9 @@ module github.com/nginx/nginx-gateway-fabric go 1.24.2 require ( + github.com/aws/aws-sdk-go-v2 v1.36.5 + github.com/aws/aws-sdk-go-v2/config v1.29.17 + github.com/aws/aws-sdk-go-v2/service/s3 v1.82.0 github.com/fsnotify/fsnotify v1.9.0 github.com/go-logr/logr v1.4.3 github.com/google/go-cmp v0.7.0 @@ -31,6 +34,21 @@ require ( require ( buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.4-20250130201111-63bb56e20495.1 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.11 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.70 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.32 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.36 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.36 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.36 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.17 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.17 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.34.0 // indirect + github.com/aws/smithy-go v1.22.4 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v5 v5.0.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect diff --git a/go.sum b/go.sum index 4737cccc9f..dba21a0073 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,42 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25 github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod 
h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/aws/aws-sdk-go-v2 v1.36.5 h1:0OF9RiEMEdDdZEMqF9MRjevyxAQcf6gY+E7vwBILFj0= +github.com/aws/aws-sdk-go-v2 v1.36.5/go.mod h1:EYrzvCCN9CMUTa5+6lf6MM4tq3Zjp8UhSGR/cBsjai0= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.11 h1:12SpdwU8Djs+YGklkinSSlcrPyj3H4VifVsKf78KbwA= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.11/go.mod h1:dd+Lkp6YmMryke+qxW/VnKyhMBDTYP41Q2Bb+6gNZgY= +github.com/aws/aws-sdk-go-v2/config v1.29.17 h1:jSuiQ5jEe4SAMH6lLRMY9OVC+TqJLP5655pBGjmnjr0= +github.com/aws/aws-sdk-go-v2/config v1.29.17/go.mod h1:9P4wwACpbeXs9Pm9w1QTh6BwWwJjwYvJ1iCt5QbCXh8= +github.com/aws/aws-sdk-go-v2/credentials v1.17.70 h1:ONnH5CM16RTXRkS8Z1qg7/s2eDOhHhaXVd72mmyv4/0= +github.com/aws/aws-sdk-go-v2/credentials v1.17.70/go.mod h1:M+lWhhmomVGgtuPOhO85u4pEa3SmssPTdcYpP/5J/xc= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.32 h1:KAXP9JSHO1vKGCr5f4O6WmlVKLFFXgWYAGoJosorxzU= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.32/go.mod h1:h4Sg6FQdexC1yYG9RDnOvLbW1a/P986++/Y/a+GyEM8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.36 h1:SsytQyTMHMDPspp+spo7XwXTP44aJZZAC7fBV2C5+5s= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.36/go.mod h1:Q1lnJArKRXkenyog6+Y+zr7WDpk4e6XlR6gs20bbeNo= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.36 h1:i2vNHQiXUvKhs3quBR6aqlgJaiaexz/aNvdCktW/kAM= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.36/go.mod h1:UdyGa7Q91id/sdyHPwth+043HhmP6yP9MBHgbZM0xo8= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.36 h1:GMYy2EOWfzdP3wfVAGXBNKY5vK4K8vMET4sYOYltmqs= 
+github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.36/go.mod h1:gDhdAV6wL3PmPqBhiPbnlS447GoWs8HTTOYef9/9Inw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.4 h1:CXV68E2dNqhuynZJPB80bhPQwAKqBWVer887figW6Jc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.4/go.mod h1:/xFi9KtvBXP97ppCz1TAEvU1Uf66qvid89rbem3wCzQ= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.4 h1:nAP2GYbfh8dd2zGZqFRSMlq+/F6cMPBUuCsGAMkN074= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.4/go.mod h1:LT10DsiGjLWh4GbjInf9LQejkYEhBgBCjLG5+lvk4EE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.17 h1:t0E6FzREdtCsiLIoLCWsYliNsRBgyGD/MCK571qk4MI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.17/go.mod h1:ygpklyoaypuyDvOM5ujWGrYWpAK3h7ugnmKCU/76Ys4= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.17 h1:qcLWgdhq45sDM9na4cvXax9dyLitn8EYBRl8Ak4XtG4= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.17/go.mod h1:M+jkjBFZ2J6DJrjMv2+vkBbuht6kxJYtJiwoVgX4p4U= +github.com/aws/aws-sdk-go-v2/service/s3 v1.82.0 h1:JubM8CGDDFaAOmBrd8CRYNr49ZNgEAiLwGwgNMdS0nw= +github.com/aws/aws-sdk-go-v2/service/s3 v1.82.0/go.mod h1:kUklwasNoCn5YpyAqC/97r6dzTA1SRKJfKq16SXeoDU= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.5 h1:AIRJ3lfb2w/1/8wOOSqYb9fUKGwQbtysJ2H1MofRUPg= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.5/go.mod h1:b7SiVprpU+iGazDUqvRSLf5XmCdn+JtT1on7uNL6Ipc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.3 h1:BpOxT3yhLwSJ77qIY3DoHAQjZsc4HEGfMCE4NGy3uFg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.3/go.mod h1:vq/GQR1gOFLquZMSrxUK/cpvKCNVYibNyJ1m7JrU88E= +github.com/aws/aws-sdk-go-v2/service/sts v1.34.0 h1:NFOJ/NXEGV4Rq//71Hs1jC/NvPs1ezajK+yQmkwnPV0= +github.com/aws/aws-sdk-go-v2/service/sts v1.34.0/go.mod h1:7ph2tGpfQvwzgistp2+zga9f+bCjlQJPkPUmMgDSD7w= +github.com/aws/smithy-go v1.22.4 h1:uqXzVZNuNexwc/xrh6Tb56u89WDlJY6HS+KC0S4QSjw= +github.com/aws/smithy-go v1.22.4/go.mod 
h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM=
diff --git a/internal/framework/fetch/fetch.go b/internal/framework/fetch/fetch.go
new file mode 100644
index 0000000000..14a0079d5d
--- /dev/null
+++ b/internal/framework/fetch/fetch.go
@@ -0,0 +1,569 @@
+// Package fetch retrieves remote files over HTTP(S) or S3 with retry and
+// optional SHA256 checksum validation.
+package fetch
+
+import (
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"strings"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+	"k8s.io/apimachinery/pkg/util/wait"
+)
+
+const (
+	// Default configuration values.
+	defaultTimeout              = 30 * time.Second
+	defaultRetryAttempts        = 3
+	defaultRetryMaxDelay        = 5 * time.Minute
+	defaultRetryInitialDuration = 200 * time.Millisecond
+	defaultRetryJitter          = 0.1
+	defaultRetryLinearFactor    = 1.0
+	exponentialBackoffFactor    = 2.0
+
+	// HTTP configuration.
+	userAgent = "nginx-gateway-fabric"
+
+	// Checksum configuration: suffix appended to the target URL/key to locate
+	// the default checksum file.
+	checksumFileSuffix = ".sha256"
+)
+
+// ChecksumMismatchError represents an error when the calculated checksum doesn't match the expected checksum.
+// This type of error should not trigger retries as it indicates data corruption or tampering.
+type ChecksumMismatchError struct {
+	Expected string
+	Actual   string
+}
+
+// Error implements the error interface.
+func (e *ChecksumMismatchError) Error() string {
+	return fmt.Sprintf("checksum mismatch: expected %s, got %s", e.Expected, e.Actual)
+}
+
+// S3Error represents an error when fetching from S3 fails.
+type S3Error struct {
+	Err    error
+	Bucket string
+	Key    string
+}
+
+// Error implements the error interface.
+func (e *S3Error) Error() string {
+	return fmt.Sprintf("S3 error for s3://%s/%s: %v", e.Bucket, e.Key, e.Err)
+}
+
+// Unwrap returns the underlying error for use with errors.Is/errors.As.
+func (e *S3Error) Unwrap() error {
+	return e.Err
+}
+
+// HTTPStatusError represents an error for an unexpected HTTP status code.
+type HTTPStatusError struct {
+	StatusCode int
+}
+
+// Error implements the error interface.
+func (e *HTTPStatusError) Error() string {
+	return fmt.Sprintf("unexpected status code: %d", e.StatusCode)
+}
+
+// HTTPError represents an error when fetching via HTTP fails.
+type HTTPError struct {
+	Err error
+	URL string
+}
+
+// Error implements the error interface.
+func (e *HTTPError) Error() string {
+	return fmt.Sprintf("HTTP error for %s: %v", e.URL, e.Err)
+}
+
+// Unwrap returns the underlying error for use with errors.Is/errors.As.
+func (e *HTTPError) Unwrap() error {
+	return e.Err
+}
+
+// RetryBackoffType defines supported backoff strategies.
+type RetryBackoffType string
+
+const (
+	RetryBackoffExponential RetryBackoffType = "exponential"
+	RetryBackoffLinear      RetryBackoffType = "linear"
+)
+
+// options contains the configuration for fetching remote files.
+type options struct {
+	checksumLocation string
+	retryBackoff     RetryBackoffType
+	timeout          time.Duration
+	retryMaxDelay    time.Duration
+	retryAttempts    int32
+	checksumEnabled  bool
+}
+
+// defaults returns options with sensible default values.
+func defaults() options {
+	return options{
+		timeout:       defaultTimeout,
+		retryAttempts: defaultRetryAttempts,
+		retryMaxDelay: defaultRetryMaxDelay,
+		retryBackoff:  RetryBackoffExponential,
+	}
+}
+
+// Option defines a function that modifies fetch options.
+type Option func(*options)
+
+// WithTimeout sets the HTTP request timeout.
+func WithTimeout(timeout time.Duration) Option {
+	return func(o *options) {
+		o.timeout = timeout
+	}
+}
+
+// WithRetryAttempts sets the number of retry attempts (total attempts = 1 + retries).
+func WithRetryAttempts(attempts int32) Option {
+	return func(o *options) {
+		o.retryAttempts = attempts
+	}
+}
+
+// WithRetryBackoff sets the retry backoff strategy. 
+func WithRetryBackoff(backoff RetryBackoffType) Option {
+	return func(o *options) {
+		o.retryBackoff = backoff
+	}
+}
+
+// WithMaxRetryDelay sets the maximum delay between retries.
+func WithMaxRetryDelay(delay time.Duration) Option {
+	return func(o *options) {
+		o.retryMaxDelay = delay
+	}
+}
+
+// WithChecksum enables checksum validation with an optional custom checksum location.
+// For HTTP URLs: if no location is provided, defaults to <target URL> + ".sha256".
+// For S3 URLs: if no location is provided, defaults to <key> + ".sha256" in the same bucket.
+// Only the first supplied location is used; extra arguments are ignored.
+func WithChecksum(checksumLocation ...string) Option {
+	return func(o *options) {
+		o.checksumEnabled = true
+		if len(checksumLocation) > 0 {
+			o.checksumLocation = checksumLocation[0]
+		}
+	}
+}
+
+// S3Client defines the interface for S3 operations.
+//
+//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 . S3Client
+type S3Client interface {
+	GetObject(ctx context.Context, input *s3.GetObjectInput, opts ...func(*s3.Options)) (*s3.GetObjectOutput, error)
+}
+
+// Fetcher defines the interface for fetching remote files.
+//
+//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate
+type Fetcher interface {
+	GetRemoteFile(targetURL string, opts ...Option) ([]byte, error)
+}
+
+// DefaultFetcher is the default implementation of Fetcher.
+// It supports both HTTP(S) and S3 URLs with automatic protocol detection.
+type DefaultFetcher struct {
+	s3Client   S3Client // nil when AWS config could not be loaded; S3 fetches then fail fast
+	httpClient *http.Client
+}
+
+// NewDefaultFetcher creates a new DefaultFetcher with AWS and HTTP clients configured.
+// If AWS credentials are not available, S3 functionality will be disabled but HTTP will still work.
+func NewDefaultFetcher() (*DefaultFetcher, error) { + // Try to load AWS config + // Note: We don't return an error if AWS config fails - HTTP fetching should still work + var s3Client S3Client + cfg, err := config.LoadDefaultConfig(context.TODO()) + if err == nil { + s3Client = s3.NewFromConfig(cfg) + } + + httpClient := &http.Client{ + Timeout: defaultTimeout, + } + + return &DefaultFetcher{ + s3Client: s3Client, + httpClient: httpClient, + }, nil +} + +// NewDefaultFetcherWithS3Client creates a new DefaultFetcher with a custom S3 client. +// This is primarily used for testing with fake S3 clients. +func NewDefaultFetcherWithS3Client(s3Client S3Client) *DefaultFetcher { + httpClient := &http.Client{ + Timeout: defaultTimeout, + } + + return &DefaultFetcher{ + s3Client: s3Client, + httpClient: httpClient, + } +} + +// GetRemoteFile fetches a remote file with retry logic and optional validation. +// Supports both HTTP(S) and S3 URLs with automatic protocol detection. +func (f *DefaultFetcher) GetRemoteFile(targetURL string, opts ...Option) ([]byte, error) { + ctx := context.Background() + + // Apply options to defaults + options := defaults() + for _, opt := range opts { + opt(&options) + } + + // Route to appropriate fetcher based on URL scheme + if strings.HasPrefix(targetURL, "s3://") { + return f.fetchS3File(ctx, targetURL, options) + } + + if strings.HasPrefix(targetURL, "http://") || strings.HasPrefix(targetURL, "https://") { + return f.fetchHTTPFile(ctx, targetURL, options) + } + + return nil, fmt.Errorf("unsupported URL scheme: %s (supported: http://, https://, s3://)", targetURL) +} + +// fetchS3File fetches a file from S3 using the AWS SDK. 
+func (f *DefaultFetcher) fetchS3File(ctx context.Context, s3URL string, opts options) ([]byte, error) {
+	if f.s3Client == nil {
+		return nil, fmt.Errorf("S3 client not available - AWS credentials may not be configured")
+	}
+
+	bucket, key, err := parseS3URL(s3URL)
+	if err != nil {
+		return nil, fmt.Errorf("invalid S3 URL %s: %w", s3URL, err)
+	}
+
+	backoff := createBackoffConfig(opts.retryBackoff, opts.retryAttempts, opts.retryMaxDelay)
+	var lastErr error
+	var result []byte
+
+	// Each attempt fetches the object and, when enabled, validates its checksum.
+	// The callback's (done, err) return drives wait: (true, nil) stops with
+	// success, (false, nil) retries, (false, err) aborts immediately.
+	err = wait.ExponentialBackoffWithContext(ctx, backoff, func(ctx context.Context) (bool, error) {
+		data, err := f.getS3Object(ctx, bucket, key, opts.timeout)
+		if err != nil {
+			lastErr = &S3Error{Bucket: bucket, Key: key, Err: err}
+			// Intentionally return nil error to signal retry mechanism to continue
+			return false, nil //nolint:nilerr // Retry on S3 errors
+		}
+
+		if opts.checksumEnabled {
+			if err := f.validateS3FileContent(ctx, data, bucket, key, opts); err != nil {
+				lastErr = err
+				// Don't retry on checksum mismatches
+				var checksumErr *ChecksumMismatchError
+				if errors.As(err, &checksumErr) {
+					return false, err // Stop retrying
+				}
+				return false, nil // Retry on other checksum errors
+			}
+		}
+
+		result = data
+		return true, nil
+	})
+
+	// result is non-nil only when an attempt fully succeeded (fetch + checksum).
+	if result != nil {
+		return result, nil
+	}
+
+	// Return the most meaningful error: the last per-attempt error carries more
+	// context than wait's generic timeout error.
+	if lastErr != nil {
+		return nil, lastErr
+	}
+
+	if err != nil {
+		return nil, fmt.Errorf("failed to fetch S3 file after retries: %w", err)
+	}
+
+	return nil, fmt.Errorf("failed to fetch S3 file %s: unknown error", s3URL)
+}
+
+// fetchHTTPFile fetches a file using HTTP(S). 
+func (f *DefaultFetcher) fetchHTTPFile(ctx context.Context, targetURL string, opts options) ([]byte, error) {
+	backoff := createBackoffConfig(opts.retryBackoff, opts.retryAttempts, opts.retryMaxDelay)
+	var lastErr error
+	var result []byte
+
+	// The callback's (done, err) return drives wait: (true, nil) stops with
+	// success, (false, nil) retries, (false, err) aborts immediately.
+	err := wait.ExponentialBackoffWithContext(ctx, backoff, func(ctx context.Context) (bool, error) {
+		data, err := f.getHTTPContent(ctx, targetURL, opts.timeout)
+		if err != nil {
+			lastErr = &HTTPError{URL: targetURL, Err: err}
+
+			// Only transient gateway-style failures are retried; any other
+			// status (e.g. 404, 500) aborts immediately.
+			var statusErr *HTTPStatusError
+			if errors.As(err, &statusErr) {
+				switch statusErr.StatusCode {
+				case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
+					return false, nil // Retry on retryable status codes
+				default:
+					return false, err // Stop retrying on non-retryable status codes
+				}
+			}
+
+			// Network/timeout errors (no status code) are retried.
+			return false, nil
+		}
+
+		if opts.checksumEnabled {
+			if err := f.validateHTTPFileContent(ctx, data, targetURL, opts); err != nil {
+				lastErr = err
+				// Don't retry on checksum mismatches
+				var checksumErr *ChecksumMismatchError
+				if errors.As(err, &checksumErr) {
+					return false, err // Stop retrying
+				}
+				return false, nil // Retry on other checksum errors
+			}
+		}
+
+		result = data
+		return true, nil
+	})
+	if err != nil {
+		// If the backoff timed out or was aborted by a non-retryable error,
+		// return the last recorded error for better context.
+		if lastErr != nil {
+			return nil, lastErr
+		}
+		return nil, fmt.Errorf("failed to fetch HTTP file after retries: %w", err)
+	}
+
+	if result != nil {
+		return result, nil
+	}
+
+	// This case should ideally not be reached, but as a fallback, return the last known error.
+	if lastErr != nil {
+		return nil, lastErr
+	}
+
+	return nil, fmt.Errorf("failed to fetch HTTP file %s: unknown error", targetURL)
+}
+
+// getS3Object fetches an object from S3. 
+func (f *DefaultFetcher) getS3Object(
+	ctx context.Context,
+	bucket, key string,
+	timeout time.Duration,
+) ([]byte, error) {
+	// Bound each individual attempt with its own timeout.
+	ctxWithTimeout, cancel := context.WithTimeout(ctx, timeout)
+	defer cancel()
+
+	input := &s3.GetObjectInput{
+		Bucket: aws.String(bucket),
+		Key:    aws.String(key),
+	}
+
+	result, err := f.s3Client.GetObject(ctxWithTimeout, input)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get S3 object: %w", err)
+	}
+	defer result.Body.Close()
+
+	data, err := io.ReadAll(result.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read S3 object body: %w", err)
+	}
+
+	return data, nil
+}
+
+// getHTTPContent fetches content via HTTP(S).
+// On a non-2xx status it returns an *HTTPStatusError so callers can decide
+// which codes are retryable.
+func (f *DefaultFetcher) getHTTPContent(
+	ctx context.Context,
+	targetURL string,
+	timeout time.Duration,
+) ([]byte, error) {
+	ctx, cancel := context.WithTimeout(ctx, timeout)
+	defer cancel()
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+	req.Header.Set("User-Agent", userAgent)
+
+	resp, err := f.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("request failed: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+		// Drain a bounded amount of the error body before the deferred Close so
+		// the transport can reuse the underlying TCP/TLS connection on retries;
+		// closing an undrained body forces a new connection per attempt.
+		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 64*1024))
+		return nil, &HTTPStatusError{StatusCode: resp.StatusCode}
+	}
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read HTTP response body: %w", err)
+	}
+
+	return body, nil
+}
+
+// validateS3FileContent validates the fetched S3 file content using the enabled validation methods. 
+func (f *DefaultFetcher) validateS3FileContent(
+	ctx context.Context,
+	data []byte,
+	bucket, key string,
+	opts options,
+) error {
+	// NOTE(review): callers already gate on opts.checksumEnabled before calling
+	// this; the re-check here is defensive and keeps the function safe to call
+	// directly.
+	if opts.checksumEnabled {
+		if err := f.validateS3Checksum(ctx, data, bucket, key, opts); err != nil {
+			return fmt.Errorf("checksum validation failed: %w", err)
+		}
+	}
+	return nil
+}
+
+// validateS3Checksum validates S3 file content against a SHA256 checksum.
+// The checksum is read from <key>.sha256 in the same bucket unless
+// opts.checksumLocation overrides it (either a bare key or a full s3:// URL).
+func (f *DefaultFetcher) validateS3Checksum(
+	ctx context.Context,
+	data []byte,
+	bucket, key string,
+	opts options,
+) error {
+	checksumBucket := bucket
+	checksumKey := key + checksumFileSuffix
+
+	if opts.checksumLocation != "" {
+		if strings.HasPrefix(opts.checksumLocation, "s3://") {
+			// Parse full S3 URL
+			var err error
+			checksumBucket, checksumKey, err = parseS3URL(opts.checksumLocation)
+			if err != nil {
+				return fmt.Errorf("invalid checksum S3 URL: %w", err)
+			}
+		} else {
+			// A bare location is treated as a key in the same bucket.
+			checksumKey = opts.checksumLocation
+		}
+	}
+
+	checksumData, err := f.getS3Object(ctx, checksumBucket, checksumKey, opts.timeout)
+	if err != nil {
+		checksumURL := fmt.Sprintf("s3://%s/%s", checksumBucket, checksumKey)
+		return &S3Error{
+			Bucket: checksumBucket,
+			Key:    checksumKey,
+			Err:    fmt.Errorf("failed to fetch checksum from %s: %w", checksumURL, err),
+		}
+	}
+
+	return validateChecksum(data, checksumData)
+}
+
+// validateHTTPFileContent validates the fetched HTTP file content using the enabled validation methods.
+func (f *DefaultFetcher) validateHTTPFileContent(
+	ctx context.Context,
+	data []byte,
+	targetURL string,
+	opts options,
+) error {
+	// NOTE(review): defensive re-check; callers already gate on checksumEnabled.
+	if opts.checksumEnabled {
+		if err := f.validateHTTPChecksum(ctx, data, targetURL, opts); err != nil {
+			return fmt.Errorf("checksum validation failed: %w", err)
+		}
+	}
+	return nil
+}
+
+// validateHTTPChecksum validates HTTP file content against a SHA256 checksum. 
+func (f *DefaultFetcher) validateHTTPChecksum(
+	ctx context.Context,
+	data []byte,
+	targetURL string,
+	opts options,
+) error {
+	// Determine checksum URL: explicit override, else <targetURL>.sha256.
+	checksumURL := opts.checksumLocation
+	if checksumURL == "" {
+		checksumURL = targetURL + checksumFileSuffix
+	}
+
+	// Fetch checksum file
+	checksumData, err := f.getHTTPContent(ctx, checksumURL, opts.timeout)
+	if err != nil {
+		return &HTTPError{URL: checksumURL, Err: fmt.Errorf("failed to fetch checksum: %w", err)}
+	}
+
+	return validateChecksum(data, checksumData)
+}
+
+// validateChecksum validates data against checksum content.
+// It accepts the common `sha256sum` output format ("hash filename") as well as
+// a bare hash, and returns *ChecksumMismatchError when the digests differ.
+func validateChecksum(data, checksumData []byte) error {
+	// Parse checksum (format: "hash filename" or just "hash")
+	checksumStr := strings.TrimSpace(string(checksumData))
+	checksumFields := strings.Fields(checksumStr)
+
+	if len(checksumFields) == 0 {
+		return fmt.Errorf("checksum file is empty or contains only whitespace")
+	}
+
+	expectedChecksum := checksumFields[0]
+
+	// Calculate the actual SHA256 digest of the payload.
+	sum := sha256.Sum256(data)
+	actualChecksum := hex.EncodeToString(sum[:])
+
+	// Hex digests are case-insensitive; some tools emit uppercase hex, so
+	// compare with EqualFold instead of a byte-exact match.
+	if !strings.EqualFold(actualChecksum, expectedChecksum) {
+		return &ChecksumMismatchError{Expected: expectedChecksum, Actual: actualChecksum}
+	}
+
+	return nil
+}
+
+// parseS3URL parses an S3 URL and returns bucket and key. 
+func parseS3URL(s3URL string) (bucket, key string, err error) { + parsedURL, err := url.Parse(s3URL) + if err != nil { + return "", "", fmt.Errorf("failed to parse S3 URL: %w", err) + } + + bucket = parsedURL.Host + if bucket == "" { + return "", "", fmt.Errorf("S3 bucket name cannot be empty") + } + + key = strings.TrimPrefix(parsedURL.Path, "/") + if key == "" { + return "", "", fmt.Errorf("S3 object key cannot be empty") + } + + // URL decode the key to handle encoded characters + key, err = url.QueryUnescape(key) + if err != nil { + return "", "", fmt.Errorf("failed to decode S3 object key: %w", err) + } + + return bucket, key, nil +} + +// createBackoffConfig creates a backoff configuration for retries. +func createBackoffConfig( + backoffType RetryBackoffType, + attempts int32, + maxDelay time.Duration, +) wait.Backoff { + backoff := wait.Backoff{ + Duration: defaultRetryInitialDuration, + Factor: defaultRetryLinearFactor, + Jitter: defaultRetryJitter, + Steps: int(attempts + 1), + Cap: maxDelay, + } + + if backoffType == RetryBackoffExponential { + backoff.Factor = exponentialBackoffFactor + } + + return backoff +} diff --git a/internal/framework/fetch/fetch_test.go b/internal/framework/fetch/fetch_test.go new file mode 100644 index 0000000000..bb252c87cb --- /dev/null +++ b/internal/framework/fetch/fetch_test.go @@ -0,0 +1,720 @@ +package fetch + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +func TestGetRemoteFile(t *testing.T) { + fetcher, err := NewDefaultFetcher() + if err != nil { + t.Fatalf("NewDefaultFetcher() failed: %v", err) + } + + fileContent := "test file content" + hasher := sha256.New() + hasher.Write([]byte(fileContent)) + expectedChecksum := hex.EncodeToString(hasher.Sum(nil)) + + tests := []struct { + setupServer func() *httptest.Server + setupFetcher func() 
Fetcher + validateFunc func(t *testing.T, data []byte, err error) + name string + url string + expectedErr string + options []Option + expectErr bool + }{ + // HTTP Checksum validation scenarios + { + name: "valid checksum with filename", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, ".sha256") { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(expectedChecksum + " filename.txt")) + } else { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(fileContent)) + } + })) + }, + url: "/file.txt", + options: []Option{WithChecksum()}, + expectErr: false, + }, + { + name: "checksum mismatch", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, ".sha256") { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("0000000000000000000000000000000000000000000000000000000000000000")) + } else { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(fileContent)) + } + })) + }, + url: "/file.txt", + options: []Option{WithChecksum()}, + expectErr: true, + }, + { + name: "empty checksum file", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, ".sha256") { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(" \n\t ")) + } else { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(fileContent)) + } + })) + }, + url: "/file.txt", + options: []Option{WithChecksum()}, + expectErr: true, + }, + // URL validation error cases + { + name: "S3 missing bucket and key", + url: "s3://", + options: []Option{WithTimeout(1 * time.Second), WithRetryAttempts(0)}, + expectErr: true, + expectedErr: "S3 bucket name cannot be empty", + }, + { + name: "S3 missing key", + url: "s3://bucket", + options: []Option{WithTimeout(1 * time.Second), 
WithRetryAttempts(0)}, + expectErr: true, + expectedErr: "S3 object key cannot be empty", + }, + { + name: "S3 empty bucket with key", + url: "s3:///key", + options: []Option{WithTimeout(1 * time.Second), WithRetryAttempts(0)}, + expectErr: true, + expectedErr: "S3 bucket name cannot be empty", + }, + { + name: "FTP scheme", + url: "ftp://example.com/file.txt", + options: []Option{WithTimeout(1 * time.Second), WithRetryAttempts(0)}, + expectErr: true, + expectedErr: "unsupported URL scheme", + }, + { + name: "File scheme", + url: "file:///local/path", + options: []Option{WithTimeout(1 * time.Second), WithRetryAttempts(0)}, + expectErr: true, + expectedErr: "unsupported URL scheme", + }, + { + name: "Invalid URL", + url: "invalid-url", + options: []Option{WithTimeout(1 * time.Second), WithRetryAttempts(0)}, + expectErr: true, + expectedErr: "unsupported URL scheme", + }, + { + name: "Empty URL", + url: "", + options: []Option{WithTimeout(1 * time.Second), WithRetryAttempts(0)}, + expectErr: true, + expectedErr: "unsupported URL scheme", + }, + { + name: "S3 client unavailable", + setupFetcher: func() Fetcher { + return &DefaultFetcher{ + s3Client: nil, + httpClient: &http.Client{}, + } + }, + url: "s3://bucket/key", + options: []Option{WithTimeout(1 * time.Second), WithRetryAttempts(0)}, + expectErr: true, + expectedErr: "S3 client not available", + }, + // Options testing + { + name: "timeout option", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("fast response")) + })) + }, + url: "/", + options: []Option{WithTimeout(5 * time.Second)}, + expectErr: false, + }, + { + name: "multiple options", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("success")) + })) + }, + url: "/", + options: 
[]Option{ + WithRetryAttempts(1), + WithRetryBackoff(RetryBackoffExponential), + WithMaxRetryDelay(50 * time.Millisecond), + }, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var testFetcher Fetcher = fetcher + if tt.setupFetcher != nil { + testFetcher = tt.setupFetcher() + } + + testURL := tt.url + if tt.setupServer != nil { + server := tt.setupServer() + defer server.Close() + testURL = server.URL + tt.url + } + + data, err := testFetcher.GetRemoteFile(testURL, tt.options...) + + if tt.expectErr { + if err == nil { + t.Error("Expected error, got nil") + } + if tt.expectedErr != "" && !strings.Contains(err.Error(), tt.expectedErr) { + t.Errorf("Expected error containing %q, got: %v", tt.expectedErr, err) + } + } else if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if tt.validateFunc != nil { + tt.validateFunc(t, data, err) + } + }) + } +} + +func TestGetRemoteFileError(t *testing.T) { + fetcher, err := NewDefaultFetcher() + if err != nil { + t.Fatalf("NewDefaultFetcher() failed: %v", err) + } + + tests := []struct { + setupServer func() *httptest.Server + name string + url string + expectErrType string + options []Option + }{ + { + name: "HTTP error response", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + }, + options: []Option{WithRetryAttempts(0)}, + expectErrType: "HTTPError", + }, + { + name: "network connection error", + url: "http://127.0.0.1:1", + options: []Option{WithRetryAttempts(0), WithTimeout(10 * time.Millisecond)}, + expectErrType: "HTTPError", + }, + { + name: "timeout during request", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + time.Sleep(20 * time.Millisecond) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("delayed response")) + })) + }, + options: 
[]Option{WithTimeout(10 * time.Millisecond), WithRetryAttempts(0)},
			expectErrType: "HTTPError",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			testURL := tt.url
			if tt.setupServer != nil {
				server := tt.setupServer()
				defer server.Close()
				testURL = server.URL
			}

			_, err := fetcher.GetRemoteFile(testURL, tt.options...)
			if err == nil {
				t.Error("Expected error, got nil")
			}

			if tt.expectErrType != "" {
				switch tt.expectErrType {
				case "HTTPError":
					var httpErr *HTTPError
					if !errors.As(err, &httpErr) {
						t.Errorf("Expected HTTPError, got %T: %v", err, err)
					}
				default:
					t.Errorf("Unknown expected error type: %s", tt.expectErrType)
				}
			}
		})
	}
}

// TestParseS3URL covers parsing error paths of parseS3URL that are not
// exercised by the integration-style tests above.
func TestParseS3URL(t *testing.T) {
	tests := []struct {
		name      string
		url       string
		bucket    string
		key       string
		expectErr bool
	}{
		// Error cases specific to parsing logic (not covered in integration tests)
		{
			name:      "missing key with trailing slash",
			url:       "s3://bucket/",
			expectErr: true,
		},
		{
			name:      "invalid URL encoding",
			url:       "s3://bucket/invalid%gg",
			expectErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			bucket, key, err := parseS3URL(tt.url)

			switch {
			case tt.expectErr:
				if err == nil {
					t.Error("Expected error, got nil")
				}
			case err != nil:
				t.Errorf("Unexpected error: %v", err)
			default:
				if bucket != tt.bucket {
					t.Errorf("Expected bucket %q, got %q", tt.bucket, bucket)
				}
				if key != tt.key {
					t.Errorf("Expected key %q, got %q", tt.key, key)
				}
			}
		})
	}
}

// TestErrorTypes verifies the Error() strings and Unwrap() behavior of the
// package's custom error types.
func TestErrorTypes(t *testing.T) {
	tests := []struct {
		err      error
		unwraps  error
		name     string
		expected string
	}{
		{
			name: "ChecksumMismatchError",
			err: &ChecksumMismatchError{
				Expected: "abc123",
				Actual:   "def456",
			},
			expected: "checksum mismatch: expected abc123, got def456",
		},
		{
			name: "S3Error",
			err: &S3Error{
				Bucket: "my-bucket",
				Key:    "my-key",
				Err:    errors.New("access denied"),
			},
			expected: "S3 error for s3://my-bucket/my-key: access denied",
			unwraps:  errors.New("access denied"),
		},
		{
			name: "HTTPError",
			err: &HTTPError{
				URL: "http://example.com",
				Err: errors.New("connection refused"),
			},
			expected: "HTTP error for http://example.com: connection refused",
			unwraps:  errors.New("connection refused"),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.err.Error() != tt.expected {
				t.Errorf("Expected error message %q, got %q", tt.expected, tt.err.Error())
			}

			if tt.unwraps != nil {
				// errors.Unwrap returns nil when the error does not implement
				// Unwrap() error, which replaces the manual type assertion.
				inner := errors.Unwrap(tt.err)
				if inner == nil {
					t.Fatal("Expected error to implement Unwrap()")
				}
				if inner.Error() != tt.unwraps.Error() {
					t.Errorf("Expected unwrapped error %q, got %q", tt.unwraps.Error(), inner.Error())
				}
			}
		})
	}
}

// TestGetRemoteFileRetry exercises the HTTP retry behavior: linear backoff
// until success, giving up after the configured attempts, and no retries
// when attempts is zero.
func TestGetRemoteFileRetry(t *testing.T) {
	fetcher, err := NewDefaultFetcher()
	if err != nil {
		t.Fatalf("NewDefaultFetcher() failed: %v", err)
	}

	t.Run("retry with linear backoff", func(t *testing.T) {
		var attemptCount atomic.Int32
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			count := attemptCount.Add(1)
			if count < 4 {
				w.WriteHeader(http.StatusServiceUnavailable)
				return
			}
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte("success"))
		}))
		defer server.Close()

		data, err := fetcher.GetRemoteFile(server.URL,
			WithRetryAttempts(3),
			WithRetryBackoff(RetryBackoffLinear))
		if err != nil {
			t.Errorf("Unexpected error: %v", err)
		}
		if string(data) != "success" {
			t.Errorf("Expected 'success', got %q", string(data))
		}
		// 1 initial attempt + 3 retries.
		if attemptCount.Load() != 4 {
			t.Errorf("Expected 4 attempts, got %d", attemptCount.Load())
		}
	})

	t.Run("max retries exceeded", func(t *testing.T) {
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusInternalServerError)
		}))
		defer server.Close()

		_, err := fetcher.GetRemoteFile(server.URL,
			WithRetryAttempts(1),
			WithTimeout(10*time.Millisecond))

		if err == nil {
			t.Error("Expected error, got nil")
		}

		var statusErr *HTTPStatusError
		if !errors.As(err, &statusErr) {
			t.Errorf("Expected HTTPStatusError, got %T: %v", err, err)
		}
	})

	t.Run("no retries", func(t *testing.T) {
		var attemptCount atomic.Int32
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			attemptCount.Add(1)
			w.WriteHeader(http.StatusInternalServerError)
		}))
		defer server.Close()

		_, err := fetcher.GetRemoteFile(server.URL, WithRetryAttempts(0))

		if err == nil {
			t.Error("Expected error, got nil")
		}
		if attemptCount.Load() != 1 {
			t.Errorf("Expected 1 attempt, got %d", attemptCount.Load())
		}
	})
}

// TestChecksumMismatch verifies that a checksum mismatch fails immediately
// and is NOT retried, even when retry attempts are configured.
func TestChecksumMismatch(t *testing.T) {
	fetcher, err := NewDefaultFetcher()
	if err != nil {
		t.Fatalf("NewDefaultFetcher() failed: %v", err)
	}

	fileContent := "mismatch test"
	invalidChecksum := "0000000000000000000000000000000000000000000000000000000000000000"

	var attempts atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if strings.HasSuffix(r.URL.Path, ".sha256") {
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte(invalidChecksum))
		} else {
			attempts.Add(1)
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte(fileContent))
		}
	}))
	defer server.Close()

	_, err = fetcher.GetRemoteFile(server.URL+"/file.txt",
		WithChecksum(),
		WithRetryAttempts(3))

	if err == nil {
		t.Error("Expected checksum mismatch error, got nil")
	}
	var checksumErr *ChecksumMismatchError
	if !errors.As(err, &checksumErr) {
		t.Errorf("Expected ChecksumMismatchError, got %T: %v", err, err)
	}
	if attempts.Load() != 1 {
		t.Errorf("Expected 1 attempt (no retries on checksum mismatch), got %d", attempts.Load())
	}
}

// mockS3Client is an in-memory S3 client test double. Objects and errors are
// keyed by "bucket/key", and every GetObject invocation is recorded in calls.
// NOTE(review): calls is appended to without a mutex — fine while the tests
// drive the mock serially; add locking if a test ever fetches concurrently.
type mockS3Client struct {
	objects map[string][]byte
	errors  map[string]error
	calls   []s3GetObjectCall
}

// s3GetObjectCall captures the bucket and key of a single GetObject call.
type s3GetObjectCall struct {
	bucket string
	key    string
}

// GetObject serves objects from the in-memory maps, preferring a configured
// error for the key over configured data.
func (m *mockS3Client) GetObject(
	_ context.Context,
	input *s3.GetObjectInput,
	_ ...func(*s3.Options),
) (*s3.GetObjectOutput, error) {
	m.calls = append(m.calls, s3GetObjectCall{bucket: *input.Bucket, key: *input.Key})

	key := *input.Bucket + "/" + *input.Key
	if err, exists := m.errors[key]; exists {
		return nil, err
	}
	if data, exists := m.objects[key]; exists {
		return &s3.GetObjectOutput{
			Body: io.NopCloser(bytes.NewReader(data)),
		}, nil
	}
	return nil, fmt.Errorf("NoSuchKey: key not found")
}

// getCallCount reports how many GetObject calls were recorded.
func (m *mockS3Client) getCallCount() int {
	return len(m.calls)
}

// TestGetRemoteFileS3 tests S3 scenarios.
func TestGetRemoteFileS3(t *testing.T) {
	type testCase struct {
		expectErrType any
		mockObjects   map[string][]byte
		mockErrors    map[string]error
		validate      func(t *testing.T, data []byte, s3Client *mockS3Client, content []byte)
		name          string
		url           string
		options       []Option
		content       []byte
		expectErr     bool
	}

	// Common content for tests
	content1 := []byte("s3 test content")
	hasher1 := sha256.New()
	hasher1.Write(content1)
	checksum1 := hex.EncodeToString(hasher1.Sum(nil))

	content2 := []byte("relative checksum test")
	hasher2 := sha256.New()
	hasher2.Write(content2)
	checksum2 := hex.EncodeToString(hasher2.Sum(nil))

	tests := []testCase{
		{
			name:    "success",
			url:     "s3://test-bucket/test-key.txt",
			content: content1,
			mockObjects: map[string][]byte{
				"test-bucket/test-key.txt": content1,
			},
			options: []Option{WithRetryAttempts(0)},
			validate: func(t *testing.T, data []byte, s3Client *mockS3Client, content []byte) {
				t.Helper()
				if !bytes.Equal(data, content) {
					t.Errorf("Expected %q, got %q", content, data)
				}
				if s3Client.getCallCount() != 1 {
					t.Errorf("Expected 1 S3 call, got %d", s3Client.getCallCount())
				}
			},
		},
		{
			name:    "checksum validation success",
			url:     "s3://test-bucket/file.txt",
			content: content1,
			mockObjects: map[string][]byte{
				"test-bucket/file.txt":        content1,
				"test-bucket/file.txt.sha256": []byte(checksum1 + " file.txt"),
			},
			options: []Option{WithChecksum(), WithRetryAttempts(0)},
			validate: func(t *testing.T, data []byte, s3Client *mockS3Client, content []byte) {
				t.Helper()
				if !bytes.Equal(data, content) {
					t.Errorf("Expected %q, got %q", content, data)
				}
				if s3Client.getCallCount() != 2 {
					t.Errorf("Expected 2 S3 calls, got %d", s3Client.getCallCount())
				}
			},
		},
		{
			name:    "checksum mismatch",
			url:     "s3://test-bucket/file.txt",
			content: content1,
			mockObjects: map[string][]byte{
				"test-bucket/file.txt":        content1,
				"test-bucket/file.txt.sha256": []byte("badchecksum"),
			},
			options:       []Option{WithChecksum(), WithRetryAttempts(0)},
			expectErr:     true,
			expectErrType: &ChecksumMismatchError{},
		},
		{
			name:      "S3 access error",
			url:       "s3://test-bucket/error-key",
			options:   []Option{WithRetryAttempts(0)},
			expectErr: true,
			mockErrors: map[string]error{
				"test-bucket/error-key": fmt.Errorf("access denied"),
			},
			expectErrType: &S3Error{},
		},
		{
			name:    "checksum file error",
			url:     "s3://test-bucket/file.txt",
			options: []Option{WithChecksum(), WithRetryAttempts(0)},
			content: content1,
			mockObjects: map[string][]byte{
				"test-bucket/file.txt": content1,
			},
			mockErrors: map[string]error{
				"test-bucket/file.txt.sha256": fmt.Errorf("checksum file not found"),
			},
			expectErr:     true,
			expectErrType: &S3Error{},
		},
		{
			name:    "full S3 URL checksum location",
			url:     "s3://test-bucket/file.txt",
			content: content1,
			mockObjects: map[string][]byte{
				"test-bucket/file.txt":          content1,
				"checksum-bucket/custom.sha256": []byte(checksum1),
			},
			options: []Option{WithChecksum("s3://checksum-bucket/custom.sha256"), WithRetryAttempts(0)},
			validate: func(t *testing.T, data []byte, s3Client *mockS3Client, content []byte) {
				t.Helper()
				if !bytes.Equal(data, content) {
					t.Errorf("Expected %q, got %q", content, data)
				}
				if s3Client.getCallCount() != 2 {
					t.Errorf("Expected 2 S3 calls, got %d", s3Client.getCallCount())
				}
			},
		},
		{
			name:    "relative checksum location",
			url:     "s3://test-bucket/file.txt",
			content: content2,
			mockObjects: map[string][]byte{
				"test-bucket/file.txt":               content2,
				"test-bucket/custom-checksum.sha256": []byte(checksum2),
			},
			options: []Option{WithChecksum("custom-checksum.sha256"), WithRetryAttempts(0)},
			validate: func(t *testing.T, data []byte, _ *mockS3Client, content []byte) {
				t.Helper()
				if !bytes.Equal(data, content) {
					t.Errorf("Expected %q, got %q", content, data)
				}
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fakeS3Client := &mockS3Client{
				objects: tt.mockObjects,
				errors:  tt.mockErrors,
			}

			fetcher := NewDefaultFetcherWithS3Client(fakeS3Client)

			data, err := fetcher.GetRemoteFile(tt.url, tt.options...)

			if tt.expectErr {
				assertErrorType(t, err, tt.expectErrType)
			} else if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			if tt.validate != nil {
				tt.validate(t, data, fakeS3Client, tt.content)
			}
		})
	}
}

// assertErrorType fails the test unless err matches (via errors.As) the
// concrete error type carried by expectedType.
func assertErrorType(t *testing.T, err error, expectedType any) {
	t.Helper()

	if err == nil {
		t.Fatalf("Expected error of type %T, got nil", expectedType)
	}

	switch expectedType.(type) {
	case *ChecksumMismatchError:
		var target *ChecksumMismatchError
		if !errors.As(err, &target) {
			t.Errorf("Expected ChecksumMismatchError, got %T", err)
		}
	case *S3Error:
		var target *S3Error
		if !errors.As(err, &target) {
			t.Errorf("Expected S3Error, got %T", err)
		}
	default:
		if expectedType != nil {
			t.Fatalf("unhandled expected error type: %T", expectedType)
		}
	}
}

// Code generated by counterfeiter. DO NOT EDIT.
package fetchfakes

import (
	"sync"

	"github.com/nginx/nginx-gateway-fabric/internal/framework/fetch"
)

// FakeFetcher is a counterfeiter-generated, thread-safe test double for
// fetch.Fetcher. Regenerate with counterfeiter rather than editing by hand.
type FakeFetcher struct {
	GetRemoteFileStub        func(string, ...fetch.Option) ([]byte, error)
	getRemoteFileMutex       sync.RWMutex
	getRemoteFileArgsForCall []struct {
		arg1 string
		arg2 []fetch.Option
	}
	getRemoteFileReturns struct {
		result1 []byte
		result2 error
	}
	getRemoteFileReturnsOnCall map[int]struct {
		result1 []byte
		result2 error
	}
	invocations      map[string][][]interface{}
	invocationsMutex sync.RWMutex
}

func (fake *FakeFetcher) GetRemoteFile(arg1 string, arg2 ...fetch.Option) ([]byte, error) {
	fake.getRemoteFileMutex.Lock()
	ret, specificReturn := fake.getRemoteFileReturnsOnCall[len(fake.getRemoteFileArgsForCall)]
	fake.getRemoteFileArgsForCall = append(fake.getRemoteFileArgsForCall, struct {
		arg1 string
		arg2 []fetch.Option
	}{arg1, arg2})
	stub := fake.GetRemoteFileStub
	fakeReturns := fake.getRemoteFileReturns
	fake.recordInvocation("GetRemoteFile", []interface{}{arg1, arg2})
	fake.getRemoteFileMutex.Unlock()
	if stub != nil {
		return stub(arg1, arg2...)
	}
	if specificReturn {
		return ret.result1, ret.result2
	}
	return fakeReturns.result1, fakeReturns.result2
}

func (fake *FakeFetcher) GetRemoteFileCallCount() int {
	fake.getRemoteFileMutex.RLock()
	defer fake.getRemoteFileMutex.RUnlock()
	return len(fake.getRemoteFileArgsForCall)
}

func (fake *FakeFetcher) GetRemoteFileCalls(stub func(string, ...fetch.Option) ([]byte, error)) {
	fake.getRemoteFileMutex.Lock()
	defer fake.getRemoteFileMutex.Unlock()
	fake.GetRemoteFileStub = stub
}

func (fake *FakeFetcher) GetRemoteFileArgsForCall(i int) (string, []fetch.Option) {
	fake.getRemoteFileMutex.RLock()
	defer fake.getRemoteFileMutex.RUnlock()
	argsForCall := fake.getRemoteFileArgsForCall[i]
	return argsForCall.arg1, argsForCall.arg2
}

func (fake *FakeFetcher) GetRemoteFileReturns(result1 []byte, result2 error) {
	fake.getRemoteFileMutex.Lock()
	defer fake.getRemoteFileMutex.Unlock()
	fake.GetRemoteFileStub = nil
	fake.getRemoteFileReturns = struct {
		result1 []byte
		result2 error
	}{result1, result2}
}

func (fake *FakeFetcher) GetRemoteFileReturnsOnCall(i int, result1 []byte, result2 error) {
	fake.getRemoteFileMutex.Lock()
	defer fake.getRemoteFileMutex.Unlock()
	fake.GetRemoteFileStub = nil
	if fake.getRemoteFileReturnsOnCall == nil {
		fake.getRemoteFileReturnsOnCall = make(map[int]struct {
			result1 []byte
			result2 error
		})
	}
	fake.getRemoteFileReturnsOnCall[i] = struct {
		result1 []byte
		result2 error
	}{result1, result2}
}

func (fake *FakeFetcher) Invocations() map[string][][]interface{} {
	fake.invocationsMutex.RLock()
	defer fake.invocationsMutex.RUnlock()
	fake.getRemoteFileMutex.RLock()
	defer fake.getRemoteFileMutex.RUnlock()
	copiedInvocations := map[string][][]interface{}{}
	for key, value := range fake.invocations {
		copiedInvocations[key] = value
	}
	return copiedInvocations
}

func (fake *FakeFetcher) recordInvocation(key string, args []interface{}) {
	fake.invocationsMutex.Lock()
	defer fake.invocationsMutex.Unlock()
	if fake.invocations == nil {
		fake.invocations = map[string][][]interface{}{}
	}
	if fake.invocations[key] == nil {
		fake.invocations[key] = [][]interface{}{}
	}
	fake.invocations[key] = append(fake.invocations[key], args)
}

var _ fetch.Fetcher = new(FakeFetcher)

// ---- file boundary: internal/framework/fetch/fetchfakes/fake_s3client.go ----
// Code generated by counterfeiter. DO NOT EDIT.
package fetchfakes

import (
	"context"
	"sync"

	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/nginx/nginx-gateway-fabric/internal/framework/fetch"
)

// FakeS3Client is a counterfeiter-generated, thread-safe test double for
// fetch.S3Client. Regenerate with counterfeiter rather than editing by hand.
type FakeS3Client struct {
	GetObjectStub        func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
	getObjectMutex       sync.RWMutex
	getObjectArgsForCall []struct {
		arg1 context.Context
		arg2 *s3.GetObjectInput
		arg3 []func(*s3.Options)
	}
	getObjectReturns struct {
		result1 *s3.GetObjectOutput
		result2 error
	}
	getObjectReturnsOnCall map[int]struct {
		result1 *s3.GetObjectOutput
		result2 error
	}
	invocations      map[string][][]interface{}
	invocationsMutex sync.RWMutex
}

func (fake *FakeS3Client) GetObject(arg1 context.Context, arg2 *s3.GetObjectInput, arg3 ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
	fake.getObjectMutex.Lock()
	ret, specificReturn := fake.getObjectReturnsOnCall[len(fake.getObjectArgsForCall)]
	fake.getObjectArgsForCall = append(fake.getObjectArgsForCall, struct {
		arg1 context.Context
		arg2 *s3.GetObjectInput
		arg3 []func(*s3.Options)
	}{arg1, arg2, arg3})
	stub := fake.GetObjectStub
	fakeReturns := fake.getObjectReturns
	fake.recordInvocation("GetObject", []interface{}{arg1, arg2, arg3})
	fake.getObjectMutex.Unlock()
	if stub != nil {
		return stub(arg1, arg2, arg3...)
	}
	if specificReturn {
		return ret.result1, ret.result2
	}
	return fakeReturns.result1, fakeReturns.result2
}

func (fake *FakeS3Client) GetObjectCallCount() int {
	fake.getObjectMutex.RLock()
	defer fake.getObjectMutex.RUnlock()
	return len(fake.getObjectArgsForCall)
}

func (fake *FakeS3Client) GetObjectCalls(stub func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)) {
	fake.getObjectMutex.Lock()
	defer fake.getObjectMutex.Unlock()
	fake.GetObjectStub = stub
}

func (fake *FakeS3Client) GetObjectArgsForCall(i int) (context.Context, *s3.GetObjectInput, []func(*s3.Options)) {
	fake.getObjectMutex.RLock()
	defer fake.getObjectMutex.RUnlock()
	argsForCall := fake.getObjectArgsForCall[i]
	return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3
}

func (fake *FakeS3Client) GetObjectReturns(result1 *s3.GetObjectOutput, result2 error) {
	fake.getObjectMutex.Lock()
	defer fake.getObjectMutex.Unlock()
	fake.GetObjectStub = nil
	fake.getObjectReturns = struct {
		result1 *s3.GetObjectOutput
		result2 error
	}{result1, result2}
}

func (fake *FakeS3Client) GetObjectReturnsOnCall(i int, result1 *s3.GetObjectOutput, result2 error) {
	fake.getObjectMutex.Lock()
	defer fake.getObjectMutex.Unlock()
	fake.GetObjectStub = nil
	if fake.getObjectReturnsOnCall == nil {
		fake.getObjectReturnsOnCall = make(map[int]struct {
			result1 *s3.GetObjectOutput
			result2 error
		})
	}
	fake.getObjectReturnsOnCall[i] = struct {
		result1 *s3.GetObjectOutput
		result2 error
	}{result1, result2}
}

func (fake *FakeS3Client) Invocations() map[string][][]interface{} {
	fake.invocationsMutex.RLock()
	defer fake.invocationsMutex.RUnlock()
	fake.getObjectMutex.RLock()
	defer fake.getObjectMutex.RUnlock()
	copiedInvocations := map[string][][]interface{}{}
	for key, value := range fake.invocations {
		copiedInvocations[key] = value
	}
	return copiedInvocations
}

func (fake *FakeS3Client) recordInvocation(key string, args []interface{}) {
	fake.invocationsMutex.Lock()
	defer fake.invocationsMutex.Unlock()
	if fake.invocations == nil {
		fake.invocations = map[string][][]interface{}{}
	}
	if fake.invocations[key] == nil {
		fake.invocations[key] = [][]interface{}{}
	}
	fake.invocations[key] = append(fake.invocations[key], args)
}

var _ fetch.S3Client = new(FakeS3Client)