Skip to content

Commit c24ae57

Browse files
committed
JWT Auth
1 parent 3ab92e4 commit c24ae57

17 files changed

+297
-41
lines changed

clickhouse.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,3 +375,8 @@ func (ch *clickhouse) Close() error {
375375
}
376376
}
377377
}
378+
379+
func (ch *clickhouse) UpdateJWT(jwt string) error {
380+
ch.opt.Auth.JWT = jwt
381+
return nil
382+
}

clickhouse_options.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,12 @@ var compressionMap = map[string]CompressionMethod{
7777

7878
type Auth struct { // has_control_character
7979
Database string
80+
8081
Username string
8182
Password string
83+
84+
// JWT for ClickHouse Cloud. Use this instead of Username and Password if you're using JWT auth.
85+
JWT string
8286
}
8387

8488
type Compression struct {

clickhouse_std.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ func (o *stdConnOpener) Driver() driver.Driver {
5454
debugf = log.New(os.Stdout, "[clickhouse-std] ", 0).Printf
5555
}
5656
}
57-
return &stdDriver{debugf: debugf}
57+
return &stdDriver{
58+
opt: o.opt,
59+
debugf: debugf,
60+
}
5861
}
5962

6063
func (o *stdConnOpener) Connect(ctx context.Context) (_ driver.Conn, err error) {
@@ -201,6 +204,7 @@ type stdConnect interface {
201204
}
202205

203206
type stdDriver struct {
207+
opt *Options
204208
conn stdConnect
205209
commit func() error
206210
debugf func(format string, v ...any)
@@ -382,6 +386,11 @@ func (std *stdDriver) Close() error {
382386
return err
383387
}
384388

389+
func (std *stdDriver) UpdateJWT(jwt string) error {
390+
std.opt.Auth.JWT = jwt
391+
return nil
392+
}
393+
385394
type stdBatch struct {
386395
batch ldriver.Batch
387396
debugf func(format string, v ...any)

conn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
110110
}
111111
)
112112

113-
if err := connect.handshake(opt.Auth.Database, opt.Auth.Username, opt.Auth.Password); err != nil {
113+
if err := connect.handshake(opt.Auth); err != nil {
114114
return nil, err
115115
}
116116

conn_handshake.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ import (
2525
"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
2626
)
2727

28-
func (c *connect) handshake(database, username, password string) error {
28+
// jwtAuthMarker is the marker for JSON Web Token authentication in ClickHouse Cloud.
29+
// At the protocol level this is used in place of a username.
30+
const jwtAuthMarker = " JWT AUTHENTICATION "
31+
32+
func (c *connect) handshake(auth Auth) error {
2933
defer c.buffer.Reset()
3034
c.debugf("[handshake] -> %s", proto.ClientHandshake{})
3135
// set a read deadline - alternative to context.Read operation will fail if no data is received after deadline.
@@ -43,9 +47,15 @@ func (c *connect) handshake(database, username, password string) error {
4347
}
4448
handshake.Encode(c.buffer)
4549
{
46-
c.buffer.PutString(database)
47-
c.buffer.PutString(username)
48-
c.buffer.PutString(password)
50+
c.buffer.PutString(auth.Database)
51+
52+
if auth.JWT != "" {
53+
c.buffer.PutString(jwtAuthMarker)
54+
c.buffer.PutString(auth.JWT)
55+
} else {
56+
c.buffer.PutString(auth.Username)
57+
c.buffer.PutString(auth.Password)
58+
}
4959
}
5060
if err := c.flush(); err != nil {
5161
return err

conn_http.go

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,34 @@ func (rw *HTTPReaderWriter) reset(pw *io.PipeWriter) io.WriteCloser {
119119
}
120120
}
121121

122+
// applyOptionsToRequest applies the client Options (such as auth, headers, client info) to the given http.Request
123+
func applyOptionsToRequest(req *http.Request, opt *Options) {
124+
if opt.TLS != nil && len(opt.Auth.JWT) > 0 {
125+
req.Header.Set("Authorization", "Bearer "+opt.Auth.JWT)
126+
} else if opt.TLS != nil && len(opt.Auth.Username) > 0 {
127+
req.Header.Set("X-ClickHouse-User", opt.Auth.Username)
128+
if len(opt.Auth.Password) > 0 {
129+
req.Header.Set("X-ClickHouse-Key", opt.Auth.Password)
130+
req.Header.Set("X-ClickHouse-SSL-Certificate-Auth", "off")
131+
} else {
132+
req.Header.Set("X-ClickHouse-SSL-Certificate-Auth", "on")
133+
}
134+
} else if opt.TLS == nil && len(opt.Auth.Username) > 0 {
135+
if len(opt.Auth.Password) > 0 {
136+
req.URL.User = url.UserPassword(opt.Auth.Username, opt.Auth.Password)
137+
138+
} else {
139+
req.URL.User = url.User(opt.Auth.Username)
140+
}
141+
}
142+
143+
req.Header.Set("User-Agent", opt.ClientInfo.String())
144+
145+
for k, v := range opt.HttpHeaders {
146+
req.Header.Set(k, v)
147+
}
148+
}
149+
122150
func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpConnect, error) {
123151
var debugf = func(format string, v ...any) {}
124152
if opt.Debug {
@@ -151,29 +179,6 @@ func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpCon
151179
Path: opt.HttpUrlPath,
152180
}
153181

154-
headers := make(map[string]string)
155-
for k, v := range opt.HttpHeaders {
156-
headers[k] = v
157-
}
158-
159-
if opt.TLS == nil && len(opt.Auth.Username) > 0 {
160-
if len(opt.Auth.Password) > 0 {
161-
u.User = url.UserPassword(opt.Auth.Username, opt.Auth.Password)
162-
} else {
163-
u.User = url.User(opt.Auth.Username)
164-
}
165-
} else if opt.TLS != nil && len(opt.Auth.Username) > 0 {
166-
headers["X-ClickHouse-User"] = opt.Auth.Username
167-
if len(opt.Auth.Password) > 0 {
168-
headers["X-ClickHouse-Key"] = opt.Auth.Password
169-
headers["X-ClickHouse-SSL-Certificate-Auth"] = "off"
170-
} else {
171-
headers["X-ClickHouse-SSL-Certificate-Auth"] = "on"
172-
}
173-
}
174-
175-
headers["User-Agent"] = opt.ClientInfo.String()
176-
177182
query := u.Query()
178183
if len(opt.Auth.Database) > 0 {
179184
query.Set("database", opt.Auth.Database)
@@ -225,6 +230,7 @@ func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpCon
225230
}
226231

227232
conn := &httpConnect{
233+
opt: opt,
228234
client: &http.Client{
229235
Transport: t,
230236
},
@@ -234,7 +240,6 @@ func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpCon
234240
blockCompressor: compress.NewWriter(compress.Level(opt.Compression.Level), compress.Method(opt.Compression.Method)),
235241
compressionPool: compressionPool,
236242
blockBufferSize: opt.BlockBufferSize,
237-
headers: headers,
238243
}
239244
location, err := conn.readTimeZone(ctx)
240245
if err != nil {
@@ -251,6 +256,7 @@ func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpCon
251256
}
252257

253258
return &httpConnect{
259+
opt: opt,
254260
client: &http.Client{
255261
Transport: t,
256262
},
@@ -261,11 +267,11 @@ func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpCon
261267
compressionPool: compressionPool,
262268
location: location,
263269
blockBufferSize: opt.BlockBufferSize,
264-
headers: headers,
265270
}, nil
266271
}
267272

268273
type httpConnect struct {
274+
opt *Options
269275
url *url.URL
270276
client *http.Client
271277
location *time.Location
@@ -274,7 +280,6 @@ type httpConnect struct {
274280
blockCompressor *compress.Writer
275281
compressionPool Pool[HTTPReaderWriter]
276282
blockBufferSize uint8
277-
headers map[string]string
278283
}
279284

280285
func (h *httpConnect) isBad() bool {
@@ -456,9 +461,12 @@ func (h *httpConnect) createRequest(ctx context.Context, requestUrl string, read
456461
if err != nil {
457462
return nil, err
458463
}
464+
465+
applyOptionsToRequest(req, h.opt)
459466
for k, v := range headers {
460467
req.Header.Add(k, v)
461468
}
469+
462470
var query url.Values
463471
if options != nil {
464472
query = req.URL.Query()

conn_http_async_insert.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func (h *httpConnect) asyncInsert(ctx context.Context, query string, wait bool,
3838
}
3939
}
4040

41-
res, err := h.sendQuery(ctx, query, &options, h.headers)
41+
res, err := h.sendQuery(ctx, query, &options, nil)
4242
if res != nil {
4343
defer res.Body.Close()
4444
// we don't care about result, so just discard it to reuse connection

conn_http_batch.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,7 @@ func (b *httpBatch) Send() (err error) {
207207

208208
options.settings["query"] = b.query
209209
headers["Content-Type"] = "application/octet-stream"
210-
for k, v := range b.conn.headers {
211-
headers[k] = v
212-
}
210+
213211
res, err := b.conn.sendStreamQuery(b.ctx, r, &options, headers)
214212

215213
if res != nil {

conn_http_exec.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func (h *httpConnect) exec(ctx context.Context, query string, args ...any) error
2929
return err
3030
}
3131

32-
res, err := h.sendQuery(ctx, query, &options, h.headers)
32+
res, err := h.sendQuery(ctx, query, &options, nil)
3333
if res != nil {
3434
defer res.Body.Close()
3535
// we don't care about result, so just discard it to reuse connection

conn_http_query.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@ func (h *httpConnect) query(ctx context.Context, release func(*connect, error),
4242
headers["Accept-Encoding"] = h.compression.String()
4343
}
4444

45-
for k, v := range h.headers {
46-
headers[k] = v
47-
}
48-
4945
res, err := h.sendQuery(ctx, query, &options, headers)
5046
if err != nil {
5147
return nil, err

0 commit comments

Comments
 (0)