Skip to content

Commit 7a195fb

Browse files
committed
added ability to set custom ProtocolRevision in Options.
1 parent d301c86 commit 7a195fb

File tree

5 files changed

+95
-6
lines changed

5 files changed

+95
-6
lines changed

clickhouse_options.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,10 @@ type Options struct {
158158

159159
scheme string
160160
ReadTimeout time.Duration
161+
162+
// ClientTCPProtocolVersion specifies the custom protocol revision, as defined in lib/proto/const.go
163+
// if not specified, the latest supported protocol revision, proto.DBMS_TCP_PROTOCOL_VERSION , is used.
164+
ClientTCPProtocolVersion uint64
161165
}
162166

163167
func (o *Options) fromDSN(in string) error {
@@ -391,5 +395,9 @@ func (o Options) setDefaults() *Options {
391395
o.Addr = []string{"localhost:8123"}
392396
}
393397
}
398+
if o.ClientTCPProtocolVersion == 0 {
399+
o.ClientTCPProtocolVersion = ClientTCPProtocolVersion
400+
}
401+
394402
return &o
395403
}

conn.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
9191
compressor = compress.NewWriter(compress.LevelZero, compress.None)
9292
}
9393

94+
if opt.ClientTCPProtocolVersion < proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO || opt.ClientTCPProtocolVersion > proto.DBMS_TCP_PROTOCOL_VERSION {
95+
return nil, fmt.Errorf("unsupported protocol revision")
96+
}
97+
9498
var (
9599
connect = &connect{
96100
id: num,
@@ -99,7 +103,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
99103
debugf: debugf,
100104
buffer: new(chproto.Buffer),
101105
reader: chproto.NewReader(conn),
102-
revision: ClientTCPProtocolVersion,
106+
revision: opt.ClientTCPProtocolVersion,
103107
structMap: &structMap{},
104108
compression: compression,
105109
connectedAt: time.Now(),

conn_handshake.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func (c *connect) handshake(database, username, password string) error {
6060
case proto.ServerException:
6161
return c.exception()
6262
case proto.ServerHello:
63-
if err := c.server.Decode(c.reader); err != nil {
63+
if err := c.server.Decode(c.reader, c.revision); err != nil {
6464
return err
6565
}
6666
case proto.ServerEndOfStream:

lib/proto/handshake.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func CheckMinVersion(constraint Version, version Version) bool {
8585
return true
8686
}
8787

88-
func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) {
88+
func (srv *ServerHandshake) Decode(reader *chproto.Reader, clientRevision uint64) (err error) {
8989
if srv.Name, err = reader.Str(); err != nil {
9090
return fmt.Errorf("could not read server name: %v", err)
9191
}
@@ -98,7 +98,8 @@ func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) {
9898
if srv.Revision, err = reader.UVarInt(); err != nil {
9999
return fmt.Errorf("could not read server revision: %v", err)
100100
}
101-
if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE {
101+
rev := min(clientRevision, srv.Revision)
102+
if rev >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE {
102103
name, err := reader.Str()
103104
if err != nil {
104105
return fmt.Errorf("could not read server timezone: %v", err)
@@ -107,12 +108,12 @@ func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) {
107108
return fmt.Errorf("could not load time location: %v", err)
108109
}
109110
}
110-
if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME {
111+
if rev >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME {
111112
if srv.DisplayName, err = reader.Str(); err != nil {
112113
return fmt.Errorf("could not read server display name: %v", err)
113114
}
114115
}
115-
if srv.Revision >= DBMS_MIN_REVISION_WITH_VERSION_PATCH {
116+
if rev >= DBMS_MIN_REVISION_WITH_VERSION_PATCH {
116117
if srv.Version.Patch, err = reader.UVarInt(); err != nil {
117118
return fmt.Errorf("could not read server patch: %v", err)
118119
}

tests/std/conn_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ import (
3131
"time"
3232

3333
"github.com/ClickHouse/clickhouse-go/v2"
34+
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
35+
"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
3436
clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests"
3537
"github.com/stretchr/testify/assert"
3638
"github.com/stretchr/testify/require"
@@ -231,6 +233,80 @@ func TestStdConnector(t *testing.T) {
231233
require.NoError(t, err)
232234
}
233235

236+
func TestCustomProtocolRevision(t *testing.T) {
237+
env, err := GetStdTestEnvironment()
238+
require.NoError(t, err)
239+
useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false"))
240+
require.NoError(t, err)
241+
port := env.Port
242+
var tlsConfig *tls.Config
243+
if useSSL {
244+
port = env.SslPort
245+
tlsConfig = &tls.Config{}
246+
}
247+
baseOpts := clickhouse.Options{
248+
Addr: []string{fmt.Sprintf("%s:%d", env.Host, port)},
249+
Auth: clickhouse.Auth{
250+
Database: "default",
251+
Username: env.Username,
252+
Password: env.Password,
253+
},
254+
Compression: &clickhouse.Compression{
255+
Method: clickhouse.CompressionLZ4,
256+
},
257+
TLS: tlsConfig,
258+
}
259+
t.Run("unsupported proto versions", func(t *testing.T) {
260+
badOpts := baseOpts
261+
badOpts.ClientTCPProtocolVersion = proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO - 1
262+
conn, _ := clickhouse.Open(&badOpts)
263+
require.NotNil(t, conn)
264+
err = conn.Ping(t.Context())
265+
require.Error(t, err)
266+
badOpts.ClientTCPProtocolVersion = proto.DBMS_TCP_PROTOCOL_VERSION + 1
267+
conn, _ = clickhouse.Open(&badOpts)
268+
require.NotNil(t, conn)
269+
err = conn.Ping(t.Context())
270+
require.Error(t, err)
271+
})
272+
273+
t.Run("minimal proto version", func(t *testing.T) {
274+
opts := baseOpts
275+
opts.ClientTCPProtocolVersion = proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO
276+
conn, err := clickhouse.Open(&opts)
277+
require.NoError(t, err)
278+
require.NotNil(t, conn)
279+
err = conn.Ping(t.Context())
280+
require.NoError(t, err)
281+
282+
defer func() {
283+
_ = conn.Exec(t.Context(), "DROP TABLE insert_example")
284+
}()
285+
err = conn.Exec(t.Context(), "DROP TABLE IF EXISTS insert_example")
286+
287+
err = conn.Exec(t.Context(), `
288+
CREATE TABLE insert_example (
289+
Col1 UInt64
290+
) Engine = MergeTree() ORDER BY tuple()
291+
`)
292+
require.NoError(t, err)
293+
var batch driver.Batch
294+
batch, err = conn.PrepareBatch(t.Context(), "INSERT INTO insert_example (Col1)")
295+
require.NoError(t, err)
296+
require.NoError(t, batch.Append(10))
297+
require.NoError(t, batch.Send())
298+
299+
rows, err := conn.Query(t.Context(), "SELECT Col1 FROM insert_example")
300+
require.NoError(t, err)
301+
count := 0
302+
for rows.Next() {
303+
count++
304+
}
305+
assert.Equal(t, 1, count)
306+
})
307+
308+
}
309+
234310
func TestBlockBufferSize(t *testing.T) {
235311
env, err := GetStdTestEnvironment()
236312
require.NoError(t, err)

0 commit comments

Comments
 (0)