Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions clickhouse_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ type Options struct {

scheme string
ReadTimeout time.Duration

// ClientTCPProtocolVersion specifies the custom protocol revision, as defined in lib/proto/const.go
// if not specified, the latest supported protocol revision, proto.DBMS_TCP_PROTOCOL_VERSION , is used.
ClientTCPProtocolVersion uint64
}

func (o *Options) fromDSN(in string) error {
Expand Down Expand Up @@ -391,5 +395,9 @@ func (o Options) setDefaults() *Options {
o.Addr = []string{"localhost:8123"}
}
}
if o.ClientTCPProtocolVersion == 0 {
o.ClientTCPProtocolVersion = ClientTCPProtocolVersion
}

return &o
}
6 changes: 5 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
compressor = compress.NewWriter(compress.LevelZero, compress.None)
}

if opt.ClientTCPProtocolVersion < proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO || opt.ClientTCPProtocolVersion > proto.DBMS_TCP_PROTOCOL_VERSION {
return nil, fmt.Errorf("unsupported protocol revision")
}

var (
connect = &connect{
id: num,
Expand All @@ -99,7 +103,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
debugf: debugf,
buffer: new(chproto.Buffer),
reader: chproto.NewReader(conn),
revision: ClientTCPProtocolVersion,
revision: opt.ClientTCPProtocolVersion,
structMap: &structMap{},
compression: compression,
connectedAt: time.Now(),
Expand Down
4 changes: 2 additions & 2 deletions conn_handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (c *connect) handshake(database, username, password string) error {
{
c.buffer.PutByte(proto.ClientHello)
handshake := &proto.ClientHandshake{
ProtocolVersion: ClientTCPProtocolVersion,
ProtocolVersion: c.revision,
ClientName: c.opt.ClientInfo.String(),
ClientVersion: proto.Version{ClientVersionMajor, ClientVersionMinor, ClientVersionPatch}, //nolint:govet
}
Expand All @@ -60,7 +60,7 @@ func (c *connect) handshake(database, username, password string) error {
case proto.ServerException:
return c.exception()
case proto.ServerHello:
if err := c.server.Decode(c.reader); err != nil {
if err := c.server.Decode(c.reader, c.revision); err != nil {
return err
}
case proto.ServerEndOfStream:
Expand Down
2 changes: 1 addition & 1 deletion conn_send_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (c *connect) sendQuery(body string, o *QueryOptions) error {
c.debugf("[send query] compression=%q %s", c.compression, body)
c.buffer.PutByte(proto.ClientQuery)
q := proto.Query{
ClientTCPProtocolVersion: ClientTCPProtocolVersion,
ClientTCPProtocolVersion: c.revision,
ClientName: c.opt.ClientInfo.String(),
ClientVersion: proto.Version{ClientVersionMajor, ClientVersionMinor, ClientVersionPatch}, //nolint:govet
ID: o.queryID,
Expand Down
9 changes: 5 additions & 4 deletions lib/proto/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func CheckMinVersion(constraint Version, version Version) bool {
return true
}

func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) {
func (srv *ServerHandshake) Decode(reader *chproto.Reader, clientRevision uint64) (err error) {
if srv.Name, err = reader.Str(); err != nil {
return fmt.Errorf("could not read server name: %v", err)
}
Expand All @@ -98,7 +98,8 @@ func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) {
if srv.Revision, err = reader.UVarInt(); err != nil {
return fmt.Errorf("could not read server revision: %v", err)
}
if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE {
rev := min(clientRevision, srv.Revision)
if rev >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE {
name, err := reader.Str()
if err != nil {
return fmt.Errorf("could not read server timezone: %v", err)
Expand All @@ -107,12 +108,12 @@ func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) {
return fmt.Errorf("could not load time location: %v", err)
}
}
if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME {
if rev >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME {
if srv.DisplayName, err = reader.Str(); err != nil {
return fmt.Errorf("could not read server display name: %v", err)
}
}
if srv.Revision >= DBMS_MIN_REVISION_WITH_VERSION_PATCH {
if rev >= DBMS_MIN_REVISION_WITH_VERSION_PATCH {
if srv.Version.Patch, err = reader.UVarInt(); err != nil {
return fmt.Errorf("could not read server patch: %v", err)
}
Expand Down
76 changes: 76 additions & 0 deletions tests/std/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import (
"time"

"github.com/ClickHouse/clickhouse-go/v2"
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -231,6 +233,80 @@ func TestStdConnector(t *testing.T) {
require.NoError(t, err)
}

func TestCustomProtocolRevision(t *testing.T) {
env, err := GetStdTestEnvironment()
require.NoError(t, err)
useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false"))
require.NoError(t, err)
port := env.Port
var tlsConfig *tls.Config
if useSSL {
port = env.SslPort
tlsConfig = &tls.Config{}
}
baseOpts := clickhouse.Options{
Addr: []string{fmt.Sprintf("%s:%d", env.Host, port)},
Auth: clickhouse.Auth{
Database: "default",
Username: env.Username,
Password: env.Password,
},
Compression: &clickhouse.Compression{
Method: clickhouse.CompressionLZ4,
},
TLS: tlsConfig,
}
t.Run("unsupported proto versions", func(t *testing.T) {
badOpts := baseOpts
badOpts.ClientTCPProtocolVersion = proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO - 1
conn, _ := clickhouse.Open(&badOpts)
require.NotNil(t, conn)
err = conn.Ping(t.Context())
require.Error(t, err)
badOpts.ClientTCPProtocolVersion = proto.DBMS_TCP_PROTOCOL_VERSION + 1
conn, _ = clickhouse.Open(&badOpts)
require.NotNil(t, conn)
err = conn.Ping(t.Context())
require.Error(t, err)
})

t.Run("minimal proto version", func(t *testing.T) {
opts := baseOpts
opts.ClientTCPProtocolVersion = proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO
conn, err := clickhouse.Open(&opts)
require.NoError(t, err)
require.NotNil(t, conn)
err = conn.Ping(t.Context())
require.NoError(t, err)

defer func() {
_ = conn.Exec(t.Context(), "DROP TABLE insert_example")
}()
err = conn.Exec(t.Context(), "DROP TABLE IF EXISTS insert_example")

err = conn.Exec(t.Context(), `
CREATE TABLE insert_example (
Col1 UInt64
) Engine = MergeTree() ORDER BY tuple()
`)
require.NoError(t, err)
var batch driver.Batch
batch, err = conn.PrepareBatch(t.Context(), "INSERT INTO insert_example (Col1)")
require.NoError(t, err)
require.NoError(t, batch.Append(10))
require.NoError(t, batch.Send())

rows, err := conn.Query(t.Context(), "SELECT Col1 FROM insert_example")
require.NoError(t, err)
count := 0
for rows.Next() {
count++
}
assert.Equal(t, 1, count)
})

}

func TestBlockBufferSize(t *testing.T) {
env, err := GetStdTestEnvironment()
require.NoError(t, err)
Expand Down
Loading