Skip to content

Commit 73a3264

Browse files
h9jiangdennypenta
authored andcommitted
gopls/internal/golang: add missing imports in foo_test.go
- Gopls will honor any renaming of package "testing" if any. - Gopls will collect all the package that have not been imported in foo_test.go and modify the foo_test.go imports. For golang/vscode-go#1594 Change-Id: Id6b87b6417a26f8e925582317e91fb4ebff4a0e7 Reviewed-on: https://go-review.googlesource.com/c/tools/+/620697 Reviewed-by: Robert Findley <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent 7eb3c4c commit 73a3264

File tree

2 files changed

+312
-66
lines changed

2 files changed

+312
-66
lines changed

gopls/internal/golang/addtest.go

Lines changed: 151 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,23 @@ import (
1717
"html/template"
1818
"os"
1919
"path/filepath"
20+
"sort"
2021
"strconv"
2122
"strings"
2223
"unicode"
2324

2425
"golang.org/x/tools/go/ast/astutil"
2526
"golang.org/x/tools/gopls/internal/cache"
27+
"golang.org/x/tools/gopls/internal/cache/metadata"
2628
"golang.org/x/tools/gopls/internal/cache/parsego"
2729
"golang.org/x/tools/gopls/internal/protocol"
2830
goplsastutil "golang.org/x/tools/gopls/internal/util/astutil"
31+
"golang.org/x/tools/internal/imports"
2932
"golang.org/x/tools/internal/typesinternal"
3033
)
3134

32-
const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
35+
const testTmplString = `
36+
func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
3337
{{- /* Constructor input parameters struct declaration. */}}
3438
{{- if and .Receiver .Receiver.Constructor}}
3539
{{- if gt (len .Receiver.Constructor.Args) 1}}
@@ -83,7 +87,7 @@ const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
8387
8488
{{- /* Loop over all the test cases. */}}
8589
for _, tt := range tests {
86-
t.Run(tt.name, func(t *testing.T) {
90+
t.Run(tt.name, func(t *{{.TestingPackageName}}.T) {
8791
{{- /* Constructor or empty initialization. */}}
8892
{{- if .Receiver}}
8993
{{- if .Receiver.Constructor}}
@@ -170,6 +174,10 @@ type receiver struct {
170174
}
171175

172176
type testInfo struct {
177+
// TestingPackageName is the package name should be used when referencing
178+
// package "testing"
179+
TestingPackageName string
180+
// PackageName is the package name the target function/method is delcared from.
173181
PackageName string
174182
TestFuncName string
175183
// Func holds information about the function or method being tested.
@@ -202,37 +210,79 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
202210
return nil, err
203211
}
204212

213+
if metadata.IsCommandLineArguments(pkg.Metadata().ID) {
214+
return nil, fmt.Errorf("current file in command-line-arguments package")
215+
}
216+
205217
if errors := pkg.ParseErrors(); len(errors) > 0 {
206218
return nil, fmt.Errorf("package has parse errors: %v", errors[0])
207219
}
208220
if errors := pkg.TypeErrors(); len(errors) > 0 {
209221
return nil, fmt.Errorf("package has type errors: %v", errors[0])
210222
}
211223

212-
// imports is a map from package path to local package name.
213-
var imports = make(map[string]string)
224+
type packageInfo struct {
225+
name string
226+
renamed bool
227+
}
228+
229+
var (
230+
// fileImports is a map contains all the path imported in the original
231+
// file foo.go.
232+
fileImports map[string]packageInfo
233+
// testImports is a map contains all the path already imported in test
234+
// file foo_test.go.
235+
testImports map[string]packageInfo
236+
// extraImportsis a map from package path to local package name that
237+
// need to be imported for the test function.
238+
extraImports = make(map[string]packageInfo)
239+
)
214240

215-
var collectImports = func(file *ast.File) error {
241+
var collectImports = func(file *ast.File) (map[string]packageInfo, error) {
242+
imps := make(map[string]packageInfo)
216243
for _, spec := range file.Imports {
217244
// TODO(hxjiang): support dot imports.
218245
if spec.Name != nil && spec.Name.Name == "." {
219-
return fmt.Errorf("\"add a test for FUNC\" does not support files containing dot imports")
246+
return nil, fmt.Errorf("\"add a test for func\" does not support files containing dot imports")
220247
}
221248
path, err := strconv.Unquote(spec.Path.Value)
222249
if err != nil {
223-
return err
250+
return nil, err
224251
}
225-
if spec.Name != nil && spec.Name.Name != "_" {
226-
imports[path] = spec.Name.Name
252+
if spec.Name != nil {
253+
if spec.Name.Name == "_" {
254+
continue
255+
}
256+
imps[path] = packageInfo{spec.Name.Name, true}
227257
} else {
228-
imports[path] = filepath.Base(path)
258+
// The package name might differ from the base of its import
259+
// path. For example, "/path/to/package/foo" could declare a
260+
// package named "bar". Look up the target package ensures the
261+
// accurate package name reference.
262+
//
263+
// While it's best practice to rename imported packages when
264+
// their name differs from the base path (e.g.,
265+
// "import bar \"path/to/package/foo\""), this is not mandatory.
266+
id := pkg.Metadata().DepsByImpPath[metadata.ImportPath(path)]
267+
if metadata.IsCommandLineArguments(id) {
268+
return nil, fmt.Errorf("can not import command-line-arguments package")
269+
}
270+
if id == "" { // guess upon missing.
271+
imps[path] = packageInfo{imports.ImportPathToAssumedName(path), false}
272+
} else {
273+
fromPkg, ok := snapshot.MetadataGraph().Packages[id]
274+
if !ok {
275+
return nil, fmt.Errorf("package id %v does not exist", id)
276+
}
277+
imps[path] = packageInfo{string(fromPkg.Name), false}
278+
}
229279
}
230280
}
231-
return nil
281+
return imps, nil
232282
}
233283

234284
// Collect all the imports from the x.go, keep track of the local package name.
235-
if err := collectImports(pgf.File); err != nil {
285+
if fileImports, err = collectImports(pgf.File); err != nil {
236286
return nil, err
237287
}
238288

@@ -259,7 +309,8 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
259309
xtest = true
260310
)
261311

262-
if testPGF, err := snapshot.ParseGo(ctx, testFH, parsego.Header); err != nil {
312+
testPGF, err := snapshot.ParseGo(ctx, testFH, parsego.Header)
313+
if err != nil {
263314
if !errors.Is(err, os.ErrNotExist) {
264315
return nil, err
265316
}
@@ -288,8 +339,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
288339
header.WriteString("\n\n")
289340
}
290341
}
291-
// One empty line between package decl and rest of the file.
292-
fmt.Fprintf(&header, "package %s_test\n\n", pkg.Types().Name())
342+
fmt.Fprintf(&header, "package %s_test\n", pkg.Types().Name())
293343

294344
// Write the copyright and package decl to the beginning of the file.
295345
edits = append(edits, protocol.TextEdit{
@@ -314,29 +364,41 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
314364
return nil, err
315365
}
316366

317-
// Collect all the imports from the x_test.go, overwrite the local pakcage
318-
// name collected from x.go.
319-
if err := collectImports(testPGF.File); err != nil {
367+
// Collect all the imports from the foo_test.go.
368+
if testImports, err = collectImports(testPGF.File); err != nil {
320369
return nil, err
321370
}
322371
}
323372

324-
// qf qualifier returns the local package name need to use in x_test.go by
325-
// consulting the consolidated imports map.
373+
// qf qualifier determines the correct package name to use for a type in
374+
// foo_test.go. It does this by:
375+
// - Consult imports map from test file foo_test.go.
376+
// - If not found, consult imports map from original file foo.go.
377+
// If the package is not imported in test file foo_test.go, it is added to
378+
// extraImports map.
326379
qf := func(p *types.Package) string {
327380
// When generating test in x packages, any type/function defined in the same
328381
// x package can emit package name.
329382
if !xtest && p == pkg.Types() {
330383
return ""
331384
}
332-
if local, ok := imports[p.Path()]; ok {
333-
return local
385+
// Prefer using the package name if already defined in foo_test.go
386+
if local, ok := testImports[p.Path()]; ok {
387+
return local.name
334388
}
389+
// TODO(hxjiang): we should consult the scope of the test package to
390+
// ensure these new imports do not shadow any package-level names.
391+
// If not already imported by foo_test.go, consult the foo.go import map.
392+
if local, ok := fileImports[p.Path()]; ok {
393+
// The package that contains this type need to be added to the import
394+
// list in foo_test.go.
395+
extraImports[p.Path()] = local
396+
return local.name
397+
}
398+
extraImports[p.Path()] = packageInfo{name: p.Name()}
335399
return p.Name()
336400
}
337401

338-
// TODO(hxjiang): modify existing imports or add new imports.
339-
340402
start, end, err := pgf.RangePos(loc.Range)
341403
if err != nil {
342404
return nil, err
@@ -378,8 +440,9 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
378440
}
379441

380442
data := testInfo{
381-
PackageName: qf(pkg.Types()),
382-
TestFuncName: testName,
443+
TestingPackageName: qf(types.NewPackage("testing", "testing")),
444+
PackageName: qf(pkg.Types()),
445+
TestFuncName: testName,
383446
Func: function{
384447
Name: fn.Name(),
385448
},
@@ -557,15 +620,73 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
557620
}
558621
}
559622

623+
// Compute edits to update imports.
624+
//
625+
// If we're adding to an existing test file, we need to adjust existing
626+
// imports. Otherwise, we can simply write out the imports to the new file.
627+
if testPGF != nil {
628+
var importFixes []*imports.ImportFix
629+
for path, info := range extraImports {
630+
name := ""
631+
if info.renamed {
632+
name = info.name
633+
}
634+
importFixes = append(importFixes, &imports.ImportFix{
635+
StmtInfo: imports.ImportInfo{
636+
ImportPath: path,
637+
Name: name,
638+
},
639+
FixType: imports.AddImport,
640+
})
641+
}
642+
importEdits, err := ComputeImportFixEdits(snapshot.Options().Local, testPGF.Src, importFixes...)
643+
if err != nil {
644+
return nil, fmt.Errorf("could not compute the import fix edits: %w", err)
645+
}
646+
edits = append(edits, importEdits...)
647+
} else {
648+
var importsBuffer bytes.Buffer
649+
if len(extraImports) == 1 {
650+
importsBuffer.WriteString("\nimport ")
651+
for path, info := range extraImports {
652+
if info.renamed {
653+
importsBuffer.WriteString(info.name + " ")
654+
}
655+
importsBuffer.WriteString(fmt.Sprintf("\"%s\"\n", path))
656+
}
657+
} else {
658+
importsBuffer.WriteString("\nimport(")
659+
// Loop over the map in sorted order ensures deterministic outcome.
660+
paths := make([]string, 0, len(extraImports))
661+
for key := range extraImports {
662+
paths = append(paths, key)
663+
}
664+
sort.Strings(paths)
665+
for _, path := range paths {
666+
importsBuffer.WriteString("\n\t")
667+
if extraImports[path].renamed {
668+
importsBuffer.WriteString(extraImports[path].name + " ")
669+
}
670+
importsBuffer.WriteString(fmt.Sprintf("\"%s\"", path))
671+
}
672+
importsBuffer.WriteString("\n)\n")
673+
}
674+
edits = append(edits, protocol.TextEdit{
675+
Range: protocol.Range{},
676+
NewText: importsBuffer.String(),
677+
})
678+
}
679+
560680
var test bytes.Buffer
561681
if err := testTmpl.Execute(&test, data); err != nil {
562682
return nil, err
563683
}
564684

565-
edits = append(edits, protocol.TextEdit{
566-
Range: eofRange,
567-
NewText: test.String(),
568-
})
685+
edits = append(edits,
686+
protocol.TextEdit{
687+
Range: eofRange,
688+
NewText: test.String(),
689+
})
569690

570691
return append(changes, protocol.DocumentChangeEdit(testFH, edits)), nil
571692
}

0 commit comments

Comments
 (0)