@@ -19,15 +19,19 @@ package vstreamer
19
19
import (
20
20
"context"
21
21
"fmt"
22
+ "os"
22
23
"regexp"
23
24
"testing"
25
+ "time"
24
26
27
+ "github.com/spf13/pflag"
25
28
"github.com/stretchr/testify/require"
26
29
27
30
"vitess.io/vitess/go/mysql"
28
31
"vitess.io/vitess/go/mysql/collations"
29
32
"vitess.io/vitess/go/sqltypes"
30
33
"vitess.io/vitess/go/vt/log"
34
+ "vitess.io/vitess/go/vt/servenv"
31
35
32
36
binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
33
37
)
@@ -432,6 +436,100 @@ func TestStreamRowsCancel(t *testing.T) {
432
436
}
433
437
}
434
438
439
+ // setFlag() sets a flag for a test in a non-racy way:
440
+ // - it registers the flag using a different flagset scope
441
+ // - clears other flags by passing a dummy os.Args() while parsing this flagset
442
+ // - sets the specific flag, if it has not already been defined
443
+ // - resets the os.Args() so that the remaining flagsets can be parsed correctly
444
+ func setFlag (flagName , flagValue string ) {
445
+ flagSetName := "vttablet"
446
+ var tmp []string
447
+ tmp , os .Args = os .Args [:], []string {flagSetName }
448
+ defer func () { os .Args = tmp }()
449
+
450
+ servenv .OnParseFor (flagSetName , func (fs * pflag.FlagSet ) {
451
+ if fs .Lookup (flagName ) != nil {
452
+ fmt .Printf ("found %s: %+v" , flagName , fs .Lookup (flagName ).Value )
453
+ return
454
+ }
455
+ })
456
+ servenv .ParseFlags (flagSetName )
457
+
458
+ if err := pflag .Set (flagName , flagValue ); err != nil {
459
+ msg := "failed to set flag %q to %q: %v"
460
+ log .Errorf (msg , flagName , flagValue , err )
461
+ }
462
+ }
463
+
464
+ func TestStreamRowsHeartbeat (t * testing.T ) {
465
+ if testing .Short () {
466
+ t .Skip ()
467
+ }
468
+ setFlag ("vstream_packet_size" , "10" )
469
+ defer setFlag ("vstream_packet_size" , "10000" )
470
+
471
+ // Save original heartbeat interval and restore it after test
472
+ originalInterval := rowStreamertHeartbeatInterval
473
+ defer func () {
474
+ rowStreamertHeartbeatInterval = originalInterval
475
+ }()
476
+
477
+ // Set a very short heartbeat interval for testing (100ms)
478
+ rowStreamertHeartbeatInterval = 10 * time .Millisecond
479
+
480
+ execStatements (t , []string {
481
+ "create table t1(id int, val varchar(128), primary key(id))" ,
482
+ "insert into t1 values (1, 'test1')" ,
483
+ "insert into t1 values (2, 'test2')" ,
484
+ "insert into t1 values (3, 'test3')" ,
485
+ "insert into t1 values (4, 'test4')" ,
486
+ "insert into t1 values (5, 'test5')" ,
487
+ })
488
+
489
+ defer execStatements (t , []string {
490
+ "drop table t1" ,
491
+ })
492
+
493
+ ctx , cancel := context .WithTimeout (context .Background (), 1 * time .Second )
494
+ defer cancel ()
495
+
496
+ heartbeatCount := 0
497
+ dataReceived := false
498
+
499
+ err := engine .StreamRows (ctx , "select * from t1" , nil , func (rows * binlogdatapb.VStreamRowsResponse ) error {
500
+ if rows .Heartbeat {
501
+ heartbeatCount ++
502
+ // After receiving at least 3 heartbeats, we can be confident the fix is working
503
+ if heartbeatCount >= 3 {
504
+ cancel ()
505
+ return nil
506
+ }
507
+ } else if len (rows .Rows ) > 0 {
508
+ dataReceived = true
509
+ }
510
+ // Add a small delay to allow heartbeats to be sent
511
+ time .Sleep (50 * time .Millisecond )
512
+ return nil
513
+ })
514
+
515
+ // We expect context canceled error since we cancel after receiving heartbeats
516
+ if err != nil && err .Error () != "stream ended: context canceled" {
517
+ t .Errorf ("unexpected error: %v" , err )
518
+ }
519
+
520
+ // Verify we received data
521
+ if ! dataReceived {
522
+ t .Error ("expected to receive data rows" )
523
+ }
524
+
525
+ // This is the critical test: we should receive multiple heartbeats
526
+ // Without the fix (missing for loop), we would only get 1 heartbeat
527
+ // With the fix, we should get at least 3 heartbeats
528
+ if heartbeatCount < 3 {
529
+ t .Errorf ("expected at least 3 heartbeats, got %d. This indicates the heartbeat goroutine is not running continuously" , heartbeatCount )
530
+ }
531
+ }
532
+
435
533
func checkStream (t * testing.T , query string , lastpk []sqltypes.Value , wantQuery string , wantStream []string ) {
436
534
t .Helper ()
437
535
0 commit comments