@@ -3,12 +3,14 @@ package testrunner
3
3
import (
4
4
"bytes"
5
5
"errors"
6
+ "fmt"
6
7
"go/ast"
7
8
"go/format"
8
9
"go/parser"
9
10
"go/printer"
10
11
"go/token"
11
12
"log"
13
+ "path/filepath"
12
14
"regexp"
13
15
"strconv"
14
16
"strings"
@@ -34,6 +36,7 @@ type rootLevelTest struct {
34
36
fileName string
35
37
code string
36
38
taskID uint64
39
+ pkgName string
37
40
}
38
41
39
42
// FindAllRootLevelTests parses the test file and extracts the name,
@@ -60,6 +63,7 @@ func FindAllRootLevelTests(fileName string) []rootLevelTest {
60
63
fileName : fileName ,
61
64
code : buf .String (),
62
65
taskID : taskID ,
66
+ pkgName : file .Name .Name ,
63
67
})
64
68
}
65
69
}
@@ -95,16 +99,19 @@ func findTaskID(doc *ast.CommentGroup) uint64 {
95
99
}
96
100
97
101
// generate simplified test code corresponding to a subtest
98
- func getSubCode (test string , sub string , code string , file string ) string {
102
+ func getSubCode (test string , sub string , code string , file string , pkgName string ) string {
103
+ pkgLine := fmt .Sprintf ("package %s\n " , pkgName )
99
104
fset := token .NewFileSet ()
100
105
f , err := parser .ParseFile (
101
- fset , file , "package main \n " + code , parser .ParseComments ,
106
+ fset , file , pkgLine + code , parser .ParseComments ,
102
107
)
103
108
if err != nil {
104
109
log .Printf ("warning: '%s' not parsed from '%s': %s" , test , file , err )
105
110
return ""
106
111
}
107
112
113
+ resolveTestData (fset , f , file )
114
+
108
115
fAST , ok := f .Decls [0 ].(* ast.FuncDecl )
109
116
if ! ok {
110
117
log .Println ("warning: first subtest declaration must be a function" )
@@ -113,7 +120,7 @@ func getSubCode(test string, sub string, code string, file string) string {
113
120
114
121
fbAST := fAST .Body .List // f.Decls[0].Body.List
115
122
116
- astInfo , err := findTestDataAndRange (fbAST )
123
+ astInfo , err := findTestDataAndRange (fbAST , fset )
117
124
if err != nil {
118
125
log .Printf ("warning: could not find test table and/or range: %v\n " , err )
119
126
return ""
@@ -146,36 +153,33 @@ func getSubCode(test string, sub string, code string, file string) string {
146
153
log .Println ("warning: failed to format extracted AST for subtest" )
147
154
return ""
148
155
}
149
- return strings .TrimSpace (strings .TrimPrefix (buf .String (), "package main" ))
156
+ if astInfo .testDataAstIdx != - 1 { // testDataAst is already in the test function
157
+ return strings .TrimSpace (strings .TrimPrefix (buf .String (), pkgLine ))
158
+ }
159
+ return insertTestDataASTIntoFunc (fset , astInfo .testDataAst , fAST .Body , buf .Bytes (), pkgLine )
150
160
}
151
161
152
- func findTestDataAndRange (stmtList []ast.Stmt ) (subTestAstInfo , error ) {
162
+ func findTestDataAndRange (stmtList []ast.Stmt , fset * token. FileSet ) (subTestAstInfo , error ) {
153
163
result := subTestAstInfo {}
154
-
164
+ posToIndex := make ( map [token. Position ] int )
155
165
for i := range stmtList {
156
- assignCandidate , ok := stmtList [i ].(* ast.AssignStmt )
157
- if ok && result .testDataAst == nil {
158
- result .testDataAst = assignCandidate
159
- result .testDataAstIdx = i
160
- } else if ok {
161
- identifier , isIdentifier := assignCandidate .Lhs [0 ].(* ast.Ident )
162
- if ! isIdentifier {
163
- continue
164
- }
165
- // Overwrite the assignment we already found in case there is an
166
- // assignment to a "tests" variable.
167
- if identifier .Name == "tests" {
166
+ posToIndex [fset .Position (stmtList [i ].Pos ())] = i
167
+ if rangeCandidate , ok := stmtList [i ].(* ast.RangeStmt ); ok {
168
+ assignCandidate := getTestDataAssignFromRange (rangeCandidate )
169
+ if assignCandidate != nil {
170
+ // check if assignCandidate is in the same function with rangeCandidate
171
+ if idx , ok := posToIndex [fset .Position (assignCandidate .Pos ())]; ok &&
172
+ fset .File (assignCandidate .Pos ()).Name () == fset .File (rangeCandidate .Pos ()).Name () {
173
+ result .testDataAstIdx = idx
174
+ } else {
175
+ result .testDataAstIdx = - 1
176
+ }
168
177
result .testDataAst = assignCandidate
169
- result .testDataAstIdx = i
178
+ result .rangeAst = rangeCandidate
179
+ result .rangeAstIdx = i
180
+ return result , nil
170
181
}
171
- }
172
-
173
- rangeCandidate , ok := stmtList [i ].(* ast.RangeStmt )
174
- // If we found a range after we already found an assignment, we are good to go.
175
- if ok && result .testDataAst != nil {
176
- result .rangeAst = rangeCandidate
177
- result .rangeAstIdx = i
178
- return result , nil
182
+ return subTestAstInfo {}, errors .New ("failed to find assignment in sub-test" )
179
183
}
180
184
}
181
185
@@ -185,6 +189,24 @@ func findTestDataAndRange(stmtList []ast.Stmt) (subTestAstInfo, error) {
185
189
186
190
return subTestAstInfo {}, errors .New ("failed to find range statement in sub-test" )
187
191
}
192
+ func getTestDataAssignFromRange (rangeAst * ast.RangeStmt ) * ast.AssignStmt {
193
+ spec := rangeAst .X .(* ast.Ident ).Obj .Decl
194
+ if assignStmt , ok := spec .(* ast.AssignStmt ); ok {
195
+ return assignStmt
196
+ }
197
+ if valueSpec , ok := spec .(* ast.ValueSpec ); ok {
198
+ lhs := make ([]ast.Expr , len (valueSpec .Names ))
199
+ for i , name := range valueSpec .Names {
200
+ lhs [i ] = name
201
+ }
202
+ return & ast.AssignStmt {
203
+ Lhs : lhs ,
204
+ Tok : token .DEFINE ,
205
+ Rhs : valueSpec .Values ,
206
+ }
207
+ }
208
+ return nil
209
+ }
188
210
189
211
// validate the test data assignment and return the associated metadata
190
212
func processTestDataAssgn (sub string , assgn * ast.AssignStmt ) (* subTData , bool ) {
@@ -290,3 +312,44 @@ func processRange(metadata *subTData, rastmt *ast.RangeStmt) bool {
290
312
metadata .subTest = body
291
313
return true
292
314
}
315
+
316
+ // resolveTestData resolves test data variable declared in cases_test.go (if exists)
317
+ func resolveTestData (fset * token.FileSet , f * ast.File , file string ) {
318
+ filedata := filepath .Join (filepath .Dir (file ), "cases_test.go" )
319
+ fdata , _ := parser .ParseFile (fset , filedata , nil , parser .ParseComments )
320
+
321
+ // NewPackage func always return errors because f files's missing import part
322
+ // so ignore checking the returned errors
323
+ if fdata != nil {
324
+ _ , _ = ast .NewPackage (fset , map [string ]* ast.File {file : f , filedata : fdata }, nil , nil )
325
+ } else {
326
+ _ , _ = ast .NewPackage (fset , map [string ]* ast.File {file : f }, nil , nil )
327
+ }
328
+ }
329
+
330
+ // insertTestDataASTIntoFunc inserts testDataAst into the first line of fbAST function's body
331
+ func insertTestDataASTIntoFunc (fset * token.FileSet , testDataAst * ast.AssignStmt , fbAST * ast.BlockStmt , fileText []byte , pkgLine string ) string {
332
+ buf := bytes.Buffer {}
333
+
334
+ p := fset .Position (fbAST .Lbrace ).Offset + 1
335
+
336
+ // write the beginning of fileText to func (...) {
337
+ buf .Write (fileText [:p + 1 ])
338
+
339
+ // write test data assign stmt
340
+ if err := format .Node (& buf , fset , testDataAst ); err != nil {
341
+ log .Println ("warning: failed to format extracted AST for subtest" )
342
+ return ""
343
+ }
344
+ // write the rest of fileText
345
+ buf .Write (fileText [p + 1 :])
346
+
347
+ // because assign stmt is extracted from different file, its indentation is different from fileText
348
+ // so need to reformat
349
+ src , err := format .Source ((buf .Bytes ()))
350
+ if err != nil {
351
+ log .Println ("warning: failed to format extracted AST for subtest" )
352
+ return ""
353
+ }
354
+ return strings .TrimSpace (strings .TrimPrefix (string (src ), pkgLine ))
355
+ }
0 commit comments