diff --git a/enginetest/server_engine_test.go b/enginetest/server_engine_test.go index c95f60b1c3..b6bd55ddf0 100644 --- a/enginetest/server_engine_test.go +++ b/enginetest/server_engine_test.go @@ -375,3 +375,88 @@ func TestServerPreparedStatements(t *testing.T) { }) } } + +func TestServerQueries(t *testing.T) { + port, perr := findEmptyPort() + require.NoError(t, perr) + + s, serr := initTestServer(port) + require.NoError(t, serr) + + go s.Start() + defer s.Close() + + tests := []serverScriptTest{ + { + name: "test that config variables are properly set", + setup: []string{}, + assertions: []serverScriptTestAssertion{ + { + query: "select @@hostname, @@port", + //query: "select @@hostname, @@port, @@max_connections", + isExec: false, + expectedRows: []any{ + sql.Row{"macbook.local", port}, + }, + checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) { + var resHostname string + var resPort int + var rowNum int + for rows.Next() { + if err := rows.Scan(&resHostname, &resPort); err != nil { + return false, err + } + if rowNum >= len(expectedRows) { + return false, nil + } + expectedRow := expectedRows[rowNum].(sql.Row) + require.Equal(t, expectedRow[0].(string), resHostname) + require.Equal(t, expectedRow[1].(int), resPort) + } + return true, nil + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + conn, cerr := dbr.Open("mysql", fmt.Sprintf(noUserFmt, address, port), nil) + require.NoError(t, cerr) + defer conn.Close() + commonSetup := []string{ + "create database test_db;", + "use test_db;", + } + commonTeardown := []string{ + "drop database test_db", + } + for _, stmt := range append(commonSetup, test.setup...) { + _, err := conn.Exec(stmt) + require.NoError(t, err) + } + for _, assertion := range test.assertions { + t.Run(assertion.query, func(t *testing.T) { + if assertion.skip { + t.Skip() + } + rows, err := conn.Query(assertion.query, assertion.args...) + if assertion.expectErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + ok, err := assertion.checkRows(t, rows, assertion.expectedRows) + require.NoError(t, err) + require.True(t, ok) + }) + } + for _, stmt := range append(commonTeardown) { + _, err := conn.Exec(stmt) + require.NoError(t, err) + } + }) + } +} diff --git a/server/server.go b/server/server.go index f1de217aba..061e7dbc6b 100644 --- a/server/server.go +++ b/server/server.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "net" + "os" "strconv" "time" @@ -119,13 +120,38 @@ func portInUse(hostPort string) bool { return false } -func updateSystemVariables(cfg mysql.ListenerConfig) { - _, port, _ := net.SplitHostPort(cfg.Listener.Addr().String()) - portInt, _ := strconv.ParseInt(port, 10, 64) - sql.SystemVariables.AssignValues(map[string]interface{}{ - "max_connections": cfg.MaxConns, - "port": portInt, +func getHostname() (string, error) { + hostname, err := os.Hostname() + if err != nil { + return "", err + } + return hostname, nil +} + +func updateSystemVariables(cfg mysql.ListenerConfig) error { + hostname, err := getHostname() + if err != nil { + return err + } + _, port, err := net.SplitHostPort(cfg.Listener.Addr().String()) + if err != nil { + return err + } + portInt, err := strconv.ParseInt(port, 10, 64) + if err != nil { + return err + } + // TODO: add the rest of the config variables + err = sql.SystemVariables.AssignValues(map[string]interface{}{ + "port": portInt, + "hostname": hostname, + // TODO: this causes an error because max_connections is 0? + //"max_connections": cfg.MaxConns, }) + if err != nil { + return err + } + return nil } func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler, sel ServerEventListener) (*Server, error) { @@ -182,7 +208,10 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle return nil, err } - updateSystemVariables(listenerCfg) + err = updateSystemVariables(listenerCfg) + if err != nil { + return nil, err + } return &Server{ Listener: protocolListener,