diff --git a/go.mod b/go.mod index a91acac5..72ce2de9 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/julienschmidt/httprouter v1.3.0 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f + github.com/prometheus/client_golang v1.20.4 github.com/prometheus/client_model v0.6.2 github.com/stretchr/testify v1.11.1 go.yaml.in/yaml/v2 v2.4.2 @@ -25,7 +26,6 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/jpillora/backoff v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_golang v1.20.4 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect github.com/xhit/go-str2duration/v2 v2.1.0 // indirect diff --git a/secrets/README.md b/secrets/README.md new file mode 100644 index 00000000..b97bc56a --- /dev/null +++ b/secrets/README.md @@ -0,0 +1,267 @@ +# Secret Management + +The `secrets` package provides a unified way to handle secrets within configuration files for Prometheus and its ecosystem components. It allows secrets to be specified inline, loaded from files, or fetched from other sources through a pluggable provider mechanism. + +## Concepts + +The package is built around a few core concepts: + + * `SecretField`: A type used in configuration structs to represent a field that holds a secret. It handles the logic for unmarshaling from different secret sources, and the API for accessing secrets. + * `Provider`: An interface for fetching secrets from a specific source (e.g., inline string, file on disk). The package comes with built-in providers, and new ones can be registered. + * `Manager`: A component that discovers all `SecretField` instances within a configuration struct, manages their lifecycle, and handles periodic refreshing of secrets. + +## How to Use + +Using the `secrets` package involves three main steps: defining your configuration struct, initializing the secret manager, and accessing the secret values. + +### 1. Define Your Configuration Struct + +In your configuration struct, use the `secrets.SecretField` type for any fields that should contain secrets. + +```go +package main + +import "github.com/prometheus/common/secrets" + +type MyConfig struct { + APIKey secrets.SecretField `yaml:"api_key"` + Password secrets.SecretField `yaml:"password"` + // ... other config fields +} +``` + +### 2. Configure Secrets in YAML + +Users can then provide secrets in their YAML configuration file. + +For simple secrets, an inline string can be used: + +```yaml +api_key: "my_super_secret_api_key" +``` + +To load a secret from a file, use the `file` provider: + +```yaml +password: + file: /path/to/password.txt +``` + +### 3. Initialize the Secret Manager + +After unmarshaling your configuration file into your struct, you must create a `secrets.Manager` to manage the lifecycle of the secrets. The manager is initialized with a pointer to your configuration struct. + +```go +import ( + "context" + "log" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/common/secrets" + "go.yaml.in/yaml/v2" +) + +func main() { + // A Prometheus registry is needed to register the secret manager's metrics. + promRegisterer := prometheus.NewRegistry() + + // Load config from file + configData := []byte(` +api_key: "my_super_secret_api_key" +password: + file: /path/to/password.txt +`) + var cfg MyConfig + if err := yaml.Unmarshal(configData, &cfg); err != nil { + log.Fatalf("Error unmarshaling config: %v", err) + } + + // Create a secret manager. This discovers and manages all SecretFields in cfg. + // The manager will handle refreshing secrets in the background. + manager, err := secrets.NewManager(promRegisterer, &cfg) + if err != nil { + log.Fatalf("Error creating secret manager: %v", err) + } + // Start the manager's background refresh loop. + manager.Start(context.Background()) + defer manager.Stop() + + + // ... your application logic ... + + // Wait for the secrets in cfg to be ready. + for { + if ready, err := manager.SecretsReady(&cfg); err != nil { + log.Fatalf("Error checking secret readiness: %v", err) + } else if ready { + break + } + } + + // Access the secret value when needed. + apiKey := cfg.APIKey.Get() + password := cfg.Password.Get() + + log.Printf("API Key: %s", apiKey) + log.Printf("Password: %s", password) +} +``` + +### 4. Accessing Secrets + +To get the string value of a secret, simply call the `Get()` method on the `SecretField`. + +```go +secretValue := myConfig.APIKey.Get() +``` + +The manager handles caching and refreshing, so `Get()` will always return the current valid secret. + +## Built-in Providers + +The `secrets` package comes with two built-in providers: + + * `inline`: For secrets that are specified directly as a string in the configuration file. This is the default if a plain string is provided. + ```yaml + api_key: "my_inline_secret" + ``` + * `file`: For secrets that are loaded from a file on disk. + ```yaml + password: + file: + path: /etc/prometheus/secrets/password + ``` + +## Custom Providers + +You can extend the functionality by creating your own custom secret providers. A custom provider must implement the `Provider` interface: + +```go +type Provider interface { + // FetchSecret retrieves the secret value. + FetchSecret(ctx context.Context) (string, error) + + // Name returns the provider's name (e.g., "inline"). + Name() string +} +``` + +Once you have implemented the interface, you need to register a factory function for your provider with the global `ProviderRegistry`. This is typically done in an `init()` function. + +```go +package myprovider + +import ( + "context" + "github.com/prometheus/common/secrets" +) + +type MyCustomProvider struct { + // ... fields for your provider +} + +func (p *MyCustomProvider) FetchSecret(ctx context.Context) (string, error) { + // ... logic to fetch your secret +} + +func (p *MyCustomProvider) Name() string { + return "my_custom_provider" +} + +func init() { + secrets.Providers.Register(func() secrets.Provider { + return &MyCustomProvider{} + }) +} +``` + +## Secret Validation + +For secrets that can be rotated (e.g., loaded from a file that gets updated), you can provide an optional validation function. This prevents a broken or partially written secret from being loaded into your application after a rotation. The manager will use the new secret only after your validation function returns `true`. + +A common use case is to verify that a new authentication token can successfully access a protected endpoint before it is put into active use. This avoids causing monitoring gaps if, for example, a new bearer token is invalid. + +To use this feature, implement the `SecretValidator` interface and attach it to a `SecretField` instance. + +Here is an example of a validator that checks if an HTTP endpoint can be reached using the new secret as a bearer token. It performs an `HEAD` request and considers the secret valid if the server responds with any status code other than `401 Unauthorized` or `403 Forbidden`. + +```go +import ( + "context" + "fmt" + "net/http" + + "github.com/prometheus/common/secrets" +) + +// HTTPBearerTokenValidator checks if a secret is a valid bearer token for a given URL. +type HTTPBearerTokenValidator struct { + EndpointURL string + client *http.Client +} + +func NewHTTPBearerTokenValidator(url string) *HTTPBearerTokenValidator { + return &HTTPBearerTokenValidator{ + EndpointURL: url, + client: &http.Client{}, + } +} + +func (v *HTTPBearerTokenValidator) Validate(ctx context.Context, secret string) bool { + req, err := http.NewRequestWithContext(ctx, "HEAD", v.EndpointURL, nil) + if err != nil { + // Could not create the request, so we cannot validate. + return false + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", secret)) + + resp, err := v.client.Do(req) + if err != nil { + // The request failed, so we cannot consider this valid. + return false + } + defer resp.Body.Close() + + // If the status is Unauthorized or Forbidden, the token is invalid. + // Any other status code (e.g., 200 OK, 404 Not Found) means the token + // was accepted for authentication, so we consider it valid for rotation. + return resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusForbidden +} + +func (v *HTTPBearerTokenValidator) Settings() secrets.ValidationSettings { + // Return custom settings or use the default. + return secrets.DefaultValidationSettings() +} + +// In your application code, after unmarshaling the config: +validator := NewHTTPBearerTokenValidator("https://my-protected-api.com/v1/status") +cfg.APIKey.SetSecretValidation(validator) +``` + +The `ValidationSettings` allow you to configure timeouts, backoff, and retry attempts for the validation logic, making the process resilient to temporary network issues. + +## Prometheus Metrics + +The `Manager` exposes several Prometheus metrics to monitor the state of the secrets it manages. These metrics are registered with the `prometheus.Registerer` that is passed to `NewManager`. + +The following metrics are available, all labeled with `provider` and `secret_id`: + + * `prometheus_remote_secret_last_successful_fetch_seconds`: (Gauge) The Unix timestamp of the last successful secret fetch. + * `prometheus_remote_secret_state`: (Gauge) Describes the current state of a remotely fetched secret (0=success, 1=stale, 2=error, 3=initializing). + * `prometheus_remote_secret_fetch_success_total`: (Counter) Total number of successful secret fetches. + * `prometheus_remote_secret_fetch_failures_total`: (Counter) Total number of failed secret fetches. + * `prometheus_remote_secret_fetch_duration_seconds`: (Histogram) Duration of secret fetch attempts. + * `prometheus_remote_secret_validation_failures_total`: (Counter) Total number of failed secret validations. + +## Error Handling and Panics + +The `secrets` package is designed to be robust, but there is one critical error condition that will cause a panic: using a `SecretField` before the `Manager` has been initialized. + +If you call `Get()` or `TriggerRefresh()` on a `SecretField` that has not been discovered by a `Manager`, your program will panic with the message: + +``` +secret field has not been discovered by a manager; was NewManager(&cfg) called? +``` + +This is a safeguard to prevent the use of unmanaged and potentially empty secrets. To avoid this panic, ensure that you always create a `Manager` by passing a pointer to your configuration struct to `secrets.NewManager` immediately after you unmarshal your configuration. \ No newline at end of file diff --git a/secrets/field.go b/secrets/field.go new file mode 100644 index 00000000..4f06b3c0 --- /dev/null +++ b/secrets/field.go @@ -0,0 +1,181 @@ +// Copyright 2025 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secrets + +import ( + "encoding/json" + "fmt" + "time" + + "go.yaml.in/yaml/v2" +) + +// SecretField is a field containing a secret. +type SecretField struct { + provider Provider + manager *Manager + validator SecretValidator + settings SecretFieldSettings +} + +type SecretFieldSettings struct { + RefreshInterval time.Duration `yaml:"refreshInterval,omitempty"` + Default string `yaml:"default,omitempty"` +} + +func (s SecretField) String() string { + return fmt.Sprintf("SecretField{Provider: %s}", s.provider.Name()) +} + +// MarshalYAML implements the yaml.Marshaler interface for SecretField. +func (s SecretField) MarshalYAML() (interface{}, error) { + if s.provider.Name() == "inline" && s.manager != nil && s.manager.MarshalInlineSecrets { + return s.Get(), nil + } + + // Marshal settings to a map to merge them with the provider config. + settingsBytes, err := yaml.Marshal(s.settings) + if err != nil { + return nil, fmt.Errorf("failed to marshal secret field settings: %w", err) + } + + out := make(map[string]interface{}) + if err := yaml.Unmarshal(settingsBytes, &out); err != nil { + return nil, fmt.Errorf("failed to unmarshal marshaled settings: %w", err) + } + + // Add the provider configuration. + out[s.provider.Name()] = s.provider + return out, nil +} + +// MarshalJSON implements the json.Marshaler interface for SecretField. +func (s SecretField) MarshalJSON() ([]byte, error) { + data, err := s.MarshalYAML() + if err != nil { + return nil, err + } + return json.Marshal(data) +} + +type mapType = map[string]interface{} + +// splitProviderAndSettings separates provider-specific configuration from the generic SecretField settings. +func splitProviderAndSettings(baseMap mapType) (baseProvider Provider, providerData interface{}, settingsMap mapType, err error) { + settingsMap = make(mapType) + providerName := "" + + for k, v := range baseMap { + // Check if the key corresponds to a registered provider. + if p, _ := Providers.Get(k); p != nil { + if providerName != "" { + // A provider has already been found, which is an error. + return nil, nil, nil, fmt.Errorf("secret must contain exactly one provider type, but multiple were found: %s, %s", providerName, k) + } + baseProvider = p + providerName = k + providerData = v + } else { + // If it's not a provider key, treat it as a setting. + settingsMap[k] = v + } + } + + if providerName == "" { + // Marshal the map back to YAML for a readable error message. + yamlBytes, err := yaml.Marshal(baseMap) + if err != nil { + // Fallback to the original format if marshalling fails for some reason. + return nil, nil, nil, fmt.Errorf("no valid secret provider found in configuration: %v", baseMap) + } + return nil, nil, nil, fmt.Errorf("no valid secret provider found in configuration:\n%s", string(yamlBytes)) + } + + return baseProvider, providerData, settingsMap, nil +} + +// convertConfig takes a map-like structure and unmarshals it into a typed struct. +// It achieves this by first marshalling the input to YAML and then unmarshalling +// it into the target struct. +func convertConfig[T any](source interface{}, target T) error { + bytes, err := yaml.Marshal(source) + if err != nil { + return fmt.Errorf("failed to re-marshal config: %w", err) + } + if err := yaml.Unmarshal(bytes, target); err != nil { + return fmt.Errorf("failed to unmarshal config: %w", err) + } + return nil +} + +func (s *SecretField) UnmarshalYAML(unmarshal func(interface{}) error) error { + var plainSecret string + if err := unmarshal(&plainSecret); err == nil { + s.provider = &InlineProvider{ + secret: plainSecret, + } + s.validator = DefaultValidator{} + return nil + } + + var baseMap mapType + if err := unmarshal(&baseMap); err != nil { + return err + } + + concreteProvider, providerConfig, settingsMap, err := splitProviderAndSettings(baseMap) + if err != nil { + return err + } + + if err := convertConfig(providerConfig, concreteProvider); err != nil { + return fmt.Errorf("failed to unmarshal into %s provider: %w", concreteProvider.Name(), err) + } + var settings SecretFieldSettings + if err := convertConfig(settingsMap, &settings); err != nil { + return fmt.Errorf("failed to unmarshal secret field settings: %w", err) + } + s.provider = concreteProvider + s.validator = DefaultValidator{} + s.settings = settings + return nil +} + +// SetSecretValidation registers an optional validation function for the secret. +// +// When the secret manager fetches a new version of the secret, it will not +// be used immediately if there is a validator. Instead, the manager will +// hold the new secret in a pending state and call the provided Validate +// with it until it returns true, there is an explicit refresh request, +// there is a time out, or the old secret was never valid. +func (s *SecretField) SetSecretValidation(validator SecretValidator) { + s.validator = validator + if s.manager != nil { + s.manager.setSecretValidation(s, validator) + } +} + +func (s *SecretField) Get() string { + if s.manager == nil { + panic("secret field has not been discovered by a manager; was NewManager(&cfg) called?") + } + return s.manager.get(s) +} + +func (s *SecretField) TriggerRefresh() { + if s.manager == nil { + panic("secret field has not been discovered by a manager; was NewManager(&cfg) called?") + } + s.manager.triggerRefresh(s) +} diff --git a/secrets/field_test.go b/secrets/field_test.go new file mode 100644 index 00000000..566a1dbc --- /dev/null +++ b/secrets/field_test.go @@ -0,0 +1,185 @@ +// Copyright 2025 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secrets + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.yaml.in/yaml/v2" +) + +func TestSecretField_UnmarshalYAML(t *testing.T) { + tests := []struct { + name string + yaml string + expectProvider Provider + expectSettings SecretFieldSettings + expectErr string + }{ + { + name: "Unmarshal plain string into InlineProvider", + yaml: `my_secret_value`, + expectProvider: &InlineProvider{ + secret: "my_secret_value", + }, + }, + { + name: "Unmarshal file provider", + yaml: ` +file: + path: /path/to/secret +`, + expectProvider: &FileProvider{ + Path: "/path/to/secret", + }, + }, + { + name: "Unmarshal file provider with settings", + yaml: ` +file: + path: /path/to/secret +refreshInterval: 5m +`, + expectProvider: &FileProvider{ + Path: "/path/to/secret", + }, + expectSettings: SecretFieldSettings{ + RefreshInterval: 5 * time.Minute, + }, + }, + { + name: "Error on multiple providers", + yaml: ` +file: + path: /path/to/secret +inline: another_secret +`, + expectErr: "secret must contain exactly one provider type, but multiple were found: ", + }, + { + name: "Error on unknown provider", + yaml: ` +moon_secret_manager: + moon_phase: full +`, + expectErr: `no valid secret provider found in configuration:`, + }, + { + name: "Error on invalid provider config", + yaml: ` +file: + path: [ "this", "should", "be", "a", "string" ] +`, + expectErr: "failed to unmarshal into file provider", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var sf SecretField + err := yaml.Unmarshal([]byte(tt.yaml), &sf) + + if tt.expectErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectErr) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectProvider.Name(), sf.provider.Name()) + assert.Equal(t, tt.expectProvider, sf.provider) + assert.Equal(t, tt.expectSettings, sf.settings) + } + }) + } +} + +func TestSecretField_MarshalYAML(t *testing.T) { + t.Run("Marshal non-inline provider", func(t *testing.T) { + sf := SecretField{ + provider: &FileProvider{Path: "/path/to/token"}, + } + b, err := yaml.Marshal(sf) + require.NoError(t, err) + expected := "file:\n path: /path/to/token\n" + assert.Equal(t, expected, string(b)) + }) + + t.Run("Marshal non-inline provider with settings", func(t *testing.T) { + sf := SecretField{ + provider: &FileProvider{Path: "/path/to/token"}, + settings: SecretFieldSettings{ + RefreshInterval: 10 * time.Minute, + }, + } + b, err := yaml.Marshal(sf) + require.NoError(t, err) + expected := "file:\n path: /path/to/token\nrefreshInterval: 10m0s\n" + assert.Equal(t, expected, string(b)) + }) + + t.Run("Marshal inline provider without manager", func(t *testing.T) { + sf := SecretField{ + provider: &InlineProvider{secret: "my-password"}, + } + b, err := yaml.Marshal(sf) + require.NoError(t, err) + expected := "inline: \n" + assert.Equal(t, expected, string(b)) + }) + + t.Run("Marshal inline provider with manager and MarshalInlineSecrets=false", func(t *testing.T) { + m := &Manager{MarshalInlineSecrets: false} + sf := SecretField{ + manager: m, + provider: &InlineProvider{secret: "my-password"}, + } + b, err := yaml.Marshal(sf) + require.NoError(t, err) + expected := "inline: \n" + assert.Equal(t, expected, string(b)) + }) + + t.Run("Marshal inline provider with manager and MarshalInlineSecrets=true", func(t *testing.T) { + m := &Manager{MarshalInlineSecrets: true} + sf := SecretField{ + manager: m, + provider: &InlineProvider{secret: "my-password"}, + } + b, err := yaml.Marshal(sf) + require.NoError(t, err) + expected := "my-password\n" // Marshals as a plain string + assert.Equal(t, expected, string(b)) + }) +} + +func TestSecretField_MarshalJSON(t *testing.T) { + // JSON marshaling is just a wrapper around YAML marshaling, so a simple test is sufficient. + sf := SecretField{ + provider: &FileProvider{Path: "/path/to/token"}, + } + b, err := json.Marshal(sf) + require.NoError(t, err) + expected := `{"file":{"path":"/path/to/token"}}` + assert.JSONEq(t, expected, string(b)) +} + +func TestSecretField_ManagerPanics(t *testing.T) { + sf := SecretField{} // No manager attached + + assert.PanicsWithValuef(t, "secret field has not been discovered by a manager; was NewManager(&cfg) called?", func() { sf.Get() }, "Get should panic without a manager") + assert.PanicsWithValuef(t, "secret field has not been discovered by a manager; was NewManager(&cfg) called?", func() { sf.TriggerRefresh() }, "TriggerRefresh should panic without a manager") +} diff --git a/secrets/internal_providers.go b/secrets/internal_providers.go new file mode 100644 index 00000000..a44ac2b7 --- /dev/null +++ b/secrets/internal_providers.go @@ -0,0 +1,72 @@ +// Copyright 2025 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secrets + +import ( + "context" + "os" +) + +// FileProvider fetches secrets from a file. +type FileProvider struct { + Path string `yaml:"path" json:"path"` +} + +func (fp *FileProvider) FetchSecret(_ context.Context) (string, error) { + content, err := os.ReadFile(fp.Path) + if err != nil { + return "", err + } + return string(content), nil +} + +func (*FileProvider) Name() string { + return "file" +} + +func (fp *FileProvider) Key() string { + return fp.Path +} + +func (fp *FileProvider) MarshalYAML() (interface{}, error) { + return map[string]interface{}{ + "path": fp.Path, + }, nil +} + +// InlineProvider reads an config secret. +type InlineProvider struct { + secret string +} + +func (ip *InlineProvider) FetchSecret(_ context.Context) (string, error) { + return ip.secret, nil +} + +func (*InlineProvider) Name() string { + return "inline" +} + +func (ip *InlineProvider) Key() string { + return ip.secret +} + +func (*InlineProvider) MarshalYAML() (interface{}, error) { + return "", nil +} + +func init() { + Providers.Register(func() Provider { return &InlineProvider{} }) + Providers.Register(func() Provider { return &FileProvider{} }) +} diff --git a/secrets/internal_providers_test.go b/secrets/internal_providers_test.go new file mode 100644 index 00000000..d3d39abb --- /dev/null +++ b/secrets/internal_providers_test.go @@ -0,0 +1,96 @@ +// Copyright 2025 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secrets + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.yaml.in/yaml/v2" +) + +func TestFileProvider(t *testing.T) { + ctx := context.Background() + secretContent := "my-super-secret-password" + tempDir := t.TempDir() + secretFile := filepath.Join(tempDir, "secret.txt") + + err := os.WriteFile(secretFile, []byte(secretContent), 0o600) + require.NoError(t, err) + + fp := &FileProvider{Path: secretFile} + + t.Run("FetchSecret_Success", func(t *testing.T) { + content, err := fp.FetchSecret(ctx) + require.NoError(t, err) + assert.Equal(t, secretContent, content) + }) + + t.Run("FetchSecret_NotFound", func(t *testing.T) { + badFP := &FileProvider{Path: filepath.Join(tempDir, "non-existant.txt")} + _, err := badFP.FetchSecret(ctx) + require.Error(t, err) + assert.True(t, os.IsNotExist(err)) + }) + + t.Run("Name", func(t *testing.T) { + assert.Equal(t, "file", fp.Name()) + }) + + t.Run("Key", func(t *testing.T) { + assert.Equal(t, secretFile, fp.Key()) + }) + + t.Run("MarshalYAML", func(t *testing.T) { + data, err := fp.MarshalYAML() + require.NoError(t, err) + expected := map[string]interface{}{"path": secretFile} + assert.Equal(t, expected, data) + }) +} + +func TestInlineProvider(t *testing.T) { + ctx := context.Background() + secretContent := "my-inline-secret" + ip := &InlineProvider{secret: secretContent} + + t.Run("FetchSecret", func(t *testing.T) { + content, err := ip.FetchSecret(ctx) + require.NoError(t, err) + assert.Equal(t, secretContent, content) + }) + + t.Run("Name", func(t *testing.T) { + assert.Equal(t, "inline", ip.Name()) + }) + + t.Run("Key", func(t *testing.T) { + assert.Equal(t, secretContent, ip.Key()) + }) + + t.Run("MarshalYAML", func(t *testing.T) { + data, err := ip.MarshalYAML() + require.NoError(t, err) + assert.Equal(t, "", data) + + // Also check the output of a full yaml marshal + out, err := yaml.Marshal(ip) + require.NoError(t, err) + assert.Equal(t, "\n", string(out)) + }) +} diff --git a/secrets/manager.go b/secrets/manager.go new file mode 100644 index 00000000..848b5195 --- /dev/null +++ b/secrets/manager.go @@ -0,0 +1,413 @@ +// Copyright 2025 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secrets + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/prometheus/client_golang/prometheus" +) + +const ( + // fetchTimeout governs the maximum time a single fetch attempt can take. + fetchTimeout = 5 * time.Minute + // fetchInitialBackoff is the initial backoff duration for refetching a secret after a failure. + fetchInitialBackoff = 1 * time.Second + // fetchMaxBackoff is the maximum backoff duration for retrying a failed fetch. + fetchMaxBackoff = 2 * time.Minute + + // the default refresh interval for secrets. + defaultRefreshInterval = time.Hour + + // Prometheus secret states. + stateSuccess float64 = 0 + stateStale float64 = 1 + stateError float64 = 2 + stateInitializing float64 = 3 +) + +type Manager struct { + MarshalInlineSecrets bool + mtx sync.RWMutex + secrets map[*SecretField]*managedSecret + refreshC chan struct{} + allReady atomic.Bool + cancel context.CancelFunc + wg sync.WaitGroup + // Prometheus metrics + lastSuccessfulFetch *prometheus.GaugeVec + secretState *prometheus.GaugeVec + fetchSuccessTotal *prometheus.CounterVec + fetchFailuresTotal *prometheus.CounterVec + fetchDuration *prometheus.HistogramVec + validationFailuresTotal *prometheus.CounterVec +} + +type managedSecret struct { + mtx sync.RWMutex + pendingSecret string + secret string + provider Provider + fetched time.Time + fetchInProgress bool + refreshInterval time.Duration + refreshRequested bool + validator SecretValidator + verified bool + metricLabels prometheus.Labels +} + +// NewManager discovers all SecretField instances within the provided config +// structure using reflection and registers them with this manager. +func NewManager(r prometheus.Registerer, config interface{}) (*Manager, error) { + paths, err := getSecretFields(config) + if err != nil { + return nil, err + } + manager := &Manager{ + secrets: make(map[*SecretField]*managedSecret), + } + manager.registerMetrics(r) + for path, field := range paths { + manager.registerSecret(path, field) + } + return manager, nil +} + +func (m *Manager) registerMetrics(r prometheus.Registerer) { + labels := []string{"provider", "secret_id"} + + m.lastSuccessfulFetch = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "prometheus_remote_secret_last_successful_fetch_seconds", + Help: "The unix timestamp of the last successful secret fetch.", + }, + labels, + ) + m.secretState = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "prometheus_remote_secret_state", + Help: "Describes the current state of a remotely fetched secret (0=success, 1=stale, 2=error, 3=initializing).", + }, + labels, + ) + m.fetchSuccessTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "prometheus_remote_secret_fetch_success_total", + Help: "Total number of successful secret fetches.", + }, + labels, + ) + m.fetchFailuresTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "prometheus_remote_secret_fetch_failures_total", + Help: "Total number of failed secret fetches.", + }, + labels, + ) + + m.fetchDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "prometheus_remote_secret_fetch_duration_seconds", + Help: "Duration of secret fetch attempts.", + Buckets: prometheus.DefBuckets, + }, + labels, + ) + m.validationFailuresTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "prometheus_remote_secret_validation_failures_total", + Help: "Total number of failed secret validations.", + }, + labels, + ) + + // Register all metrics with the provided registry + r.MustRegister( + m.lastSuccessfulFetch, + m.secretState, + m.fetchSuccessTotal, + m.fetchFailuresTotal, + m.fetchDuration, + m.validationFailuresTotal, + ) +} + +func (m *Manager) registerSecret(path string, s *SecretField) { + s.manager = m + + m.mtx.Lock() + defer m.mtx.Unlock() + + labels := prometheus.Labels{ + "provider": s.provider.Name(), + "secret_id": path, + } + + refreshInterval := s.settings.RefreshInterval + if refreshInterval == 0 { + refreshInterval = defaultRefreshInterval + } + + ms := &managedSecret{ + provider: s.provider, + validator: s.validator, + refreshInterval: refreshInterval, + metricLabels: labels, + } + m.secrets[s] = ms + m.secretState.With(labels).Set(stateInitializing) + m.fetchSuccessTotal.With(labels).Add(0) + m.fetchFailuresTotal.With(labels).Add(0) + m.validationFailuresTotal.With(labels).Add(0) +} + +func (m *Manager) secretReady(s *SecretField) bool { + m.mtx.RLock() + defer m.mtx.RUnlock() + return !m.secrets[s].fetched.IsZero() +} + +func (m *Manager) SecretsReady(config interface{}) (bool, error) { + if m.allReady.Load() { + return true, nil + } + paths, err := getSecretFields(config) + if err != nil { + return false, err + } + for _, field := range paths { + if !m.secretReady(field) { + return false, nil + } + } + return true, nil +} + +func (m *Manager) Start(ctx context.Context) { + ctx, cancel := context.WithCancel(ctx) + + m.wg.Add(1) + go func() { + defer m.wg.Done() + m.fetchSecretsLoop(ctx) + }() + + m.cancel = cancel +} + +func (m *Manager) Stop() { + m.cancel() + m.wg.Wait() +} + +func (m *Manager) setSecretValidation(s *SecretField, validator SecretValidator) { + m.mtx.RLock() + secret := m.secrets[s] + m.mtx.RUnlock() + secret.mtx.Lock() + secret.validator = validator + secret.mtx.Unlock() +} + +func (m *Manager) get(s *SecretField) string { + if inline, ok := s.provider.(*InlineProvider); ok { + return inline.secret + } + m.mtx.RLock() + secret := m.secrets[s] + m.mtx.RUnlock() + secret.mtx.RLock() + defer secret.mtx.RUnlock() + return secret.secret +} + +func (m *Manager) triggerRefresh(s *SecretField) { + m.mtx.RLock() + secret := m.secrets[s] + m.mtx.RUnlock() + secret.mtx.Lock() + defer secret.mtx.Unlock() + secret.refreshRequested = true + secret.verified = false + secret.secret = secret.pendingSecret + select { + case m.refreshC <- struct{}{}: + default: + // a refresh is already pending, do nothing + } +} + +// fetchSecretsLoop is a long-running goroutine that periodically fetches secrets. +func (m *Manager) fetchSecretsLoop(ctx context.Context) { + timer := time.NewTimer(time.Duration(0)) + defer timer.Stop() + for { + select { + case <-ctx.Done(): + return + case <-timer.C: + case <-m.refreshC: + if !timer.Stop() { + <-timer.C + } + } + m.mtx.RLock() + // Create a list of secrets to check to avoid holding the lock during fetch operations. + secretsToCheck := make([]*managedSecret, 0, len(m.secrets)) + for _, secret := range m.secrets { + secretsToCheck = append(secretsToCheck, secret) + } + m.mtx.RUnlock() + + waitTime := 5 * time.Minute + secretsReady := true + + for _, ms := range secretsToCheck { + ms.mtx.Lock() + + timeToRefresh := time.Until(ms.fetched.Add(ms.refreshInterval)) + refreshNeeded := ms.refreshRequested || timeToRefresh < 0 + waitTime = min(waitTime, ms.refreshInterval) + if ms.fetched.IsZero() { + secretsReady = false + } + + if ms.fetchInProgress { + ms.mtx.Unlock() + continue + } + + if !refreshNeeded { + ms.mtx.Unlock() + if timeToRefresh > 0 { + waitTime = min(waitTime, timeToRefresh) + } + continue + } + ms.fetchInProgress = true + ms.mtx.Unlock() + + go m.fetchAndStoreSecret(ctx, ms) + } + m.allReady.Store(secretsReady) + timer.Reset(waitTime) + } +} + +// fetchAndStoreSecret performs a single secret fetch, including retry logic with exponential backoff. +// It is robust against hangs in the underlying provider's FetchSecret method. +func (m *Manager) fetchAndStoreSecret(ctx context.Context, ms *managedSecret) { + var newSecret string + var err error + ms.mtx.RLock() + provider := ms.provider + labels := ms.metricLabels + hasBeenFetchedBefore := !ms.fetched.IsZero() + ms.mtx.RUnlock() + + backoff := fetchInitialBackoff + for { + fetchCtx, cancel := context.WithTimeout(ctx, fetchTimeout) + + newSecret, err = provider.FetchSecret(fetchCtx) + cancel() + + if err == nil { + break // Success + } + + m.fetchFailuresTotal.With(labels).Inc() + if hasBeenFetchedBefore { + m.secretState.With(labels).Set(stateStale) + } else { + m.secretState.With(labels).Set(stateError) + } + + select { + case <-time.After(backoff): + backoff = min(fetchMaxBackoff, backoff*2) + case <-ctx.Done(): + return + } + } + ms.mtx.Lock() + + m.fetchSuccessTotal.With(labels).Inc() + m.lastSuccessfulFetch.With(labels).SetToCurrentTime() + m.secretState.With(labels).Set(stateSuccess) + + ms.pendingSecret = newSecret + ms.fetched = time.Now() + ms.fetchInProgress = false + ms.refreshRequested = false + + // If a was not verified before, we can swap it immediately + if !ms.verified { + ms.secret = newSecret + } + ms.mtx.Unlock() + m.validateAndStoreField(ctx, ms, newSecret) +} + +// validateAndStoreField performs validation for a single field, including retry logic. +func (m *Manager) validateAndStoreField(ctx context.Context, ms *managedSecret, pendingSecret string) { + var isValid bool + + ms.mtx.RLock() + var validator SecretValidator = DefaultValidator{} + if ms.validator != nil { + validator = ms.validator + } + labels := ms.metricLabels + vs := validator.Settings() + ms.mtx.RUnlock() + + backoff := vs.InitialBackoff + for i := 0; i < vs.MaxRetries; i++ { + ms.mtx.RLock() + shouldRun := ms.pendingSecret == pendingSecret + ms.mtx.RUnlock() + if !shouldRun { + return + } + validateCtx, cancel := context.WithTimeout(ctx, vs.Timeout) + isValid = validator.Validate(validateCtx, pendingSecret) + cancel() + + if isValid { + break // Success + } + m.validationFailuresTotal.With(labels).Inc() + if i < vs.MaxRetries-1 { + select { + case <-time.After(backoff): + backoff = min(vs.MaxBackoff, backoff*2) + case <-ctx.Done(): + return + } + } + } + + ms.mtx.Lock() + defer ms.mtx.Unlock() + + if ms.pendingSecret == pendingSecret { + ms.secret = pendingSecret + ms.verified = true + } +} diff --git a/secrets/manager_test.go b/secrets/manager_test.go new file mode 100644 index 00000000..a12b265f --- /dev/null +++ b/secrets/manager_test.go @@ -0,0 +1,285 @@ +// Copyright 2025 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secrets + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockProvider allows controlling the secret value and simulating errors for tests. +type mockProvider struct { + mtx sync.RWMutex + secret string + fetchErr error + fetchedLatest bool + blockChan chan struct{} + releaseChan chan struct{} +} + +func newMockProvider(secret string) *mockProvider { + return &mockProvider{secret: secret} +} + +func (mp *mockProvider) FetchSecret(ctx context.Context) (string, error) { + // Block if the test requires it, to simulate fetch latency. + if mp.blockChan != nil { + select { + case <-mp.blockChan: + case <-ctx.Done(): + return "", ctx.Err() + } + } + + // Release if the test requires it, to signal fetch has started. + if mp.releaseChan != nil { + close(mp.releaseChan) + } + + mp.mtx.RLock() + defer mp.mtx.RUnlock() + mp.fetchedLatest = true + return mp.secret, mp.fetchErr +} + +func (*mockProvider) Name() string { return "mock" } + +func (mp *mockProvider) setSecret(s string) { + mp.mtx.Lock() + defer mp.mtx.Unlock() + mp.fetchedLatest = false + mp.secret = s +} + +func (mp *mockProvider) setFetchError(err error) { + mp.mtx.Lock() + defer mp.mtx.Unlock() + mp.fetchedLatest = false + mp.fetchErr = err +} + +func (mp *mockProvider) hasFetchedLatest() bool { + mp.mtx.Lock() + defer mp.mtx.Unlock() + return mp.fetchedLatest +} + +// mockValidator allows controlling validation logic for tests. +type mockValidator struct { + mtx sync.RWMutex + secrets map[string]bool + settings ValidationSettings + verifiedLatest string +} + +func newMockValidator() *mockValidator { + return &mockValidator{ + secrets: make(map[string]bool), + settings: DefaultValidationSettings(), + } +} + +func (mv *mockValidator) Validate(_ context.Context, secret string) bool { + mv.mtx.Lock() + defer mv.mtx.Unlock() + m, e := mv.secrets[secret] + mv.verifiedLatest = secret + return m && e +} + +func (mv *mockValidator) Settings() ValidationSettings { + return mv.settings +} + +func (mv *mockValidator) setValid(secret string, isValid bool) { + mv.mtx.Lock() + defer mv.mtx.Unlock() + mv.secrets[secret] = isValid +} + +// testConfig is a struct used for discovering SecretFields in tests. +type testConfig struct { + APIKeys []SecretField `yaml:"api_keys"` +} + +func setupManagerTest(t *testing.T, cfg *testConfig) (*Manager, *prometheus.Registry) { + // Register the mock provider for tests. + originalProviders := Providers + Providers = &ProviderRegistry{} + Providers.Register(func() Provider { return &InlineProvider{} }) + Providers.Register(func() Provider { return &FileProvider{} }) + Providers.Register(func() Provider { return &mockProvider{} }) + + t.Cleanup(func() { + Providers = originalProviders + }) + + reg := prometheus.NewRegistry() + m, err := NewManager(reg, cfg) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + m.Start(ctx) + t.Cleanup(cancel) + t.Cleanup(m.Stop) + + return m, reg +} + +func TestNewManager(t *testing.T) { + provider1 := newMockProvider("secret1") + provider2 := newMockProvider("secret2") + + cfg := &testConfig{ + APIKeys: []SecretField{ + {provider: provider1}, + {provider: provider2}, + {provider: &InlineProvider{secret: "inline_secret"}}, + }, + } + + reg := prometheus.NewRegistry() + m, err := NewManager(reg, cfg) + require.NoError(t, err) + + require.Lenf(t, m.secrets, 3, "Manager should discover 3 secrets") + assert.NotNil(t, m.secrets[&cfg.APIKeys[0]]) + assert.NotNil(t, m.secrets[&cfg.APIKeys[1]]) + assert.NotNil(t, m.secrets[&cfg.APIKeys[2]]) +} + +func TestManager_SecretLifecycle(t *testing.T) { + provider := newMockProvider("initial_secret") + cfg := &testConfig{ + APIKeys: []SecretField{ + { + provider: provider, + settings: SecretFieldSettings{RefreshInterval: 50 * time.Millisecond}, + }, + }, + } + + m, _ := setupManagerTest(t, cfg) + + // 1. Initial fetch + require.Eventuallyf(t, provider.hasFetchedLatest, time.Second, 10*time.Millisecond, "Initial fetch should occur") + assert.Equal(t, "initial_secret", cfg.APIKeys[0].Get()) + + ready, err := m.SecretsReady(cfg) + require.NoError(t, err) + assert.Truef(t, ready, "Secrets should be ready after initial fetch") + + // 2. Scheduled refresh + provider.setSecret("refreshed_secret") + require.Eventuallyf(t, provider.hasFetchedLatest, time.Second, 10*time.Millisecond, "Scheduled refresh should occur") + assert.Equal(t, "refreshed_secret", cfg.APIKeys[0].Get()) + + // 3. Triggered refresh + provider.setSecret("triggered_secret") + cfg.APIKeys[0].TriggerRefresh() + require.Eventuallyf(t, provider.hasFetchedLatest, time.Second, 10*time.Millisecond, "Triggered refresh should occur") + assert.Equal(t, "triggered_secret", cfg.APIKeys[0].Get()) +} + +func TestManager_FetchErrorAndRecovery(t *testing.T) { + provider := newMockProvider("") + provider.setFetchError(errors.New("fetch failed")) + cfg := &testConfig{ + APIKeys: []SecretField{ + { + provider: provider, + }, + }, + } + m, _ := setupManagerTest(t, cfg) + + // Initial fetch fails. + require.Eventuallyf(t, provider.hasFetchedLatest, time.Second, 10*time.Millisecond, "A fetch should be attempted") + assert.Emptyf(t, cfg.APIKeys[0].Get(), "Secret should be empty after failed fetch") + + ready, err := m.SecretsReady(cfg) + require.NoError(t, err) + assert.Falsef(t, ready, "Secrets should not be ready after failed fetch") + + // Recovery. + provider.setFetchError(nil) + provider.setSecret("recovered_secret") + require.Eventuallyf(t, func() bool { return cfg.APIKeys[0].Get() == "recovered_secret" }, 2*time.Second, 50*time.Millisecond, "Manager should recover after error") + + ready, err = m.SecretsReady(cfg) + require.NoError(t, err) + assert.Truef(t, ready, "Secrets should be ready after recovery") +} + +func TestManager_Validation(t *testing.T) { + provider := newMockProvider("initial_valid") + cfg := &testConfig{ + APIKeys: []SecretField{ + { + provider: provider, + settings: SecretFieldSettings{RefreshInterval: 10 * time.Millisecond}, + }, + }, + } + m, _ := setupManagerTest(t, cfg) + validator := newMockValidator() + validator.setValid("initial_valid", true) + validator.setValid("finally_valid", true) + // Make validation super fast for the test. + validator.settings.InitialBackoff = 5 * time.Millisecond + + cfg.APIKeys[0].SetSecretValidation(validator) + + // 1. Initial fetch with successful validation. + require.Eventuallyf(t, provider.hasFetchedLatest, time.Second, 10*time.Millisecond, "A fetch should be attempted") + assert.Equal(t, "initial_valid", cfg.APIKeys[0].Get()) + require.Eventuallyf(t, func() bool { + m.secrets[&cfg.APIKeys[0]].mtx.RLock() + defer m.secrets[&cfg.APIKeys[0]].mtx.RUnlock() + return m.secrets[&cfg.APIKeys[0]].verified + }, + time.Second, 10*time.Millisecond, "should be eventually valid") + + // 2. Refresh with an invalid secret. + provider.setSecret("next_invalid") + require.Eventuallyf(t, provider.hasFetchedLatest, time.Second, 10*time.Millisecond, "A refresh should be attempted") + + // Wait a bit to ensure validation is attempted and fails. + time.Sleep(100 * time.Millisecond) + assert.Equalf(t, "initial_valid", cfg.APIKeys[0].Get(), "Old secret should be kept after validation failure") + + // 3. Refresh again with a now-valid secret. + provider.setSecret("finally_valid") + require.Eventuallyf(t, provider.hasFetchedLatest, time.Second, 10*time.Millisecond, "Another refresh should be attempted") + require.Eventuallyf(t, func() bool { return cfg.APIKeys[0].Get() == "finally_valid" }, time.Second, 10*time.Millisecond, "New secret should be adopted after validation succeeds") +} + +func TestManager_InlineSecret(t *testing.T) { + inlineSecret := "this-is-inline" + cfg := &testConfig{ + APIKeys: []SecretField{ + {provider: &InlineProvider{secret: inlineSecret}}, + }, + } + setupManagerTest(t, cfg) + + assert.Equal(t, inlineSecret, cfg.APIKeys[0].Get()) +} diff --git a/secrets/provider.go b/secrets/provider.go new file mode 100644 index 00000000..8304ea98 --- /dev/null +++ b/secrets/provider.go @@ -0,0 +1,69 @@ +// Copyright 2025 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secrets + +import ( + "context" + "time" +) + +type Provider interface { + // FetchSecret retrieves the secret value. + FetchSecret(ctx context.Context) (string, error) + + // Name returns the provider's name (e.g., "inline"). + Name() string +} + +// SecretValidator allows for validating a new secret before it is +// rotated into active use. If invalid, the old secret will be used +// while it is still considered valid and has not expired. +// This interface is optional, to prevent monitoring gaps if the +// system being scraped hasn't had its secret refreshed yet. +type SecretValidator interface { + Validate(ctx context.Context, secret string) bool + Settings() ValidationSettings +} + +// ValidationSettings holds configurable parameters for secret validation. +type ValidationSettings struct { + // Timeout governs the maximum time a single validation attempt can take. + Timeout time.Duration + // InitialBackoff is the initial backoff duration for re-validating a secret after a failure. + InitialBackoff time.Duration + // MaxBackoff is the maximum backoff duration for retrying a failed validation. + MaxBackoff time.Duration + // MaxRetries is the maximum number of retries for a failed validation. + MaxRetries int +} + +type DefaultValidator struct{} + +func (DefaultValidator) Validate(_ context.Context, _ string) bool { + return true +} + +func (DefaultValidator) Settings() ValidationSettings { + return DefaultValidationSettings() +} + +// DefaultValidationSettings returns a ValidationSettings struct with default values. +func DefaultValidationSettings() ValidationSettings { + return ValidationSettings{ + Timeout: 30 * time.Second, + InitialBackoff: 1 * time.Second, + MaxBackoff: 30 * time.Second, + MaxRetries: 10, + } +} diff --git a/secrets/registry.go b/secrets/registry.go new file mode 100644 index 00000000..771deab8 --- /dev/null +++ b/secrets/registry.go @@ -0,0 +1,42 @@ +// Copyright 2025 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secrets + +import "fmt" + +type providerFactory = func() Provider + +type ProviderRegistry struct { + factoryMap map[string]providerFactory +} + +func (r *ProviderRegistry) Get(name string) (Provider, error) { + if constructor, ok := r.factoryMap[name]; ok { + return constructor(), nil + } + return nil, fmt.Errorf("unknown provider type: %q", name) +} + +func (r *ProviderRegistry) Register(constructor func() Provider) { + name := constructor().Name() + if _, ok := r.factoryMap[name]; ok { + panic(fmt.Sprintf("attempt to register duplicate type: %q", name)) + } + if r.factoryMap == nil { + r.factoryMap = make(map[string]providerFactory) + } + r.factoryMap[name] = constructor +} + +var Providers = &ProviderRegistry{} diff --git a/secrets/registry_test.go b/secrets/registry_test.go new file mode 100644 index 00000000..73bf2c8d --- /dev/null +++ b/secrets/registry_test.go @@ -0,0 +1,75 @@ +// Copyright 2025 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secrets + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testProvider struct { + name string +} + +func (*testProvider) FetchSecret(_ context.Context) (string, error) { + return "test_secret", nil +} + +func (tp *testProvider) Name() string { + return tp.name +} + +func TestProviderRegistry(t *testing.T) { + t.Run("GetInitialProviders", func(t *testing.T) { + // Test that providers from init() are registered in the global registry. + p, err := Providers.Get("inline") + require.NoError(t, err) + assert.IsType(t, &InlineProvider{}, p) + + p, err = Providers.Get("file") + require.NoError(t, err) + assert.IsType(t, &FileProvider{}, p) + }) + + t.Run("GetUnknownProvider", func(t *testing.T) { + _, err := Providers.Get("unknown-provider") + require.Error(t, err) + assert.Contains(t, err.Error(), `unknown provider type: "unknown-provider"`) + }) + + t.Run("RegisterAndGet", func(t *testing.T) { + reg := &ProviderRegistry{} + constructor := func() Provider { return &testProvider{name: "test"} } + + reg.Register(constructor) + p, err := reg.Get("test") + require.NoError(t, err) + assert.IsType(t, &testProvider{}, p) + assert.Equal(t, "test", p.Name()) + }) + + t.Run("RegisterDuplicate", func(t *testing.T) { + reg := &ProviderRegistry{} + constructor1 := func() Provider { return &testProvider{name: "duplicate"} } + constructor2 := func() Provider { return &testProvider{name: "duplicate"} } + + reg.Register(constructor1) + assert.PanicsWithValue(t, `attempt to register duplicate type: "duplicate"`, func() { + reg.Register(constructor2) + }) + }) +} diff --git a/secrets/resolve.go b/secrets/resolve.go new file mode 100644 index 00000000..a5de5127 --- /dev/null +++ b/secrets/resolve.go @@ -0,0 +1,112 @@ +// Copyright 2025 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secrets + +import ( + "fmt" + "reflect" + "sort" + "strings" +) + +const ( + // maxRecursionDepth is the maximum depth for path traversals + // in the config. + maxRecursionDepth = 50 +) + +var secretFieldType = reflect.TypeOf(SecretField{}) + +type secretPaths map[string]*SecretField + +type walkItem struct { + path []string + val reflect.Value +} + +func getSecretFields(v interface{}) (secretPaths, error) { + results := make(secretPaths) + if v == nil { + return results, nil + } + visited := make(map[uintptr]bool) + queue := []walkItem{{path: nil, val: reflect.ValueOf(v)}} + + for len(queue) > 0 { + currentItem := queue[0] + queue = queue[1:] + + path := currentItem.path + val := currentItem.val + if len(path) > maxRecursionDepth { + return nil, fmt.Errorf("path traversal exceeded maximum depth (current depth: %d):\n%v", len(path), path) + } + + if val.Type() == secretFieldType { + path := strings.Join(path, ".") + if !val.CanAddr() { + return nil, fmt.Errorf("path '%s': found SecretField type that is not addressable", path) + } + secret, ok := val.Addr().Interface().(*SecretField) + if !ok { + return nil, fmt.Errorf("path '%s': internal error: matched SecretField type but failed type assertion", path) + } + results[path] = secret + continue + } + queue = process(path, val, visited, queue) + } + return results, nil +} + +func process(path []string, val reflect.Value, visited map[uintptr]bool, queue []walkItem) []walkItem { + if !val.IsValid() { + return queue + } + switch val.Kind() { + case reflect.Ptr: + if val.IsNil() || visited[val.Pointer()] { + return queue + } + visited[val.Pointer()] = true + return append(queue, walkItem{path: path, val: val.Elem()}) + case reflect.Interface: + return append(queue, walkItem{path: path, val: val.Elem()}) + case reflect.Struct: + for i := 0; i < val.NumField(); i++ { + newPath := append(path, val.Type().Field(i).Name) + field := val.Field(i) + if field.CanInterface() { + queue = append(queue, walkItem{path: newPath, val: field}) + } + } + case reflect.Slice, reflect.Array: + for i := 0; i < val.Len(); i++ { + newPath := append(path, fmt.Sprintf("[%d]", i)) + queue = append(queue, walkItem{path: newPath, val: val.Index(i)}) + } + case reflect.Map: + keys := val.MapKeys() + sort.Slice(keys, func(i, j int) bool { + return fmt.Sprintf("%v", keys[i].Interface()) < fmt.Sprintf("%v", keys[j].Interface()) + }) + for _, key := range keys { + keyPath := append(path, fmt.Sprintf("[%v:key]", key.Interface())) + queue = append(queue, walkItem{path: keyPath, val: key}) + valPath := append(path, fmt.Sprintf("[%v]", key.Interface())) + queue = append(queue, walkItem{path: valPath, val: val.MapIndex(key)}) + } + } + return queue +} diff --git a/secrets/resolve_test.go b/secrets/resolve_test.go new file mode 100644 index 00000000..afdcd48b --- /dev/null +++ b/secrets/resolve_test.go @@ -0,0 +1,384 @@ +// Copyright 2025 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secrets + +import ( + "reflect" + "strings" + "testing" +) + +type TestCase struct { + name string + input interface{} + want map[string]string + errContains string +} + +func newSF(secret string) SecretField { + return SecretField{ + provider: &InlineProvider{ + secret: secret, + }, + } +} + +func newSFRef(secret string) *SecretField { + val := newSF(secret) + return &val +} + +func normalizeSecretPaths(sp secretPaths) map[string]string { + normalized := make(map[string]string) + for path, ptr := range sp { + normalized[path] = ptr.provider.(*InlineProvider).secret + } + return normalized +} + +type SimpleStruct struct { + Secret SecretField + Day string +} + +type ManyStruct struct { + Birthday SecretField + MothersMaidenName **SecretField + FavoriteColors []SecretField + BookReviews map[string]*SecretField + FavoriteMaterial SecretField +} + +type NestedStruct struct { + Nested SimpleStruct + TopSecret SecretField +} + +type NestedInterfaceStruct struct { + NestedInterface interface{} + TopSecretI SecretField +} + +type PtrNestedStruct struct { + Indirect *NestedStruct + Number int +} + +type PrivateField struct { + Exported SecretField + secret SecretField +} +type PrivateNestedField struct { + Nested PrivateField +} + +type DeeplyNestedStruct struct { + S SecretField + D *DeeplyNestedStruct +} + +func TestGetSecretFields(t *testing.T) { + pointer := newSF("pointer") + + tests := []TestCase{ + { + name: "Direct SecretField", + input: newSF("direct"), + errContains: "not addressable", + }, + { + name: "Plain SecretField", + input: &pointer, + want: map[string]string{ + "": "pointer", + }, + }, + { + name: "Simple struct with one SecretField", + input: &SimpleStruct{ + Secret: newSF("secret"), + Day: "Monday", + }, + want: map[string]string{ + "Secret": "secret", + }, + }, + { + name: "Struct with multiple SecretFields and nested pointers", + input: &ManyStruct{ + Birthday: newSF("happy_birthday"), + MothersMaidenName: func() **SecretField { + s := newSF("maiden_name") + p := &s + return &p + }(), + FavoriteColors: []SecretField{ + newSF("red"), + newSF("blue"), + newSF("green"), + }, + BookReviews: map[string]*SecretField{ + "The Hitchhiker's Guide to the Galaxy": newSFRef("hitchhiker_secret"), + "The Great Gatsby": newSFRef("gatsby_secret"), + }, + FavoriteMaterial: newSF("oak"), + }, + want: map[string]string{ + "Birthday": "happy_birthday", + "MothersMaidenName": "maiden_name", + "FavoriteColors.[0]": "red", + "FavoriteColors.[1]": "blue", + "FavoriteColors.[2]": "green", + "BookReviews.[The Great Gatsby]": "gatsby_secret", + "BookReviews.[The Hitchhiker's Guide to the Galaxy]": "hitchhiker_secret", + "FavoriteMaterial": "oak", + }, + }, + { + name: "Nested struct with SecretFields at different levels", + input: &NestedStruct{ + Nested: SimpleStruct{ + Secret: newSF("inner_secret"), + Day: "Tuesday", + }, + TopSecret: newSF("outer_secret"), + }, + want: map[string]string{ + "Nested.Secret": "inner_secret", + "TopSecret": "outer_secret", + }, + }, + { + name: "Struct with nil pointer to nested struct", + input: &PtrNestedStruct{ + Indirect: nil, + Number: 10, + }, + want: map[string]string{}, + }, + { + name: "Struct with populated pointer to nested struct", + input: &PtrNestedStruct{ + Indirect: &NestedStruct{ + Nested: SimpleStruct{ + Secret: newSF("pointed_inner_secret"), + Day: "Wednesday", + }, + TopSecret: newSF("pointed_outer_secret"), + }, + Number: 20, + }, + want: map[string]string{ + "Indirect.Nested.Secret": "pointed_inner_secret", + "Indirect.TopSecret": "pointed_outer_secret", + }, + }, + { + name: "Struct with private secret field", + input: &PrivateField{ + Exported: newSF("exported_secret"), + secret: newSF("unexported_secret"), + }, + want: map[string]string{ + "Exported": "exported_secret", + }, + }, + { + name: "Nested struct with private secret field", + input: &PrivateNestedField{ + Nested: PrivateField{ + Exported: newSF("exported_secret"), + secret: newSF("unexported_secret"), + }, + }, + want: map[string]string{ + "Nested.Exported": "exported_secret", + }, + }, + { + name: "Nil input", + input: nil, + want: map[string]string{}, + }, + { + name: "Pointer to nil input", + input: func() *int { + var x *int + return x + }(), + want: map[string]string{}, + }, + { + name: "Empty struct", + input: &struct{}{}, + want: map[string]string{}, + }, + { + name: "Struct with no SecretFields", + input: &struct { + Name string + Age int + }{ + Name: "John Doe", + Age: 30, + }, + want: map[string]string{}, + }, + { + name: "Deeply nested struct (should handle depth)", + input: &DeeplyNestedStruct{ + S: newSF("level_1"), + D: &DeeplyNestedStruct{ + S: newSF("level_2"), + D: &DeeplyNestedStruct{ + S: newSF("level_3"), + D: nil, + }, + }, + }, + want: map[string]string{ + "S": "level_1", + "D.S": "level_2", + "D.D.S": "level_3", + }, + }, + { + name: "Interface holding a SimpleStruct", + input: &NestedInterfaceStruct{ + NestedInterface: &SimpleStruct{ + Secret: newSF("interface_secret"), + Day: "Friday", + }, + TopSecretI: newSF("interface_top_secret"), + }, + want: map[string]string{ + "NestedInterface.Secret": "interface_secret", + "TopSecretI": "interface_top_secret", + }, + }, + { + name: "Interface holding a pointer to SimpleStruct", + input: &NestedInterfaceStruct{ + NestedInterface: &SimpleStruct{ + Secret: newSF("interface_ptr_secret"), + Day: "Saturday", + }, + TopSecretI: newSF("interface_ptr_top_secret"), + }, + want: map[string]string{ + "NestedInterface.Secret": "interface_ptr_secret", + "TopSecretI": "interface_ptr_top_secret", + }, + }, + { + name: "Interface holding a primitive type (no secrets)", + input: &NestedInterfaceStruct{ + NestedInterface: "hello world", + TopSecretI: newSF("primitive_interface_top_secret"), + }, + want: map[string]string{ + "TopSecretI": "primitive_interface_top_secret", + }, + }, + { + name: "Slice of SecretFields", + input: &[]SecretField{ + newSF("slice_secret_1"), + newSF("slice_secret_2"), + }, + want: map[string]string{ + "[0]": "slice_secret_1", + "[1]": "slice_secret_2", + }, + }, + { + name: "Map with SecretField values", + input: &map[string]SecretField{ + "key1": newSF("map_secret_1"), + "key2": newSF("map_secret_2"), + }, + errContains: "not addressable", + }, + { + name: "Map with SecretField values references", + input: &map[string]*SecretField{ + "key1": newSFRef("map_secret_1"), + "key2": newSFRef("map_secret_2"), + }, + want: map[string]string{ + "[key1]": "map_secret_1", + "[key2]": "map_secret_2", + }, + }, + { + name: "Empty slice", + input: &[]SimpleStruct{}, + want: map[string]string{}, + }, + { + name: "Empty map", + input: &map[string]SimpleStruct{}, + want: map[string]string{}, + }, + { + name: "Deeply nested struct exceeding max depth (should return error)", + input: func() *DeeplyNestedStruct { + head := &DeeplyNestedStruct{S: newSF("head_secret")} + current := head + for i := 0; i < 51; i++ { // Create a chain longer than 50 + current.D = &DeeplyNestedStruct{S: newSF("level_" + (string)(rune('a'+i)))} + current = current.D + } + return head + }(), + errContains: "path traversal exceeded maximum depth", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + gotPaths, gotErr := getSecretFields(tc.input) + + // Check for expected error. + if tc.errContains != "" { + if gotErr == nil { + t.Fatalf("Expected error containing '%s', but got no error", tc.errContains) + } + if !strings.Contains(gotErr.Error(), tc.errContains) { + t.Errorf("Expected error containing '%s', but got: %v", tc.errContains, gotErr) + } + return + } else if gotErr != nil { + t.Fatalf("Did not expect an error, but got: %v", gotErr) + } + + normalizedGotPaths := normalizeSecretPaths(gotPaths) + + if !reflect.DeepEqual(normalizedGotPaths, tc.want) { + t.Errorf("GetSecretFields() got = %v, want %v", normalizedGotPaths, tc.want) + for k, v := range tc.want { + if actualVal, ok := normalizedGotPaths[k]; !ok { + t.Errorf("Missing %v = %q", k, v) + } else if actualVal != v { + t.Errorf("Mimatch %v = %q (want %q)", k, actualVal, v) + } + } + for k, v := range normalizedGotPaths { + if _, ok := tc.want[k]; !ok { + t.Errorf("Unexpected %v = %q", k, v) + } + } + } + }) + } +}