@@ -644,83 +644,90 @@ func (cs CumulativeTestCaseSpec) run(t *testing.T, fn func(t *testing.T)) bool {
644644 return t .Run (fmt .Sprintf ("%s_stage_%d_of_%d" , prefix , cs .StageOrdinal , cs .StagesCount ), fn )
645645}
646646
647+ // cumulativeTestForEachPostCommitStage invokes `tf` once for each stage in the
648+ // PostCommitPhase. These invocation are run in parallel.
647649func cumulativeTestForEachPostCommitStage (
648650 t * testing.T ,
649651 relTestCaseDir string ,
650652 factory TestServerFactory ,
651653 tf func (t * testing.T , spec CumulativeTestCaseSpec ),
652654) {
653- testFunc := func (t * testing.T , spec CumulativeTestSpec ) {
654- // Skip this test if any of the stmts is not fully supported.
655- if err := areStmtsFullySupportedAtClusterVersion (t , spec , factory ); err != nil {
656- skip .IgnoreLint (t , "test is skipped because" , err .Error ())
657- }
658- var postCommitCount , postCommitNonRevertibleCount int
659- var after [][]string
660- var dbName string
661- prepfn := func (db * gosql.DB , p scplan.Plan ) {
662- for _ , s := range p .Stages {
663- switch s .Phase {
664- case scop .PostCommitPhase :
665- postCommitCount ++
666- case scop .PostCommitNonRevertiblePhase :
667- postCommitNonRevertibleCount ++
655+ // Grouping the parallel subtests into a non-parallel subtest allows any defer
656+ // calls to work as expected.
657+ t .Run ("group" , func (t * testing.T ) {
658+ testFunc := func (t * testing.T , spec CumulativeTestSpec ) {
659+ // Skip this test if any of the stmts is not fully supported.
660+ if err := areStmtsFullySupportedAtClusterVersion (t , spec , factory ); err != nil {
661+ skip .IgnoreLint (t , "test is skipped because" , err .Error ())
662+ }
663+ var postCommitCount , postCommitNonRevertibleCount int
664+ var after [][]string
665+ var dbName string
666+ prepfn := func (db * gosql.DB , p scplan.Plan ) {
667+ for _ , s := range p .Stages {
668+ switch s .Phase {
669+ case scop .PostCommitPhase :
670+ postCommitCount ++
671+ case scop .PostCommitNonRevertiblePhase :
672+ postCommitNonRevertibleCount ++
673+ }
668674 }
675+ tdb := sqlutils .MakeSQLRunner (db )
676+ var ok bool
677+ dbName , ok = maybeGetDatabaseForIDs (t , tdb , screl .AllTargetStateDescIDs (p .TargetState ))
678+ if ok {
679+ tdb .Exec (t , fmt .Sprintf ("USE %q" , dbName ))
680+ }
681+ after = tdb .QueryStr (t , fetchDescriptorStateQuery )
669682 }
670- tdb := sqlutils .MakeSQLRunner (db )
671- var ok bool
672- dbName , ok = maybeGetDatabaseForIDs (t , tdb , screl .AllTargetStateDescIDs (p .TargetState ))
673- if ok {
674- tdb .Exec (t , fmt .Sprintf ("USE %q" , dbName ))
683+ withPostCommitPlanAfterSchemaChange (t , spec , factory , prepfn )
684+ if postCommitCount + postCommitNonRevertibleCount == 0 {
685+ skip .IgnoreLint (t , "test case has no post-commit stages" )
686+ return
675687 }
676- after = tdb .QueryStr (t , fetchDescriptorStateQuery )
677- }
678- withPostCommitPlanAfterSchemaChange (t , spec , factory , prepfn )
679- if postCommitCount + postCommitNonRevertibleCount == 0 {
680- skip .IgnoreLint (t , "test case has no post-commit stages" )
681- return
682- }
683- if dbName == "" {
684- skip .IgnoreLint (t , "test case has no usable database" )
685- return
686- }
687- var testCases []CumulativeTestCaseSpec
688- for stageOrdinal := 1 ; stageOrdinal <= postCommitCount ; stageOrdinal ++ {
689- testCases = append (testCases , CumulativeTestCaseSpec {
690- CumulativeTestSpec : spec ,
691- Phase : scop .PostCommitPhase ,
692- StageOrdinal : stageOrdinal ,
693- StagesCount : postCommitCount ,
694- After : after ,
695- DatabaseName : dbName ,
696- })
697- }
698- for stageOrdinal := 1 ; stageOrdinal <= postCommitNonRevertibleCount ; stageOrdinal ++ {
699- testCases = append (testCases , CumulativeTestCaseSpec {
700- CumulativeTestSpec : spec ,
701- Phase : scop .PostCommitNonRevertiblePhase ,
702- StageOrdinal : stageOrdinal ,
703- StagesCount : postCommitNonRevertibleCount ,
704- After : after ,
705- DatabaseName : dbName ,
706- })
707- }
708- var hasFailed bool
709- for _ , tc := range testCases {
710- fn := func (t * testing.T ) {
711- tf (t , tc )
688+ if dbName == "" {
689+ skip .IgnoreLint (t , "test case has no usable database" )
690+ return
712691 }
713- if hasFailed {
714- fn = func (t * testing.T ) {
715- skip .IgnoreLint (t , "skipping test cases subsequent to earlier failure" )
716- }
692+ var testCases []CumulativeTestCaseSpec
693+ for stageOrdinal := 1 ; stageOrdinal <= postCommitCount ; stageOrdinal ++ {
694+ testCases = append (testCases , CumulativeTestCaseSpec {
695+ CumulativeTestSpec : spec ,
696+ Phase : scop .PostCommitPhase ,
697+ StageOrdinal : stageOrdinal ,
698+ StagesCount : postCommitCount ,
699+ After : after ,
700+ DatabaseName : dbName ,
701+ })
717702 }
718- if ! tc .run (t , fn ) {
719- hasFailed = true
703+ for stageOrdinal := 1 ; stageOrdinal <= postCommitNonRevertibleCount ; stageOrdinal ++ {
704+ testCases = append (testCases , CumulativeTestCaseSpec {
705+ CumulativeTestSpec : spec ,
706+ Phase : scop .PostCommitNonRevertiblePhase ,
707+ StageOrdinal : stageOrdinal ,
708+ StagesCount : postCommitNonRevertibleCount ,
709+ After : after ,
710+ DatabaseName : dbName ,
711+ })
712+ }
713+ var hasFailed bool
714+ for _ , tc := range testCases {
715+ fn := func (t * testing.T ) {
716+ t .Parallel () // SAFE FOR TESTING
717+ tf (t , tc )
718+ }
719+ if hasFailed {
720+ fn = func (t * testing.T ) {
721+ skip .IgnoreLint (t , "skipping test cases subsequent to earlier failure" )
722+ }
723+ }
724+ if ! tc .run (t , fn ) {
725+ hasFailed = true
726+ }
720727 }
721728 }
722- }
723- cumulativeTest ( t , relTestCaseDir , testFunc )
729+ cumulativeTest ( t , relTestCaseDir , testFunc )
730+ } )
724731}
725732
726733// fetchDescriptorStateQuery returns the CREATE statements for all descriptors
0 commit comments