@@ -212,21 +212,23 @@ var testTmpl = template.Must(template.New("test").Funcs(template.FuncMap{
212
212
213
213
// AddTestForFunc adds a test for the function enclosing the given input range.
214
214
// It creates a _test.go file if one does not already exist.
215
- func AddTestForFunc (ctx context.Context , snapshot * cache.Snapshot , loc protocol.Location ) (changes []protocol.DocumentChange , _ error ) {
215
+ // It returns the required text edits and the predicted location of the new test
216
+ // function, which is only valid after the edits have been successfully applied.
217
+ func AddTestForFunc (ctx context.Context , snapshot * cache.Snapshot , loc protocol.Location ) (changes []protocol.DocumentChange , show * protocol.Location , _ error ) {
216
218
pkg , pgf , err := NarrowestPackageForFile (ctx , snapshot , loc .URI )
217
219
if err != nil {
218
- return nil , err
220
+ return nil , nil , err
219
221
}
220
222
221
223
if metadata .IsCommandLineArguments (pkg .Metadata ().ID ) {
222
- return nil , fmt .Errorf ("current file in command-line-arguments package" )
224
+ return nil , nil , fmt .Errorf ("current file in command-line-arguments package" )
223
225
}
224
226
225
227
if errors := pkg .ParseErrors (); len (errors ) > 0 {
226
- return nil , fmt .Errorf ("package has parse errors: %v" , errors [0 ])
228
+ return nil , nil , fmt .Errorf ("package has parse errors: %v" , errors [0 ])
227
229
}
228
230
if errors := pkg .TypeErrors (); len (errors ) > 0 {
229
- return nil , fmt .Errorf ("package has type errors: %v" , errors [0 ])
231
+ return nil , nil , fmt .Errorf ("package has type errors: %v" , errors [0 ])
230
232
}
231
233
232
234
// All three maps map the path of an imported package to
@@ -262,15 +264,15 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
262
264
263
265
// Collect all the imports from the x.go, keep track of the local package name.
264
266
if fileImports , err = collectImports (pgf .File ); err != nil {
265
- return nil , err
267
+ return nil , nil , err
266
268
}
267
269
268
270
testBase := strings .TrimSuffix (loc .URI .Base (), ".go" ) + "_test.go"
269
271
goTestFileURI := protocol .URIFromPath (filepath .Join (loc .URI .DirPath (), testBase ))
270
272
271
273
testFH , err := snapshot .ReadFile (ctx , goTestFileURI )
272
274
if err != nil {
273
- return nil , err
275
+ return nil , nil , err
274
276
}
275
277
276
278
// TODO(hxjiang): use a fresh name if the same test function name already
@@ -289,17 +291,17 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
289
291
290
292
start , end , err := pgf .RangePos (loc .Range )
291
293
if err != nil {
292
- return nil , err
294
+ return nil , nil , err
293
295
}
294
296
295
297
path , _ := astutil .PathEnclosingInterval (pgf .File , start , end )
296
298
if len (path ) < 2 {
297
- return nil , fmt .Errorf ("no enclosing function" )
299
+ return nil , nil , fmt .Errorf ("no enclosing function" )
298
300
}
299
301
300
302
decl , ok := path [len (path )- 2 ].(* ast.FuncDecl )
301
303
if ! ok {
302
- return nil , fmt .Errorf ("no enclosing function" )
304
+ return nil , nil , fmt .Errorf ("no enclosing function" )
303
305
}
304
306
305
307
fn := pkg .TypesInfo ().Defs [decl .Name ].(* types.Func )
@@ -308,7 +310,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
308
310
testPGF , err := snapshot .ParseGo (ctx , testFH , parsego .Header )
309
311
if err != nil {
310
312
if ! errors .Is (err , os .ErrNotExist ) {
311
- return nil , err
313
+ return nil , nil , err
312
314
}
313
315
changes = append (changes , protocol .DocumentChangeCreate (goTestFileURI ))
314
316
@@ -322,7 +324,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
322
324
if c := CopyrightComment (pgf .File ); c != nil {
323
325
text , err := pgf .NodeText (c )
324
326
if err != nil {
325
- return nil , err
327
+ return nil , nil , err
326
328
}
327
329
header .Write (text )
328
330
// One empty line between copyright header and following.
@@ -334,7 +336,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
334
336
if c := buildConstraintComment (pgf .File ); c != nil {
335
337
text , err := pgf .NodeText (c )
336
338
if err != nil {
337
- return nil , err
339
+ return nil , nil , err
338
340
}
339
341
header .Write (text )
340
342
// One empty line between build constraint and following.
@@ -397,25 +399,25 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
397
399
} else { // existing _test.go file.
398
400
file := testPGF .File
399
401
if ! file .Name .NamePos .IsValid () {
400
- return nil , fmt .Errorf ("missing package declaration" )
402
+ return nil , nil , fmt .Errorf ("missing package declaration" )
401
403
}
402
404
switch file .Name .Name {
403
405
case pgf .File .Name .Name :
404
406
xtest = false
405
407
case pgf .File .Name .Name + "_test" :
406
408
xtest = true
407
409
default :
408
- return nil , fmt .Errorf ("invalid package declaration %q in test file %q" , file .Name , testPGF )
410
+ return nil , nil , fmt .Errorf ("invalid package declaration %q in test file %q" , file .Name , testPGF )
409
411
}
410
412
411
413
eofRange , err = testPGF .PosRange (file .FileEnd , file .FileEnd )
412
414
if err != nil {
413
- return nil , err
415
+ return nil , nil , err
414
416
}
415
417
416
418
// Collect all the imports from the foo_test.go.
417
419
if testImports , err = collectImports (file ); err != nil {
418
- return nil , err
420
+ return nil , nil , err
419
421
}
420
422
}
421
423
@@ -453,13 +455,13 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
453
455
if xtest {
454
456
// Reject if function/method is unexported.
455
457
if ! fn .Exported () {
456
- return nil , fmt .Errorf ("cannot add test of unexported function %s to external test package %s_test" , decl .Name , pgf .File .Name )
458
+ return nil , nil , fmt .Errorf ("cannot add test of unexported function %s to external test package %s_test" , decl .Name , pgf .File .Name )
457
459
}
458
460
459
461
// Reject if receiver is unexported.
460
462
if sig .Recv () != nil {
461
463
if _ , ident , _ := goplsastutil .UnpackRecv (decl .Recv .List [0 ].Type ); ident == nil || ! ident .IsExported () {
462
- return nil , fmt .Errorf ("cannot add external test for method %s.%s as receiver type is not exported" , ident .Name , decl .Name )
464
+ return nil , nil , fmt .Errorf ("cannot add external test for method %s.%s as receiver type is not exported" , ident .Name , decl .Name )
463
465
}
464
466
}
465
467
// TODO(hxjiang): reject if the any input parameter type is unexported.
@@ -469,7 +471,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
469
471
470
472
testName , err := testName (fn )
471
473
if err != nil {
472
- return nil , err
474
+ return nil , nil , err
473
475
}
474
476
475
477
data := testInfo {
@@ -525,7 +527,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
525
527
526
528
t , ok := recvType .(typesinternal.NamedOrAlias )
527
529
if ! ok {
528
- return nil , fmt .Errorf ("the receiver type is neither named type nor alias type" )
530
+ return nil , nil , fmt .Errorf ("the receiver type is neither named type nor alias type" )
529
531
}
530
532
531
533
var varName string
@@ -707,7 +709,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
707
709
}
708
710
importEdits , err := ComputeImportFixEdits (snapshot .Options ().Local , testPGF .Src , importFixes ... )
709
711
if err != nil {
710
- return nil , fmt .Errorf ("could not compute the import fix edits: %w" , err )
712
+ return nil , nil , fmt .Errorf ("could not compute the import fix edits: %w" , err )
711
713
}
712
714
edits = append (edits , importEdits ... )
713
715
} else {
@@ -740,21 +742,41 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
740
742
741
743
var test bytes.Buffer
742
744
if err := testTmpl .Execute (& test , data ); err != nil {
743
- return nil , err
745
+ return nil , nil , err
744
746
}
745
747
746
748
formatted , err := format .Source (test .Bytes ())
747
749
if err != nil {
748
- return nil , err
750
+ return nil , nil , err
749
751
}
750
752
751
753
edits = append (edits ,
752
754
protocol.TextEdit {
753
755
Range : eofRange ,
754
756
NewText : string (formatted ),
755
- })
757
+ },
758
+ )
759
+
760
+ // Show the line of generated test function.
761
+ {
762
+ line := eofRange .Start .Line
763
+ for i := range len (edits ) - 1 { // last edits is the func decl
764
+ e := edits [i ]
765
+ oldLines := e .Range .End .Line - e .Range .Start .Line
766
+ newLines := uint32 (strings .Count (e .NewText , "\n " ))
767
+ line += (newLines - oldLines )
768
+ }
769
+ show = & protocol.Location {
770
+ URI : testFH .URI (),
771
+ Range : protocol.Range {
772
+ // Test function template have a new line at beginning.
773
+ Start : protocol.Position {Line : line + 1 },
774
+ End : protocol.Position {Line : line + 1 },
775
+ },
776
+ }
777
+ }
756
778
757
- return append (changes , protocol .DocumentChangeEdit (testFH , edits )), nil
779
+ return append (changes , protocol .DocumentChangeEdit (testFH , edits )), show , nil
758
780
}
759
781
760
782
// testName returns the name of the function to use for the new function that
0 commit comments