@@ -8,8 +8,10 @@ package sctest
88import (
99 "context"
1010 gosql "database/sql"
11+ "flag"
1112 "fmt"
1213 "math"
14+ "math/rand"
1315 "os"
1416 "path/filepath"
1517 "regexp"
@@ -634,6 +636,31 @@ type CumulativeTestCaseSpec struct {
634636 CreateDatabaseStmt string
635637}
636638
639+ // sampleAllPostCommitRevertible samples all post commit revertible stages, and
640+ // limits testing to maxStagesToTest if *runAllCumulative is not set.
641+ func sampleAllPostCommitRevertible (testCases []CumulativeTestCaseSpec ) []CumulativeTestCaseSpec {
642+ newTestCases := make ([]CumulativeTestCaseSpec , 0 , len (testCases ))
643+ for _ , tc := range testCases {
644+ if tc .Phase != scop .PostCommitNonRevertiblePhase {
645+ newTestCases = append (newTestCases , tc )
646+ }
647+ }
648+ return sampleAllPostCommitStages (newTestCases )
649+ }
650+
651+ // sampleAllPostCommitStages samples all post-commit stages, and limits these to
652+ // maxStagesToTest if *runAllCumulative is not set.
653+ func sampleAllPostCommitStages (testCases []CumulativeTestCaseSpec ) []CumulativeTestCaseSpec {
654+ if len (testCases ) > maxStagesToTest && ! (* runAllCumulative ) {
655+ // Shuffle and pick up to maxStagesToTest.
656+ rand .Shuffle (len (testCases ), func (i , j int ) {
657+ testCases [i ], testCases [j ] = testCases [j ], testCases [i ]
658+ })
659+ testCases = testCases [:maxStagesToTest ]
660+ }
661+ return testCases
662+ }
663+
637664func (cs CumulativeTestCaseSpec ) run (t * testing.T , fn func (t * testing.T )) bool {
638665 var prefix string
639666 switch cs .Phase {
@@ -647,6 +674,12 @@ func (cs CumulativeTestCaseSpec) run(t *testing.T, fn func(t *testing.T)) bool {
647674 return t .Run (fmt .Sprintf ("%s_stage_%d_of_%d" , prefix , cs .StageOrdinal , cs .StagesCount ), fn )
648675}
649676
677+ // / runAllCumulative used to disable sampling for cumulative tests.
678+ var runAllCumulative = flag .Bool (
679+ "run-all-cumulative" , false ,
680+ "if true, run all cumulative instead of a random subset" ,
681+ )
682+
650683// cumulativeTestForEachPostCommitStage invokes `tf` once for each stage in the
651684// PostCommitPhase.
652685func cumulativeTestForEachPostCommitStage (
@@ -655,6 +688,7 @@ func cumulativeTestForEachPostCommitStage(
655688 factory TestServerFactory ,
656689 prepFn func (t * testing.T , spec CumulativeTestSpec , dbName string ),
657690 tf func (t * testing.T , spec CumulativeTestCaseSpec ),
691+ samplingFn func ([]CumulativeTestCaseSpec ) []CumulativeTestCaseSpec ,
658692) {
659693 testFunc := func (t * testing.T , spec CumulativeTestSpec ) {
660694 // Skip this test if any of the stmts is not fully supported.
@@ -720,6 +754,10 @@ func cumulativeTestForEachPostCommitStage(
720754 if prepFn != nil {
721755 prepFn (t , spec , dbName )
722756 }
757+ // If sampling is enabled limit the number of stages executed.
758+ if samplingFn != nil {
759+ testCases = samplingFn (testCases )
760+ }
723761 for _ , tc := range testCases {
724762 fn := func (t * testing.T ) {
725763 tf (t , tc )
0 commit comments