@@ -25,6 +25,7 @@ import (
2525 "time"
2626
2727 "github.com/dolthub/vitess/go/mysql"
28+ "github.com/dolthub/vitess/go/race"
2829 "github.com/dolthub/vitess/go/sqltypes"
2930 "github.com/dolthub/vitess/go/vt/proto/query"
3031 "github.com/stretchr/testify/assert"
@@ -742,6 +743,113 @@ func TestHandlerKill(t *testing.T) {
742743 require .Len (handler .sm .sessions , 1 )
743744}
744745
746+ func TestHandlerKillQuery (t * testing.T ) {
747+ if race .Enabled {
748+ t .Skip ("this test is inherently racey" )
749+ }
750+ require := require .New (t )
751+ e , pro := setupMemDB (require )
752+ dbFunc := pro .Database
753+
754+ handler := & Handler {
755+ e : e ,
756+ sm : NewSessionManager (
757+ func (ctx context.Context , conn * mysql.Conn , addr string ) (sql.Session , error ) {
758+ return sql .NewBaseSessionWithClientServer (addr , sql.Client {Capabilities : conn .Capabilities }, conn .ConnectionID ), nil
759+ },
760+ sql .NoopTracer ,
761+ dbFunc ,
762+ e .MemoryManager ,
763+ e .ProcessList ,
764+ "foo" ,
765+ ),
766+ }
767+
768+ var err error
769+ conn1 := newConn (1 )
770+ handler .NewConnection (conn1 )
771+
772+ conn2 := newConn (2 )
773+ handler .NewConnection (conn2 )
774+
775+ require .Len (handler .sm .connections , 2 )
776+ require .Len (handler .sm .sessions , 0 )
777+
778+ handler .ComInitDB (conn1 , "test" )
779+ err = handler .sm .SetDB (conn1 , "test" )
780+ require .NoError (err )
781+
782+ err = handler .sm .SetDB (conn2 , "test" )
783+ require .NoError (err )
784+
785+ require .False (conn1 .Conn .(* mockConn ).closed )
786+ require .False (conn2 .Conn .(* mockConn ).closed )
787+ require .Len (handler .sm .connections , 2 )
788+ require .Len (handler .sm .sessions , 2 )
789+
790+ var wg sync.WaitGroup
791+ wg .Add (1 )
792+ sleepQuery := "SELECT SLEEP(1)"
793+ go func () {
794+ defer wg .Done ()
795+ err = handler .ComQuery (context .Background (), conn1 , sleepQuery , func (res * sqltypes.Result , more bool ) error {
796+ return nil
797+ })
798+ require .Error (err )
799+ }()
800+
801+ time .Sleep (100 * time .Millisecond )
802+ var sleepQueryID string
803+ err = handler .ComQuery (context .Background (), conn2 , "SHOW PROCESSLIST" , func (res * sqltypes.Result , more bool ) error {
804+ // 1, , , test, Query, 0, ... , SELECT SLEEP(1000)
805+ // 2, , , test, Query, 0, running, SHOW PROCESSLIST
806+ require .Equal (2 , len (res .Rows ))
807+ hasSleepQuery := false
808+ for _ , row := range res .Rows {
809+ if row [7 ].ToString () != sleepQuery {
810+ continue
811+ }
812+ hasSleepQuery = true
813+ sleepQueryID = row [0 ].ToString ()
814+ require .Equal ("Query" , row [4 ].ToString ())
815+ }
816+ require .True (hasSleepQuery )
817+ return nil
818+ })
819+ require .NoError (err )
820+
821+ time .Sleep (100 * time .Millisecond )
822+ err = handler .ComQuery (context .Background (), conn2 , "KILL QUERY " + sleepQueryID , func (res * sqltypes.Result , more bool ) error {
823+ return nil
824+ })
825+ require .NoError (err )
826+ wg .Wait ()
827+
828+ time .Sleep (100 * time .Millisecond )
829+ err = handler .ComQuery (context .Background (), conn2 , "SHOW PROCESSLIST" , func (res * sqltypes.Result , more bool ) error {
830+ // 1, , , test, Sleep, 0, ,
831+ // 2, , , test, Query, 0, running, SHOW PROCESSLIST
832+ require .Equal (2 , len (res .Rows ))
833+ hasSleepQueryID := false
834+ for _ , row := range res .Rows {
835+ if row [0 ].ToString () != sleepQueryID {
836+ continue
837+ }
838+ hasSleepQueryID = true
839+ require .Equal ("Sleep" , row [4 ].ToString ())
840+ require .Equal ("" , row [7 ].ToString ())
841+ }
842+ require .True (hasSleepQueryID )
843+ return nil
844+ })
845+ require .NoError (err )
846+
847+ require .False (conn1 .Conn .(* mockConn ).closed )
848+ require .False (conn2 .Conn .(* mockConn ).closed )
849+ require .Len (handler .sm .connections , 2 )
850+ require .Len (handler .sm .sessions , 2 )
851+ }
852+
745853func TestSchemaToFields (t * testing.T ) {
746854 require := require .New (t )
747855
0 commit comments