Skip to content

Commit d535464

Browse files
committed
Test code extraction for exercises with separate file for test cases
This commit also support arbitrary test data variable name when test function includes multiple assignments.
1 parent 35f10f2 commit d535464

File tree

7 files changed

+327
-30
lines changed

7 files changed

+327
-30
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ require (
77
github.com/pmezard/go-difflib v1.0.0 // indirect
88
github.com/stretchr/testify v1.8.2 // indirect
99
gopkg.in/yaml.v3 v3.0.1 // indirect
10-
)
10+
)

go.sum

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
1313
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
1414
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
1515
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
16-
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
16+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

testrunner/ast.go

Lines changed: 90 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ package testrunner
33
import (
44
"bytes"
55
"errors"
6+
"fmt"
67
"go/ast"
78
"go/format"
89
"go/parser"
910
"go/printer"
1011
"go/token"
1112
"log"
13+
"path/filepath"
1214
"regexp"
1315
"strconv"
1416
"strings"
@@ -34,6 +36,7 @@ type rootLevelTest struct {
3436
fileName string
3537
code string
3638
taskID uint64
39+
pkgName string
3740
}
3841

3942
// FindAllRootLevelTests parses the test file and extracts the name,
@@ -60,6 +63,7 @@ func FindAllRootLevelTests(fileName string) []rootLevelTest {
6063
fileName: fileName,
6164
code: buf.String(),
6265
taskID: taskID,
66+
pkgName: file.Name.Name,
6367
})
6468
}
6569
}
@@ -95,16 +99,19 @@ func findTaskID(doc *ast.CommentGroup) uint64 {
9599
}
96100

97101
// generate simplified test code corresponding to a subtest
98-
func getSubCode(test string, sub string, code string, file string) string {
102+
func getSubCode(test string, sub string, code string, file string, pkgName string) string {
103+
pkgLine := fmt.Sprintf("package %s\n", pkgName)
99104
fset := token.NewFileSet()
100105
f, err := parser.ParseFile(
101-
fset, file, "package main\n"+code, parser.ParseComments,
106+
fset, file, pkgLine+code, parser.ParseComments,
102107
)
103108
if err != nil {
104109
log.Printf("warning: '%s' not parsed from '%s': %s", test, file, err)
105110
return ""
106111
}
107112

113+
resolveTestData(fset, f, file)
114+
108115
fAST, ok := f.Decls[0].(*ast.FuncDecl)
109116
if !ok {
110117
log.Println("warning: first subtest declaration must be a function")
@@ -113,7 +120,7 @@ func getSubCode(test string, sub string, code string, file string) string {
113120

114121
fbAST := fAST.Body.List // f.Decls[0].Body.List
115122

116-
astInfo, err := findTestDataAndRange(fbAST)
123+
astInfo, err := findTestDataAndRange(fbAST, fset)
117124
if err != nil {
118125
log.Printf("warning: could not find test table and/or range: %v\n", err)
119126
return ""
@@ -146,36 +153,33 @@ func getSubCode(test string, sub string, code string, file string) string {
146153
log.Println("warning: failed to format extracted AST for subtest")
147154
return ""
148155
}
149-
return strings.TrimSpace(strings.TrimPrefix(buf.String(), "package main"))
156+
if astInfo.testDataAstIdx != -1 { // testDataAst is already in the test function
157+
return strings.TrimSpace(strings.TrimPrefix(buf.String(), pkgLine))
158+
}
159+
return insertTestDataASTIntoFunc(fset, astInfo.testDataAst, fAST.Body, buf.Bytes(), pkgLine)
150160
}
151161

152-
func findTestDataAndRange(stmtList []ast.Stmt) (subTestAstInfo, error) {
162+
func findTestDataAndRange(stmtList []ast.Stmt, fset *token.FileSet) (subTestAstInfo, error) {
153163
result := subTestAstInfo{}
154-
164+
posToIndex := make(map[token.Position]int)
155165
for i := range stmtList {
156-
assignCandidate, ok := stmtList[i].(*ast.AssignStmt)
157-
if ok && result.testDataAst == nil {
158-
result.testDataAst = assignCandidate
159-
result.testDataAstIdx = i
160-
} else if ok {
161-
identifier, isIdentifier := assignCandidate.Lhs[0].(*ast.Ident)
162-
if !isIdentifier {
163-
continue
164-
}
165-
// Overwrite the assignment we already found in case there is an
166-
// assignment to a "tests" variable.
167-
if identifier.Name == "tests" {
166+
posToIndex[fset.Position(stmtList[i].Pos())] = i
167+
if rangeCandidate, ok := stmtList[i].(*ast.RangeStmt); ok {
168+
assignCandidate := getTestDataAssignFromRange(rangeCandidate)
169+
if assignCandidate != nil {
170+
// check if assignCandidate is in the same function with rangeCandidate
171+
if idx, ok := posToIndex[fset.Position(assignCandidate.Pos())]; ok &&
172+
fset.File(assignCandidate.Pos()).Name() == fset.File(rangeCandidate.Pos()).Name() {
173+
result.testDataAstIdx = idx
174+
} else {
175+
result.testDataAstIdx = -1
176+
}
168177
result.testDataAst = assignCandidate
169-
result.testDataAstIdx = i
178+
result.rangeAst = rangeCandidate
179+
result.rangeAstIdx = i
180+
return result, nil
170181
}
171-
}
172-
173-
rangeCandidate, ok := stmtList[i].(*ast.RangeStmt)
174-
// If we found a range after we already found an assignment, we are good to go.
175-
if ok && result.testDataAst != nil {
176-
result.rangeAst = rangeCandidate
177-
result.rangeAstIdx = i
178-
return result, nil
182+
return subTestAstInfo{}, errors.New("failed to find assignment in sub-test")
179183
}
180184
}
181185

@@ -185,6 +189,24 @@ func findTestDataAndRange(stmtList []ast.Stmt) (subTestAstInfo, error) {
185189

186190
return subTestAstInfo{}, errors.New("failed to find range statement in sub-test")
187191
}
192+
func getTestDataAssignFromRange(rangeAst *ast.RangeStmt) *ast.AssignStmt {
193+
spec := rangeAst.X.(*ast.Ident).Obj.Decl
194+
if assignStmt, ok := spec.(*ast.AssignStmt); ok {
195+
return assignStmt
196+
}
197+
if valueSpec, ok := spec.(*ast.ValueSpec); ok {
198+
lhs := make([]ast.Expr, len(valueSpec.Names))
199+
for i, name := range valueSpec.Names {
200+
lhs[i] = name
201+
}
202+
return &ast.AssignStmt{
203+
Lhs: lhs,
204+
Tok: token.DEFINE,
205+
Rhs: valueSpec.Values,
206+
}
207+
}
208+
return nil
209+
}
188210

189211
// validate the test data assignment and return the associated metadata
190212
func processTestDataAssgn(sub string, assgn *ast.AssignStmt) (*subTData, bool) {
@@ -290,3 +312,44 @@ func processRange(metadata *subTData, rastmt *ast.RangeStmt) bool {
290312
metadata.subTest = body
291313
return true
292314
}
315+
316+
// resolveTestData resolves test data variable declared in cases_test.go (if exists)
317+
func resolveTestData(fset *token.FileSet, f *ast.File, file string) {
318+
filedata := filepath.Join(filepath.Dir(file), "cases_test.go")
319+
fdata, _ := parser.ParseFile(fset, filedata, nil, parser.ParseComments)
320+
321+
// NewPackage func always return errors because f files's missing import part
322+
// so ignore checking the returned errors
323+
if fdata != nil {
324+
_, _ = ast.NewPackage(fset, map[string]*ast.File{file: f, filedata: fdata}, nil, nil)
325+
} else {
326+
_, _ = ast.NewPackage(fset, map[string]*ast.File{file: f}, nil, nil)
327+
}
328+
}
329+
330+
// insertTestDataASTIntoFunc inserts testDataAst into the first line of fbAST function's body
331+
func insertTestDataASTIntoFunc(fset *token.FileSet, testDataAst *ast.AssignStmt, fbAST *ast.BlockStmt, fileText []byte, pkgLine string) string {
332+
buf := bytes.Buffer{}
333+
334+
p := fset.Position(fbAST.Lbrace).Offset + 1
335+
336+
// write the beginning of fileText to func (...) {
337+
buf.Write(fileText[:p+1])
338+
339+
// write test data assign stmt
340+
if err := format.Node(&buf, fset, testDataAst); err != nil {
341+
log.Println("warning: failed to format extracted AST for subtest")
342+
return ""
343+
}
344+
// write the rest of fileText
345+
buf.Write(fileText[p+1:])
346+
347+
// because assign stmt is extracted from different file, its indentation is different from fileText
348+
// so need to reformat
349+
src, err := format.Source((buf.Bytes()))
350+
if err != nil {
351+
log.Println("warning: failed to format extracted AST for subtest")
352+
return ""
353+
}
354+
return strings.TrimSpace(strings.TrimPrefix(string(src), pkgLine))
355+
}

testrunner/extract.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func ExtractTestCodeAndTaskID(rootLevelTests map[string]rootLevelTest, testName
5353
return rootLevelTest.code, rootLevelTest.taskID
5454
}
5555
defer handleASTPanic()
56-
subtc := getSubCode(test, subtest, rootLevelTest.code, rootLevelTest.fileName)
56+
subtc := getSubCode(test, subtest, rootLevelTest.code, rootLevelTest.fileName, rootLevelTest.pkgName)
5757
if len(subtc) == 0 {
5858
return rootLevelTest.code, rootLevelTest.taskID
5959
}

testrunner/extract_test.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ func TestExtractTestCode(t *testing.T) {
191191
}`,
192192
},
193193
}
194+
tests = append(tests, testsDataSeparate...)
195+
tests = append(tests, testsMultiAssignStmt...)
194196
for _, tt := range tests {
195197
t.Run(tt.name, func(t *testing.T) {
196198
code, _ := ExtractTestCodeAndTaskID(rootLevelTestsMap, tt.testName)
@@ -221,3 +223,108 @@ func TestExtractTestCode(t *testing.T) {
221223
})
222224
}
223225
}
226+
227+
var testsDataSeparate = []struct {
228+
name string
229+
testName string
230+
testFile string
231+
code string
232+
}{
233+
{
234+
name: "working subtest with separate test data",
235+
testName: "TestParseCard_Separate/parse_jack",
236+
testFile: filepath.Join("testdata", "concept", "conditionals", "conditionals_test.go"),
237+
code: `func TestParseCard_Separate(t *testing.T) {
238+
tt := struct {
239+
name string
240+
card string
241+
want int
242+
}{
243+
name: "parse jack",
244+
card: "jack",
245+
want: 10,
246+
}
247+
248+
if got := ParseCard(tt.card); got != tt.want {
249+
t.Errorf("ParseCard(%s) = %d, want %d", tt.card, got, tt.want)
250+
}
251+
252+
}`,
253+
}, {
254+
name: "missing / not found subtest with separate test data",
255+
testName: "TestParseCard_Separate/parse_missing_subtests",
256+
testFile: filepath.Join("testdata", "concept", "conditionals", "conditionals_test.go"),
257+
code: `func TestParseCard_Separate(t *testing.T) {
258+
for _, tt := range testcases {
259+
t.Run(tt.name, func(t *testing.T) {
260+
if got := ParseCard(tt.card); got != tt.want {
261+
t.Errorf("ParseCard(%s) = %d, want %d", tt.card, got, tt.want)
262+
}
263+
})
264+
}
265+
}`,
266+
}, {
267+
name: "multiple statements with separate test data",
268+
testName: "TestBlackjack_Separate/blackjack_with_ten_(ace_first)",
269+
testFile: filepath.Join("testdata", "concept", "conditionals", "conditionals_test.go"),
270+
code: `func TestBlackjack_Separate(t *testing.T) {
271+
tt := struct {
272+
name string
273+
hand hand
274+
want bool
275+
}{
276+
name: "blackjack with ten (ace first)",
277+
hand: hand{card1: "ace", card2: "ten"},
278+
want: true,
279+
}
280+
someAssignment := "test"
281+
fmt.Println(someAssignment)
282+
283+
_ = "literally anything"
284+
285+
got := IsBlackjack(tt.hand.card1, tt.hand.card2)
286+
if got != tt.want {
287+
t.Errorf("IsBlackjack(%s, %s) = %t, want %t", tt.hand.card1, tt.hand.card2, got, tt.want)
288+
}
289+
290+
// Additional statements should be included
291+
fmt.Println("the whole block")
292+
fmt.Println("should be returned")
293+
}`,
294+
},
295+
}
296+
var testsMultiAssignStmt = []struct {
297+
name string
298+
testName string
299+
testFile string
300+
code string
301+
}{
302+
{
303+
name: "subtest with arbitrary test data variable name, additional assign statements above and below test data",
304+
testName: "TestSubtest_MultiAssignStmt/parse_king",
305+
testFile: filepath.Join("testdata", "concept", "conditionals", "conditionals_test.go"),
306+
code: `func TestSubtest_MultiAssignStmt(t *testing.T) {
307+
someAssignment := "test"
308+
309+
tt := struct {
310+
name string
311+
card string
312+
want int
313+
}{
314+
name: "parse king",
315+
card: "king",
316+
want: 10,
317+
}
318+
319+
someAssignment2 := "test2"
320+
321+
if got := ParseCard(tt.card); got != tt.want {
322+
t.Errorf("ParseCard(%s) = %d, want %d", tt.card, got, tt.want)
323+
}
324+
325+
// Additional statements should be included
326+
fmt.Println("the whole block")
327+
fmt.Println("should be returned")
328+
}`,
329+
},
330+
}

testrunner/testdata/concept/conditionals/cases_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,62 @@ var allergicToTests = []struct {
1919
expected: false,
2020
},
2121
}
22+
var testcases = []struct {
23+
name string
24+
card string
25+
want int
26+
}{
27+
{
28+
name: "parse two",
29+
card: "two",
30+
want: 2,
31+
},
32+
{
33+
name: "parse jack",
34+
card: "jack",
35+
want: 10,
36+
},
37+
{
38+
name: "parse king",
39+
card: "king",
40+
want: 10,
41+
},
42+
}
43+
44+
type hand struct {
45+
card1, card2 string
46+
}
47+
48+
var testcases2 = []struct {
49+
name string
50+
hand hand
51+
want bool
52+
}{
53+
{
54+
name: "blackjack with ten (ace first)",
55+
hand: hand{card1: "ace", card2: "ten"},
56+
want: true,
57+
},
58+
{
59+
name: "blackjack with jack (ace first)",
60+
hand: hand{card1: "ace", card2: "jack"},
61+
want: true,
62+
},
63+
{
64+
name: "blackjack with queen (ace first)",
65+
hand: hand{
66+
card1: "ace", card2: "queen"
67+
},
68+
want: true,
69+
},
70+
{
71+
name: "blackjack with king (ace first)",
72+
hand: hand{card1: "ace", card2: "king"},
73+
want: true,
74+
},
75+
{
76+
name: "no blackjack with eight and five",
77+
hand: hand{card2: "eight", card1: "five"},
78+
want: false,
79+
},
80+
}

0 commit comments

Comments
 (0)