Skip to content

Commit 0d23024

Browse files
authored
Merge pull request #3015 from dolthub/james/server_var_tests
server var tests
2 parents 768a3ee + 3f7ef0f commit 0d23024

File tree

2 files changed

+121
-7
lines changed

2 files changed

+121
-7
lines changed

enginetest/server_engine_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,3 +375,88 @@ func TestServerPreparedStatements(t *testing.T) {
375375
})
376376
}
377377
}
378+
379+
func TestServerQueries(t *testing.T) {
380+
port, perr := findEmptyPort()
381+
require.NoError(t, perr)
382+
383+
s, serr := initTestServer(port)
384+
require.NoError(t, serr)
385+
386+
go s.Start()
387+
defer s.Close()
388+
389+
tests := []serverScriptTest{
390+
{
391+
name: "test that config variables are properly set",
392+
setup: []string{},
393+
assertions: []serverScriptTestAssertion{
394+
{
395+
query: "select @@hostname, @@port",
396+
//query: "select @@hostname, @@port, @@max_connections",
397+
isExec: false,
398+
expectedRows: []any{
399+
sql.Row{"macbook.local", port},
400+
},
401+
checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) {
402+
var resHostname string
403+
var resPort int
404+
var rowNum int
405+
for rows.Next() {
406+
if err := rows.Scan(&resHostname, &resPort); err != nil {
407+
return false, err
408+
}
409+
if rowNum >= len(expectedRows) {
410+
return false, nil
411+
}
412+
expectedRow := expectedRows[rowNum].(sql.Row)
413+
require.Equal(t, expectedRow[0].(string), resHostname)
414+
require.Equal(t, expectedRow[1].(int), resPort)
415+
}
416+
return true, nil
417+
},
418+
},
419+
},
420+
},
421+
}
422+
423+
for _, test := range tests {
424+
t.Run(test.name, func(t *testing.T) {
425+
conn, cerr := dbr.Open("mysql", fmt.Sprintf(noUserFmt, address, port), nil)
426+
require.NoError(t, cerr)
427+
defer conn.Close()
428+
commonSetup := []string{
429+
"create database test_db;",
430+
"use test_db;",
431+
}
432+
commonTeardown := []string{
433+
"drop database test_db",
434+
}
435+
for _, stmt := range append(commonSetup, test.setup...) {
436+
_, err := conn.Exec(stmt)
437+
require.NoError(t, err)
438+
}
439+
for _, assertion := range test.assertions {
440+
t.Run(assertion.query, func(t *testing.T) {
441+
if assertion.skip {
442+
t.Skip()
443+
}
444+
rows, err := conn.Query(assertion.query, assertion.args...)
445+
if assertion.expectErr {
446+
require.Error(t, err)
447+
return
448+
}
449+
require.NoError(t, err)
450+
451+
ok, err := assertion.checkRows(t, rows, assertion.expectedRows)
452+
require.NoError(t, err)
453+
require.True(t, ok)
454+
})
455+
}
456+
for _, stmt := range append(commonTeardown) {
457+
_, err := conn.Exec(stmt)
458+
require.NoError(t, err)
459+
}
460+
})
461+
}
462+
}

server/server.go

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"errors"
1919
"fmt"
2020
"net"
21+
"os"
2122
"strconv"
2223
"time"
2324

@@ -119,13 +120,38 @@ func portInUse(hostPort string) bool {
119120
return false
120121
}
121122

122-
func updateSystemVariables(cfg mysql.ListenerConfig) {
123-
_, port, _ := net.SplitHostPort(cfg.Listener.Addr().String())
124-
portInt, _ := strconv.ParseInt(port, 10, 64)
125-
sql.SystemVariables.AssignValues(map[string]interface{}{
126-
"max_connections": cfg.MaxConns,
127-
"port": portInt,
123+
func getHostname() (string, error) {
124+
hostname, err := os.Hostname()
125+
if err != nil {
126+
return "", err
127+
}
128+
return hostname, nil
129+
}
130+
131+
func updateSystemVariables(cfg mysql.ListenerConfig) error {
132+
hostname, err := getHostname()
133+
if err != nil {
134+
return err
135+
}
136+
_, port, err := net.SplitHostPort(cfg.Listener.Addr().String())
137+
if err != nil {
138+
return err
139+
}
140+
portInt, err := strconv.ParseInt(port, 10, 64)
141+
if err != nil {
142+
return err
143+
}
144+
// TODO: add the rest of the config variables
145+
err = sql.SystemVariables.AssignValues(map[string]interface{}{
146+
"port": portInt,
147+
"hostname": hostname,
148+
// TODO: this causes an error because max_connections is 0?
149+
//"max_connections": cfg.MaxConns,
128150
})
151+
if err != nil {
152+
return err
153+
}
154+
return nil
129155
}
130156

131157
func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler, sel ServerEventListener) (*Server, error) {
@@ -182,7 +208,10 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
182208
return nil, err
183209
}
184210

185-
updateSystemVariables(listenerCfg)
211+
err = updateSystemVariables(listenerCfg)
212+
if err != nil {
213+
return nil, err
214+
}
186215

187216
return &Server{
188217
Listener: protocolListener,

0 commit comments

Comments
 (0)