diff --git a/cfg/kafka_client.go b/cfg/kafka_client.go
index 9b9087f6a..00f83191d 100644
--- a/cfg/kafka_client.go
+++ b/cfg/kafka_client.go
@@ -1,10 +1,14 @@
package cfg
import (
+ "context"
+ "errors"
"os"
+ "github.com/ozontech/file.d/xoauth"
"github.com/twmb/franz-go/pkg/kgo"
"github.com/twmb/franz-go/pkg/sasl/aws"
+ "github.com/twmb/franz-go/pkg/sasl/oauth"
"github.com/twmb/franz-go/pkg/sasl/plain"
"github.com/twmb/franz-go/pkg/sasl/scram"
"github.com/twmb/franz-go/plugin/kzap"
@@ -27,6 +31,8 @@ type KafkaClientSaslConfig struct {
SaslMechanism string
SaslUsername string
SaslPassword string
+
+ SaslOAuth KafkaClientOAuthConfig
}
type KafkaClientSslConfig struct {
@@ -36,7 +42,63 @@ type KafkaClientSslConfig struct {
SslSkipVerify bool
}
-func GetKafkaClientOptions(c KafkaClientConfig, l *zap.Logger) []kgo.Opt {
+type KafkaClientOAuthConfig struct {
+ // static
+ Token string `json:"token"`
+
+ // dynamic
+ ClientID string `json:"client_id"`
+ ClientSecret string `json:"client_secret"`
+ TokenURL string `json:"token_url"`
+ Scopes []string `json:"scopes" slice:"true"`
+ AuthStyle string `json:"auth_style" default:"params" options:"params|header"`
+}
+
+func (c *KafkaClientOAuthConfig) isStatic() bool {
+ return c.Token != ""
+}
+
+func (c *KafkaClientOAuthConfig) isDynamic() bool {
+ return c.ClientID != "" && c.ClientSecret != "" && c.TokenURL != ""
+}
+
+func (c *KafkaClientOAuthConfig) isValid() bool {
+ return c.isStatic() || c.isDynamic()
+}
+
+func GetKafkaClientOAuthTokenSource(ctx context.Context, c KafkaClientConfig) (xoauth.TokenSource, error) {
+ saslCfg := c.GetSaslConfig()
+
+ if !c.IsSaslEnabled() || saslCfg.SaslMechanism != "OAUTHBEARER" {
+ return nil, nil
+ }
+
+ saslOAuth := saslCfg.SaslOAuth
+ if !saslOAuth.isValid() {
+ return nil, errors.New("invalid SASL OAUTHBEARER config")
+ }
+
+ if saslOAuth.isDynamic() {
+ authStyle := xoauth.AuthStyleInParams
+ if saslOAuth.AuthStyle == "header" {
+ authStyle = xoauth.AuthStyleInHeader
+ }
+
+ return xoauth.NewReuseTokenSource(ctx, &xoauth.Config{
+ ClientID: saslOAuth.ClientID,
+ ClientSecret: saslOAuth.ClientSecret,
+ TokenURL: saslOAuth.TokenURL,
+ Scopes: saslOAuth.Scopes,
+ AuthStyle: authStyle,
+ })
+ }
+
+ return xoauth.NewStaticTokenSource(&xoauth.Token{
+ AccessToken: saslOAuth.Token,
+ }), nil
+}
+
+func GetKafkaClientOptions(c KafkaClientConfig, l *zap.Logger, tokenSource xoauth.TokenSource) []kgo.Opt {
opts := []kgo.Opt{
kgo.SeedBrokers(c.GetBrokers()...),
kgo.ClientID(c.GetClientID()),
@@ -66,6 +128,18 @@ func GetKafkaClientOptions(c KafkaClientConfig, l *zap.Logger) []kgo.Opt {
AccessKey: saslConfig.SaslUsername,
SecretKey: saslConfig.SaslPassword,
}.AsManagedStreamingIAMMechanism()))
+ case "OAUTHBEARER":
+ authFn := func(ctx context.Context) (oauth.Auth, error) {
+ if tokenSource == nil {
+ return oauth.Auth{}, errors.New("uninitialized token source")
+ }
+ t := tokenSource.Token(ctx)
+ if t == nil {
+ return oauth.Auth{}, errors.New("empty token from token source")
+ }
+ return oauth.Auth{Token: t.AccessToken}, nil
+ }
+ opts = append(opts, kgo.SASL(oauth.Oauth(authFn)))
}
}
diff --git a/e2e/kafka_auth/kafka_auth.go b/e2e/kafka_auth/kafka_auth.go
index 43fe5409f..db2839aa7 100644
--- a/e2e/kafka_auth/kafka_auth.go
+++ b/e2e/kafka_auth/kafka_auth.go
@@ -124,8 +124,9 @@ func (c *Config) Configure(t *testing.T, _ *cfg.Config, _ string) {
config.ClientCert = "./kafka_auth/certs/client_cert.pem"
}
- kafka_out.NewClient(config,
+ kafka_out.NewClient(context.Background(), config,
zap.NewNop().WithOptions(zap.WithFatalHook(zapcore.WriteThenPanic)),
+ nil,
)
},
func() {
@@ -154,9 +155,9 @@ func (c *Config) Configure(t *testing.T, _ *cfg.Config, _ string) {
config.ClientCert = "./kafka_auth/certs/client_cert.pem"
}
- kafka_in.NewClient(config,
+ kafka_in.NewClient(context.Background(), config,
zap.NewNop().WithOptions(zap.WithFatalHook(zapcore.WriteThenPanic)),
- Consumer{},
+ Consumer{}, nil,
)
},
}
diff --git a/e2e/kafka_file/kafka_file.go b/e2e/kafka_file/kafka_file.go
index 3a32e906d..60ac47d1d 100644
--- a/e2e/kafka_file/kafka_file.go
+++ b/e2e/kafka_file/kafka_file.go
@@ -52,8 +52,9 @@ func (c *Config) Send(t *testing.T) {
BatchSize_: c.Count,
}
- client := kafka_out.NewClient(config,
+ client := kafka_out.NewClient(context.Background(), config,
zap.NewNop().WithOptions(zap.WithFatalHook(zapcore.WriteThenPanic)),
+ nil,
)
adminClient := kadm.NewClient(client)
_, err := adminClient.CreateTopic(context.TODO(), 1, 1, nil, c.Topics[0])
diff --git a/e2e/split_join/split_join.go b/e2e/split_join/split_join.go
index cbd119bb0..652223629 100644
--- a/e2e/split_join/split_join.go
+++ b/e2e/split_join/split_join.go
@@ -63,9 +63,9 @@ func (c *Config) Configure(t *testing.T, conf *cfg.Config, pipelineName string)
HeartbeatInterval_: 10 * time.Second,
}
- c.client = kafka_in.NewClient(config,
+ c.client = kafka_in.NewClient(context.Background(), config,
zap.NewNop().WithOptions(zap.WithFatalHook(zapcore.WriteThenPanic)),
- Consumer{},
+ Consumer{}, nil,
)
adminClient := kadm.NewClient(c.client)
diff --git a/go.mod b/go.mod
index 01f77cf23..a16ab139c 100644
--- a/go.mod
+++ b/go.mod
@@ -45,11 +45,11 @@ require (
github.com/twmb/franz-go/plugin/kzap v1.1.2
github.com/twmb/tlscfg v1.2.1
github.com/valyala/fasthttp v1.48.0
- github.com/xdg-go/scram v1.1.2
go.uber.org/atomic v1.11.0
go.uber.org/automaxprocs v1.5.3
go.uber.org/zap v1.27.0
golang.org/x/net v0.47.0
+ golang.org/x/oauth2 v0.27.0
google.golang.org/protobuf v1.36.5
gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.1
@@ -131,8 +131,6 @@ require (
github.com/twmb/franz-go/pkg/kmsg v1.12.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
- github.com/xdg-go/pbkdf2 v1.0.0 // indirect
- github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/otel v1.34.0 // indirect
@@ -143,7 +141,6 @@ require (
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/mod v0.29.0 // indirect
- golang.org/x/oauth2 v0.27.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/term v0.37.0 // indirect
diff --git a/go.sum b/go.sum
index feff277bf..efebc884c 100644
--- a/go.sum
+++ b/go.sum
@@ -363,12 +363,6 @@ github.com/valyala/fasthttp v1.48.0 h1:oJWvHb9BIZToTQS3MuQ2R3bJZiNSa2KiNdeI8A+79
github.com/valyala/fasthttp v1.48.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
-github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
-github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
-github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
-github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
-github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
-github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
@@ -499,7 +493,6 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
-golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
diff --git a/plugin/input/kafka/README.md b/plugin/input/kafka/README.md
index 56885602c..2350a456e 100755
--- a/plugin/input/kafka/README.md
+++ b/plugin/input/kafka/README.md
@@ -140,7 +140,7 @@ If set, the plugin will use SASL authentications mechanism.
-**`sasl_mechanism`** *`string`* *`default=SCRAM-SHA-512`* *`options=PLAIN|SCRAM-SHA-256|SCRAM-SHA-512|AWS_MSK_IAM`*
+**`sasl_mechanism`** *`string`* *`default=SCRAM-SHA-512`* *`options=PLAIN|SCRAM-SHA-256|SCRAM-SHA-512|AWS_MSK_IAM|OAUTHBEARER`*
SASL mechanism to use.
@@ -158,6 +158,22 @@ SASL password.
+**`sasl_oauth`** *`cfg.KafkaClientOAuthConfig`*
+
+SASL OAUTHBEARER config. It is used only when `sasl_mechanism: OAUTHBEARER`.
+> There are two options: a static token or a dynamically updated one.
+
+`OAuthConfig` params:
+* **`token`** *`string`* - static token
+---
+* **`client_id`** *`string`* - client ID
+* **`client_secret`** *`string`* - client secret
+* **`token_url`** *`string`* - resource server's token endpoint URL
+* **`scopes`** *`[]string`* - optional requested permissions
+* **`auth_style`** *`string`* - specifies how the endpoint wants the client ID and client secret sent: `params` (default) or `header`
+
+
+
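+Example of the dynamic variant (a minimal sketch: the broker, topic and issuer values are placeholders, and the fields outside `sasl_oauth` are the regular kafka input options, not new ones):
+
+```yaml
+pipelines:
+  example_pipeline:
+    input:
+      type: kafka
+      brokers: ["kafka:9092"]
+      topics: ["logs"]
+      is_sasl_enabled: true
+      sasl_mechanism: OAUTHBEARER
+      sasl_oauth:
+        client_id: my-client
+        client_secret: my-secret
+        token_url: https://sso.example.com/oauth/token
+        auth_style: params
+```
+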
**`is_ssl_enabled`** *`bool`* *`default=false`*
If set, the plugin will use SSL/TLS connections method.
diff --git a/plugin/input/kafka/client.go b/plugin/input/kafka/client.go
index f316ad2ce..a570aeca8 100644
--- a/plugin/input/kafka/client.go
+++ b/plugin/input/kafka/client.go
@@ -5,12 +5,13 @@ import (
"time"
"github.com/ozontech/file.d/cfg"
+ "github.com/ozontech/file.d/xoauth"
"github.com/twmb/franz-go/pkg/kgo"
"go.uber.org/zap"
)
-func NewClient(c *Config, l *zap.Logger, s Consumer) *kgo.Client {
- opts := cfg.GetKafkaClientOptions(c, l)
+func NewClient(ctx context.Context, c *Config, l *zap.Logger, s Consumer, tokenSource xoauth.TokenSource) *kgo.Client {
+ opts := cfg.GetKafkaClientOptions(c, l, tokenSource)
opts = append(opts, []kgo.Opt{
kgo.ConsumerGroup(c.ConsumerGroup),
kgo.ConsumeTopics(c.Topics...),
@@ -53,10 +54,10 @@ func NewClient(c *Config, l *zap.Logger, s Consumer) *kgo.Client {
l.Fatal("can't create kafka client", zap.Error(err))
}
- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ pingCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
- err = client.Ping(ctx)
+ err = client.Ping(pingCtx)
if err != nil {
l.Fatal("can't connect to kafka", zap.Error(err))
}
diff --git a/plugin/input/kafka/kafka.go b/plugin/input/kafka/kafka.go
index 6e7c0e068..15ccfc21b 100644
--- a/plugin/input/kafka/kafka.go
+++ b/plugin/input/kafka/kafka.go
@@ -9,6 +9,7 @@ import (
"github.com/ozontech/file.d/metric"
"github.com/ozontech/file.d/pipeline"
"github.com/ozontech/file.d/pipeline/metadata"
+ "github.com/ozontech/file.d/xoauth"
"github.com/prometheus/client_golang/prometheus"
"github.com/twmb/franz-go/pkg/kgo"
"go.uber.org/zap"
@@ -47,19 +48,18 @@ pipelines:
type Plugin struct {
config *Config
- logger *zap.SugaredLogger
- client *kgo.Client
+ logger *zap.Logger
cancel context.CancelFunc
controller pipeline.InputPluginController
- idByTopic map[string]int
+
+ client *kgo.Client
+ s *splitConsume
+ tokenSource xoauth.TokenSource
+ metaTemplater *metadata.MetaTemplater
// plugin metrics
commitErrorsMetric prometheus.Counter
consumeErrorsMetric prometheus.Counter
-
- metaTemplater *metadata.MetaTemplater
-
- s *splitConsume
}
type OffsetType byte
@@ -177,7 +177,7 @@ type Config struct {
// > @3@4@5@6
// >
// > SASL mechanism to use.
- SaslMechanism string `json:"sasl_mechanism" default:"SCRAM-SHA-512" options:"PLAIN|SCRAM-SHA-256|SCRAM-SHA-512|AWS_MSK_IAM"` // *
+ SaslMechanism string `json:"sasl_mechanism" default:"SCRAM-SHA-512" options:"PLAIN|SCRAM-SHA-256|SCRAM-SHA-512|AWS_MSK_IAM|OAUTHBEARER"` // *
// > @3@4@5@6
// >
@@ -189,6 +189,21 @@ type Config struct {
// > SASL password.
SaslPassword string `json:"sasl_password" default:"password"` // *
+ // > @3@4@5@6
+ // >
+ // > SASL OAUTHBEARER config. It is used only when `sasl_mechanism: OAUTHBEARER`.
+ // >> There are two options: a static token or a dynamically updated one.
+ // >
+ // > `OAuthConfig` params:
+ // > * **`token`** *`string`* - static token
+ // > ---
+ // > * **`client_id`** *`string`* - client ID
+ // > * **`client_secret`** *`string`* - client secret
+ // > * **`token_url`** *`string`* - resource server's token endpoint URL
+ // > * **`scopes`** *`[]string`* - optional requested permissions
+ // > * **`auth_style`** *`string`* - specifies how the endpoint wants the client ID and client secret sent: `params` (default) or `header`
+ SaslOAuth cfg.KafkaClientOAuthConfig `json:"sasl_oauth" child:"true"` // *
+
// > @3@4@5@6
// >
// > If set, the plugin will use SSL/TLS connections method.
@@ -242,6 +257,7 @@ func (c *Config) GetSaslConfig() cfg.KafkaClientSaslConfig {
SaslMechanism: c.SaslMechanism,
SaslUsername: c.SaslUsername,
SaslPassword: c.SaslPassword,
+ SaslOAuth: c.SaslOAuth,
}
}
@@ -271,36 +287,45 @@ func Factory() (pipeline.AnyPlugin, pipeline.AnyConfig) {
func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.InputPluginParams) {
p.controller = params.Controller
- p.logger = params.Logger
+ p.logger = params.Logger.Desugar()
p.config = config.(*Config)
p.registerMetrics(params.MetricCtl)
if len(p.config.Meta) > 0 {
p.metaTemplater = metadata.NewMetaTemplater(
p.config.Meta,
- p.logger.Desugar(),
+ p.logger,
params.PipelineSettings.MetaCacheSize,
)
}
- p.idByTopic = make(map[string]int, len(p.config.Topics))
+ idByTopic := make(map[string]int, len(p.config.Topics))
for i, topic := range p.config.Topics {
- p.idByTopic[topic] = i
+ idByTopic[topic] = i
}
ctx, cancel := context.WithCancel(context.Background())
p.cancel = cancel
+
p.s = &splitConsume{
consumers: make(map[tp]*pconsumer),
bufferSize: p.config.ChannelBufferSize,
maxConcurrentConsumers: p.config.MaxConcurrentConsumers,
- idByTopic: p.idByTopic,
+ idByTopic: idByTopic,
controller: p.controller,
- logger: p.logger.Desugar(),
+ logger: p.logger,
metaTemplater: p.metaTemplater,
consumeErrorsMetric: p.consumeErrorsMetric,
}
- p.client = NewClient(p.config, p.logger.Desugar(), p.s)
+
+ var err error
+ p.tokenSource, err = cfg.GetKafkaClientOAuthTokenSource(ctx, p.config)
+ if err != nil {
+ p.logger.Fatal(err.Error())
+ }
+
+ p.client = NewClient(ctx, p.config, p.logger, p.s, p.tokenSource)
+
p.controller.UseSpread()
p.controller.DisableStreams()
@@ -313,13 +338,16 @@ func (p *Plugin) registerMetrics(ctl *metric.Ctl) {
}
func (p *Plugin) Stop() {
- p.logger.Infof("Stopping")
+ p.logger.Info("Stopping")
err := p.client.CommitMarkedOffsets(context.Background())
if err != nil {
p.commitErrorsMetric.Inc()
- p.logger.Errorf("can't commit marked offsets: %s", err.Error())
+ p.logger.Error("can't commit marked offsets", zap.Error(err))
}
p.client.Close()
+ if p.tokenSource != nil {
+ p.tokenSource.Stop()
+ }
p.cancel()
}
diff --git a/plugin/output/kafka/README.md b/plugin/output/kafka/README.md
index d1ef51b20..84cff1058 100755
--- a/plugin/output/kafka/README.md
+++ b/plugin/output/kafka/README.md
@@ -115,7 +115,7 @@ If set, the plugin will use SASL authentications mechanism.
-**`sasl_mechanism`** *`string`* *`default=SCRAM-SHA-512`* *`options=PLAIN|SCRAM-SHA-256|SCRAM-SHA-512`*
+**`sasl_mechanism`** *`string`* *`default=SCRAM-SHA-512`* *`options=PLAIN|SCRAM-SHA-256|SCRAM-SHA-512|AWS_MSK_IAM|OAUTHBEARER`*
SASL mechanism to use.
@@ -133,6 +133,22 @@ SASL password.
+**`sasl_oauth`** *`cfg.KafkaClientOAuthConfig`*
+
+SASL OAUTHBEARER config. It is used only when `sasl_mechanism: OAUTHBEARER`.
+> There are two options: a static token or a dynamically updated one.
+
+`OAuthConfig` params:
+* **`token`** *`string`* - static token
+---
+* **`client_id`** *`string`* - client ID
+* **`client_secret`** *`string`* - client secret
+* **`token_url`** *`string`* - resource server's token endpoint URL
+* **`scopes`** *`[]string`* - optional requested permissions
+* **`auth_style`** *`string`* - specifies how the endpoint wants the client ID and client secret sent: `params` (default) or `header`
+
+
+
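+Example of the static-token variant (a minimal sketch: the broker and topic values are placeholders, and the fields outside `sasl_oauth` are the regular kafka output options):
+
+```yaml
+pipelines:
+  example_pipeline:
+    output:
+      type: kafka
+      brokers: ["kafka:9092"]
+      default_topic: logs
+      is_sasl_enabled: true
+      sasl_mechanism: OAUTHBEARER
+      sasl_oauth:
+        token: my-static-token
+```
+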
**`is_ssl_enabled`** *`bool`* *`default=false`*
If set, the plugin will use SSL/TLS connections method.
diff --git a/plugin/output/kafka/client.go b/plugin/output/kafka/client.go
index 9f2853537..4869e3f86 100644
--- a/plugin/output/kafka/client.go
+++ b/plugin/output/kafka/client.go
@@ -5,6 +5,7 @@ import (
"time"
"github.com/ozontech/file.d/cfg"
+ "github.com/ozontech/file.d/xoauth"
"github.com/twmb/franz-go/pkg/kgo"
"go.uber.org/zap"
)
@@ -15,8 +16,8 @@ type KafkaClient interface {
ForceMetadataRefresh()
}
-func NewClient(c *Config, l *zap.Logger) *kgo.Client {
- opts := cfg.GetKafkaClientOptions(c, l)
+func NewClient(ctx context.Context, c *Config, l *zap.Logger, tokenSource xoauth.TokenSource) *kgo.Client {
+ opts := cfg.GetKafkaClientOptions(c, l, tokenSource)
opts = append(opts, []kgo.Opt{
kgo.DefaultProduceTopic(c.DefaultTopic),
kgo.MaxBufferedRecords(c.BatchSize_),
@@ -54,10 +55,10 @@ func NewClient(c *Config, l *zap.Logger) *kgo.Client {
l.Fatal("can't create kafka client", zap.Error(err))
}
- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ pingCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
- err = client.Ping(ctx)
+ err = client.Ping(pingCtx)
if err != nil {
l.Fatal("can't connect to kafka", zap.Error(err))
}
diff --git a/plugin/output/kafka/kafka.go b/plugin/output/kafka/kafka.go
index 213ff57ec..69245eed1 100644
--- a/plugin/output/kafka/kafka.go
+++ b/plugin/output/kafka/kafka.go
@@ -3,12 +3,14 @@ package kafka
import (
"context"
"errors"
+ "fmt"
"time"
"github.com/ozontech/file.d/cfg"
"github.com/ozontech/file.d/fd"
"github.com/ozontech/file.d/metric"
"github.com/ozontech/file.d/pipeline"
+ "github.com/ozontech/file.d/xoauth"
"github.com/prometheus/client_golang/prometheus"
"github.com/twmb/franz-go/pkg/kerr"
"github.com/twmb/franz-go/pkg/kgo"
@@ -32,13 +34,15 @@ type data struct {
}
type Plugin struct {
- logger *zap.SugaredLogger
+ logger *zap.Logger
config *Config
avgEventSize int
controller pipeline.OutputPluginController
+ cancel context.CancelFunc
- client KafkaClient
- batcher *pipeline.RetriableBatcher
+ client KafkaClient
+ batcher *pipeline.RetriableBatcher
+ tokenSource xoauth.TokenSource
// plugin metrics
sendErrorMetric prometheus.Counter
@@ -151,7 +155,7 @@ type Config struct {
// > @3@4@5@6
// >
// > SASL mechanism to use.
- SaslMechanism string `json:"sasl_mechanism" default:"SCRAM-SHA-512" options:"PLAIN|SCRAM-SHA-256|SCRAM-SHA-512"` // *
+ SaslMechanism string `json:"sasl_mechanism" default:"SCRAM-SHA-512" options:"PLAIN|SCRAM-SHA-256|SCRAM-SHA-512|AWS_MSK_IAM|OAUTHBEARER"` // *
// > @3@4@5@6
// >
@@ -163,6 +167,21 @@ type Config struct {
// > SASL password.
SaslPassword string `json:"sasl_password" default:"password"` // *
+ // > @3@4@5@6
+ // >
+ // > SASL OAUTHBEARER config. It is used only when `sasl_mechanism: OAUTHBEARER`.
+ // >> There are two options: a static token or a dynamically updated one.
+ // >
+ // > `OAuthConfig` params:
+ // > * **`token`** *`string`* - static token
+ // > ---
+ // > * **`client_id`** *`string`* - client ID
+ // > * **`client_secret`** *`string`* - client secret
+ // > * **`token_url`** *`string`* - resource server's token endpoint URL
+ // > * **`scopes`** *`[]string`* - optional requested permissions
+ // > * **`auth_style`** *`string`* - specifies how the endpoint wants the client ID and client secret sent: `params` (default) or `header`
+ SaslOAuth cfg.KafkaClientOAuthConfig `json:"sasl_oauth" child:"true"` // *
+
// > @3@4@5@6
// >
// > If set, the plugin will use SSL/TLS connections method.
@@ -206,6 +225,7 @@ func (c *Config) GetSaslConfig() cfg.KafkaClientSaslConfig {
SaslMechanism: c.SaslMechanism,
SaslUsername: c.SaslUsername,
SaslPassword: c.SaslPassword,
+ SaslOAuth: c.SaslOAuth,
}
}
@@ -235,7 +255,7 @@ func Factory() (pipeline.AnyPlugin, pipeline.AnyConfig) {
func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginParams) {
p.config = config.(*Config)
- p.logger = params.Logger
+ p.logger = params.Logger.Desugar()
p.avgEventSize = params.PipelineSettings.AvgEventSize
p.controller = params.Controller
p.registerMetrics(params.MetricCtl)
@@ -244,9 +264,18 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP
p.logger.Fatal("'retention' can't be <1")
}
- p.logger.Infof("workers count=%d, batch size=%d", p.config.WorkersCount_, p.config.BatchSize_)
+ p.logger.Info(fmt.Sprintf("workers count=%d, batch size=%d", p.config.WorkersCount_, p.config.BatchSize_))
- p.client = NewClient(p.config, p.logger.Desugar())
+ ctx, cancel := context.WithCancel(context.Background())
+ p.cancel = cancel
+
+ var err error
+ p.tokenSource, err = cfg.GetKafkaClientOAuthTokenSource(ctx, p.config)
+ if err != nil {
+ p.logger.Fatal(err.Error())
+ }
+
+ p.client = NewClient(ctx, p.config, p.logger, p.tokenSource)
batcherOpts := pipeline.BatcherOptions{
PipelineName: params.PipelineName,
@@ -275,7 +304,7 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP
level = zapcore.ErrorLevel
}
- p.logger.Desugar().Log(level, "can't write batch",
+ p.logger.Log(level, "can't write batch",
zap.Int("retries", p.config.Retry),
)
@@ -343,7 +372,7 @@ func (p *Plugin) out(workerData *pipeline.WorkerData, batch *pipeline.Batch) err
if errors.Is(err, kerr.LeaderNotAvailable) || errors.Is(err, kerr.NotLeaderForPartition) {
p.client.ForceMetadataRefresh()
}
- p.logger.Errorf("can't write batch: %v", err)
+ p.logger.Error("can't write batch", zap.Error(err))
p.sendErrorMetric.Inc()
return err
}
@@ -354,4 +383,8 @@ func (p *Plugin) out(workerData *pipeline.WorkerData, batch *pipeline.Batch) err
func (p *Plugin) Stop() {
p.batcher.Stop()
p.client.Close()
+ if p.tokenSource != nil {
+ p.tokenSource.Stop()
+ }
+ p.cancel()
}
diff --git a/plugin/output/kafka/kafka_test.go b/plugin/output/kafka/kafka_test.go
index a1a5e9b5a..7061b750b 100644
--- a/plugin/output/kafka/kafka_test.go
+++ b/plugin/output/kafka/kafka_test.go
@@ -62,7 +62,7 @@ func FuzzKafka(f *testing.F) {
}
worker := pipeline.WorkerData(nil)
- logger := zaptest.NewLogger(f).Sugar()
+ logger := zaptest.NewLogger(f)
p := Plugin{
logger: logger,
config: &config,
diff --git a/xoauth/backoff.go b/xoauth/backoff.go
new file mode 100644
index 000000000..d0a7c017c
--- /dev/null
+++ b/xoauth/backoff.go
@@ -0,0 +1,72 @@
+package xoauth
+
+import (
+ "math"
+ "math/rand"
+ "time"
+)
+
+// exponentialJitterBackoff provides a callback which will
+// perform an exponential backoff based on the attempt number and with jitter to
+// prevent a thundering herd.
+//
+// min and max here are *not* absolute values. The number to be multiplied by
+// the attempt number will be chosen at random from between them, thus they are
+// bounding the jitter.
+func exponentialJitterBackoff() func(min, max time.Duration, attemptNum int) time.Duration {
+ rnd := rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
+
+ return func(min, max time.Duration, attemptNum int) time.Duration {
+ minf := float64(min)
+ mult := math.Pow(2, float64(attemptNum)) * minf
+
+ jitter := rnd.Float64() * (mult - minf)
+
+ sleepf := mult + jitter
+ maxf := float64(max)
+ if sleepf > maxf {
+ sleepf = maxf
+ }
+
+ return time.Duration(sleepf)
+ }
+}
+
+// linearJitterBackoff will perform linear backoff based on the attempt number and with jitter to
+// prevent a thundering herd.
+//
+// min and max here are *not* absolute values. The number to be multiplied by
+// the attempt number will be chosen at random from between them, thus they are
+// bounding the jitter.
+//
+// For instance:
+// - To get strictly linear backoff of one second increasing each retry, set
+// both to one second (1s, 2s, 3s, 4s, ...)
+// - To get a small amount of jitter centered around one second increasing each
+// retry, set to around one second, such as a min of 800ms and max of 1200ms
+// (892ms, 2102ms, 2945ms, 4312ms, ...)
+// - To get extreme jitter, set to a very wide spread, such as a min of 100ms
+// and a max of 20s (15382ms, 292ms, 51321ms, 35234ms, ...)
+func linearJitterBackoff() func(min, max time.Duration, attemptNum int) time.Duration {
+ rnd := rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
+
+ return func(min, max time.Duration, attemptNum int) time.Duration {
+ // attemptNum always starts at zero but we want to start at 1 for multiplication
+ attemptNum++
+
+ if max <= min {
+ // max <= min leaves no room for jitter (or the bounds are equal),
+ // so just return min * attemptNum
+ return min * time.Duration(attemptNum)
+ }
+
+ // Pick a random number that lies somewhere between the min and max and
+ // multiply by the attemptNum. attemptNum starts at zero so we always
+ // increment here. We first get a random percentage, then apply that to the
+ // difference between min and max, and add to min.
+ jitter := rnd.Float64() * float64(max-min)
+
+ jitterMin := int64(jitter) + int64(min)
+ return time.Duration(jitterMin * int64(attemptNum))
+ }
+}
diff --git a/xoauth/backoff_test.go b/xoauth/backoff_test.go
new file mode 100644
index 000000000..c0091e1fe
--- /dev/null
+++ b/xoauth/backoff_test.go
@@ -0,0 +1,94 @@
+package xoauth
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestExponentialJitterBackoff(t *testing.T) {
+ cases := []struct {
+ name string
+ min time.Duration
+ max time.Duration
+ attempts int
+ }{
+ {
+ name: "1s_1m_10",
+ min: time.Second,
+ max: time.Minute,
+ attempts: 10,
+ },
+ {
+ name: "1m_1h_10",
+ min: time.Minute,
+ max: time.Hour,
+ attempts: 10,
+ },
+ {
+ name: "1s_1m_1000",
+ min: time.Second,
+ max: time.Minute,
+ attempts: 1000,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ backoff := exponentialJitterBackoff()
+
+ for attempt := range tt.attempts {
+ got := backoff(tt.min, tt.max, attempt)
+ require.GreaterOrEqual(t, got, tt.min)
+ require.LessOrEqual(t, got, tt.max)
+ }
+ })
+ }
+}
+
+func TestLinearJitterBackoff(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ name string
+ min time.Duration
+ max time.Duration
+ attempts int
+ }{
+ {
+ name: "1s_1m_60",
+ min: time.Second,
+ max: time.Minute,
+ attempts: 60,
+ },
+ {
+ name: "1s_1s_60",
+ min: time.Second,
+ max: time.Second,
+ attempts: 60,
+ },
+ {
+ name: "1s_2s_1000",
+ min: time.Second,
+ max: 2 * time.Second,
+ attempts: 1000,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ backoff := linearJitterBackoff()
+
+ for attempt := range tt.attempts {
+ got := backoff(tt.min, tt.max, attempt)
+ require.GreaterOrEqual(t, got, tt.min)
+ require.LessOrEqual(t, got, tt.max*time.Duration(attempt+1))
+ }
+ })
+ }
+}
diff --git a/xoauth/error.go b/xoauth/error.go
new file mode 100644
index 000000000..ddcada7c0
--- /dev/null
+++ b/xoauth/error.go
@@ -0,0 +1,90 @@
+package xoauth
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "mime"
+ "net"
+ "net/http"
+
+ "golang.org/x/oauth2"
+)
+
+// error fields: https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
+type errorJSON struct {
+ ErrorCode string `json:"error"`
+ ErrorDescription string `json:"error_description,omitempty"`
+}
+
+const (
+ // issuer errors
+ ecInvalidRequest = "invalid_request"
+ ecInvalidClient = "invalid_client"
+ ecInvalidGrant = "invalid_grant"
+ ecInvalidScope = "invalid_scope"
+ ecUnauthorizedClient = "unauthorized_client"
+ ecUnsupportedGrantType = "unsupported_grant_type"
+
+ ecTimeout = "timeout"
+ ecNetwork = "network_error"
+ ecRateLimit = "rate_limit"
+
+ ecUnknownServer = "unknown_server"
+ ecUnknown = "unknown"
+)
+
+type errorAuth struct {
+ code string
+ message string
+ cause error
+}
+
+func (e *errorAuth) Error() string { return e.message }
+func (e *errorAuth) Code() string { return e.code }
+func (e *errorAuth) Unwrap() error { return e.cause }
+func (e *errorAuth) Is(target error) bool { return errors.Is(target, e.cause) }
+
+func wrapErr(err error, code, prefix string) *errorAuth {
+ return &errorAuth{
+ code: code,
+ message: fmt.Sprintf("%s: %v", prefix, err),
+
+ cause: err,
+ }
+}
+
+func parseError(err error) *errorAuth {
+ var retrieveErr *oauth2.RetrieveError
+ if errors.As(err, &retrieveErr) {
+ return parseRetrieveError(err, retrieveErr)
+ }
+
+ if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
+ return wrapErr(err, ecTimeout, "timeout")
+ }
+
+ var netError net.Error
+ if errors.As(err, &netError) {
+ return wrapErr(netError, ecNetwork, "network error")
+ }
+
+ return wrapErr(err, ecUnknown, "unknown error")
+}
+
+func parseRetrieveError(err error, retrieveErr *oauth2.RetrieveError) *errorAuth {
+ content, _, _ := mime.ParseMediaType(retrieveErr.Response.Header.Get("Content-Type"))
+ if content == "application/json" {
+ var errJson errorJSON
+ if err = json.Unmarshal(retrieveErr.Body, &errJson); err == nil {
+ return wrapErr(retrieveErr, errJson.ErrorCode, errJson.ErrorDescription)
+ }
+ }
+
+ if retrieveErr.Response.StatusCode == http.StatusTooManyRequests {
+ return wrapErr(err, ecRateLimit, "rate limit")
+ }
+
+ return wrapErr(err, ecUnknownServer, err.Error())
+}
diff --git a/xoauth/error_test.go b/xoauth/error_test.go
new file mode 100644
index 000000000..cad275634
--- /dev/null
+++ b/xoauth/error_test.go
@@ -0,0 +1,104 @@
+package xoauth
+
+import (
+ "context"
+ "errors"
+ "net"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "golang.org/x/oauth2"
+)
+
+func TestParseError(t *testing.T) {
+ cases := []struct {
+ name string
+ in []error
+ want *errorAuth
+ }{
+ {
+ name: "retrieve_err_json",
+ in: []error{
+ &oauth2.RetrieveError{
+ Response: &http.Response{
+ Header: http.Header{
+ "Content-Type": []string{"application/json"},
+ },
+ },
+ Body: []byte(`{"error":"invalid_client","error_description":"some err description"}`),
+ },
+ },
+ want: &errorAuth{
+ code: ecInvalidClient,
+ },
+ },
+ {
+ name: "retrieve_err_json_unmarshal_err",
+ in: []error{
+ &oauth2.RetrieveError{
+ Response: &http.Response{
+ Header: http.Header{
+ "Content-Type": []string{"application/json"},
+ },
+ },
+ Body: []byte(`invalid json`),
+ },
+ &oauth2.RetrieveError{
+ Response: &http.Response{
+ StatusCode: http.StatusBadRequest,
+ },
+ },
+ },
+ want: &errorAuth{
+ code: ecUnknownServer,
+ },
+ },
+ {
+ name: "retrieve_err_rate_limit",
+ in: []error{
+ &oauth2.RetrieveError{
+ Response: &http.Response{
+ StatusCode: http.StatusTooManyRequests,
+ },
+ },
+ },
+ want: &errorAuth{
+ code: ecRateLimit,
+ },
+ },
+ {
+ name: "timeout",
+ in: []error{context.DeadlineExceeded, context.Canceled},
+ want: &errorAuth{
+ code: ecTimeout,
+ },
+ },
+ {
+ name: "network",
+ in: []error{&net.AddrError{}},
+ want: &errorAuth{
+ code: ecNetwork,
+ },
+ },
+ {
+ name: "unknown",
+ in: []error{errors.New("some err")},
+ want: &errorAuth{
+ code: ecUnknown,
+ },
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ for _, err := range tt.in {
+ got := parseError(err)
+
+ require.Equal(t, tt.want.Code(), got.Code())
+ }
+ })
+ }
+}
diff --git a/xoauth/token.go b/xoauth/token.go
new file mode 100644
index 000000000..b555ca7b2
--- /dev/null
+++ b/xoauth/token.go
@@ -0,0 +1,10 @@
+package xoauth
+
+import "time"
+
+type Token struct {
+ AccessToken string
+ TokenType string
+ Expiry time.Time
+ Scope string
+}
diff --git a/xoauth/tokenissuer.go b/xoauth/tokenissuer.go
new file mode 100644
index 000000000..f0da47c02
--- /dev/null
+++ b/xoauth/tokenissuer.go
@@ -0,0 +1,113 @@
+package xoauth
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/ozontech/file.d/xtime"
+ "golang.org/x/oauth2"
+)
+
+const (
+ defaultHttpTimeout = 1 * time.Minute
+
+ grantTypeClientCreds = "client_credentials"
+)
+
+type httpTokenIssuer struct {
+ client *http.Client
+ cfg *Config
+}
+
+func newHTTPTokenIssuer(cfg *Config) *httpTokenIssuer {
+ return &httpTokenIssuer{
+ client: &http.Client{
+ Timeout: defaultHttpTimeout,
+ },
+ cfg: cfg,
+ }
+}
+
+type tokenJSON struct {
+ AccessToken string `json:"access_token"`
+ TokenType string `json:"token_type"`
+ ExpiresIn int64 `json:"expires_in"`
+ Scope string `json:"scope"`
+}
+
+func (ti *httpTokenIssuer) issueToken(ctx context.Context) (*Token, error) {
+ req, err := newTokenRequest(ctx, ti.cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ resp, err := ti.client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ _ = resp.Body.Close()
+ }()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+ if err != nil {
+ return nil, fmt.Errorf("cannot fetch token: %w", err)
+ }
+
+ if c := resp.StatusCode; c < http.StatusOK || c >= http.StatusMultipleChoices {
+ return nil, &oauth2.RetrieveError{
+ Response: resp,
+ Body: body,
+ }
+ }
+
+ var data tokenJSON
+ if err := json.Unmarshal(body, &data); err != nil {
+ return nil, err
+ }
+
+ var expiry time.Time
+ if secs := data.ExpiresIn; secs > 0 {
+ expiry = xtime.GetInaccurateTime().Add(time.Duration(secs) * time.Second)
+ }
+ return &Token{
+ AccessToken: data.AccessToken,
+ TokenType: data.TokenType,
+ Expiry: expiry,
+ Scope: data.Scope,
+ }, nil
+}
+
+func newTokenRequest(ctx context.Context, cfg *Config) (*http.Request, error) {
+ v := url.Values{}
+ v.Set("grant_type", grantTypeClientCreds)
+ if len(cfg.Scopes) > 0 {
+ v.Set("scope", strings.Join(cfg.Scopes, " "))
+ }
+
+ if cfg.AuthStyle == AuthStyleInParams {
+ v.Set("client_id", cfg.ClientID)
+ if cfg.ClientSecret != "" {
+ v.Set("client_secret", cfg.ClientSecret)
+ }
+ }
+
+ reqBody := io.NopCloser(strings.NewReader(v.Encode()))
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.TokenURL, reqBody)
+ if err != nil {
+ return nil, err
+ }
+
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ if cfg.AuthStyle == AuthStyleInHeader {
+ req.SetBasicAuth(url.QueryEscape(cfg.ClientID), url.QueryEscape(cfg.ClientSecret))
+ }
+
+ return req, nil
+}
diff --git a/xoauth/tokenissuer_test.go b/xoauth/tokenissuer_test.go
new file mode 100644
index 000000000..93c4cc2d5
--- /dev/null
+++ b/xoauth/tokenissuer_test.go
@@ -0,0 +1,122 @@
+package xoauth
+
+import (
+ "context"
+ "encoding/base64"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewTokenRequest(t *testing.T) {
+ cases := []struct {
+ name string
+ cfg *Config
+
+ wantBodyForm map[string]string
+ wantHeader map[string]string
+ }{
+ {
+ name: "auth_in_params",
+ cfg: &Config{
+ ClientID: "test-client",
+ ClientSecret: "test-secret",
+ TokenURL: "http://example.com",
+ AuthStyle: AuthStyleInParams,
+ },
+ wantBodyForm: map[string]string{
+ "grant_type": "client_credentials",
+ "client_id": "test-client",
+ "client_secret": "test-secret",
+ },
+ wantHeader: map[string]string{
+ "Content-Type": "application/x-www-form-urlencoded",
+ },
+ },
+ {
+ name: "auth_in_params_no_secret",
+ cfg: &Config{
+ ClientID: "test-client",
+ TokenURL: "http://example.com",
+ AuthStyle: AuthStyleInParams,
+ },
+ wantBodyForm: map[string]string{
+ "grant_type": "client_credentials",
+ "client_id": "test-client",
+ },
+ wantHeader: map[string]string{
+ "Content-Type": "application/x-www-form-urlencoded",
+ },
+ },
+ {
+ name: "auth_in_header",
+ cfg: &Config{
+ ClientID: "test-client",
+ ClientSecret: "test-secret",
+ TokenURL: "http://example.com",
+ AuthStyle: AuthStyleInHeader,
+ },
+ wantBodyForm: map[string]string{
+ "grant_type": "client_credentials",
+ },
+ wantHeader: map[string]string{
+ "Content-Type": "application/x-www-form-urlencoded",
+ "Authorization": basicAuth("test-client", "test-secret"),
+ },
+ },
+ {
+ name: "auth_in_header_no_secret",
+ cfg: &Config{
+ ClientID: "test-client",
+ TokenURL: "http://example.com",
+ AuthStyle: AuthStyleInHeader,
+ },
+ wantBodyForm: map[string]string{
+ "grant_type": "client_credentials",
+ },
+ wantHeader: map[string]string{
+ "Content-Type": "application/x-www-form-urlencoded",
+ "Authorization": basicAuth("test-client", ""),
+ },
+ },
+ {
+ name: "scopes",
+ cfg: &Config{
+ ClientID: "test-client",
+ TokenURL: "http://example.com",
+ Scopes: []string{"scp1", "scp2"},
+ AuthStyle: AuthStyleInParams,
+ },
+ wantBodyForm: map[string]string{
+ "grant_type": "client_credentials",
+ "client_id": "test-client",
+ "scope": "scp1 scp2",
+ },
+ wantHeader: map[string]string{
+ "Content-Type": "application/x-www-form-urlencoded",
+ },
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ got, err := newTokenRequest(context.Background(), tt.cfg)
+ require.NoError(t, err)
+
+ for k, v := range tt.wantBodyForm {
+ require.Equal(t, v, got.PostFormValue(k))
+ }
+
+ for k, v := range tt.wantHeader {
+ require.Equal(t, v, got.Header.Get(k))
+ }
+ })
+ }
+}
+
+func basicAuth(user, pass string) string {
+ auth := user + ":" + pass
+ return "Basic " + base64.StdEncoding.EncodeToString([]byte(auth))
+}
diff --git a/xoauth/tokensource.go b/xoauth/tokensource.go
new file mode 100644
index 000000000..b496c1936
--- /dev/null
+++ b/xoauth/tokensource.go
@@ -0,0 +1,194 @@
+package xoauth
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/ozontech/file.d/logger"
+ "go.uber.org/atomic"
+)
+
+type AuthStyle int
+
+const (
+ AuthStyleUnknown AuthStyle = iota
+ AuthStyleInParams
+ AuthStyleInHeader
+)
+
+type Config struct {
+ ClientID string
+ ClientSecret string
+ TokenURL string
+ Scopes []string
+ AuthStyle AuthStyle
+}
+
+func (c *Config) validate() error {
+ if c.ClientID == "" {
+ return fmt.Errorf("client id must be non-empty")
+ }
+ if c.TokenURL == "" {
+ return fmt.Errorf("token url must be non-empty")
+ }
+ if c.AuthStyle == AuthStyleUnknown {
+ return fmt.Errorf("auth style must be specified")
+ }
+ return nil
+}
+
+// tokenIssuer issues a token for a specific grant type flow
+type tokenIssuer interface {
+ issueToken(ctx context.Context) (*Token, error)
+}
+
+type TokenSource interface {
+ Token(ctx context.Context) *Token // read-only
+ Stop()
+}
+
+type staticTokenSource struct {
+ t *Token
+}
+
+func NewStaticTokenSource(t *Token) TokenSource {
+ return &staticTokenSource{
+ t: t,
+ }
+}
+
+func (ts *staticTokenSource) Token(_ context.Context) *Token {
+ return ts.t
+}
+
+func (ts *staticTokenSource) Stop() {}
+
+// reuseTokenSource implements the lifecycle of auth token refreshing.
+// - Once reuseTokenSource is created, the first token is issued immediately
+// - Further Token() calls are non-blocking
+// - The token is updated in the background depending on its expiration date
+// - After Stop(), reuseTokenSource irreversibly stops all background work
+type reuseTokenSource struct {
+ tokenHolder atomic.Pointer[Token]
+ tokenIssuer tokenIssuer
+
+ stopCh chan struct{}
+}
+
+func NewReuseTokenSource(ctx context.Context, cfg *Config) (TokenSource, error) {
+ if err := cfg.validate(); err != nil {
+ return nil, err
+ }
+
+ ti := newHTTPTokenIssuer(cfg)
+ return newReuseTokenSource(ctx, ti)
+}
+
+func newReuseTokenSource(ctx context.Context, ti tokenIssuer) (*reuseTokenSource, error) {
+ ts := &reuseTokenSource{
+ tokenHolder: atomic.Pointer[Token]{},
+ tokenIssuer: ti,
+
+ stopCh: make(chan struct{}),
+ }
+
+ // get first token during initialization to verify provided data
+ t, err := ti.issueToken(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("failed to init token source: %w", err)
+ }
+
+ ts.tokenHolder.Store(t)
+ go ts.maintenance(ctx, time.Until(t.Expiry)/2)
+
+ return ts, nil
+}
+
+func (ts *reuseTokenSource) Token(ctx context.Context) *Token {
+ select {
+ case <-ctx.Done():
+ return nil
+ default:
+ }
+
+ return ts.tokenHolder.Load()
+}
+
+func (ts *reuseTokenSource) Stop() {
+ close(ts.stopCh)
+}
+
+// maintenance runs a token update loop
+func (ts *reuseTokenSource) maintenance(ctx context.Context, firstDelay time.Duration) {
+ scheduler := time.NewTimer(firstDelay)
+ defer scheduler.Stop()
+
+ // paths to calculate the next scheduler delay
+ success, fail := ts.newDelayer()
+
+ updateToken := func() time.Duration {
+ t, err := ts.tokenIssuer.issueToken(ctx)
+ if err != nil {
+ return fail(parseError(err))
+ }
+
+ ts.tokenHolder.Store(t)
+ return success(time.Until(t.Expiry))
+ }
+
+ for {
+ select {
+ case <-ts.stopCh:
+ return
+ case <-ctx.Done():
+ return
+ case <-scheduler.C:
+ }
+
+ delay := updateToken()
+ resetTimer(scheduler, delay)
+ }
+}
+
+func resetTimer(t *time.Timer, d time.Duration) {
+ if !t.Stop() {
+ select {
+ case <-t.C:
+ default:
+ }
+ }
+ t.Reset(d)
+}
+
+// newDelayer returns success and failure paths that will be applied to the refresh scheduler
+func (ts *reuseTokenSource) newDelayer() (
+ func(ttl time.Duration) time.Duration, // success
+ func(err *errorAuth) time.Duration, // failure
+) {
+ expBackoff := exponentialJitterBackoff()
+ linBackoff := linearJitterBackoff()
+ attempt := 0
+
+ success := func(ttl time.Duration) time.Duration {
+ attempt = 0
+ return linBackoff(ttl/3, ttl/2, 0)
+ }
+
+ failure := func(err *errorAuth) time.Duration {
+ attempt++
+ code := err.Code()
+ logger.Errorf("error occurred while updating oauth token: attempt=%d, code=%s, error=%s",
+ attempt, code, err.Error())
+
+ switch code {
+ case ecInvalidRequest, ecInvalidClient, ecInvalidGrant, ecInvalidScope,
+ ecUnauthorizedClient, ecUnsupportedGrantType:
+ return linBackoff(time.Minute, 10*time.Minute, attempt)
+ default:
+ return expBackoff(time.Second, time.Minute, attempt)
+ }
+ }
+
+ return success, failure
+}
diff --git a/xoauth/tokensource_test.go b/xoauth/tokensource_test.go
new file mode 100644
index 000000000..78dd2f118
--- /dev/null
+++ b/xoauth/tokensource_test.go
@@ -0,0 +1,74 @@
+package xoauth
+
+import (
+ "context"
+ "errors"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestReuseTokenSource(t *testing.T) {
+ ctx := context.Background()
+ now := time.Now()
+
+ issuer := &mockTokenIssuer{}
+
+ // first token fail
+ issuer.expect(nil, errors.New("some error"))
+ _, err := newReuseTokenSource(ctx, issuer)
+ require.Error(t, err)
+
+ // first token success, save and start maintenance
+ tok := &Token{AccessToken: "test-token", Expiry: now.Add(time.Second)}
+ issuer.expect(tok, nil)
+ source, err := newReuseTokenSource(ctx, issuer)
+ require.NoError(t, err)
+
+ // check token
+ require.Equal(t, tok.AccessToken, source.Token(ctx).AccessToken)
+
+ // the first update happens after half of the first token's lifetime;
+ // set the next token and wait
+ tok = &Token{AccessToken: "test-token-1", Expiry: now.Add(2 * time.Second)}
+ issuer.expect(tok, nil)
+ time.Sleep(time.Second)
+
+ // check token
+ require.Equal(t, tok.AccessToken, source.Token(ctx).AccessToken)
+
+ // stop the token source and wait for the next potential update;
+ // set an issuer token that should never be fetched
+ issuer.expect(&Token{AccessToken: "test-token-2", Expiry: now.Add(3 * time.Second)}, nil)
+ source.Stop()
+ time.Sleep(time.Second)
+
+ // check that the token is the one saved before the stop
+ require.Equal(t, tok.AccessToken, source.Token(ctx).AccessToken)
+
+ // check that Token() with canceled context returns nil
+ ctx2, cancel := context.WithCancel(ctx)
+ cancel()
+ require.Nil(t, source.Token(ctx2))
+}
+
+type mockTokenIssuer struct {
+ token *Token
+ err error
+
+ m sync.RWMutex
+}
+
+func (ti *mockTokenIssuer) issueToken(ctx context.Context) (*Token, error) {
+ ti.m.RLock()
+ defer ti.m.RUnlock()
+ return ti.token, ti.err
+}
+
+func (ti *mockTokenIssuer) expect(t *Token, err error) {
+ ti.m.Lock()
+ defer ti.m.Unlock()
+ ti.token, ti.err = t, err
+}
diff --git a/xscram/client.go b/xscram/client.go
deleted file mode 100644
index 4004e2fd7..000000000
--- a/xscram/client.go
+++ /dev/null
@@ -1,43 +0,0 @@
-package xscram
-
-import (
- "crypto/sha256"
- "crypto/sha512"
-
- "github.com/xdg-go/scram"
-)
-
-var (
- SHA256 scram.HashGeneratorFcn = sha256.New
- SHA512 scram.HashGeneratorFcn = sha512.New
-)
-
-type Client struct {
- *scram.Client
- *scram.ClientConversation
- scram.HashGeneratorFcn
-}
-
-func NewClient(hashFn scram.HashGeneratorFcn) *Client {
- return &Client{
- HashGeneratorFcn: hashFn,
- }
-}
-
-func (x *Client) Begin(userName, password, authzID string) error {
- var err error
- x.Client, err = x.HashGeneratorFcn.NewClient(userName, password, authzID)
- if err != nil {
- return err
- }
- x.ClientConversation = x.Client.NewConversation()
- return nil
-}
-
-func (x *Client) Step(challenge string) (string, error) {
- return x.ClientConversation.Step(challenge)
-}
-
-func (x *Client) Done() bool {
- return x.ClientConversation.Done()
-}