Skip to content

Commit b3b56a5

Browse files
authored
Merge branch 'main' into fix-clickhouse-proxy-issue
2 parents d512624 + b38d732 commit b3b56a5

17 files changed

+359
-42
lines changed

clickhouse_options.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ var compressionMap = map[string]CompressionMethod{
7777

7878
type Auth struct { // has_control_character
7979
Database string
80+
8081
Username string
8182
Password string
8283
}
@@ -156,6 +157,11 @@ type Options struct {
156157
// HTTPProxy specifies an HTTP proxy URL to use for requests made by the client.
157158
HTTPProxyURL *url.URL
158159

160+
// GetJWT should return a JWT for authentication with ClickHouse Cloud.
161+
// This is called per connection/request, so you may cache the token in your app if needed.
162+
// Use this instead of Auth.Username and Auth.Password if you're using JWT auth.
163+
GetJWT GetJWTFunc
164+
159165
scheme string
160166
ReadTimeout time.Duration
161167
}

clickhouse_std.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ func (o *stdConnOpener) Driver() driver.Driver {
5454
debugf = log.New(os.Stdout, "[clickhouse-std] ", 0).Printf
5555
}
5656
}
57-
return &stdDriver{debugf: debugf}
57+
return &stdDriver{
58+
opt: o.opt,
59+
debugf: debugf,
60+
}
5861
}
5962

6063
func (o *stdConnOpener) Connect(ctx context.Context) (_ driver.Conn, err error) {
@@ -201,6 +204,7 @@ type stdConnect interface {
201204
}
202205

203206
type stdDriver struct {
207+
opt *Options
204208
conn stdConnect
205209
commit func() error
206210
debugf func(format string, v ...any)

conn.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,18 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
110110
}
111111
)
112112

113-
if err := connect.handshake(opt.Auth.Database, opt.Auth.Username, opt.Auth.Password); err != nil {
113+
auth := opt.Auth
114+
if useJWTAuth(opt) {
115+
jwt, err := opt.GetJWT(ctx)
116+
if err != nil {
117+
return nil, fmt.Errorf("failed to get JWT: %w", err)
118+
}
119+
120+
auth.Username = jwtAuthMarker
121+
auth.Password = jwt
122+
}
123+
124+
if err := connect.handshake(auth); err != nil {
114125
return nil, err
115126
}
116127

conn_handshake.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import (
2525
"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
2626
)
2727

28-
func (c *connect) handshake(database, username, password string) error {
28+
func (c *connect) handshake(auth Auth) error {
2929
defer c.buffer.Reset()
3030
c.debugf("[handshake] -> %s", proto.ClientHandshake{})
3131
// set a read deadline - alternative to context.Read operation will fail if no data is received after deadline.
@@ -43,9 +43,9 @@ func (c *connect) handshake(database, username, password string) error {
4343
}
4444
handshake.Encode(c.buffer)
4545
{
46-
c.buffer.PutString(database)
47-
c.buffer.PutString(username)
48-
c.buffer.PutString(password)
46+
c.buffer.PutString(auth.Database)
47+
c.buffer.PutString(auth.Username)
48+
c.buffer.PutString(auth.Password)
4949
}
5050
if err := c.flush(); err != nil {
5151
return err

conn_http.go

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,47 @@ func (rw *HTTPReaderWriter) reset(pw *io.PipeWriter) io.WriteCloser {
119119
}
120120
}
121121

122+
// applyOptionsToRequest applies the client Options (such as auth, headers, client info) to the given http.Request
123+
func applyOptionsToRequest(ctx context.Context, req *http.Request, opt *Options) error {
124+
jwt := queryOptionsJWT(ctx)
125+
useJWT := jwt != "" || useJWTAuth(opt)
126+
127+
if opt.TLS != nil && useJWT {
128+
if jwt == "" {
129+
var err error
130+
jwt, err = opt.GetJWT(ctx)
131+
if err != nil {
132+
return fmt.Errorf("failed to get JWT: %w", err)
133+
}
134+
}
135+
136+
req.Header.Set("Authorization", "Bearer "+jwt)
137+
} else if opt.TLS != nil && len(opt.Auth.Username) > 0 {
138+
req.Header.Set("X-ClickHouse-User", opt.Auth.Username)
139+
if len(opt.Auth.Password) > 0 {
140+
req.Header.Set("X-ClickHouse-Key", opt.Auth.Password)
141+
req.Header.Set("X-ClickHouse-SSL-Certificate-Auth", "off")
142+
} else {
143+
req.Header.Set("X-ClickHouse-SSL-Certificate-Auth", "on")
144+
}
145+
} else if opt.TLS == nil && len(opt.Auth.Username) > 0 {
146+
if len(opt.Auth.Password) > 0 {
147+
req.URL.User = url.UserPassword(opt.Auth.Username, opt.Auth.Password)
148+
149+
} else {
150+
req.URL.User = url.User(opt.Auth.Username)
151+
}
152+
}
153+
154+
req.Header.Set("User-Agent", opt.ClientInfo.String())
155+
156+
for k, v := range opt.HttpHeaders {
157+
req.Header.Set(k, v)
158+
}
159+
160+
return nil
161+
}
162+
122163
func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpConnect, error) {
123164
var debugf = func(format string, v ...any) {}
124165
if opt.Debug {
@@ -151,29 +192,6 @@ func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpCon
151192
Path: opt.HttpUrlPath,
152193
}
153194

154-
headers := make(map[string]string)
155-
for k, v := range opt.HttpHeaders {
156-
headers[k] = v
157-
}
158-
159-
if opt.TLS == nil && len(opt.Auth.Username) > 0 {
160-
if len(opt.Auth.Password) > 0 {
161-
u.User = url.UserPassword(opt.Auth.Username, opt.Auth.Password)
162-
} else {
163-
u.User = url.User(opt.Auth.Username)
164-
}
165-
} else if opt.TLS != nil && len(opt.Auth.Username) > 0 {
166-
headers["X-ClickHouse-User"] = opt.Auth.Username
167-
if len(opt.Auth.Password) > 0 {
168-
headers["X-ClickHouse-Key"] = opt.Auth.Password
169-
headers["X-ClickHouse-SSL-Certificate-Auth"] = "off"
170-
} else {
171-
headers["X-ClickHouse-SSL-Certificate-Auth"] = "on"
172-
}
173-
}
174-
175-
headers["User-Agent"] = opt.ClientInfo.String()
176-
177195
query := u.Query()
178196
if len(opt.Auth.Database) > 0 {
179197
query.Set("database", opt.Auth.Database)
@@ -225,6 +243,7 @@ func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpCon
225243
}
226244

227245
conn := &httpConnect{
246+
opt: opt,
228247
client: &http.Client{
229248
Transport: t,
230249
},
@@ -234,7 +253,6 @@ func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpCon
234253
blockCompressor: compress.NewWriter(compress.Level(opt.Compression.Level), compress.Method(opt.Compression.Method)),
235254
compressionPool: compressionPool,
236255
blockBufferSize: opt.BlockBufferSize,
237-
headers: headers,
238256
}
239257
location, err := conn.readTimeZone(ctx)
240258
if err != nil {
@@ -251,6 +269,7 @@ func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpCon
251269
}
252270

253271
return &httpConnect{
272+
opt: opt,
254273
client: &http.Client{
255274
Transport: t,
256275
},
@@ -261,11 +280,11 @@ func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpCon
261280
compressionPool: compressionPool,
262281
location: location,
263282
blockBufferSize: opt.BlockBufferSize,
264-
headers: headers,
265283
}, nil
266284
}
267285

268286
type httpConnect struct {
287+
opt *Options
269288
url *url.URL
270289
client *http.Client
271290
location *time.Location
@@ -274,7 +293,6 @@ type httpConnect struct {
274293
blockCompressor *compress.Writer
275294
compressionPool Pool[HTTPReaderWriter]
276295
blockBufferSize uint8
277-
headers map[string]string
278296
}
279297

280298
func (h *httpConnect) isBad() bool {
@@ -456,9 +474,16 @@ func (h *httpConnect) createRequest(ctx context.Context, requestUrl string, read
456474
if err != nil {
457475
return nil, err
458476
}
477+
478+
err = applyOptionsToRequest(ctx, req, h.opt)
479+
if err != nil {
480+
return nil, err
481+
}
482+
459483
for k, v := range headers {
460484
req.Header.Add(k, v)
461485
}
486+
462487
var query url.Values
463488
if options != nil {
464489
query = req.URL.Query()

conn_http_async_insert.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func (h *httpConnect) asyncInsert(ctx context.Context, query string, wait bool,
3838
}
3939
}
4040

41-
res, err := h.sendQuery(ctx, query, &options, h.headers)
41+
res, err := h.sendQuery(ctx, query, &options, nil)
4242
if res != nil {
4343
defer res.Body.Close()
4444
// we don't care about result, so just discard it to reuse connection

conn_http_batch.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,7 @@ func (b *httpBatch) Send() (err error) {
207207

208208
options.settings["query"] = b.query
209209
headers["Content-Type"] = "application/octet-stream"
210-
for k, v := range b.conn.headers {
211-
headers[k] = v
212-
}
210+
213211
res, err := b.conn.sendStreamQuery(b.ctx, r, &options, headers)
214212

215213
if res != nil {

conn_http_exec.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func (h *httpConnect) exec(ctx context.Context, query string, args ...any) error
2929
return err
3030
}
3131

32-
res, err := h.sendQuery(ctx, query, &options, h.headers)
32+
res, err := h.sendQuery(ctx, query, &options, nil)
3333
if res != nil {
3434
defer res.Body.Close()
3535
// we don't care about result, so just discard it to reuse connection

conn_http_query.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@ func (h *httpConnect) query(ctx context.Context, release func(*connect, error),
4242
headers["Accept-Encoding"] = h.compression.String()
4343
}
4444

45-
for k, v := range h.headers {
46-
headers[k] = v
47-
}
48-
4945
res, err := h.sendQuery(ctx, query, &options, headers)
5046
if err != nil {
5147
return nil, err

context.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ type (
5353
async AsyncOptions
5454
queryID string
5555
quotaKey string
56+
jwt string
5657
events struct {
5758
logs func(*Log)
5859
progress func(*Progress)
@@ -95,6 +96,15 @@ func WithQuotaKey(quotaKey string) QueryOption {
9596
}
9697
}
9798

99+
// WithJWT overrides the existing authentication with the given JWT.
100+
// This only applies for clients connected with HTTPS to ClickHouse Cloud.
101+
func WithJWT(jwt string) QueryOption {
102+
return func(o *QueryOptions) error {
103+
o.jwt = jwt
104+
return nil
105+
}
106+
}
107+
98108
func WithSettings(settings Settings) QueryOption {
99109
return func(o *QueryOptions) error {
100110
o.settings = settings
@@ -211,6 +221,16 @@ func queryOptions(ctx context.Context) QueryOptions {
211221
return opt
212222
}
213223

224+
// queryOptionsJWT returns the JWT within the given context's QueryOptions.
225+
// Empty string if not present.
226+
func queryOptionsJWT(ctx context.Context) string {
227+
if opt, ok := ctx.Value(_contextOptionKey).(QueryOptions); ok {
228+
return opt.jwt
229+
}
230+
231+
return ""
232+
}
233+
214234
// queryOptionsAsync returns the AsyncOptions struct within the given context's QueryOptions.
215235
func queryOptionsAsync(ctx context.Context) AsyncOptions {
216236
if opt, ok := ctx.Value(_contextOptionKey).(QueryOptions); ok {

0 commit comments

Comments
 (0)