Skip to content

Commit 8039a98

Browse files
authored
Merge pull request #444 from dolthub/aaron/connectionauthenticated-callback
go/mysql: server.go: Add a callback on Handler, ConnectionAuthenticated, which is called immediately after the connection is authenticated.
2 parents 843d10a + 6b5f6cc commit 8039a98

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

go/mysql/server.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,17 @@ type Handler interface {
9393
// ConnectionClosed is called when a connection is closed.
9494
ConnectionClosed(c *Conn)
9595

96+
// ConnectionAuthenticated is called when a connection is authenticated.
97+
// Always called after NewConnection and before ConnectionClosed.
98+
ConnectionAuthenticated(*Conn) error
99+
96100
// ConnectionAborted is called when a new connection cannot be fully established. For
97101
// example, if a client connects to the server, but fails authentication, or can't
98102
// negotiate an authentication handshake, this method will be called to let integrators
99103
// know about the failed connection attempt.
104+
//
105+
// ConnectionClosed will still be called for the connection after ConnectionAborted is
106+
// called.
100107
ConnectionAborted(c *Conn, reason string) error
101108

102109
// ComInitDB is called once at the beginning to set db name,
@@ -560,6 +567,13 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3
560567
defer connCountPerUser.Add(c.User, -1)
561568
}
562569

570+
if err = l.handler.ConnectionAuthenticated(c); err != nil {
571+
log.Errorf("failed to register the connection as authenticated %s: %v", c, err)
572+
573+
c.writeErrorPacketFromError(err)
574+
return
575+
}
576+
563577
// Set initial db name.
564578
if c.schemaName != "" {
565579
if err = l.handler.ComInitDB(c, c.schemaName); err != nil {

go/mysql/server_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ func (th *testHandler) ConnectionAborted(c *Conn, reason string) error {
119119
return nil
120120
}
121121

122+
func (th *testHandler) ConnectionAuthenticated(c *Conn) error {
123+
return nil
124+
}
125+
122126
func (th *testHandler) ParserOptionsForConnection(c *Conn) (sqlparser.ParserOptions, error) {
123127
return sqlparser.ParserOptions{}, nil
124128
}

go/sqltypes/arithmetic_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"fmt"
2222
"math"
2323
"reflect"
24+
"runtime"
2425
"strconv"
2526
"testing"
2627

@@ -486,6 +487,7 @@ func TestNullsafeAdd(t *testing.T) {
486487
v1, v2 Value
487488
out Value
488489
err error
490+
skip bool
489491
}{{
490492
// All nulls.
491493
v1: NULL,
@@ -521,13 +523,24 @@ func TestNullsafeAdd(t *testing.T) {
521523
v1: NewInt64(-1),
522524
v2: NewUint64(2),
523525
out: NewInt64(-9223372036854775808),
526+
// This test relies on float64 -> int64 conversion, where the float64 is larger than the maximum int64.
527+
// According to the Golang spec:
528+
//
529+
// > In all non-constant conversions involving floating-point or complex values, if the result
530+
// type cannot represent the value the conversion succeeds but the result value is implementation-dependent.
531+
//
532+
// And indeed, the test fails on arm64 but passes on amd64 because the computed conversion is different.
533+
skip: runtime.GOARCH != "amd64",
524534
}, {
525535
// Make sure underlying error is returned while converting.
526536
v1: NewFloat64(1),
527537
v2: NewFloat64(2),
528538
out: NewInt64(3),
529539
}}
530540
for _, tcase := range tcases {
541+
if tcase.skip {
542+
t.SkipNow()
543+
}
531544
got := NullsafeAdd(tcase.v1, tcase.v2, Int64)
532545

533546
if !reflect.DeepEqual(got, tcase.out) {

0 commit comments

Comments
 (0)