Skip to content

Commit 76e6942

Browse files
committed
Go: support extracting test code
This implements support for test extraction by two mechanisms: * In autobuild mode, setting `CODEQL_EXTRACTOR_GO_EXTRACT_TESTS` to `true`. * In manual build mode, tracing a `go test` command (`go test -c` is to be recommended for efficiency). Go deals with test compilation by creating several extra packages on top of those expected from inspection of the source code (see docs of `packages.Load` for more detail): packages whose IDs include a suffix like `mydomain.com/mypackage [mydomain.com/mypackage.test]`, and packages containing generated test driver code like `mydomain.com/mypackage.test`. There are also additional packages like `mydomain.com/mypackage_tests` which are explicitly present in source code, but not compiled by a normal `go build`. So far as I can tell, the purpose of the two variants of the package is to resolve dependency cycles (because the tests variant of the package can have more dependencies than the non-tests variant, and non-test code can compile against non-test package variants). Since the test package variants seems to be a superset of the non-tests variant, I employ the simple heuristic of ignoring the variant of each package with the shortest ID. I haven't seen a case where there are three or more variants of a package, so I expect this to always identify the tests variant as the preferred one. If several variants were extracted, and we were to attempt to match Golang's linkage strategy among the different variants, we would need to extend trap-file name and most top-level symbol trap IDs with the package variant they come from; I hope this won't prove necessary. "Real" `_tests` packages, and wholly synthetic driver code packages, are extracted just like normal.
1 parent 594045b commit 76e6942

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

go/extractor/cli/go-extractor/go-extractor.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func usage() {
2121
fmt.Fprintf(os.Stderr, "--help Print this help.\n")
2222
}
2323

24-
func parseFlags(args []string, mimic bool) ([]string, []string) {
24+
func parseFlags(args []string, mimic bool, extractTests bool) ([]string, []string, bool) {
2525
i := 0
2626
buildFlags := []string{}
2727
for ; i < len(args) && strings.HasPrefix(args[i], "-"); i++ {
@@ -44,9 +44,9 @@ func parseFlags(args []string, mimic bool) ([]string, []string) {
4444
if i+1 < len(args) {
4545
i++
4646
command := args[i]
47-
if command == "build" || command == "install" || command == "run" {
48-
log.Printf("Intercepting build")
49-
return parseFlags(args[i+1:], true)
47+
if command == "build" || command == "install" || command == "run" || command == "test" {
48+
log.Printf("Intercepting build for %s command", command)
49+
return parseFlags(args[i+1:], true, command == "test")
5050
} else {
5151
log.Printf("Non-build command '%s'; skipping", strings.Join(args[1:], " "))
5252
os.Exit(0)
@@ -63,12 +63,12 @@ func parseFlags(args []string, mimic bool) ([]string, []string) {
6363

6464
// parse go build flags
6565
switch args[i] {
66-
// skip `-o output` and `-i`, if applicable
66+
// skip `-o output`, `-i` and `-c`, if applicable
6767
case "-o":
6868
if i+1 < len(args) {
6969
i++
7070
}
71-
case "-i":
71+
case "-i", "-c":
7272
case "-p", "-asmflags", "-buildmode", "-compiler", "-gccgoflags", "-gcflags", "-installsuffix",
7373
"-ldflags", "-mod", "-modfile", "-pkgdir", "-tags", "-toolexec", "-overlay":
7474
if i+1 < len(args) {
@@ -90,11 +90,12 @@ func parseFlags(args []string, mimic bool) ([]string, []string) {
9090
cpuprofile = os.Getenv("CODEQL_EXTRACTOR_GO_CPU_PROFILE")
9191
memprofile = os.Getenv("CODEQL_EXTRACTOR_GO_MEM_PROFILE")
9292

93-
return buildFlags, args[i:]
93+
return buildFlags, args[i:], extractTests
9494
}
9595

9696
func main() {
97-
buildFlags, patterns := parseFlags(os.Args[1:], false)
97+
extractTestsDefault := os.Getenv("CODEQL_EXTRACTOR_GO_EXTRACT_TESTS") == "true"
98+
buildFlags, patterns, extractTests := parseFlags(os.Args[1:], false, extractTestsDefault)
9899

99100
if cpuprofile != "" {
100101
f, err := os.Create(cpuprofile)
@@ -114,7 +115,7 @@ func main() {
114115
}
115116

116117
log.Printf("Build flags: '%s'; patterns: '%s'\n", strings.Join(buildFlags, " "), strings.Join(patterns, " "))
117-
err := extractor.ExtractWithFlags(buildFlags, patterns)
118+
err := extractor.ExtractWithFlags(buildFlags, patterns, extractTests)
118119
if err != nil {
119120
errString := err.Error()
120121
if strings.Contains(errString, "unexpected directory layout:") {

go/extractor/extractor.go

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ func init() {
5959

6060
// Extract extracts the packages specified by the given patterns
6161
func Extract(patterns []string) error {
62-
return ExtractWithFlags(nil, patterns)
62+
return ExtractWithFlags(nil, patterns, false)
6363
}
6464

6565
// ExtractWithFlags extracts the packages specified by the given patterns and build flags
66-
func ExtractWithFlags(buildFlags []string, patterns []string) error {
66+
func ExtractWithFlags(buildFlags []string, patterns []string, extractTests bool) error {
6767
startTime := time.Now()
6868

6969
extraction := NewExtraction(buildFlags, patterns)
@@ -89,6 +89,7 @@ func ExtractWithFlags(buildFlags []string, patterns []string) error {
8989
packages.NeedTypes | packages.NeedTypesSizes |
9090
packages.NeedTypesInfo | packages.NeedSyntax,
9191
BuildFlags: buildFlags,
92+
Tests: extractTests,
9293
}
9394
pkgs, err := packages.Load(cfg, patterns...)
9495
if err != nil {
@@ -132,10 +133,33 @@ func ExtractWithFlags(buildFlags []string, patterns []string) error {
132133

133134
pkgsNotFound := make([]string, 0, len(pkgs))
134135

136+
// Build a map from package paths to their longest IDs--
137+
// in the context of a `go test -c` compilation, we will see the same package more than
138+
// once, with IDs like "abc.com/pkgname [abc.com/pkgname.test]" to distinguish the version
139+
// that contains and is used by test code.
140+
// For our purposes it is simplest to just ignore the non-test version, since the test
141+
// version seems to be a superset of it.
142+
longestPackageIds := make(map[string]string)
143+
packages.Visit(pkgs, nil, func(pkg *packages.Package) {
144+
if shortestID, present := longestPackageIds[pkg.PkgPath]; present {
145+
if len(pkg.ID) > len(shortestID) {
146+
longestPackageIds[pkg.PkgPath] = pkg.ID
147+
}
148+
} else {
149+
longestPackageIds[pkg.PkgPath] = pkg.ID
150+
}
151+
})
152+
135153
// Do a post-order traversal and extract the package scope of each package
136154
packages.Visit(pkgs, nil, func(pkg *packages.Package) {
137155
log.Printf("Processing package %s.", pkg.PkgPath)
138156

157+
// If this is a variant of a package that also occurs with a longer ID, skip it.
158+
if pkg.ID != longestPackageIds[pkg.PkgPath] {
159+
log.Printf("Skipping variant of package %s with ID %s.", pkg.PkgPath, pkg.ID)
160+
return
161+
}
162+
139163
if _, ok := pkgInfos[pkg.PkgPath]; !ok {
140164
pkgInfos[pkg.PkgPath] = toolchain.GetPkgInfo(pkg.PkgPath, modFlags...)
141165
}
@@ -210,6 +234,13 @@ func ExtractWithFlags(buildFlags []string, patterns []string) error {
210234

211235
// extract AST information for all packages
212236
packages.Visit(pkgs, nil, func(pkg *packages.Package) {
237+
238+
// If this is a variant of a package that also occurs with a longer ID, skip it.
239+
if pkg.ID != longestPackageIds[pkg.PkgPath] {
240+
// Don't log here; we already mentioned this above.
241+
return
242+
}
243+
213244
for root := range wantedRoots {
214245
pkgInfo := pkgInfos[pkg.PkgPath]
215246
relDir, err := filepath.Rel(root, pkgInfo.PkgDir)

0 commit comments

Comments
 (0)