Skip to content

Commit e6186a3

Browse files
committed
added ability to set custom ProtocolRevision in Options.
1 parent cd06326 commit e6186a3

File tree

5 files changed

+95
-6
lines changed

5 files changed

+95
-6
lines changed

clickhouse_options.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ type Options struct {
164164

165165
scheme string
166166
ReadTimeout time.Duration
167+
168+
// ProtocolRevision specifies custom protocol revision, from lib/proto/const.go
169+
// if nil then used latest supported protocol revision ClientTCPProtocolVersion
170+
ProtocolRevision uint64
167171
}
168172

169173
func (o *Options) fromDSN(in string) error {

conn.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
9191
compressor = compress.NewWriter(compress.LevelZero, compress.None)
9292
}
9393

94+
var revision uint64 = ClientTCPProtocolVersion
95+
if opt.ProtocolRevision != 0 {
96+
revision = opt.ProtocolRevision
97+
if revision < proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO || revision > ClientTCPProtocolVersion {
98+
return nil, fmt.Errorf("unsupported protocol revision")
99+
}
100+
}
101+
94102
var (
95103
connect = &connect{
96104
id: num,
@@ -99,7 +107,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
99107
debugf: debugf,
100108
buffer: new(chproto.Buffer),
101109
reader: chproto.NewReader(conn),
102-
revision: ClientTCPProtocolVersion,
110+
revision: revision,
103111
structMap: &structMap{},
104112
compression: compression,
105113
connectedAt: time.Now(),

conn_handshake.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func (c *connect) handshake(auth Auth) error {
6060
case proto.ServerException:
6161
return c.exception()
6262
case proto.ServerHello:
63-
if err := c.server.Decode(c.reader); err != nil {
63+
if err := c.server.Decode(c.reader, c.revision); err != nil {
6464
return err
6565
}
6666
case proto.ServerEndOfStream:

lib/proto/handshake.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func CheckMinVersion(constraint Version, version Version) bool {
8585
return true
8686
}
8787

88-
func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) {
88+
func (srv *ServerHandshake) Decode(reader *chproto.Reader, maxRevision uint64) (err error) {
8989
if srv.Name, err = reader.Str(); err != nil {
9090
return fmt.Errorf("could not read server name: %v", err)
9191
}
@@ -98,7 +98,8 @@ func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) {
9898
if srv.Revision, err = reader.UVarInt(); err != nil {
9999
return fmt.Errorf("could not read server revision: %v", err)
100100
}
101-
if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE {
101+
rev := min(maxRevision, srv.Revision)
102+
if rev >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE {
102103
name, err := reader.Str()
103104
if err != nil {
104105
return fmt.Errorf("could not read server timezone: %v", err)
@@ -107,12 +108,12 @@ func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) {
107108
return fmt.Errorf("could not load time location: %v", err)
108109
}
109110
}
110-
if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME {
111+
if rev >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME {
111112
if srv.DisplayName, err = reader.Str(); err != nil {
112113
return fmt.Errorf("could not read server display name: %v", err)
113114
}
114115
}
115-
if srv.Revision >= DBMS_MIN_REVISION_WITH_VERSION_PATCH {
116+
if rev >= DBMS_MIN_REVISION_WITH_VERSION_PATCH {
116117
if srv.Version.Patch, err = reader.UVarInt(); err != nil {
117118
return fmt.Errorf("could not read server patch: %v", err)
118119
}

tests/std/conn_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ import (
3131
"time"
3232

3333
"github.com/ClickHouse/clickhouse-go/v2"
34+
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
35+
"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
3436
clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests"
3537
"github.com/stretchr/testify/assert"
3638
"github.com/stretchr/testify/require"
@@ -231,6 +233,80 @@ func TestStdConnector(t *testing.T) {
231233
require.NoError(t, err)
232234
}
233235

236+
func TestCustomProtocolRevision(t *testing.T) {
237+
env, err := GetStdTestEnvironment()
238+
require.NoError(t, err)
239+
useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false"))
240+
require.NoError(t, err)
241+
port := env.Port
242+
var tlsConfig *tls.Config
243+
if useSSL {
244+
port = env.SslPort
245+
tlsConfig = &tls.Config{}
246+
}
247+
baseOpts := clickhouse.Options{
248+
Addr: []string{fmt.Sprintf("%s:%d", env.Host, port)},
249+
Auth: clickhouse.Auth{
250+
Database: "default",
251+
Username: env.Username,
252+
Password: env.Password,
253+
},
254+
Compression: &clickhouse.Compression{
255+
Method: clickhouse.CompressionLZ4,
256+
},
257+
TLS: tlsConfig,
258+
}
259+
t.Run("unsupported proto versions", func(t *testing.T) {
260+
badOpts := baseOpts
261+
badOpts.ProtocolRevision = proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO - 1
262+
conn, _ := clickhouse.Open(&badOpts)
263+
require.NotNil(t, conn)
264+
err = conn.Ping(t.Context())
265+
require.Error(t, err)
266+
badOpts.ProtocolRevision = proto.DBMS_TCP_PROTOCOL_VERSION + 1
267+
conn, _ = clickhouse.Open(&badOpts)
268+
require.NotNil(t, conn)
269+
err = conn.Ping(t.Context())
270+
require.Error(t, err)
271+
})
272+
273+
t.Run("minimal proto version", func(t *testing.T) {
274+
opts := baseOpts
275+
opts.ProtocolRevision = proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO
276+
conn, err := clickhouse.Open(&opts)
277+
require.NoError(t, err)
278+
require.NotNil(t, conn)
279+
err = conn.Ping(t.Context())
280+
require.NoError(t, err)
281+
282+
defer func() {
283+
_ = conn.Exec(t.Context(), "DROP TABLE insert_example")
284+
}()
285+
err = conn.Exec(t.Context(), "DROP TABLE IF EXISTS insert_example")
286+
287+
err = conn.Exec(t.Context(), `
288+
CREATE TABLE insert_example (
289+
Col1 UInt64
290+
) Engine = MergeTree() ORDER BY tuple()
291+
`)
292+
require.NoError(t, err)
293+
var batch driver.Batch
294+
batch, err = conn.PrepareBatch(t.Context(), "INSERT INTO insert_example (Col1)")
295+
require.NoError(t, err)
296+
require.NoError(t, batch.Append(10))
297+
require.NoError(t, batch.Send())
298+
299+
rows, err := conn.Query(t.Context(), "SELECT Col1 FROM insert_example")
300+
require.NoError(t, err)
301+
count := 0
302+
for rows.Next() {
303+
count++
304+
}
305+
assert.Equal(t, 1, count)
306+
})
307+
308+
}
309+
234310
func TestBlockBufferSize(t *testing.T) {
235311
env, err := GetStdTestEnvironment()
236312
require.NoError(t, err)

0 commit comments

Comments
 (0)