From d301c86d2fa0e66b93919eb560769aeea10156e2 Mon Sep 17 00:00:00 2001 From: titanproger Date: Wed, 7 May 2025 18:24:58 +0300 Subject: [PATCH 1/2] use one tcp protocol revision, remove multiple usage of ClientTCPProtocolVersion constant. --- conn_handshake.go | 2 +- conn_send_query.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/conn_handshake.go b/conn_handshake.go index 265c52c2f0..ed41cc5281 100644 --- a/conn_handshake.go +++ b/conn_handshake.go @@ -37,7 +37,7 @@ func (c *connect) handshake(database, username, password string) error { { c.buffer.PutByte(proto.ClientHello) handshake := &proto.ClientHandshake{ - ProtocolVersion: ClientTCPProtocolVersion, + ProtocolVersion: c.revision, ClientName: c.opt.ClientInfo.String(), ClientVersion: proto.Version{ClientVersionMajor, ClientVersionMinor, ClientVersionPatch}, //nolint:govet } diff --git a/conn_send_query.go b/conn_send_query.go index 8897a8c768..dc05305781 100644 --- a/conn_send_query.go +++ b/conn_send_query.go @@ -27,7 +27,7 @@ func (c *connect) sendQuery(body string, o *QueryOptions) error { c.debugf("[send query] compression=%q %s", c.compression, body) c.buffer.PutByte(proto.ClientQuery) q := proto.Query{ - ClientTCPProtocolVersion: ClientTCPProtocolVersion, + ClientTCPProtocolVersion: c.revision, ClientName: c.opt.ClientInfo.String(), ClientVersion: proto.Version{ClientVersionMajor, ClientVersionMinor, ClientVersionPatch}, //nolint:govet ID: o.queryID, From 7a195fb6c98ace0f5c2480557e43359ea66fd6f0 Mon Sep 17 00:00:00 2001 From: Evgeny Pronin Date: Wed, 7 May 2025 18:43:52 +0300 Subject: [PATCH 2/2] added ability to set custom ProtocolRevision in Options. --- clickhouse_options.go | 8 +++++ conn.go | 6 +++- conn_handshake.go | 2 +- lib/proto/handshake.go | 9 ++--- tests/std/conn_test.go | 76 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 95 insertions(+), 6 deletions(-) diff --git a/clickhouse_options.go b/clickhouse_options.go index fe15e6faf1..619ccc332b 100644 --- a/clickhouse_options.go +++ b/clickhouse_options.go @@ -158,6 +158,10 @@ type Options struct { scheme string ReadTimeout time.Duration + + // ClientTCPProtocolVersion specifies the custom protocol revision, as defined in lib/proto/const.go + // if not specified, the latest supported protocol revision, proto.DBMS_TCP_PROTOCOL_VERSION , is used. + ClientTCPProtocolVersion uint64 } func (o *Options) fromDSN(in string) error { @@ -391,5 +395,9 @@ func (o Options) setDefaults() *Options { o.Addr = []string{"localhost:8123"} } } + if o.ClientTCPProtocolVersion == 0 { + o.ClientTCPProtocolVersion = ClientTCPProtocolVersion + } + return &o } diff --git a/conn.go b/conn.go index 9dee9fc302..664706b4e2 100644 --- a/conn.go +++ b/conn.go @@ -91,6 +91,10 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er compressor = compress.NewWriter(compress.LevelZero, compress.None) } + if opt.ClientTCPProtocolVersion < proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO || opt.ClientTCPProtocolVersion > proto.DBMS_TCP_PROTOCOL_VERSION { + return nil, fmt.Errorf("unsupported protocol revision") + } + var ( connect = &connect{ id: num, @@ -99,7 +103,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er debugf: debugf, buffer: new(chproto.Buffer), reader: chproto.NewReader(conn), - revision: ClientTCPProtocolVersion, + revision: opt.ClientTCPProtocolVersion, structMap: &structMap{}, compression: compression, connectedAt: time.Now(), diff --git a/conn_handshake.go b/conn_handshake.go index ed41cc5281..e8f0ff7e1b 100644 --- a/conn_handshake.go +++ b/conn_handshake.go @@ -60,7 +60,7 @@ func (c *connect) handshake(database, username, password string) error { case proto.ServerException: return c.exception() case proto.ServerHello: - if err := c.server.Decode(c.reader); err != nil { + if err := c.server.Decode(c.reader, c.revision); err != nil { return err } case proto.ServerEndOfStream: diff --git a/lib/proto/handshake.go b/lib/proto/handshake.go index 6ee620905c..c880d565a6 100644 --- a/lib/proto/handshake.go +++ b/lib/proto/handshake.go @@ -85,7 +85,7 @@ func CheckMinVersion(constraint Version, version Version) bool { return true } -func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) { +func (srv *ServerHandshake) Decode(reader *chproto.Reader, clientRevision uint64) (err error) { if srv.Name, err = reader.Str(); err != nil { return fmt.Errorf("could not read server name: %v", err) } @@ -98,7 +98,8 @@ func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) { if srv.Revision, err = reader.UVarInt(); err != nil { return fmt.Errorf("could not read server revision: %v", err) } - if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE { + rev := min(clientRevision, srv.Revision) + if rev >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE { name, err := reader.Str() if err != nil { return fmt.Errorf("could not read server timezone: %v", err) @@ -107,12 +108,12 @@ func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) { return fmt.Errorf("could not load time location: %v", err) } } - if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME { + if rev >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME { if srv.DisplayName, err = reader.Str(); err != nil { return fmt.Errorf("could not read server display name: %v", err) } } - if srv.Revision >= DBMS_MIN_REVISION_WITH_VERSION_PATCH { + if rev >= DBMS_MIN_REVISION_WITH_VERSION_PATCH { if srv.Version.Patch, err = reader.UVarInt(); err != nil { return fmt.Errorf("could not read server patch: %v", err) } diff --git a/tests/std/conn_test.go b/tests/std/conn_test.go index c8576e422d..cd0b90ad3f 100644 --- a/tests/std/conn_test.go +++ b/tests/std/conn_test.go @@ -31,6 +31,8 @@ import ( "time" "github.com/ClickHouse/clickhouse-go/v2" + "github.com/ClickHouse/clickhouse-go/v2/lib/driver" + "github.com/ClickHouse/clickhouse-go/v2/lib/proto" clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -231,6 +233,80 @@ func TestStdConnector(t *testing.T) { require.NoError(t, err) } +func TestCustomProtocolRevision(t *testing.T) { + env, err := GetStdTestEnvironment() + require.NoError(t, err) + useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false")) + require.NoError(t, err) + port := env.Port + var tlsConfig *tls.Config + if useSSL { + port = env.SslPort + tlsConfig = &tls.Config{} + } + baseOpts := clickhouse.Options{ + Addr: []string{fmt.Sprintf("%s:%d", env.Host, port)}, + Auth: clickhouse.Auth{ + Database: "default", + Username: env.Username, + Password: env.Password, + }, + Compression: &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }, + TLS: tlsConfig, + } + t.Run("unsupported proto versions", func(t *testing.T) { + badOpts := baseOpts + badOpts.ClientTCPProtocolVersion = proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO - 1 + conn, _ := clickhouse.Open(&badOpts) + require.NotNil(t, conn) + err = conn.Ping(t.Context()) + require.Error(t, err) + badOpts.ClientTCPProtocolVersion = proto.DBMS_TCP_PROTOCOL_VERSION + 1 + conn, _ = clickhouse.Open(&badOpts) + require.NotNil(t, conn) + err = conn.Ping(t.Context()) + require.Error(t, err) + }) + + t.Run("minimal proto version", func(t *testing.T) { + opts := baseOpts + opts.ClientTCPProtocolVersion = proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO + conn, err := clickhouse.Open(&opts) + require.NoError(t, err) + require.NotNil(t, conn) + err = conn.Ping(t.Context()) + require.NoError(t, err) + + defer func() { + _ = conn.Exec(t.Context(), "DROP TABLE insert_example") + }() + err = conn.Exec(t.Context(), "DROP TABLE IF EXISTS insert_example") + + err = conn.Exec(t.Context(), ` + CREATE TABLE insert_example ( + Col1 UInt64 + ) Engine = MergeTree() ORDER BY tuple() + `) + require.NoError(t, err) + var batch driver.Batch + batch, err = conn.PrepareBatch(t.Context(), "INSERT INTO insert_example (Col1)") + require.NoError(t, err) + require.NoError(t, batch.Append(10)) + require.NoError(t, batch.Send()) + + rows, err := conn.Query(t.Context(), "SELECT Col1 FROM insert_example") + require.NoError(t, err) + count := 0 + for rows.Next() { + count++ + } + assert.Equal(t, 1, count) + }) + +} + func TestBlockBufferSize(t *testing.T) { env, err := GetStdTestEnvironment() require.NoError(t, err)