From 5babfa7d6b893cd987a166966d0bf79c03c6f4db Mon Sep 17 00:00:00 2001 From: Anand Rajagopal Date: Tue, 6 May 2025 21:32:41 +0000 Subject: [PATCH] Receivers initial commit Signed-off-by: Anand Rajagopal --- cmd/alertmanager/main.go | 25 ++- config/notifiers.go | 39 ++--- config/receiver/receiver.go | 7 +- go.mod | 14 ++ go.sum | 28 ++++ notify/pagerduty/pagerduty.go | 42 +++-- secrets/generic_secret.go | 32 ++++ secrets/providers/aws_secrets_manager.go | 185 +++++++++++++++++++++++ secrets/secrets_provider.go | 118 +++++++++++++++ 9 files changed, 453 insertions(+), 37 deletions(-) create mode 100644 secrets/generic_secret.go create mode 100644 secrets/providers/aws_secrets_manager.go create mode 100644 secrets/secrets_provider.go diff --git a/cmd/alertmanager/main.go b/cmd/alertmanager/main.go index 2b918061f4..c394673060 100644 --- a/cmd/alertmanager/main.go +++ b/cmd/alertmanager/main.go @@ -17,6 +17,8 @@ import ( "context" "errors" "fmt" + "github.com/prometheus/alertmanager/secrets" + "github.com/prometheus/alertmanager/secrets/providers" "log/slog" "net" "net/http" @@ -158,10 +160,10 @@ func run() int { httpTimeout = kingpin.Flag("web.timeout", "Timeout for HTTP requests. If negative or zero, no timeout is set.").Default("0").Duration() memlimitRatio = kingpin.Flag("auto-gomemlimit.ratio", "The ratio of reserved GOMEMLIMIT memory to the detected maximum container or system memory. The value must be greater than 0 and less than or equal to 1."). - Default("0.9").Float64() + Default("0.9").Float64() clusterBindAddr = kingpin.Flag("cluster.listen-address", "Listen address for cluster. Set to empty string to disable HA mode."). - Default(defaultClusterAddr).String() + Default(defaultClusterAddr).String() clusterAdvertiseAddr = kingpin.Flag("cluster.advertise-address", "Explicit address to advertise in cluster.").String() peers = kingpin.Flag("cluster.peer", "Initial peers (may be repeated).").Strings() peerTimeout = kingpin.Flag("cluster.peer-timeout", "Time to wait between peers to send notifications.").Default("15s").Duration() @@ -402,8 +404,9 @@ func run() int { } var ( - inhibitor *inhibit.Inhibitor - tmpl *template.Template + inhibitor *inhibit.Inhibitor + tmpl *template.Template + secretsProviderRegistry *secrets.SecretsProviderRegistry ) dispMetrics := dispatch.NewDispatcherMetrics(false, prometheus.DefaultRegisterer) @@ -414,6 +417,9 @@ func run() int { prometheus.DefaultRegisterer, configLogger, ) + defer func() { + secretsProviderRegistry.Stop() + }() configCoordinator.Subscribe(func(conf *config.Config) error { tmpl, err = template.FromGlobs(conf.Templates) if err != nil { @@ -428,6 +434,10 @@ func run() int { activeReceivers[r.RouteOpts.Receiver] = struct{}{} }) + spRegistry := secrets.NewSecretsProviderRegistry(logger, prometheus.NewRegistry()) + // currently only one secrets providers is supported + spRegistry.Register(providers.AWSSecretsManagerSecretProviderDiscoveryConfig{}) + spRegistry.Init() // Build the map of receiver to integrations. receivers := make(map[string][]notify.Integration, len(activeReceivers)) var integrationsNum int @@ -437,7 +447,7 @@ func run() int { configLogger.Info("skipping creation of receiver not referenced by any route", "receiver", rcv.Name) continue } - integrations, err := receiver.BuildReceiverIntegrations(rcv, tmpl, logger) + integrations, err := receiver.BuildReceiverIntegrations(rcv, tmpl, logger, spRegistry) if err != nil { return err } @@ -460,10 +470,13 @@ func run() int { inhibitor.Stop() disp.Stop() + if secretsProviderRegistry != nil { + secretsProviderRegistry.Stop() + } inhibitor = inhibit.NewInhibitor(alerts, conf.InhibitRules, marker, logger) silencer := silence.NewSilencer(silences, marker, logger) - + secretsProviderRegistry = spRegistry // An interface value that holds a nil concrete value is non-nil. // Therefore we explicly pass an empty interface, to detect if the // cluster is not enabled in notify. diff --git a/config/notifiers.go b/config/notifiers.go index 87f806aa27..ff3a0d1394 100644 --- a/config/notifiers.go +++ b/config/notifiers.go @@ -16,6 +16,7 @@ package config import ( "errors" "fmt" + "github.com/prometheus/alertmanager/secrets" "net/textproto" "regexp" "strings" @@ -328,22 +329,22 @@ type PagerdutyConfig struct { HTTPConfig *commoncfg.HTTPClientConfig `yaml:"http_config,omitempty" json:"http_config,omitempty"` - ServiceKey Secret `yaml:"service_key,omitempty" json:"service_key,omitempty"` - ServiceKeyFile string `yaml:"service_key_file,omitempty" json:"service_key_file,omitempty"` - RoutingKey Secret `yaml:"routing_key,omitempty" json:"routing_key,omitempty"` - RoutingKeyFile string `yaml:"routing_key_file,omitempty" json:"routing_key_file,omitempty"` - URL *URL `yaml:"url,omitempty" json:"url,omitempty"` - Client string `yaml:"client,omitempty" json:"client,omitempty"` - ClientURL string `yaml:"client_url,omitempty" json:"client_url,omitempty"` - Description string `yaml:"description,omitempty" json:"description,omitempty"` - Details map[string]string `yaml:"details,omitempty" json:"details,omitempty"` - Images []PagerdutyImage `yaml:"images,omitempty" json:"images,omitempty"` - Links []PagerdutyLink `yaml:"links,omitempty" json:"links,omitempty"` - Source string `yaml:"source,omitempty" json:"source,omitempty"` - Severity string `yaml:"severity,omitempty" json:"severity,omitempty"` - Class string `yaml:"class,omitempty" json:"class,omitempty"` - Component string `yaml:"component,omitempty" json:"component,omitempty"` - Group string `yaml:"group,omitempty" json:"group,omitempty"` + ServiceKey *secrets.GenericSecret `yaml:"service_key,omitempty" json:"service_key,omitempty"` + ServiceKeyFile string `yaml:"service_key_file,omitempty" json:"service_key_file,omitempty"` + RoutingKey *secrets.GenericSecret `yaml:"routing_key,omitempty" json:"routing_key,omitempty"` + RoutingKeyFile string `yaml:"routing_key_file,omitempty" json:"routing_key_file,omitempty"` + URL *URL `yaml:"url,omitempty" json:"url,omitempty"` + Client string `yaml:"client,omitempty" json:"client,omitempty"` + ClientURL string `yaml:"client_url,omitempty" json:"client_url,omitempty"` + Description string `yaml:"description,omitempty" json:"description,omitempty"` + Details map[string]string `yaml:"details,omitempty" json:"details,omitempty"` + Images []PagerdutyImage `yaml:"images,omitempty" json:"images,omitempty"` + Links []PagerdutyLink `yaml:"links,omitempty" json:"links,omitempty"` + Source string `yaml:"source,omitempty" json:"source,omitempty"` + Severity string `yaml:"severity,omitempty" json:"severity,omitempty"` + Class string `yaml:"class,omitempty" json:"class,omitempty"` + Component string `yaml:"component,omitempty" json:"component,omitempty"` + Group string `yaml:"group,omitempty" json:"group,omitempty"` } // PagerdutyLink is a link. @@ -366,13 +367,13 @@ func (c *PagerdutyConfig) UnmarshalYAML(unmarshal func(interface{}) error) error if err := unmarshal((*plain)(c)); err != nil { return err } - if c.RoutingKey == "" && c.ServiceKey == "" && c.RoutingKeyFile == "" && c.ServiceKeyFile == "" { + if c.RoutingKey == nil && c.ServiceKey == nil && c.RoutingKeyFile == "" && c.ServiceKeyFile == "" { return errors.New("missing service or routing key in PagerDuty config") } - if len(c.RoutingKey) > 0 && len(c.RoutingKeyFile) > 0 { + if c.RoutingKey != nil && len(c.RoutingKeyFile) > 0 { return errors.New("at most one of routing_key & routing_key_file must be configured") } - if len(c.ServiceKey) > 0 && len(c.ServiceKeyFile) > 0 { + if c.ServiceKey != nil && len(c.ServiceKeyFile) > 0 { return errors.New("at most one of service_key & service_key_file must be configured") } if c.Details == nil { diff --git a/config/receiver/receiver.go b/config/receiver/receiver.go index d92a19a4c5..33a85850a3 100644 --- a/config/receiver/receiver.go +++ b/config/receiver/receiver.go @@ -14,6 +14,7 @@ package receiver import ( + "github.com/prometheus/alertmanager/secrets" "log/slog" commoncfg "github.com/prometheus/common/config" @@ -43,7 +44,7 @@ import ( // BuildReceiverIntegrations builds a list of integration notifiers off of a // receiver config. -func BuildReceiverIntegrations(nc config.Receiver, tmpl *template.Template, logger *slog.Logger, httpOpts ...commoncfg.HTTPClientOption) ([]notify.Integration, error) { +func BuildReceiverIntegrations(nc config.Receiver, tmpl *template.Template, logger *slog.Logger, spRegistry *secrets.SecretsProviderRegistry, httpOpts ...commoncfg.HTTPClientOption) ([]notify.Integration, error) { if logger == nil { logger = promslog.NewNopLogger() } @@ -68,7 +69,9 @@ func BuildReceiverIntegrations(nc config.Receiver, tmpl *template.Template, logg add("email", i, c, func(l *slog.Logger) (notify.Notifier, error) { return email.New(c, tmpl, l), nil }) } for i, c := range nc.PagerdutyConfigs { - add("pagerduty", i, c, func(l *slog.Logger) (notify.Notifier, error) { return pagerduty.New(c, tmpl, l, httpOpts...) }) + add("pagerduty", i, c, func(l *slog.Logger) (notify.Notifier, error) { + return pagerduty.New(c, tmpl, l, spRegistry, httpOpts...) + }) } for i, c := range nc.OpsGenieConfigs { add("opsgenie", i, c, func(l *slog.Logger) (notify.Notifier, error) { return opsgenie.New(c, tmpl, l, httpOpts...) }) diff --git a/go.mod b/go.mod index 886f16f31d..3b2fcf7f74 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,9 @@ require ( github.com/alecthomas/kingpin/v2 v2.4.0 github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b github.com/aws/aws-sdk-go v1.55.5 + github.com/aws/aws-sdk-go-v2 v1.36.3 + github.com/aws/aws-sdk-go-v2/config v1.29.14 + github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.35.4 github.com/cenkalti/backoff/v4 v4.3.0 github.com/cespare/xxhash/v2 v2.3.0 github.com/coder/quartz v0.1.2 @@ -53,6 +56,17 @@ require ( require ( github.com/armon/go-metrics v0.3.10 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect + github.com/aws/smithy-go v1.22.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index 8b90e15442..60511aaa2b 100644 --- a/go.sum +++ b/go.sum @@ -80,6 +80,34 @@ github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3d github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= +github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= +github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.35.4 h1:EKXYJ8kgz4fiqef8xApu7eH0eae2SrVG+oHCLFybMRI= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.35.4/go.mod h1:yGhDiLKguA3iFJYxbrQkQiNzuy+ddxesSZYWVeeEH5Q= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= diff --git a/notify/pagerduty/pagerduty.go b/notify/pagerduty/pagerduty.go index abab5a70be..6a2ba01be2 100644 --- a/notify/pagerduty/pagerduty.go +++ b/notify/pagerduty/pagerduty.go @@ -19,6 +19,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/prometheus/alertmanager/secrets" "io" "log/slog" "net/http" @@ -45,27 +46,30 @@ const ( // Notifier implements a Notifier for PagerDuty notifications. type Notifier struct { - conf *config.PagerdutyConfig - tmpl *template.Template - logger *slog.Logger - apiV1 string // for tests. - client *http.Client - retrier *notify.Retrier + conf *config.PagerdutyConfig + tmpl *template.Template + logger *slog.Logger + apiV1 string // for tests. + client *http.Client + retrier *notify.Retrier + secretsFetcher secrets.SecretsFetcher } // New returns a new PagerDuty notifier. -func New(c *config.PagerdutyConfig, t *template.Template, l *slog.Logger, httpOpts ...commoncfg.HTTPClientOption) (*Notifier, error) { +func New(c *config.PagerdutyConfig, t *template.Template, l *slog.Logger, spRegistry *secrets.SecretsProviderRegistry, httpOpts ...commoncfg.HTTPClientOption) (*Notifier, error) { client, err := commoncfg.NewClientFromConfig(*c.HTTPConfig, "pagerduty", httpOpts...) if err != nil { return nil, err } n := &Notifier{conf: c, tmpl: t, logger: l, client: client} - if c.ServiceKey != "" || c.ServiceKeyFile != "" { + if c.ServiceKey != nil || c.ServiceKeyFile != "" { + n.secretsFetcher, err = spRegistry.RegisterSecret(c.ServiceKey) n.apiV1 = "https://events.pagerduty.com/generic/2010-04-15/create_event.json" // Retrying can solve the issue on 403 (rate limiting) and 5xx response codes. // https://v2.developer.pagerduty.com/docs/trigger-events n.retrier = ¬ify.Retrier{RetryCodes: []int{http.StatusForbidden}, CustomDetailsFunc: errDetails} } else { + n.secretsFetcher, err = spRegistry.RegisterSecret(c.RoutingKey) // Retrying can solve the issue on 429 (rate limiting) and 5xx response codes. // https://v2.developer.pagerduty.com/docs/events-api-v2#api-response-codes--retry-logic n.retrier = ¬ify.Retrier{RetryCodes: []int{http.StatusTooManyRequests}, CustomDetailsFunc: errDetails} @@ -143,6 +147,22 @@ func (n *Notifier) encodeMessage(msg *pagerDutyMessage) (bytes.Buffer, error) { return buf, nil } +func (n *Notifier) getSecret(ctx context.Context) string { + var secret *secrets.GenericSecret + if n.conf.ServiceKey != nil { + secret = n.conf.ServiceKey + } else { + secret = n.conf.RoutingKey + } + + if sec, err := n.secretsFetcher.FetchSecret(ctx, secret); err != nil { + n.logger.Error("unable to fetch secret", err) + return "" + } else { + return sec + } +} + func (n *Notifier) notifyV1( ctx context.Context, eventType string, @@ -159,7 +179,8 @@ func (n *Notifier) notifyV1( n.logger.Warn("Truncated description", "key", key, "max_runes", maxV1DescriptionLenRunes) } - serviceKey := string(n.conf.ServiceKey) + //serviceKey := string(n.conf.ServiceKey) + serviceKey := n.getSecret(ctx) if serviceKey == "" { content, fileErr := os.ReadFile(n.conf.ServiceKeyFile) if fileErr != nil { @@ -224,7 +245,8 @@ func (n *Notifier) notifyV2( n.logger.Warn("Truncated summary", "key", key, "max_runes", maxV2SummaryLenRunes) } - routingKey := string(n.conf.RoutingKey) + //routingKey := string(n.conf.RoutingKey) + routingKey := n.getSecret(ctx) if routingKey == "" { content, fileErr := os.ReadFile(n.conf.RoutingKeyFile) if fileErr != nil { diff --git a/secrets/generic_secret.go b/secrets/generic_secret.go new file mode 100644 index 0000000000..8742a705af --- /dev/null +++ b/secrets/generic_secret.go @@ -0,0 +1,32 @@ +package secrets + +import ( + "errors" + "time" +) + +type GenericSecret struct { + AWSSecretsManagerConfig *AWSSecretsManagerConfig `yaml:"aws_secrets_manager" json:"aws_secrets_manager_config"` +} + +// TODO implement this correctly +func (gs *GenericSecret) String() string { + return "" +} + +// TODO implement Marshal and JSON equivalent methods +func (gs *GenericSecret) UnmarshalYAML(unmarshalFn func(any) error) error { + var inlineForm string + if err := unmarshalFn(&inlineForm); err == nil { + return errors.New("inline form is not supported") + } + type plain GenericSecret + // We need to do this to avoid infinite recursion. + return unmarshalFn((*plain)(gs)) +} + +type AWSSecretsManagerConfig struct { + SecretARN string `yaml:"secret_arn"` + SecretKey string `yaml:"secret_key"` + RefreshInterval time.Duration `yaml:"refresh_interval"` +} diff --git a/secrets/providers/aws_secrets_manager.go b/secrets/providers/aws_secrets_manager.go new file mode 100644 index 0000000000..d569ad8405 --- /dev/null +++ b/secrets/providers/aws_secrets_manager.go @@ -0,0 +1,185 @@ +package providers + +import ( + "context" + "encoding/json" + "errors" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/secretsmanager" + "github.com/prometheus/alertmanager/secrets" + "github.com/prometheus/client_golang/prometheus" + "log/slog" + "sync" + "time" +) + +//TODO metrics + +type AWSSecretsManagerProvider struct { + mtx sync.RWMutex + fetchers map[string]*secretFetcher + logger *slog.Logger + reg prometheus.Registerer + ctx context.Context +} + +func (a *AWSSecretsManagerProvider) Register(secret *secrets.GenericSecret) secrets.SecretsFetcher { + s := secret.AWSSecretsManagerConfig + if s == nil { + a.logger.Error("secret is nil. nothing to register") + return nil + } + a.logger.Info("registering secret") + a.mtx.Lock() + defer a.mtx.Unlock() + if f, OK := a.fetchers[s.SecretARN]; OK { + a.logger.Info("found an existing secret fetcher") + f.update(s.RefreshInterval) + return f + } + a.logger.Info("no secret fetcher found. creating a new one") + a.fetchers[s.SecretARN] = newSecretFetcher(a.ctx, a.logger, a.reg, s.SecretARN, s.RefreshInterval) + return a.fetchers[s.SecretARN] +} + +func (a *AWSSecretsManagerProvider) Stop() { + a.mtx.Lock() + defer a.mtx.Unlock() + for name, fetcher := range a.fetchers { + a.logger.Info("stopping secrets fetcher", "name", name) + fetcher.Stop() + } + a.logger.Info("aws secrets manager providers stopped") +} + +type secretFetcher struct { + secrets map[string]string + mtx sync.RWMutex + logger *slog.Logger + reg prometheus.Registerer + arn string + interval time.Duration + ctx context.Context + client *secretsmanager.Client + done chan struct{} + ticker *time.Ticker + initialFetch bool +} + +func newSecretFetcher(ctx context.Context, logger *slog.Logger, reg prometheus.Registerer, arn string, interval time.Duration) *secretFetcher { + sf := &secretFetcher{ + secrets: make(map[string]string), + logger: logger, + reg: reg, + arn: arn, + interval: interval, + ctx: ctx, + done: make(chan struct{}), + ticker: time.NewTicker(interval), + } + sf.createSecretsManagerClient() + go sf.run() + return sf +} + +func (s *secretFetcher) createSecretsManagerClient() { + parsedARN, err := arn.Parse(s.arn) + if err != nil { + s.logger.Error("unable to create secret manager client", err) + return + } + config, err := awsconfig.LoadDefaultConfig(s.ctx, awsconfig.WithRegion(parsedARN.Region)) + if err != nil { + s.logger.Error("unable to load config", err) + return + } + s.client = secretsmanager.NewFromConfig(config) +} + +func (s *secretFetcher) Stop() { + <-s.done + s.logger.Info("secret fetcher stopped") +} + +func (s *secretFetcher) run() { + defer close(s.done) + defer s.ticker.Stop() + input := &secretsmanager.GetSecretValueInput{ + SecretId: aws.String(s.arn), + } + s.logger.Debug("fetch secret", "reason", "initial") + s.retrieveSecret(input) + s.initialFetch = true + for { + select { + case <-s.ticker.C: + s.logger.Debug("fetching secret", "reason", "periodic") + s.retrieveSecret(input) + s.initialFetch = true + case <-s.ctx.Done(): + s.logger.Info("stopping secrets fetcher") + return + } + } +} + +func (s *secretFetcher) retrieveSecret(input *secretsmanager.GetSecretValueInput) { + result, err := s.client.GetSecretValue(s.ctx, input) + if err != nil { + s.logger.Error("unable to fetch secret for ARN", "arn", s.arn, "error", err) + return + } + secretString := *result.SecretString + var m map[string]string + if err = json.Unmarshal([]byte(secretString), &m); err != nil { + s.logger.Error("unable to unmarshal payload", "arn", s.arn, "error", err) + return + } + s.logger.Debug("retrieved keys", "key count", len(m)) + s.mtx.Lock() + defer s.mtx.Unlock() + s.secrets = nil + s.secrets = m +} + +func (s *secretFetcher) update(interval time.Duration) { + s.mtx.Lock() + defer s.mtx.Unlock() + if s.interval > interval { + s.interval = interval + s.ticker.Reset(s.interval) + } +} + +func (s *secretFetcher) FetchSecret(_ context.Context, secret *secrets.GenericSecret) (string, error) { + sec := secret.AWSSecretsManagerConfig + if sec == nil { + return "", errors.New("cannot fetch empty secret") + } + + s.mtx.RLock() + value, exists := s.secrets[sec.SecretKey] + s.mtx.RUnlock() + if !exists { + return "", errors.New("secret not found") + } + return value, nil +} + +type AWSSecretsManagerSecretProviderDiscoveryConfig struct { +} + +func (a AWSSecretsManagerSecretProviderDiscoveryConfig) Name() string { + return "aws_secrets_manager" +} + +func (a AWSSecretsManagerSecretProviderDiscoveryConfig) NewSecretsProvider(options secrets.SecretProviderOptions) (secrets.SecretsProvider, error) { + return &AWSSecretsManagerProvider{ + fetchers: make(map[string]*secretFetcher), + logger: options.Logger, + reg: options.Registerer, + ctx: options.Context, + }, nil +} diff --git a/secrets/secrets_provider.go b/secrets/secrets_provider.go new file mode 100644 index 0000000000..407a9505e0 --- /dev/null +++ b/secrets/secrets_provider.go @@ -0,0 +1,118 @@ +package secrets + +import ( + "context" + "errors" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/common/config" + "log/slog" + "sync" +) + +var ( + AWS_SECRETS_MANAGER_PROVIDER = "aws_secrets_manager" +) + +type SecretsFetcher interface { + FetchSecret(ctx context.Context, secret *GenericSecret) (string, error) + Stop() +} + +type SecretsProvider interface { + Register(secret *GenericSecret) SecretsFetcher + Stop() +} + +type SecretsProviderRegistry struct { + mtx sync.RWMutex + providers map[string]SecretsProvider + logger *slog.Logger + reg prometheus.Registerer + configs map[string]SecretProviderDiscoveryConfig + ctx context.Context + cancel context.CancelFunc +} + +func NewSecretsProviderRegistry(logger *slog.Logger, reg prometheus.Registerer) *SecretsProviderRegistry { + registry := &SecretsProviderRegistry{ + providers: make(map[string]SecretsProvider), + configs: make(map[string]SecretProviderDiscoveryConfig), + logger: logger, + reg: reg, + } + return registry +} + +func (s *SecretsProviderRegistry) Register(config SecretProviderDiscoveryConfig) { + s.mtx.Lock() + defer s.mtx.Unlock() + s.logger.Info("registering secret providers", "name", config.Name()) + s.configs[config.Name()] = config +} + +func (s *SecretsProviderRegistry) Init() { + s.mtx.Lock() + defer s.mtx.Unlock() + s.ctx, s.cancel = context.WithCancel(context.Background()) + for name, providerConfig := range s.configs { + s.logger.Info("initializing secret providers", "name", name) + provider, err := providerConfig.NewSecretsProvider(SecretProviderOptions{ + Logger: s.logger, + Registerer: s.reg, + Context: s.ctx, + }) + if err != nil { + s.logger.Error("unable to initialize secrets provider", "name", name, "error", err.Error()) + continue + } + s.providers[name] = provider + } +} + +func (s *SecretsProviderRegistry) Stop() { + if s == nil { + return + } + s.mtx.Lock() + defer s.mtx.Unlock() + if s.cancel == nil { + return + } + s.cancel() + s.cancel = nil + for name, provider := range s.providers { + s.logger.Info("stopping secrets providers", "name", name) + provider.Stop() + } + s.logger.Info("stopped secrets providers registry") +} + +func (s *SecretsProviderRegistry) RegisterSecret(secret *GenericSecret) (SecretsFetcher, error) { + s.mtx.RLock() + defer s.mtx.RUnlock() + + s.logger.Info("registering secret") + if secret.AWSSecretsManagerConfig != nil { + s.logger.Info("registering aws_secret_manager secret") + return s.providers[AWS_SECRETS_MANAGER_PROVIDER].Register(secret), nil + } + return nil, errors.New("no secrets fetcher found for the given secret") +} + +type SecretProviderDiscoveryConfig interface { + // Name returns the name of the discovery mechanism. + Name() string + + NewSecretsProvider(SecretProviderOptions) (SecretsProvider, error) +} + +type SecretProviderOptions struct { + Logger *slog.Logger + + // A registerer for the SecretProvider's metrics. + Registerer prometheus.Registerer + + HTTPClientOptions []config.HTTPClientOption + + Context context.Context +}