Skip to content

Commit 1c00987

Browse files
committed
update JWT interface
1 parent 6649c4f commit 1c00987

File tree

13 files changed

+126
-66
lines changed

13 files changed

+126
-66
lines changed

clickhouse.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,3 @@ func (ch *clickhouse) Close() error {
375375
}
376376
}
377377
}
378-
379-
func (ch *clickhouse) UpdateJWT(jwt string) error {
380-
ch.opt.Auth.JWT = jwt
381-
return nil
382-
}

clickhouse_options.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,6 @@ type Auth struct { // has_control_character
8080

8181
Username string
8282
Password string
83-
84-
// JWT for ClickHouse Cloud. Use this instead of Username and Password if you're using JWT auth.
85-
JWT string
8683
}
8784

8885
type Compression struct {
@@ -160,6 +157,11 @@ type Options struct {
160157
// HTTPProxy specifies an HTTP proxy URL to use for requests made by the client.
161158
HTTPProxyURL *url.URL
162159

160+
// GetJWT should return a JWT for authentication with ClickHouse Cloud.
161+
// This is called per connection/request, so you may cache the token in your app if needed.
162+
// Use this instead of Auth.Username and Auth.Password if you're using JWT auth.
163+
GetJWT GetJWTFunc
164+
163165
scheme string
164166
ReadTimeout time.Duration
165167
}

clickhouse_std.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -386,11 +386,6 @@ func (std *stdDriver) Close() error {
386386
return err
387387
}
388388

389-
func (std *stdDriver) UpdateJWT(jwt string) error {
390-
std.opt.Auth.JWT = jwt
391-
return nil
392-
}
393-
394389
type stdBatch struct {
395390
batch ldriver.Batch
396391
debugf func(format string, v ...any)

conn.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,18 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
110110
}
111111
)
112112

113-
if err := connect.handshake(opt.Auth); err != nil {
113+
auth := opt.Auth
114+
if useJWTAuth(opt) {
115+
jwt, err := opt.GetJWT(ctx)
116+
if err != nil {
117+
return nil, fmt.Errorf("failed to get JWT: %w", err)
118+
}
119+
120+
auth.Username = jwtAuthMarker
121+
auth.Password = jwt
122+
}
123+
124+
if err := connect.handshake(auth); err != nil {
114125
return nil, err
115126
}
116127

conn_handshake.go

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,6 @@ import (
2525
"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
2626
)
2727

28-
// jwtAuthMarker is the marker for JSON Web Token authentication in ClickHouse Cloud.
29-
// At the protocol level this is used in place of a username.
30-
const jwtAuthMarker = " JWT AUTHENTICATION "
31-
3228
func (c *connect) handshake(auth Auth) error {
3329
defer c.buffer.Reset()
3430
c.debugf("[handshake] -> %s", proto.ClientHandshake{})
@@ -48,14 +44,8 @@ func (c *connect) handshake(auth Auth) error {
4844
handshake.Encode(c.buffer)
4945
{
5046
c.buffer.PutString(auth.Database)
51-
52-
if auth.JWT != "" {
53-
c.buffer.PutString(jwtAuthMarker)
54-
c.buffer.PutString(auth.JWT)
55-
} else {
56-
c.buffer.PutString(auth.Username)
57-
c.buffer.PutString(auth.Password)
58-
}
47+
c.buffer.PutString(auth.Username)
48+
c.buffer.PutString(auth.Password)
5949
}
6050
if err := c.flush(); err != nil {
6151
return err

conn_http.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,20 @@ func (rw *HTTPReaderWriter) reset(pw *io.PipeWriter) io.WriteCloser {
120120
}
121121

122122
// applyOptionsToRequest applies the client Options (such as auth, headers, client info) to the given http.Request
123-
func applyOptionsToRequest(req *http.Request, opt *Options) {
124-
if opt.TLS != nil && len(opt.Auth.JWT) > 0 {
125-
req.Header.Set("Authorization", "Bearer "+opt.Auth.JWT)
123+
func applyOptionsToRequest(ctx context.Context, req *http.Request, opt *Options) error {
124+
jwt := queryOptionsJWT(ctx)
125+
useJWT := jwt != "" || useJWTAuth(opt)
126+
127+
if opt.TLS != nil && useJWT {
128+
if jwt == "" {
129+
var err error
130+
jwt, err = opt.GetJWT(ctx)
131+
if err != nil {
132+
return fmt.Errorf("failed to get JWT: %w", err)
133+
}
134+
}
135+
136+
req.Header.Set("Authorization", "Bearer "+jwt)
126137
} else if opt.TLS != nil && len(opt.Auth.Username) > 0 {
127138
req.Header.Set("X-ClickHouse-User", opt.Auth.Username)
128139
if len(opt.Auth.Password) > 0 {
@@ -145,6 +156,8 @@ func applyOptionsToRequest(req *http.Request, opt *Options) {
145156
for k, v := range opt.HttpHeaders {
146157
req.Header.Set(k, v)
147158
}
159+
160+
return nil
148161
}
149162

150163
func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpConnect, error) {
@@ -462,7 +475,11 @@ func (h *httpConnect) createRequest(ctx context.Context, requestUrl string, read
462475
return nil, err
463476
}
464477

465-
applyOptionsToRequest(req, h.opt)
478+
err = applyOptionsToRequest(ctx, req, h.opt)
479+
if err != nil {
480+
return nil, err
481+
}
482+
466483
for k, v := range headers {
467484
req.Header.Add(k, v)
468485
}

context.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ type (
5353
async AsyncOptions
5454
queryID string
5555
quotaKey string
56+
jwt string
5657
events struct {
5758
logs func(*Log)
5859
progress func(*Progress)
@@ -95,6 +96,15 @@ func WithQuotaKey(quotaKey string) QueryOption {
9596
}
9697
}
9798

99+
// WithJWT overrides the existing authentication with the given JWT.
100+
// This only applies for clients connected with HTTPS to ClickHouse Cloud.
101+
func WithJWT(jwt string) QueryOption {
102+
return func(o *QueryOptions) error {
103+
o.jwt = jwt
104+
return nil
105+
}
106+
}
107+
98108
func WithSettings(settings Settings) QueryOption {
99109
return func(o *QueryOptions) error {
100110
o.settings = settings
@@ -211,6 +221,16 @@ func queryOptions(ctx context.Context) QueryOptions {
211221
return opt
212222
}
213223

224+
// queryOptionsJWT returns the JWT within the given context's QueryOptions.
225+
// Empty string if not present.
226+
func queryOptionsJWT(ctx context.Context) string {
227+
if opt, ok := ctx.Value(_contextOptionKey).(QueryOptions); ok {
228+
return opt.jwt
229+
}
230+
231+
return ""
232+
}
233+
214234
// queryOptionsAsync returns the AsyncOptions struct within the given context's QueryOptions.
215235
func queryOptionsAsync(ctx context.Context) AsyncOptions {
216236
if opt, ok := ctx.Value(_contextOptionKey).(QueryOptions); ok {

jwt.go

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,16 @@
1818
package clickhouse
1919

2020
import (
21-
"database/sql"
22-
"fmt"
21+
"context"
2322
)
2423

25-
type jwtUpdater interface {
26-
UpdateJWT(jwt string) error
27-
}
28-
29-
// UpdateSqlJWT is a helper function that updates the JWT within the given sql.DB instance, useful for
30-
// updating expired tokens.
31-
// For the Native interface, the JWT is only updated for new connections.
32-
// For the HTTP interface, the JWT is updated immediately for subsequent requests.
33-
// Existing Native connections are unaffected, but may be forcibly closed by the server upon token expiry.
34-
// For a completely fresh set of connections you should open a new instance.
35-
func UpdateSqlJWT(db *sql.DB, jwt string) error {
36-
if db == nil {
37-
return nil
38-
}
24+
// jwtAuthMarker is the marker for JSON Web Token authentication in ClickHouse Cloud.
25+
// At the protocol level this is used in place of a username.
26+
const jwtAuthMarker = " JWT AUTHENTICATION "
3927

40-
chDriver, ok := db.Driver().(jwtUpdater)
41-
if !ok {
42-
return fmt.Errorf("failed to update JWT: db instance must be ClickHouse")
43-
}
28+
type GetJWTFunc = func(ctx context.Context) (string, error)
4429

45-
return chDriver.UpdateJWT(jwt)
30+
// useJWTAuth returns true if the client should use JWT auth
31+
func useJWTAuth(opt *Options) bool {
32+
return opt.GetJWT != nil
4633
}

lib/driver/driver.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,6 @@ type (
6161
Ping(context.Context) error
6262
Stats() Stats
6363
Close() error
64-
65-
// UpdateJWT updates the JWT used for new connections, useful for swapping expired tokens.
66-
// Existing connections are unaffected, but may be forcibly closed by the server upon token expiry.
67-
// For a completely fresh set of connections you should create a new Conn instance.
68-
UpdateJWT(jwt string) error
6964
}
7065
Row interface {
7166
Err() error

tests/conn_test.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,10 +456,25 @@ func TestFreeBufOnConnRelease(t *testing.T) {
456456
require.NoError(t, err)
457457
}
458458

459+
func TestJWTError(t *testing.T) {
460+
getJWT := func(ctx context.Context) (string, error) {
461+
return "", fmt.Errorf("test error")
462+
}
463+
464+
conn, err := GetJWTConnection(testSet, nil, nil, 1000*time.Millisecond, getJWT)
465+
require.NoError(t, err)
466+
require.ErrorContains(t, conn.Ping(context.Background()), "test error")
467+
}
468+
459469
func TestNativeJWTAuth(t *testing.T) {
460470
SkipNotCloud(t)
461471

462-
conn, err := GetJWTConnection(testSet, nil, &tls.Config{}, 1000*time.Millisecond)
472+
jwt := GetEnv("CLICKHOUSE_JWT", "")
473+
getJWT := func(ctx context.Context) (string, error) {
474+
return jwt, nil
475+
}
476+
477+
conn, err := GetJWTConnection(testSet, nil, &tls.Config{}, 1000*time.Millisecond, getJWT)
463478
require.NoError(t, err)
464479

465480
// Token works
@@ -469,7 +484,7 @@ func TestNativeJWTAuth(t *testing.T) {
469484
time.Sleep(1500 * time.Millisecond)
470485

471486
// Break the token
472-
require.NoError(t, conn.UpdateJWT("broken_jwt"))
487+
jwt = "broken_jwt"
473488

474489
// Next ping should fail
475490
require.Error(t, conn.Ping(context.Background()))

0 commit comments

Comments
 (0)