diff --git a/cfg/kafka_client.go b/cfg/kafka_client.go index 9b9087f6a..00f83191d 100644 --- a/cfg/kafka_client.go +++ b/cfg/kafka_client.go @@ -1,10 +1,14 @@ package cfg import ( + "context" + "errors" "os" + "github.com/ozontech/file.d/xoauth" "github.com/twmb/franz-go/pkg/kgo" "github.com/twmb/franz-go/pkg/sasl/aws" + "github.com/twmb/franz-go/pkg/sasl/oauth" "github.com/twmb/franz-go/pkg/sasl/plain" "github.com/twmb/franz-go/pkg/sasl/scram" "github.com/twmb/franz-go/plugin/kzap" @@ -27,6 +31,8 @@ type KafkaClientSaslConfig struct { SaslMechanism string SaslUsername string SaslPassword string + + SaslOAuth KafkaClientOAuthConfig } type KafkaClientSslConfig struct { @@ -36,7 +42,63 @@ type KafkaClientSslConfig struct { SslSkipVerify bool } -func GetKafkaClientOptions(c KafkaClientConfig, l *zap.Logger) []kgo.Opt { +type KafkaClientOAuthConfig struct { + // static + Token string `json:"token"` + + // dynamic + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + TokenURL string `json:"token_url"` + Scopes []string `json:"scopes" slice:"true"` + AuthStyle string `json:"auth_style" default:"params" options:"params|header"` +} + +func (c *KafkaClientOAuthConfig) isStatic() bool { + return c.Token != "" +} + +func (c *KafkaClientOAuthConfig) isDynamic() bool { + return c.ClientID != "" && c.ClientSecret != "" && c.TokenURL != "" +} + +func (c *KafkaClientOAuthConfig) isValid() bool { + return c.isStatic() || c.isDynamic() +} + +func GetKafkaClientOAuthTokenSource(ctx context.Context, cfg KafkaClientConfig) (xoauth.TokenSource, error) { + saslCfg := cfg.GetSaslConfig() + + if !cfg.IsSaslEnabled() || saslCfg.SaslMechanism != "OAUTHBEARER" { + return nil, nil + } + + saslOAuth := saslCfg.SaslOAuth + if !saslOAuth.isValid() { + return nil, errors.New("invalid SASL OAUTHBEARER config") + } + + if saslOAuth.isDynamic() { + authStyle := xoauth.AuthStyleInParams + if saslOAuth.AuthStyle == "header" { + authStyle = xoauth.AuthStyleInHeader + } + + return xoauth.NewReuseTokenSource(ctx, &xoauth.Config{ + ClientID: saslOAuth.ClientID, + ClientSecret: saslOAuth.ClientSecret, + TokenURL: saslOAuth.TokenURL, + Scopes: saslOAuth.Scopes, + AuthStyle: authStyle, + }) + } + + return xoauth.NewStaticTokenSource(&xoauth.Token{ + AccessToken: saslOAuth.Token, + }), nil +} + +func GetKafkaClientOptions(c KafkaClientConfig, l *zap.Logger, tokenSource xoauth.TokenSource) []kgo.Opt { opts := []kgo.Opt{ kgo.SeedBrokers(c.GetBrokers()...), kgo.ClientID(c.GetClientID()), @@ -66,6 +128,18 @@ func GetKafkaClientOptions(c KafkaClientConfig, l *zap.Logger) []kgo.Opt { AccessKey: saslConfig.SaslUsername, SecretKey: saslConfig.SaslPassword, }.AsManagedStreamingIAMMechanism())) + case "OAUTHBEARER": + authFn := func(ctx context.Context) (oauth.Auth, error) { + if tokenSource == nil { + return oauth.Auth{}, errors.New("uninitialized token source") + } + t := tokenSource.Token(ctx) + if t == nil { + return oauth.Auth{}, errors.New("empty token from token source") + } + return oauth.Auth{Token: t.AccessToken}, nil + } + opts = append(opts, kgo.SASL(oauth.Oauth(authFn))) } } diff --git a/e2e/kafka_auth/kafka_auth.go b/e2e/kafka_auth/kafka_auth.go index 43fe5409f..db2839aa7 100644 --- a/e2e/kafka_auth/kafka_auth.go +++ b/e2e/kafka_auth/kafka_auth.go @@ -124,8 +124,9 @@ func (c *Config) Configure(t *testing.T, _ *cfg.Config, _ string) { config.ClientCert = "./kafka_auth/certs/client_cert.pem" } - kafka_out.NewClient(config, + kafka_out.NewClient(context.Background(), config, zap.NewNop().WithOptions(zap.WithFatalHook(zapcore.WriteThenPanic)), + nil, ) }, func() { @@ -154,9 +155,9 @@ func (c *Config) Configure(t *testing.T, _ *cfg.Config, _ string) { config.ClientCert = "./kafka_auth/certs/client_cert.pem" } - kafka_in.NewClient(config, + kafka_in.NewClient(context.Background(), config, zap.NewNop().WithOptions(zap.WithFatalHook(zapcore.WriteThenPanic)), - Consumer{}, + Consumer{}, nil, ) }, } diff --git a/e2e/kafka_file/kafka_file.go b/e2e/kafka_file/kafka_file.go index 3a32e906d..60ac47d1d 100644 --- a/e2e/kafka_file/kafka_file.go +++ b/e2e/kafka_file/kafka_file.go @@ -52,8 +52,9 @@ func (c *Config) Send(t *testing.T) { BatchSize_: c.Count, } - client := kafka_out.NewClient(config, + client := kafka_out.NewClient(context.Background(), config, zap.NewNop().WithOptions(zap.WithFatalHook(zapcore.WriteThenPanic)), + nil, ) adminClient := kadm.NewClient(client) _, err := adminClient.CreateTopic(context.TODO(), 1, 1, nil, c.Topics[0]) diff --git a/e2e/split_join/split_join.go b/e2e/split_join/split_join.go index cbd119bb0..652223629 100644 --- a/e2e/split_join/split_join.go +++ b/e2e/split_join/split_join.go @@ -63,9 +63,9 @@ func (c *Config) Configure(t *testing.T, conf *cfg.Config, pipelineName string) HeartbeatInterval_: 10 * time.Second, } - c.client = kafka_in.NewClient(config, + c.client = kafka_in.NewClient(context.Background(), config, zap.NewNop().WithOptions(zap.WithFatalHook(zapcore.WriteThenPanic)), - Consumer{}, + Consumer{}, nil, ) adminClient := kadm.NewClient(c.client) diff --git a/go.mod b/go.mod index 01f77cf23..a16ab139c 100644 --- a/go.mod +++ b/go.mod @@ -45,11 +45,11 @@ require ( github.com/twmb/franz-go/plugin/kzap v1.1.2 github.com/twmb/tlscfg v1.2.1 github.com/valyala/fasthttp v1.48.0 - github.com/xdg-go/scram v1.1.2 go.uber.org/atomic v1.11.0 go.uber.org/automaxprocs v1.5.3 go.uber.org/zap v1.27.0 golang.org/x/net v0.47.0 + golang.org/x/oauth2 v0.27.0 google.golang.org/protobuf v1.36.5 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 @@ -131,8 +131,6 @@ require ( github.com/twmb/franz-go/pkg/kmsg v1.12.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/x448/float16 v0.8.4 // indirect - github.com/xdg-go/pbkdf2 v1.0.0 // indirect - github.com/xdg-go/stringprep v1.0.4 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/otel v1.34.0 // indirect @@ -143,7 +141,6 @@ require ( go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/crypto v0.45.0 // indirect golang.org/x/mod v0.29.0 // indirect - golang.org/x/oauth2 v0.27.0 // indirect golang.org/x/sync v0.18.0 // indirect golang.org/x/sys v0.38.0 // indirect golang.org/x/term v0.37.0 // indirect diff --git a/go.sum b/go.sum index feff277bf..efebc884c 100644 --- a/go.sum +++ b/go.sum @@ -363,12 +363,6 @@ github.com/valyala/fasthttp v1.48.0 h1:oJWvHb9BIZToTQS3MuQ2R3bJZiNSa2KiNdeI8A+79 github.com/valyala/fasthttp v1.48.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= -github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= -github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= -github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= -github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= -github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= -github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= @@ -499,7 +493,6 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= diff --git a/plugin/input/kafka/README.md b/plugin/input/kafka/README.md index 56885602c..2350a456e 100755 --- a/plugin/input/kafka/README.md +++ b/plugin/input/kafka/README.md @@ -140,7 +140,7 @@ If set, the plugin will use SASL authentications mechanism.
-**`sasl_mechanism`** *`string`* *`default=SCRAM-SHA-512`* *`options=PLAIN|SCRAM-SHA-256|SCRAM-SHA-512|AWS_MSK_IAM`* +**`sasl_mechanism`** *`string`* *`default=SCRAM-SHA-512`* *`options=PLAIN|SCRAM-SHA-256|SCRAM-SHA-512|AWS_MSK_IAM|OAUTHBEARER`* SASL mechanism to use. @@ -158,6 +158,22 @@ SASL password.
+**`sasl_oauth`** *`cfg.KafkaClientOAuthConfig`* + +SASL OAUTHBEARER config. It works only if `sasl_mechanism:"OAUTHBEARER"`. +> There are 2 options - a static token or a dynamically updated. + +`OAuthConfig` params: +* **`token`** *`string`* - static token +--- +* **`client_id`** *`string`* - client ID +* **`client_secret`** *`string`* - client secret +* **`token_url`** *`string`* - resource server's token endpoint URL +* **`scopes`** *`string`* - optional requested permissions +* **`auth_style`** *`string`* - specifies how the endpoint wants the client ID & client secret sent + +
+ **`is_ssl_enabled`** *`bool`* *`default=false`* If set, the plugin will use SSL/TLS connections method. diff --git a/plugin/input/kafka/client.go b/plugin/input/kafka/client.go index f316ad2ce..a570aeca8 100644 --- a/plugin/input/kafka/client.go +++ b/plugin/input/kafka/client.go @@ -5,12 +5,13 @@ import ( "time" "github.com/ozontech/file.d/cfg" + "github.com/ozontech/file.d/xoauth" "github.com/twmb/franz-go/pkg/kgo" "go.uber.org/zap" ) -func NewClient(c *Config, l *zap.Logger, s Consumer) *kgo.Client { - opts := cfg.GetKafkaClientOptions(c, l) +func NewClient(ctx context.Context, c *Config, l *zap.Logger, s Consumer, tokenSource xoauth.TokenSource) *kgo.Client { + opts := cfg.GetKafkaClientOptions(c, l, tokenSource) opts = append(opts, []kgo.Opt{ kgo.ConsumerGroup(c.ConsumerGroup), kgo.ConsumeTopics(c.Topics...), @@ -53,10 +54,10 @@ func NewClient(c *Config, l *zap.Logger, s Consumer) *kgo.Client { l.Fatal("can't create kafka client", zap.Error(err)) } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + pingCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - err = client.Ping(ctx) + err = client.Ping(pingCtx) if err != nil { l.Fatal("can't connect to kafka", zap.Error(err)) } diff --git a/plugin/input/kafka/kafka.go b/plugin/input/kafka/kafka.go index 6e7c0e068..15ccfc21b 100644 --- a/plugin/input/kafka/kafka.go +++ b/plugin/input/kafka/kafka.go @@ -9,6 +9,7 @@ import ( "github.com/ozontech/file.d/metric" "github.com/ozontech/file.d/pipeline" "github.com/ozontech/file.d/pipeline/metadata" + "github.com/ozontech/file.d/xoauth" "github.com/prometheus/client_golang/prometheus" "github.com/twmb/franz-go/pkg/kgo" "go.uber.org/zap" @@ -47,19 +48,18 @@ pipelines: type Plugin struct { config *Config - logger *zap.SugaredLogger - client *kgo.Client + logger *zap.Logger cancel context.CancelFunc controller pipeline.InputPluginController - idByTopic map[string]int + + client *kgo.Client + s *splitConsume + tokenSource xoauth.TokenSource + metaTemplater *metadata.MetaTemplater // plugin metrics commitErrorsMetric prometheus.Counter consumeErrorsMetric prometheus.Counter - - metaTemplater *metadata.MetaTemplater - - s *splitConsume } type OffsetType byte @@ -177,7 +177,7 @@ type Config struct { // > @3@4@5@6 // > // > SASL mechanism to use. - SaslMechanism string `json:"sasl_mechanism" default:"SCRAM-SHA-512" options:"PLAIN|SCRAM-SHA-256|SCRAM-SHA-512|AWS_MSK_IAM"` // * + SaslMechanism string `json:"sasl_mechanism" default:"SCRAM-SHA-512" options:"PLAIN|SCRAM-SHA-256|SCRAM-SHA-512|AWS_MSK_IAM|OAUTHBEARER"` // * // > @3@4@5@6 // > @@ -189,6 +189,21 @@ type Config struct { // > SASL password. SaslPassword string `json:"sasl_password" default:"password"` // * + // > @3@4@5@6 + // > + // > SASL OAUTHBEARER config. It works only if `sasl_mechanism:"OAUTHBEARER"`. + // >> There are 2 options - a static token or a dynamically updated. + // > + // > `OAuthConfig` params: + // > * **`token`** *`string`* - static token + // > --- + // > * **`client_id`** *`string`* - client ID + // > * **`client_secret`** *`string`* - client secret + // > * **`token_url`** *`string`* - resource server's token endpoint URL + // > * **`scopes`** *`string`* - optional requested permissions + // > * **`auth_style`** *`string`* - specifies how the endpoint wants the client ID & client secret sent + SaslOAuth cfg.KafkaClientOAuthConfig `json:"sasl_oauth" child:"true"` // * + // > @3@4@5@6 // > // > If set, the plugin will use SSL/TLS connections method. @@ -242,6 +257,7 @@ func (c *Config) GetSaslConfig() cfg.KafkaClientSaslConfig { SaslMechanism: c.SaslMechanism, SaslUsername: c.SaslUsername, SaslPassword: c.SaslPassword, + SaslOAuth: c.SaslOAuth, } } @@ -271,36 +287,45 @@ func Factory() (pipeline.AnyPlugin, pipeline.AnyConfig) { func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.InputPluginParams) { p.controller = params.Controller - p.logger = params.Logger + p.logger = params.Logger.Desugar() p.config = config.(*Config) p.registerMetrics(params.MetricCtl) if len(p.config.Meta) > 0 { p.metaTemplater = metadata.NewMetaTemplater( p.config.Meta, - p.logger.Desugar(), + p.logger, params.PipelineSettings.MetaCacheSize, ) } - p.idByTopic = make(map[string]int, len(p.config.Topics)) + idByTopic := make(map[string]int, len(p.config.Topics)) for i, topic := range p.config.Topics { - p.idByTopic[topic] = i + idByTopic[topic] = i } ctx, cancel := context.WithCancel(context.Background()) p.cancel = cancel + p.s = &splitConsume{ consumers: make(map[tp]*pconsumer), bufferSize: p.config.ChannelBufferSize, maxConcurrentConsumers: p.config.MaxConcurrentConsumers, - idByTopic: p.idByTopic, + idByTopic: idByTopic, controller: p.controller, - logger: p.logger.Desugar(), + logger: p.logger, metaTemplater: p.metaTemplater, consumeErrorsMetric: p.consumeErrorsMetric, } - p.client = NewClient(p.config, p.logger.Desugar(), p.s) + + var err error + p.tokenSource, err = cfg.GetKafkaClientOAuthTokenSource(ctx, p.config) + if err != nil { + p.logger.Fatal(err.Error()) + } + + p.client = NewClient(ctx, p.config, p.logger, p.s, p.tokenSource) + p.controller.UseSpread() p.controller.DisableStreams() @@ -313,13 +338,16 @@ func (p *Plugin) registerMetrics(ctl *metric.Ctl) { } func (p *Plugin) Stop() { - p.logger.Infof("Stopping") + p.logger.Info("Stopping") err := p.client.CommitMarkedOffsets(context.Background()) if err != nil { p.commitErrorsMetric.Inc() - p.logger.Errorf("can't commit marked offsets: %s", err.Error()) + p.logger.Error("can't commit marked offsets", zap.Error(err)) } p.client.Close() + if p.tokenSource != nil { + p.tokenSource.Stop() + } p.cancel() } diff --git a/plugin/output/kafka/README.md b/plugin/output/kafka/README.md index d1ef51b20..84cff1058 100755 --- a/plugin/output/kafka/README.md +++ b/plugin/output/kafka/README.md @@ -115,7 +115,7 @@ If set, the plugin will use SASL authentications mechanism.
-**`sasl_mechanism`** *`string`* *`default=SCRAM-SHA-512`* *`options=PLAIN|SCRAM-SHA-256|SCRAM-SHA-512`* +**`sasl_mechanism`** *`string`* *`default=SCRAM-SHA-512`* *`options=PLAIN|SCRAM-SHA-256|SCRAM-SHA-512|AWS_MSK_IAM|OAUTHBEARER`* SASL mechanism to use. @@ -133,6 +133,22 @@ SASL password.
+**`sasl_oauth`** *`cfg.KafkaClientOAuthConfig`* + +SASL OAUTHBEARER config. It works only if `sasl_mechanism:"OAUTHBEARER"`. +> There are 2 options - a static token or a dynamically updated. + +`OAuthConfig` params: +* **`token`** *`string`* - static token +--- +* **`client_id`** *`string`* - client ID +* **`client_secret`** *`string`* - client secret +* **`token_url`** *`string`* - resource server's token endpoint URL +* **`scopes`** *`[]string`* - optional requested permissions +* **`auth_style`** *`string`* - specifies how the endpoint wants the client ID & client secret sent + +
+ **`is_ssl_enabled`** *`bool`* *`default=false`* If set, the plugin will use SSL/TLS connections method. diff --git a/plugin/output/kafka/client.go b/plugin/output/kafka/client.go index 9f2853537..4869e3f86 100644 --- a/plugin/output/kafka/client.go +++ b/plugin/output/kafka/client.go @@ -5,6 +5,7 @@ import ( "time" "github.com/ozontech/file.d/cfg" + "github.com/ozontech/file.d/xoauth" "github.com/twmb/franz-go/pkg/kgo" "go.uber.org/zap" ) @@ -15,8 +16,8 @@ type KafkaClient interface { ForceMetadataRefresh() } -func NewClient(c *Config, l *zap.Logger) *kgo.Client { - opts := cfg.GetKafkaClientOptions(c, l) +func NewClient(ctx context.Context, c *Config, l *zap.Logger, tokenSource xoauth.TokenSource) *kgo.Client { + opts := cfg.GetKafkaClientOptions(c, l, tokenSource) opts = append(opts, []kgo.Opt{ kgo.DefaultProduceTopic(c.DefaultTopic), kgo.MaxBufferedRecords(c.BatchSize_), @@ -54,10 +55,10 @@ func NewClient(c *Config, l *zap.Logger) *kgo.Client { l.Fatal("can't create kafka client", zap.Error(err)) } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + pingCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - err = client.Ping(ctx) + err = client.Ping(pingCtx) if err != nil { l.Fatal("can't connect to kafka", zap.Error(err)) } diff --git a/plugin/output/kafka/kafka.go b/plugin/output/kafka/kafka.go index 213ff57ec..69245eed1 100644 --- a/plugin/output/kafka/kafka.go +++ b/plugin/output/kafka/kafka.go @@ -3,12 +3,14 @@ package kafka import ( "context" "errors" + "fmt" "time" "github.com/ozontech/file.d/cfg" "github.com/ozontech/file.d/fd" "github.com/ozontech/file.d/metric" "github.com/ozontech/file.d/pipeline" + "github.com/ozontech/file.d/xoauth" "github.com/prometheus/client_golang/prometheus" "github.com/twmb/franz-go/pkg/kerr" "github.com/twmb/franz-go/pkg/kgo" @@ -32,13 +34,15 @@ type data struct { } type Plugin struct { - logger *zap.SugaredLogger + logger *zap.Logger config *Config avgEventSize int controller pipeline.OutputPluginController + cancel context.CancelFunc - client KafkaClient - batcher *pipeline.RetriableBatcher + client KafkaClient + batcher *pipeline.RetriableBatcher + tokenSource xoauth.TokenSource // plugin metrics sendErrorMetric prometheus.Counter @@ -151,7 +155,7 @@ type Config struct { // > @3@4@5@6 // > // > SASL mechanism to use. - SaslMechanism string `json:"sasl_mechanism" default:"SCRAM-SHA-512" options:"PLAIN|SCRAM-SHA-256|SCRAM-SHA-512"` // * + SaslMechanism string `json:"sasl_mechanism" default:"SCRAM-SHA-512" options:"PLAIN|SCRAM-SHA-256|SCRAM-SHA-512|AWS_MSK_IAM|OAUTHBEARER"` // * // > @3@4@5@6 // > @@ -163,6 +167,21 @@ type Config struct { // > SASL password. SaslPassword string `json:"sasl_password" default:"password"` // * + // > @3@4@5@6 + // > + // > SASL OAUTHBEARER config. It works only if `sasl_mechanism:"OAUTHBEARER"`. + // >> There are 2 options - a static token or a dynamically updated. + // > + // > `OAuthConfig` params: + // > * **`token`** *`string`* - static token + // > --- + // > * **`client_id`** *`string`* - client ID + // > * **`client_secret`** *`string`* - client secret + // > * **`token_url`** *`string`* - resource server's token endpoint URL + // > * **`scopes`** *`[]string`* - optional requested permissions + // > * **`auth_style`** *`string`* - specifies how the endpoint wants the client ID & client secret sent + SaslOAuth cfg.KafkaClientOAuthConfig `json:"sasl_oauth" child:"true"` // * + // > @3@4@5@6 // > // > If set, the plugin will use SSL/TLS connections method. @@ -206,6 +225,7 @@ func (c *Config) GetSaslConfig() cfg.KafkaClientSaslConfig { SaslMechanism: c.SaslMechanism, SaslUsername: c.SaslUsername, SaslPassword: c.SaslPassword, + SaslOAuth: c.SaslOAuth, } } @@ -235,7 +255,7 @@ func Factory() (pipeline.AnyPlugin, pipeline.AnyConfig) { func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginParams) { p.config = config.(*Config) - p.logger = params.Logger + p.logger = params.Logger.Desugar() p.avgEventSize = params.PipelineSettings.AvgEventSize p.controller = params.Controller p.registerMetrics(params.MetricCtl) @@ -244,9 +264,18 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP p.logger.Fatal("'retention' can't be <1") } - p.logger.Infof("workers count=%d, batch size=%d", p.config.WorkersCount_, p.config.BatchSize_) + p.logger.Info(fmt.Sprintf("workers count=%d, batch size=%d", p.config.WorkersCount_, p.config.BatchSize_)) - p.client = NewClient(p.config, p.logger.Desugar()) + ctx, cancel := context.WithCancel(context.Background()) + p.cancel = cancel + + var err error + p.tokenSource, err = cfg.GetKafkaClientOAuthTokenSource(ctx, p.config) + if err != nil { + p.logger.Fatal(err.Error()) + } + + p.client = NewClient(ctx, p.config, p.logger, p.tokenSource) batcherOpts := pipeline.BatcherOptions{ PipelineName: params.PipelineName, @@ -275,7 +304,7 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP level = zapcore.ErrorLevel } - p.logger.Desugar().Log(level, "can't write batch", + p.logger.Log(level, "can't write batch", zap.Int("retries", p.config.Retry), ) @@ -343,7 +372,7 @@ func (p *Plugin) out(workerData *pipeline.WorkerData, batch *pipeline.Batch) err if errors.Is(err, kerr.LeaderNotAvailable) || errors.Is(err, kerr.NotLeaderForPartition) { p.client.ForceMetadataRefresh() } - p.logger.Errorf("can't write batch: %v", err) + p.logger.Error("can't write batch", zap.Error(err)) p.sendErrorMetric.Inc() return err } @@ -354,4 +383,8 @@ func (p *Plugin) out(workerData *pipeline.WorkerData, batch *pipeline.Batch) err func (p *Plugin) Stop() { p.batcher.Stop() p.client.Close() + if p.tokenSource != nil { + p.tokenSource.Stop() + } + p.cancel() } diff --git a/plugin/output/kafka/kafka_test.go b/plugin/output/kafka/kafka_test.go index a1a5e9b5a..7061b750b 100644 --- a/plugin/output/kafka/kafka_test.go +++ b/plugin/output/kafka/kafka_test.go @@ -62,7 +62,7 @@ func FuzzKafka(f *testing.F) { } worker := pipeline.WorkerData(nil) - logger := zaptest.NewLogger(f).Sugar() + logger := zaptest.NewLogger(f) p := Plugin{ logger: logger, config: &config, diff --git a/xoauth/backoff.go b/xoauth/backoff.go new file mode 100644 index 000000000..d0a7c017c --- /dev/null +++ b/xoauth/backoff.go @@ -0,0 +1,72 @@ +package xoauth + +import ( + "math" + "math/rand" + "time" +) + +// exponentialJitterBackoff provides a callback which will +// perform en exponential backoff based on the attempt number and with jitter to +// prevent a thundering herd. +// +// min and max here are *not* absolute values. The number to be multiplied by +// the attempt number will be chosen at random from between them, thus they are +// bounding the jitter. +func exponentialJitterBackoff() func(min, max time.Duration, attemptNum int) time.Duration { + rnd := rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) + + return func(min, max time.Duration, attemptNum int) time.Duration { + minf := float64(min) + mult := math.Pow(2, float64(attemptNum)) * minf + + jitter := rnd.Float64() * (mult - minf) + + sleepf := mult + jitter + maxf := float64(max) + if sleepf > maxf { + sleepf = maxf + } + + return time.Duration(sleepf) + } +} + +// linearJitterBackoff will perform linear backoff based on the attempt number and with jitter to +// prevent a thundering herd. +// +// min and max here are *not* absolute values. The number to be multiplied by +// the attempt number will be chosen at random from between them, thus they are +// bounding the jitter. +// +// For instance: +// - To get strictly linear backoff of one second increasing each retry, set +// both to one second (1s, 2s, 3s, 4s, ...) +// - To get a small amount of jitter centered around one second increasing each +// retry, set to around one second, such as a min of 800ms and max of 1200ms +// (892ms, 2102ms, 2945ms, 4312ms, ...) +// - To get extreme jitter, set to a very wide spread, such as a min of 100ms +// and a max of 20s (15382ms, 292ms, 51321ms, 35234ms, ...) +func linearJitterBackoff() func(min, max time.Duration, attemptNum int) time.Duration { + rnd := rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) + + return func(min, max time.Duration, attemptNum int) time.Duration { + // attemptNum always starts at zero but we want to start at 1 for multiplication + attemptNum++ + + if max <= min { + // Unclear what to do here, or they are the same, so return min * + // attemptNum + return min * time.Duration(attemptNum) + } + + // Pick a random number that lies somewhere between the min and max and + // multiply by the attemptNum. attemptNum starts at zero so we always + // increment here. We first get a random percentage, then apply that to the + // difference between min and max, and add to min. + jitter := rnd.Float64() * float64(max-min) + + jitterMin := int64(jitter) + int64(min) + return time.Duration(jitterMin * int64(attemptNum)) + } +} diff --git a/xoauth/backoff_test.go b/xoauth/backoff_test.go new file mode 100644 index 000000000..c0091e1fe --- /dev/null +++ b/xoauth/backoff_test.go @@ -0,0 +1,94 @@ +package xoauth + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestExponentialJitterBackoff(t *testing.T) { + cases := []struct { + name string + min time.Duration + max time.Duration + attempts int + }{ + { + name: "1s_1m_10", + min: time.Second, + max: time.Minute, + attempts: 10, + }, + { + name: "1m_1h_10", + min: time.Minute, + max: time.Hour, + attempts: 10, + }, + { + name: "1s_1m_1000", + min: time.Second, + max: time.Minute, + attempts: 1000, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + backoff := exponentialJitterBackoff() + + for attempt := range tt.attempts { + got := backoff(tt.min, tt.max, attempt) + require.GreaterOrEqual(t, got, tt.min) + require.LessOrEqual(t, got, tt.max) + } + }) + } +} + +func TestLinearJitterBackoff(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + min time.Duration + max time.Duration + attempts int + }{ + { + name: "1s_1m_60", + min: time.Second, + max: time.Minute, + attempts: 60, + }, + { + name: "1s_1s_60", + min: time.Second, + max: time.Second, + attempts: 60, + }, + { + name: "1s_2s_1000", + min: time.Second, + max: time.Second, + attempts: 1000, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + backoff := linearJitterBackoff() + + for attempt := range tt.attempts { + got := backoff(tt.min, tt.max, attempt) + require.GreaterOrEqual(t, got, tt.min) + require.LessOrEqual(t, got, int64(tt.max)*int64(attempt+1)) + } + }) + } +} diff --git a/xoauth/error.go b/xoauth/error.go new file mode 100644 index 000000000..ddcada7c0 --- /dev/null +++ b/xoauth/error.go @@ -0,0 +1,90 @@ +package xoauth + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "mime" + "net" + "net/http" + + "golang.org/x/oauth2" +) + +// error fields: https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 +type errorJSON struct { + ErrorCode string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` +} + +const ( + // issuer errors + ecInvalidRequest = "invalid_request" + ecInvalidClient = "invalid_client" + ecInvalidGrant = "invalid_grant" + ecInvalidScope = "invalid_scope" + ecUnauthorizedClient = "unauthorized_client" + ecUnsupportedGrantType = "unsupported_grant_type" + + ecTimeout = "timeout" + ecNetwork = "network_error" + ecRateLimit = "rate_limit" + + ecUnknownServer = "unknown_server" + ecUnknown = "unknown" +) + +type errorAuth struct { + code string + message string + cause error +} + +func (e *errorAuth) Error() string { return e.message } +func (e *errorAuth) Code() string { return e.code } +func (e *errorAuth) Unwrap() error { return e.cause } +func (e *errorAuth) Is(target error) bool { return errors.Is(target, e.cause) } + +func wrapErr(err error, code, prefix string) *errorAuth { + return &errorAuth{ + code: code, + message: fmt.Sprintf("%s: %v", prefix, err), + + cause: err, + } +} + +func parseError(err error) *errorAuth { + var retrieveErr *oauth2.RetrieveError + if errors.As(err, &retrieveErr) { + return parseRetrieveError(err, retrieveErr) + } + + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return wrapErr(err, ecTimeout, "timeout") + } + + var netError net.Error + if errors.As(err, &netError) { + return wrapErr(netError, ecNetwork, "network error") + } + + return wrapErr(err, ecUnknown, "unknown error") +} + +func parseRetrieveError(err error, retrieveErr *oauth2.RetrieveError) *errorAuth { + content, _, _ := mime.ParseMediaType(retrieveErr.Response.Header.Get("Content-Type")) + if content == "application/json" { + var errJson errorJSON + if err = json.Unmarshal(retrieveErr.Body, &errJson); err == nil { + return wrapErr(retrieveErr, errJson.ErrorCode, errJson.ErrorDescription) + } + } + + if retrieveErr.Response.StatusCode == http.StatusTooManyRequests { + return wrapErr(err, ecRateLimit, "rate limit") + } + + return wrapErr(err, ecUnknownServer, err.Error()) +} diff --git a/xoauth/error_test.go b/xoauth/error_test.go new file mode 100644 index 000000000..cad275634 --- /dev/null +++ b/xoauth/error_test.go @@ -0,0 +1,104 @@ +package xoauth + +import ( + "context" + "errors" + "net" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestParseError(t *testing.T) { + cases := []struct { + name string + in []error + want *errorAuth + }{ + { + name: "retrieve_err_json", + in: []error{ + &oauth2.RetrieveError{ + Response: &http.Response{ + Header: http.Header{ + "Content-Type": []string{"application/json"}, + }, + }, + Body: []byte(`{"error":"invalid_client","error_description":"some err description"}`), + }, + }, + want: &errorAuth{ + code: ecInvalidClient, + }, + }, + { + name: "retrieve_err_json_unmarshal_err", + in: []error{ + &oauth2.RetrieveError{ + Response: &http.Response{ + Header: http.Header{ + "Content-Type": []string{"application/json"}, + }, + }, + Body: []byte(`invalid json`), + }, + &oauth2.RetrieveError{ + Response: &http.Response{ + StatusCode: http.StatusBadRequest, + }, + }, + }, + want: &errorAuth{ + code: ecUnknownServer, + }, + }, + { + name: "retrieve_err_rate_limit", + in: []error{ + &oauth2.RetrieveError{ + Response: &http.Response{ + StatusCode: http.StatusTooManyRequests, + }, + }, + }, + want: &errorAuth{ + code: ecRateLimit, + }, + }, + { + name: "timeout", + in: []error{context.DeadlineExceeded, context.Canceled}, + want: &errorAuth{ + code: ecTimeout, + }, + }, + { + name: "network", + in: []error{&net.AddrError{}}, + want: &errorAuth{ + code: ecNetwork, + }, + }, + { + name: "unknown", + in: []error{errors.New("some err")}, + want: &errorAuth{ + code: ecUnknown, + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + for _, err := range tt.in { + got := parseError(err) + + require.Equal(t, tt.want.Code(), got.Code()) + } + }) + } +} diff --git a/xoauth/token.go b/xoauth/token.go new file mode 100644 index 000000000..b555ca7b2 --- /dev/null +++ b/xoauth/token.go @@ -0,0 +1,10 @@ +package xoauth + +import "time" + +type Token struct { + AccessToken string + TokenType string + Expiry time.Time + Scope string +} diff --git a/xoauth/tokenissuer.go b/xoauth/tokenissuer.go new file mode 100644 index 000000000..f0da47c02 --- /dev/null +++ b/xoauth/tokenissuer.go @@ -0,0 +1,113 @@ +package xoauth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/ozontech/file.d/xtime" + "golang.org/x/oauth2" +) + +const ( + defaultHttpTimeout = 1 * time.Minute + + grantTypeClientCreds = "client_credentials" +) + +type httpTokenIssuer struct { + client *http.Client + cfg *Config +} + +func newHTTPTokenIssuer(cfg *Config) *httpTokenIssuer { + return &httpTokenIssuer{ + client: &http.Client{ + Timeout: defaultHttpTimeout, + }, + cfg: cfg, + } +} + +type tokenJSON struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + Scope string `json:"scope"` +} + +func (ti *httpTokenIssuer) issueToken(ctx context.Context) (*Token, error) { + req, err := newTokenRequest(ctx, ti.cfg) + if err != nil { + return nil, err + } + + resp, err := ti.client.Do(req) + if err != nil { + return nil, err + } + defer func() { + _ = resp.Body.Close() + }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("cannot fetch token: %w", err) + } + + if c := resp.StatusCode; c < http.StatusOK || c >= http.StatusMultipleChoices { + return nil, &oauth2.RetrieveError{ + Response: resp, + Body: body, + } + } + + var data tokenJSON + if err := json.Unmarshal(body, &data); err != nil { + return nil, err + } + + var expiry time.Time + if secs := data.ExpiresIn; secs > 0 { + expiry = xtime.GetInaccurateTime().Add(time.Duration(secs) * time.Second) + } + return &Token{ + AccessToken: data.AccessToken, + TokenType: data.TokenType, + Expiry: expiry, + Scope: data.Scope, + }, nil +} + +func newTokenRequest(ctx context.Context, cfg *Config) (*http.Request, error) { + v := url.Values{} + v.Set("grant_type", grantTypeClientCreds) + if len(cfg.Scopes) > 0 { + v.Set("scope", strings.Join(cfg.Scopes, " ")) + } + + if cfg.AuthStyle == AuthStyleInParams { + v.Set("client_id", cfg.ClientID) + if cfg.ClientSecret != "" { + v.Set("client_secret", cfg.ClientSecret) + } + } + + reqBody := io.NopCloser(strings.NewReader(v.Encode())) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.TokenURL, reqBody) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if cfg.AuthStyle == AuthStyleInHeader { + req.SetBasicAuth(url.QueryEscape(cfg.ClientID), url.QueryEscape(cfg.ClientSecret)) + } + + return req, nil +} diff --git a/xoauth/tokenissuer_test.go b/xoauth/tokenissuer_test.go new file mode 100644 index 000000000..93c4cc2d5 --- /dev/null +++ b/xoauth/tokenissuer_test.go @@ -0,0 +1,122 @@ +package xoauth + +import ( + "context" + "encoding/base64" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewTokenRequest(t *testing.T) { + cases := []struct { + name string + cfg *Config + + wantBodyForm map[string]string + wantHeader map[string]string + }{ + { + name: "auth_in_params", + cfg: &Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + TokenURL: "http://example.com", + AuthStyle: AuthStyleInParams, + }, + wantBodyForm: map[string]string{ + "grant_type": "client_credentials", + "client_id": "test-client", + "client_secret": "test-secret", + }, + wantHeader: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + }, + { + name: "auth_in_params_no_secret", + cfg: &Config{ + ClientID: "test-client", + TokenURL: "http://example.com", + AuthStyle: AuthStyleInParams, + }, + wantBodyForm: map[string]string{ + "grant_type": "client_credentials", + "client_id": "test-client", + }, + wantHeader: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + }, + { + name: "auth_in_header", + cfg: &Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + TokenURL: "http://example.com", + AuthStyle: AuthStyleInHeader, + }, + wantBodyForm: map[string]string{ + "grant_type": "client_credentials", + }, + wantHeader: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": basicAuth("test-client", "test-secret"), + }, + }, + { + name: "auth_in_header_no_secret", + cfg: &Config{ + ClientID: "test-client", + TokenURL: "http://example.com", + AuthStyle: AuthStyleInHeader, + }, + wantBodyForm: map[string]string{ + "grant_type": "client_credentials", + }, + wantHeader: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": basicAuth("test-client", ""), + }, + }, + { + name: "scopes", + cfg: &Config{ + ClientID: "test-client", + TokenURL: "http://example.com", + Scopes: []string{"scp1", "scp2"}, + AuthStyle: AuthStyleInParams, + }, + wantBodyForm: map[string]string{ + "grant_type": "client_credentials", + "client_id": "test-client", + "scope": "scp1 scp2", + }, + wantHeader: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := newTokenRequest(context.Background(), tt.cfg) + require.NoError(t, err) + + for k, v := range tt.wantBodyForm { + require.Equal(t, v, got.PostFormValue(k)) + } + + for k, v := range tt.wantHeader { + require.Equal(t, v, got.Header.Get(k)) + } + }) + } +} + +func basicAuth(user, pass string) string { + auth := user + ":" + pass + return "Basic " + base64.StdEncoding.EncodeToString([]byte(auth)) +} diff --git a/xoauth/tokensource.go b/xoauth/tokensource.go new file mode 100644 index 000000000..b496c1936 --- /dev/null +++ b/xoauth/tokensource.go @@ -0,0 +1,194 @@ +package xoauth + +import ( + "context" + "fmt" + "time" + + "github.com/ozontech/file.d/logger" + "go.uber.org/atomic" +) + +type AuthStyle int + +const ( + AuthStyleUnknown AuthStyle = iota + AuthStyleInParams + AuthStyleInHeader +) + +type Config struct { + ClientID string + ClientSecret string + TokenURL string + Scopes []string + AuthStyle AuthStyle +} + +func (c *Config) validate() error { + if c.ClientID == "" { + return fmt.Errorf("client id must be non-empty") + } + if c.TokenURL == "" { + return fmt.Errorf("token url must be non-empty") + } + if c.AuthStyle == AuthStyleUnknown { + return fmt.Errorf("auth style must be specified") + } + return nil +} + +// tokenIssuer issues token for specific token grant type flow +type tokenIssuer interface { + issueToken(ctx context.Context) (*Token, error) +} + +type TokenSource interface { + Token(ctx context.Context) *Token // read-only + Stop() +} + +type staticTokenSource struct { + t *Token +} + +func NewStaticTokenSource(t *Token) TokenSource { + return &staticTokenSource{ + t: t, + } +} + +func (ts *staticTokenSource) Token(_ context.Context) *Token { + return ts.t +} + +func (ts *staticTokenSource) Stop() {} + +// reuseTokenSource implements lifecycle of auth token refreshing. +// - Once reuseTokenSource is created, first token issuance happens +// - Further Token() calls must be non-blocking +// - Token is updated in the background depending on the expidation date +// - After Close() reuseTokenSource is irreversibly stops all background work +type reuseTokenSource struct { + tokenHolder atomic.Pointer[Token] + tokenIssuer tokenIssuer + + stopCh chan struct{} +} + +func NewReuseTokenSource(ctx context.Context, cfg *Config) (TokenSource, error) { + if err := cfg.validate(); err != nil { + return nil, err + } + + ti := newHTTPTokenIssuer(cfg) + return newReuseTokenSource(ctx, ti) +} + +func newReuseTokenSource(ctx context.Context, ti tokenIssuer) (*reuseTokenSource, error) { + ts := &reuseTokenSource{ + tokenHolder: atomic.Pointer[Token]{}, + tokenIssuer: ti, + + stopCh: make(chan struct{}), + } + + // get first token during initialization to verify provided data + t, err := ti.issueToken(ctx) + if err != nil { + return nil, fmt.Errorf("failed to init token source: %w", err) + } + + ts.tokenHolder.Store(t) + go ts.maintenance(ctx, time.Until(t.Expiry)/2) + + return ts, nil +} + +func (ts *reuseTokenSource) Token(ctx context.Context) *Token { + select { + case <-ctx.Done(): + return nil + default: + } + + return ts.tokenHolder.Load() +} + +func (ts *reuseTokenSource) Stop() { + close(ts.stopCh) +} + +// maintenance runs a token update loop +func (ts *reuseTokenSource) maintenance(ctx context.Context, firstDelay time.Duration) { + scheduler := time.NewTimer(firstDelay) + defer scheduler.Stop() + + // paths to calculate next scheduler delay + success, fail := ts.newDelayer() + + updateToken := func() time.Duration { + t, err := ts.tokenIssuer.issueToken(ctx) + if err != nil { + return fail(parseError(err)) + } + + ts.tokenHolder.Store(t) + return success(time.Until(t.Expiry)) + } + + for { + select { + case <-ts.stopCh: + return + case <-ctx.Done(): + return + case <-scheduler.C: + } + + delay := updateToken() + resetTimer(scheduler, delay) + } +} + +func resetTimer(t *time.Timer, d time.Duration) { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + t.Reset(d) +} + +// newDelayer returns success and failure paths that will be applied to refresh scheduler +func (ts *reuseTokenSource) newDelayer() ( + func(ttl time.Duration) time.Duration, // success + func(err *errorAuth) time.Duration, // failure +) { + expBackoff := exponentialJitterBackoff() + linBackoff := linearJitterBackoff() + attempt := 0 + + success := func(ttl time.Duration) time.Duration { + attempt = 0 + return linBackoff(ttl/3, ttl/2, 0) + } + + failure := func(err *errorAuth) time.Duration { + attempt++ + code := err.Code() + logger.Errorf("error occurred while updating oauth token: attempt=%d, code=%s, error=%s", + attempt, code, err.Error()) + + switch code { + case ecInvalidRequest, ecInvalidClient, ecInvalidGrant, ecInvalidScope, + ecUnauthorizedClient, ecUnsupportedGrantType: + return linBackoff(time.Minute, 10*time.Minute, attempt) + default: + return expBackoff(time.Second, time.Minute, attempt) + } + } + + return success, failure +} diff --git a/xoauth/tokensource_test.go b/xoauth/tokensource_test.go new file mode 100644 index 000000000..78dd2f118 --- /dev/null +++ b/xoauth/tokensource_test.go @@ -0,0 +1,74 @@ +package xoauth + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestReuseTokenSource(t *testing.T) { + ctx := context.Background() + now := time.Now() + + issuer := &mockTokenIssuer{} + + // first token fail + issuer.expect(nil, errors.New("some error")) + _, err := newReuseTokenSource(ctx, issuer) + require.Error(t, err) + + // first token success, save and start maintenance + tok := &Token{AccessToken: "test-token", Expiry: now.Add(time.Second)} + issuer.expect(tok, nil) + source, err := newReuseTokenSource(ctx, issuer) + require.NoError(t, err) + + // check token + require.Equal(t, tok.AccessToken, source.Token(ctx).AccessToken) + + // first update will be after first token expire / 2, + // set next token and wait + tok = &Token{AccessToken: "test-token-1", Expiry: now.Add(2 * time.Second)} + issuer.expect(tok, nil) + time.Sleep(time.Second) + + // check token + require.Equal(t, tok.AccessToken, source.Token(ctx).AccessToken) + + // stop token source and wait next potential update, + // set issuer token which we won't have to get + issuer.expect(&Token{AccessToken: "test-token-2", Expiry: now.Add(3 * time.Second)}, nil) + source.Stop() + time.Sleep(time.Second) + + // check that token is one that was saved before the stop + require.Equal(t, tok.AccessToken, source.Token(ctx).AccessToken) + + // check that Token() with canceled context returns nil + ctx2, cancel := context.WithCancel(ctx) + cancel() + require.Nil(t, source.Token(ctx2)) +} + +type mockTokenIssuer struct { + token *Token + err error + + m sync.RWMutex +} + +func (ti *mockTokenIssuer) issueToken(ctx context.Context) (*Token, error) { + ti.m.RLock() + defer ti.m.RUnlock() + return ti.token, ti.err +} + +func (ti *mockTokenIssuer) expect(t *Token, err error) { + ti.m.Lock() + defer ti.m.Unlock() + ti.token, ti.err = t, err +} diff --git a/xscram/client.go b/xscram/client.go deleted file mode 100644 index 4004e2fd7..000000000 --- a/xscram/client.go +++ /dev/null @@ -1,43 +0,0 @@ -package xscram - -import ( - "crypto/sha256" - "crypto/sha512" - - "github.com/xdg-go/scram" -) - -var ( - SHA256 scram.HashGeneratorFcn = sha256.New - SHA512 scram.HashGeneratorFcn = sha512.New -) - -type Client struct { - *scram.Client - *scram.ClientConversation - scram.HashGeneratorFcn -} - -func NewClient(hashFn scram.HashGeneratorFcn) *Client { - return &Client{ - HashGeneratorFcn: hashFn, - } -} - -func (x *Client) Begin(userName, password, authzID string) error { - var err error - x.Client, err = x.HashGeneratorFcn.NewClient(userName, password, authzID) - if err != nil { - return err - } - x.ClientConversation = x.Client.NewConversation() - return nil -} - -func (x *Client) Step(challenge string) (string, error) { - return x.ClientConversation.Step(challenge) -} - -func (x *Client) Done() bool { - return x.ClientConversation.Done() -}