Skip to content

Commit e6f2f5b

Browse files
committed
Token renewal in a interceptor
1 parent d0a2fcb commit e6f2f5b

File tree

4 files changed

+90
-79
lines changed

4 files changed

+90
-79
lines changed

generate/go_client.tpl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,16 @@ func New(config *DialConfig) (Client, error) {
5656
authInterceptor := &authInterceptor{config: config}
5757
c.interceptors = append(c.interceptors, authInterceptor)
5858
}
59+
if config.TokenRenewal != nil {
60+
tokenRenewingInterceptor := &tokenRenewingInterceptor{config: config, client: c}
61+
c.interceptors = append(c.interceptors, tokenRenewingInterceptor)
62+
}
5963
if config.Log != nil {
6064
loggingInterceptor := &loggingInterceptor{config: config}
6165
c.interceptors = append(c.interceptors, loggingInterceptor)
6266
}
6367
c.interceptors = append(c.interceptors, config.Interceptors...)
6468

65-
// TODO convert to interceptor
66-
go c.startTokenRenewal()
67-
6869
return c, nil
6970
}
7071

go/client/client-interceptors.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,14 @@ package client
22

33
import (
44
"context"
5+
"fmt"
6+
"log/slog"
7+
"sync"
8+
"sync/atomic"
9+
"time"
510

611
"connectrpc.com/connect"
12+
apiv2models "github.com/metal-stack/api/go/metalstack/api/v2"
713
)
814

915
// authinterceptor adds the required auth headers
@@ -65,3 +71,79 @@ func (i *loggingInterceptor) WrapStreamingClient(next connect.StreamingClientFun
6571
func (i *loggingInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
6672
return next
6773
}
74+
75+
type tokenRenewingInterceptor struct {
76+
config *DialConfig
77+
client *client
78+
79+
renewing atomic.Bool
80+
81+
sync.Mutex
82+
}
83+
84+
func (i *tokenRenewingInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
85+
return connect.UnaryFunc(func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
86+
err := i.renewTokenIfNeeded()
87+
if err != nil {
88+
return nil, err
89+
}
90+
return next(ctx, request)
91+
})
92+
}
93+
94+
func (i *tokenRenewingInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
95+
return next
96+
}
97+
98+
func (i *tokenRenewingInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
99+
return next
100+
}
101+
102+
func (i *tokenRenewingInterceptor) renewTokenIfNeeded() error {
103+
if i.config.expiresAt.IsZero() {
104+
return nil
105+
}
106+
if i.renewing.Load() {
107+
return nil
108+
}
109+
if i.config.Log == nil {
110+
i.config.Log = slog.Default()
111+
}
112+
113+
replaceBefore := i.config.expiresAt.Sub(i.config.issuedAt) / tokenRenewChecksDuringLifetime
114+
115+
if time.Until(i.config.expiresAt) > replaceBefore {
116+
return nil
117+
}
118+
119+
i.renewing.Store(true)
120+
defer i.renewing.Store(false)
121+
122+
i.config.Log.Info("call token refresh, current token expires soon", "expires", i.config.expiresAt.String())
123+
124+
i.Lock()
125+
defer i.Unlock()
126+
127+
resp, err := i.client.Apiv2().Token().Refresh(context.Background(), &apiv2models.TokenServiceRefreshRequest{})
128+
if err != nil {
129+
return fmt.Errorf("unable to refresh token %w", err)
130+
}
131+
132+
i.config.Token = resp.Secret
133+
err = i.config.parse()
134+
if err != nil {
135+
return fmt.Errorf("unable to parse token %w", err)
136+
}
137+
138+
if i.config.TokenRenewal.PersistTokenFn == nil {
139+
return nil
140+
}
141+
142+
err = i.config.TokenRenewal.PersistTokenFn(i.config.Token)
143+
if err != nil {
144+
return fmt.Errorf("unable to persist token %w", err)
145+
}
146+
147+
i.config.Log.Info("token refreshed, new token expires in", "expires", i.config.expiresAt.String())
148+
return nil
149+
}

go/client/client.go

Lines changed: 4 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

go/client/conn.go

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package client
22

33
import (
4-
"context"
54
"errors"
65
"fmt"
76
"log/slog"
@@ -10,7 +9,6 @@ import (
109

1110
"connectrpc.com/connect"
1211
"github.com/golang-jwt/jwt/v5"
13-
api "github.com/metal-stack/api/go/metalstack/api/v2"
1412
)
1513

1614
const tokenRenewChecksDuringLifetime = 4
@@ -84,70 +82,3 @@ func (dc *DialConfig) parse() error {
8482
}
8583
return nil
8684
}
87-
88-
func (c *client) startTokenRenewal() {
89-
if c.config.TokenRenewal == nil {
90-
return
91-
}
92-
if c.config.expiresAt.IsZero() {
93-
return
94-
}
95-
if c.config.Log == nil {
96-
c.config.Log = slog.Default()
97-
}
98-
99-
replaceBefore := c.config.expiresAt.Sub(c.config.issuedAt) / tokenRenewChecksDuringLifetime
100-
101-
err := c.renewTokenIfNeeded(replaceBefore)
102-
if err != nil {
103-
c.config.Log.Error("unable to renew token", "error", err)
104-
}
105-
106-
ticker := time.NewTicker(replaceBefore)
107-
defer ticker.Stop()
108-
done := make(chan bool)
109-
for {
110-
select {
111-
case <-done:
112-
return
113-
case <-ticker.C:
114-
err := c.renewTokenIfNeeded(replaceBefore)
115-
if err != nil {
116-
c.config.Log.Error("unable to renew token", "error", err)
117-
}
118-
}
119-
}
120-
}
121-
122-
func (c *client) renewTokenIfNeeded(replaceBefore time.Duration) error {
123-
if time.Until(c.config.expiresAt) > replaceBefore {
124-
return nil
125-
}
126-
c.config.Log.Info("call token refresh, current token expires soon", "expires", c.config.expiresAt.String())
127-
128-
c.Lock()
129-
defer c.Unlock()
130-
131-
resp, err := c.Apiv2().Token().Refresh(context.Background(), &api.TokenServiceRefreshRequest{})
132-
if err != nil {
133-
return fmt.Errorf("unable to refresh token %w", err)
134-
}
135-
136-
c.config.Token = resp.Secret
137-
err = c.config.parse()
138-
if err != nil {
139-
return fmt.Errorf("unable to parse token %w", err)
140-
}
141-
142-
if c.config.TokenRenewal.PersistTokenFn == nil {
143-
return nil
144-
}
145-
146-
err = c.config.TokenRenewal.PersistTokenFn(c.config.Token)
147-
if err != nil {
148-
return fmt.Errorf("unable to persist token %w", err)
149-
}
150-
151-
c.config.Log.Info("token refreshed, new token expires in", "expires", c.config.expiresAt.String())
152-
return nil
153-
}

0 commit comments

Comments
 (0)