Skip to content

Commit 6bf02c2

Browse files
authored
Merge pull request #3005 from dolthub/angela/system_variables
Added support for `@@port`, `@@hostname`, and other system variables
2 parents 628730f + 9590ef7 commit 6bf02c2

File tree

7 files changed

+145
-20
lines changed

7 files changed

+145
-20
lines changed

enginetest/queries/variable_queries.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,6 @@ import (
2222
)
2323

2424
var VariableQueries = []ScriptTest{
25-
{
26-
Name: "use string name for foreign_key checks",
27-
SetUpScript: []string{},
28-
Query: "select @@GLOBAL.unknown",
29-
ExpectedErr: sql.ErrUnknownSystemVariable,
30-
},
3125
{
3226
Name: "use string name for foreign_key checks",
3327
SetUpScript: []string{},
@@ -649,6 +643,10 @@ var VariableQueries = []ScriptTest{
649643
}
650644

651645
var VariableErrorTests = []QueryErrorTest{
646+
{
647+
Query: "select @@GLOBAL.unknown",
648+
ExpectedErr: sql.ErrUnknownSystemVariable,
649+
},
652650
{
653651
Query: "set @@does_not_exist = 100",
654652
ExpectedErr: sql.ErrUnknownSystemVariable,

enginetest/server_engine_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"math"
88
"net"
9+
"os"
910
"testing"
1011

1112
"github.com/dolthub/vitess/go/mysql"
@@ -375,3 +376,92 @@ func TestServerPreparedStatements(t *testing.T) {
375376
})
376377
}
377378
}
379+
380+
func TestServerVariables(t *testing.T) {
381+
hostname, herr := os.Hostname()
382+
require.NoError(t, herr)
383+
384+
port, perr := findEmptyPort()
385+
require.NoError(t, perr)
386+
387+
s, serr := initTestServer(port)
388+
require.NoError(t, serr)
389+
390+
go s.Start()
391+
defer s.Close()
392+
393+
tests := []serverScriptTest{
394+
{
395+
name: "test that config system variables are properly set",
396+
setup: []string{},
397+
assertions: []serverScriptTestAssertion{
398+
{
399+
query: "select @@hostname, @@port, @@net_read_timeout, @@net_write_timeout",
400+
isExec: false,
401+
expectedRows: []any{
402+
sql.Row{hostname, port, 1, 1},
403+
},
404+
checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) {
405+
var resHostname string
406+
var resPort int
407+
var resNetReadTimeout int
408+
var resNetWriteTimeout int
409+
var rowNum int
410+
for rows.Next() {
411+
if err := rows.Scan(&resHostname, &resPort, &resNetReadTimeout, &resNetWriteTimeout); err != nil {
412+
return false, err
413+
}
414+
if rowNum >= len(expectedRows) {
415+
return false, nil
416+
}
417+
expectedRow := expectedRows[rowNum].(sql.Row)
418+
require.Equal(t, expectedRow[0].(string), resHostname)
419+
require.Equal(t, expectedRow[1].(int), resPort)
420+
}
421+
return true, nil
422+
},
423+
},
424+
},
425+
},
426+
}
427+
428+
for _, test := range tests {
429+
t.Run(test.name, func(t *testing.T) {
430+
conn, cerr := dbr.Open("mysql", fmt.Sprintf(noUserFmt, address, port), nil)
431+
require.NoError(t, cerr)
432+
defer conn.Close()
433+
commonSetup := []string{
434+
"create database test_db;",
435+
"use test_db;",
436+
}
437+
commonTeardown := []string{
438+
"drop database test_db",
439+
}
440+
for _, stmt := range append(commonSetup, test.setup...) {
441+
_, err := conn.Exec(stmt)
442+
require.NoError(t, err)
443+
}
444+
for _, assertion := range test.assertions {
445+
t.Run(assertion.query, func(t *testing.T) {
446+
if assertion.skip {
447+
t.Skip()
448+
}
449+
rows, err := conn.Query(assertion.query, assertion.args...)
450+
if assertion.expectErr {
451+
require.Error(t, err)
452+
return
453+
}
454+
require.NoError(t, err)
455+
456+
ok, err := assertion.checkRows(t, rows, assertion.expectedRows)
457+
require.NoError(t, err)
458+
require.True(t, ok)
459+
})
460+
}
461+
for _, stmt := range append(commonTeardown) {
462+
_, err := conn.Exec(stmt)
463+
require.NoError(t, err)
464+
}
465+
})
466+
}
467+
}

server/server.go

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"errors"
1919
"fmt"
2020
"net"
21+
"strconv"
2122
"time"
2223

2324
"github.com/dolthub/vitess/go/mysql"
@@ -122,17 +123,42 @@ func portInUse(hostPort string) bool {
122123
return false
123124
}
124125

125-
func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler, sel ServerEventListener) (*Server, error) {
126-
if cfg.ConnReadTimeout < 0 {
127-
cfg.ConnReadTimeout = 0
126+
func getPort(cfg mysql.ListenerConfig) (int64, error) {
127+
_, port, err := net.SplitHostPort(cfg.Listener.Addr().String())
128+
if err != nil {
129+
return 0, err
130+
}
131+
portInt, err := strconv.ParseInt(port, 10, 64)
132+
if err != nil {
133+
return 0, err
134+
}
135+
return portInt, nil
136+
}
137+
138+
func updateSystemVariables(cfg mysql.ListenerConfig) error {
139+
sysVars := make(map[string]interface{})
140+
141+
if port, err := getPort(cfg); err == nil {
142+
sysVars["port"] = port
128143
}
129-
if cfg.ConnWriteTimeout < 0 {
130-
cfg.ConnWriteTimeout = 0
144+
145+
oneSecond := time.Duration(1) * time.Second
146+
if cfg.ConnReadTimeout >= oneSecond {
147+
sysVars["net_read_timeout"] = cfg.ConnReadTimeout.Seconds()
131148
}
132-
if cfg.MaxConnections < 0 {
133-
cfg.MaxConnections = 0
149+
if cfg.ConnWriteTimeout >= oneSecond {
150+
sysVars["net_write_timeout"] = cfg.ConnWriteTimeout.Seconds()
134151
}
135152

153+
// TODO: add the rest of the config variables
154+
err := sql.SystemVariables.AssignValues(sysVars)
155+
if err != nil {
156+
return err
157+
}
158+
return nil
159+
}
160+
161+
func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler, sel ServerEventListener) (*Server, error) {
136162
for _, opt := range cfg.Options {
137163
e, sm, handler = opt(e, sm, handler)
138164
}
@@ -176,6 +202,11 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
176202
return nil, err
177203
}
178204

205+
err = updateSystemVariables(listenerCfg)
206+
if err != nil {
207+
return nil, err
208+
}
209+
179210
return &Server{
180211
Listener: protocolListener,
181212
handler: handler,

server/server_config.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,14 @@ func (c Config) NewConfig() (Config, error) {
116116
if !ok {
117117
return Config{}, sql.ErrUnknownSystemVariable.New("net_write_timeout")
118118
}
119-
c.ConnWriteTimeout = time.Duration(timeout) * time.Millisecond
119+
c.ConnWriteTimeout = time.Duration(timeout) * time.Second
120120
}
121121
if _, val, ok := sql.SystemVariables.GetGlobal("net_read_timeout"); ok {
122122
timeout, ok := val.(int64)
123123
if !ok {
124124
return Config{}, sql.ErrUnknownSystemVariable.New("net_read_timeout")
125125
}
126-
c.ConnReadTimeout = time.Duration(timeout) * time.Millisecond
126+
c.ConnReadTimeout = time.Duration(timeout) * time.Second
127127
}
128128
return c, nil
129129
}

server/server_config_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,14 @@ func TestConfigWithDefaults(t *testing.T) {
4949
Type: types.NewSystemIntType("net_write_timeout", 1, 9223372036854775807, false),
5050
ConfigField: "ConnWriteTimeout",
5151
Default: int64(76),
52-
ExpectedCmp: int64(76000000),
52+
ExpectedCmp: int64(76000000000),
5353
}, {
5454
Name: "net_read_timeout",
5555
Scope: sql.SystemVariableScope_Both,
5656
Type: types.NewSystemIntType("net_read_timeout", 1, 9223372036854775807, false),
5757
ConfigField: "ConnReadTimeout",
5858
Default: int64(67),
59-
ExpectedCmp: int64(67000000),
59+
ExpectedCmp: int64(67000000000),
6060
},
6161
}
6262

server/server_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ import (
1818
gsql "github.com/dolthub/go-mysql-server/sql"
1919
)
2020

21-
// TestSeverCustomListener verifies a caller can provide their own net.Conn implementation for the server to use
22-
func TestSeverCustomListener(t *testing.T) {
21+
// TestServerCustomListener verifies a caller can provide their own net.Conn implementation for the server to use
22+
func TestServerCustomListener(t *testing.T) {
2323
dbName := "mydb"
2424
// create a net.Conn thats based on a golang buffer
2525
buffer := 1024

sql/variables/system_variables.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package variables
1717
import (
1818
"fmt"
1919
"math"
20+
"os"
2021
"strings"
2122
"sync"
2223
"time"
@@ -186,6 +187,11 @@ func init() {
186187
InitSystemVariables()
187188
}
188189

190+
func getHostname() string {
191+
hostname, _ := os.Hostname()
192+
return hostname
193+
}
194+
189195
// systemVars is the internal collection of all MySQL system variables according to the following pages:
190196
// https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html
191197
// https://dev.mysql.com/doc/refman/8.0/en/replication-options-gtids.html
@@ -1008,7 +1014,7 @@ var systemVars = map[string]sql.SystemVariable{
10081014
Dynamic: false,
10091015
SetVarHintApplies: false,
10101016
Type: types.NewSystemStringType("hostname"),
1011-
Default: "",
1017+
Default: getHostname(),
10121018
},
10131019
"immediate_server_version": &sql.MysqlSystemVariable{
10141020
Name: "immediate_server_version",

0 commit comments

Comments
 (0)