@@ -17,19 +17,23 @@ import (
17
17
"html/template"
18
18
"os"
19
19
"path/filepath"
20
+ "sort"
20
21
"strconv"
21
22
"strings"
22
23
"unicode"
23
24
24
25
"golang.org/x/tools/go/ast/astutil"
25
26
"golang.org/x/tools/gopls/internal/cache"
27
+ "golang.org/x/tools/gopls/internal/cache/metadata"
26
28
"golang.org/x/tools/gopls/internal/cache/parsego"
27
29
"golang.org/x/tools/gopls/internal/protocol"
28
30
goplsastutil "golang.org/x/tools/gopls/internal/util/astutil"
31
+ "golang.org/x/tools/internal/imports"
29
32
"golang.org/x/tools/internal/typesinternal"
30
33
)
31
34
32
- const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
35
+ const testTmplString = `
36
+ func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
33
37
{{- /* Constructor input parameters struct declaration. */}}
34
38
{{- if and .Receiver .Receiver.Constructor}}
35
39
{{- if gt (len .Receiver.Constructor.Args) 1}}
@@ -83,7 +87,7 @@ const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
83
87
84
88
{{- /* Loop over all the test cases. */}}
85
89
for _, tt := range tests {
86
- t.Run(tt.name, func(t *testing .T) {
90
+ t.Run(tt.name, func(t *{{.TestingPackageName}} .T) {
87
91
{{- /* Constructor or empty initialization. */}}
88
92
{{- if .Receiver}}
89
93
{{- if .Receiver.Constructor}}
@@ -170,6 +174,10 @@ type receiver struct {
170
174
}
171
175
172
176
type testInfo struct {
177
+ // TestingPackageName is the package name should be used when referencing
178
+ // package "testing"
179
+ TestingPackageName string
180
+ // PackageName is the package name the target function/method is delcared from.
173
181
PackageName string
174
182
TestFuncName string
175
183
// Func holds information about the function or method being tested.
@@ -202,37 +210,79 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
202
210
return nil , err
203
211
}
204
212
213
+ if metadata .IsCommandLineArguments (pkg .Metadata ().ID ) {
214
+ return nil , fmt .Errorf ("current file in command-line-arguments package" )
215
+ }
216
+
205
217
if errors := pkg .ParseErrors (); len (errors ) > 0 {
206
218
return nil , fmt .Errorf ("package has parse errors: %v" , errors [0 ])
207
219
}
208
220
if errors := pkg .TypeErrors (); len (errors ) > 0 {
209
221
return nil , fmt .Errorf ("package has type errors: %v" , errors [0 ])
210
222
}
211
223
212
- // imports is a map from package path to local package name.
213
- var imports = make (map [string ]string )
224
+ type packageInfo struct {
225
+ name string
226
+ renamed bool
227
+ }
228
+
229
+ var (
230
+ // fileImports is a map contains all the path imported in the original
231
+ // file foo.go.
232
+ fileImports map [string ]packageInfo
233
+ // testImports is a map contains all the path already imported in test
234
+ // file foo_test.go.
235
+ testImports map [string ]packageInfo
236
+ // extraImportsis a map from package path to local package name that
237
+ // need to be imported for the test function.
238
+ extraImports = make (map [string ]packageInfo )
239
+ )
214
240
215
- var collectImports = func (file * ast.File ) error {
241
+ var collectImports = func (file * ast.File ) (map [string ]packageInfo , error ) {
242
+ imps := make (map [string ]packageInfo )
216
243
for _ , spec := range file .Imports {
217
244
// TODO(hxjiang): support dot imports.
218
245
if spec .Name != nil && spec .Name .Name == "." {
219
- return fmt .Errorf ("\" add a test for FUNC \" does not support files containing dot imports" )
246
+ return nil , fmt .Errorf ("\" add a test for func \" does not support files containing dot imports" )
220
247
}
221
248
path , err := strconv .Unquote (spec .Path .Value )
222
249
if err != nil {
223
- return err
250
+ return nil , err
224
251
}
225
- if spec .Name != nil && spec .Name .Name != "_" {
226
- imports [path ] = spec .Name .Name
252
+ if spec .Name != nil {
253
+ if spec .Name .Name == "_" {
254
+ continue
255
+ }
256
+ imps [path ] = packageInfo {spec .Name .Name , true }
227
257
} else {
228
- imports [path ] = filepath .Base (path )
258
+ // The package name might differ from the base of its import
259
+ // path. For example, "/path/to/package/foo" could declare a
260
+ // package named "bar". Look up the target package ensures the
261
+ // accurate package name reference.
262
+ //
263
+ // While it's best practice to rename imported packages when
264
+ // their name differs from the base path (e.g.,
265
+ // "import bar \"path/to/package/foo\""), this is not mandatory.
266
+ id := pkg .Metadata ().DepsByImpPath [metadata .ImportPath (path )]
267
+ if metadata .IsCommandLineArguments (id ) {
268
+ return nil , fmt .Errorf ("can not import command-line-arguments package" )
269
+ }
270
+ if id == "" { // guess upon missing.
271
+ imps [path ] = packageInfo {imports .ImportPathToAssumedName (path ), false }
272
+ } else {
273
+ fromPkg , ok := snapshot .MetadataGraph ().Packages [id ]
274
+ if ! ok {
275
+ return nil , fmt .Errorf ("package id %v does not exist" , id )
276
+ }
277
+ imps [path ] = packageInfo {string (fromPkg .Name ), false }
278
+ }
229
279
}
230
280
}
231
- return nil
281
+ return imps , nil
232
282
}
233
283
234
284
// Collect all the imports from the x.go, keep track of the local package name.
235
- if err : = collectImports (pgf .File ); err != nil {
285
+ if fileImports , err = collectImports (pgf .File ); err != nil {
236
286
return nil , err
237
287
}
238
288
@@ -259,7 +309,8 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
259
309
xtest = true
260
310
)
261
311
262
- if testPGF , err := snapshot .ParseGo (ctx , testFH , parsego .Header ); err != nil {
312
+ testPGF , err := snapshot .ParseGo (ctx , testFH , parsego .Header )
313
+ if err != nil {
263
314
if ! errors .Is (err , os .ErrNotExist ) {
264
315
return nil , err
265
316
}
@@ -288,8 +339,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
288
339
header .WriteString ("\n \n " )
289
340
}
290
341
}
291
- // One empty line between package decl and rest of the file.
292
- fmt .Fprintf (& header , "package %s_test\n \n " , pkg .Types ().Name ())
342
+ fmt .Fprintf (& header , "package %s_test\n " , pkg .Types ().Name ())
293
343
294
344
// Write the copyright and package decl to the beginning of the file.
295
345
edits = append (edits , protocol.TextEdit {
@@ -314,29 +364,41 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
314
364
return nil , err
315
365
}
316
366
317
- // Collect all the imports from the x_test.go, overwrite the local pakcage
318
- // name collected from x.go.
319
- if err := collectImports (testPGF .File ); err != nil {
367
+ // Collect all the imports from the foo_test.go.
368
+ if testImports , err = collectImports (testPGF .File ); err != nil {
320
369
return nil , err
321
370
}
322
371
}
323
372
324
- // qf qualifier returns the local package name need to use in x_test.go by
325
- // consulting the consolidated imports map.
373
+ // qf qualifier determines the correct package name to use for a type in
374
+ // foo_test.go. It does this by:
375
+ // - Consult imports map from test file foo_test.go.
376
+ // - If not found, consult imports map from original file foo.go.
377
+ // If the package is not imported in test file foo_test.go, it is added to
378
+ // extraImports map.
326
379
qf := func (p * types.Package ) string {
327
380
// When generating test in x packages, any type/function defined in the same
328
381
// x package can emit package name.
329
382
if ! xtest && p == pkg .Types () {
330
383
return ""
331
384
}
332
- if local , ok := imports [p .Path ()]; ok {
333
- return local
385
+ // Prefer using the package name if already defined in foo_test.go
386
+ if local , ok := testImports [p .Path ()]; ok {
387
+ return local .name
334
388
}
389
+ // TODO(hxjiang): we should consult the scope of the test package to
390
+ // ensure these new imports do not shadow any package-level names.
391
+ // If not already imported by foo_test.go, consult the foo.go import map.
392
+ if local , ok := fileImports [p .Path ()]; ok {
393
+ // The package that contains this type need to be added to the import
394
+ // list in foo_test.go.
395
+ extraImports [p .Path ()] = local
396
+ return local .name
397
+ }
398
+ extraImports [p .Path ()] = packageInfo {name : p .Name ()}
335
399
return p .Name ()
336
400
}
337
401
338
- // TODO(hxjiang): modify existing imports or add new imports.
339
-
340
402
start , end , err := pgf .RangePos (loc .Range )
341
403
if err != nil {
342
404
return nil , err
@@ -378,8 +440,9 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
378
440
}
379
441
380
442
data := testInfo {
381
- PackageName : qf (pkg .Types ()),
382
- TestFuncName : testName ,
443
+ TestingPackageName : qf (types .NewPackage ("testing" , "testing" )),
444
+ PackageName : qf (pkg .Types ()),
445
+ TestFuncName : testName ,
383
446
Func : function {
384
447
Name : fn .Name (),
385
448
},
@@ -557,15 +620,73 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
557
620
}
558
621
}
559
622
623
+ // Compute edits to update imports.
624
+ //
625
+ // If we're adding to an existing test file, we need to adjust existing
626
+ // imports. Otherwise, we can simply write out the imports to the new file.
627
+ if testPGF != nil {
628
+ var importFixes []* imports.ImportFix
629
+ for path , info := range extraImports {
630
+ name := ""
631
+ if info .renamed {
632
+ name = info .name
633
+ }
634
+ importFixes = append (importFixes , & imports.ImportFix {
635
+ StmtInfo : imports.ImportInfo {
636
+ ImportPath : path ,
637
+ Name : name ,
638
+ },
639
+ FixType : imports .AddImport ,
640
+ })
641
+ }
642
+ importEdits , err := ComputeImportFixEdits (snapshot .Options ().Local , testPGF .Src , importFixes ... )
643
+ if err != nil {
644
+ return nil , fmt .Errorf ("could not compute the import fix edits: %w" , err )
645
+ }
646
+ edits = append (edits , importEdits ... )
647
+ } else {
648
+ var importsBuffer bytes.Buffer
649
+ if len (extraImports ) == 1 {
650
+ importsBuffer .WriteString ("\n import " )
651
+ for path , info := range extraImports {
652
+ if info .renamed {
653
+ importsBuffer .WriteString (info .name + " " )
654
+ }
655
+ importsBuffer .WriteString (fmt .Sprintf ("\" %s\" \n " , path ))
656
+ }
657
+ } else {
658
+ importsBuffer .WriteString ("\n import(" )
659
+ // Loop over the map in sorted order ensures deterministic outcome.
660
+ paths := make ([]string , 0 , len (extraImports ))
661
+ for key := range extraImports {
662
+ paths = append (paths , key )
663
+ }
664
+ sort .Strings (paths )
665
+ for _ , path := range paths {
666
+ importsBuffer .WriteString ("\n \t " )
667
+ if extraImports [path ].renamed {
668
+ importsBuffer .WriteString (extraImports [path ].name + " " )
669
+ }
670
+ importsBuffer .WriteString (fmt .Sprintf ("\" %s\" " , path ))
671
+ }
672
+ importsBuffer .WriteString ("\n )\n " )
673
+ }
674
+ edits = append (edits , protocol.TextEdit {
675
+ Range : protocol.Range {},
676
+ NewText : importsBuffer .String (),
677
+ })
678
+ }
679
+
560
680
var test bytes.Buffer
561
681
if err := testTmpl .Execute (& test , data ); err != nil {
562
682
return nil , err
563
683
}
564
684
565
- edits = append (edits , protocol.TextEdit {
566
- Range : eofRange ,
567
- NewText : test .String (),
568
- })
685
+ edits = append (edits ,
686
+ protocol.TextEdit {
687
+ Range : eofRange ,
688
+ NewText : test .String (),
689
+ })
569
690
570
691
return append (changes , protocol .DocumentChangeEdit (testFH , edits )), nil
571
692
}
0 commit comments