Skip to content

Commit 9420753

Browse files
committed
gopls/internal/golang: show document after test generation
Upon successful test generation, call LSP show document method pointing to the generated test function decl. The location is being calculated based on the current version of the file and the edits that will be applied. The location should only be used if the edtis got successfully applied by the LSP client. For golang/vscode-go#1594 Change-Id: I2912879d524246d6618761b930a48c883e647046 Reviewed-on: https://go-review.googlesource.com/c/tools/+/692055 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Robert Findley <[email protected]>
1 parent b7dd6b4 commit 9420753

File tree

3 files changed

+177
-29
lines changed

3 files changed

+177
-29
lines changed

gopls/internal/golang/addtest.go

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -212,21 +212,23 @@ var testTmpl = template.Must(template.New("test").Funcs(template.FuncMap{
212212

213213
// AddTestForFunc adds a test for the function enclosing the given input range.
214214
// It creates a _test.go file if one does not already exist.
215-
func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.Location) (changes []protocol.DocumentChange, _ error) {
215+
// It returns the required text edits and the predicted location of the new test
216+
// function, which is only valid after the edits have been successfully applied.
217+
func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.Location) (changes []protocol.DocumentChange, show *protocol.Location, _ error) {
216218
pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, loc.URI)
217219
if err != nil {
218-
return nil, err
220+
return nil, nil, err
219221
}
220222

221223
if metadata.IsCommandLineArguments(pkg.Metadata().ID) {
222-
return nil, fmt.Errorf("current file in command-line-arguments package")
224+
return nil, nil, fmt.Errorf("current file in command-line-arguments package")
223225
}
224226

225227
if errors := pkg.ParseErrors(); len(errors) > 0 {
226-
return nil, fmt.Errorf("package has parse errors: %v", errors[0])
228+
return nil, nil, fmt.Errorf("package has parse errors: %v", errors[0])
227229
}
228230
if errors := pkg.TypeErrors(); len(errors) > 0 {
229-
return nil, fmt.Errorf("package has type errors: %v", errors[0])
231+
return nil, nil, fmt.Errorf("package has type errors: %v", errors[0])
230232
}
231233

232234
// All three maps map the path of an imported package to
@@ -262,15 +264,15 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
262264

263265
// Collect all the imports from the x.go, keep track of the local package name.
264266
if fileImports, err = collectImports(pgf.File); err != nil {
265-
return nil, err
267+
return nil, nil, err
266268
}
267269

268270
testBase := strings.TrimSuffix(loc.URI.Base(), ".go") + "_test.go"
269271
goTestFileURI := protocol.URIFromPath(filepath.Join(loc.URI.DirPath(), testBase))
270272

271273
testFH, err := snapshot.ReadFile(ctx, goTestFileURI)
272274
if err != nil {
273-
return nil, err
275+
return nil, nil, err
274276
}
275277

276278
// TODO(hxjiang): use a fresh name if the same test function name already
@@ -289,17 +291,17 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
289291

290292
start, end, err := pgf.RangePos(loc.Range)
291293
if err != nil {
292-
return nil, err
294+
return nil, nil, err
293295
}
294296

295297
path, _ := astutil.PathEnclosingInterval(pgf.File, start, end)
296298
if len(path) < 2 {
297-
return nil, fmt.Errorf("no enclosing function")
299+
return nil, nil, fmt.Errorf("no enclosing function")
298300
}
299301

300302
decl, ok := path[len(path)-2].(*ast.FuncDecl)
301303
if !ok {
302-
return nil, fmt.Errorf("no enclosing function")
304+
return nil, nil, fmt.Errorf("no enclosing function")
303305
}
304306

305307
fn := pkg.TypesInfo().Defs[decl.Name].(*types.Func)
@@ -308,7 +310,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
308310
testPGF, err := snapshot.ParseGo(ctx, testFH, parsego.Header)
309311
if err != nil {
310312
if !errors.Is(err, os.ErrNotExist) {
311-
return nil, err
313+
return nil, nil, err
312314
}
313315
changes = append(changes, protocol.DocumentChangeCreate(goTestFileURI))
314316

@@ -322,7 +324,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
322324
if c := CopyrightComment(pgf.File); c != nil {
323325
text, err := pgf.NodeText(c)
324326
if err != nil {
325-
return nil, err
327+
return nil, nil, err
326328
}
327329
header.Write(text)
328330
// One empty line between copyright header and following.
@@ -334,7 +336,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
334336
if c := buildConstraintComment(pgf.File); c != nil {
335337
text, err := pgf.NodeText(c)
336338
if err != nil {
337-
return nil, err
339+
return nil, nil, err
338340
}
339341
header.Write(text)
340342
// One empty line between build constraint and following.
@@ -397,25 +399,25 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
397399
} else { // existing _test.go file.
398400
file := testPGF.File
399401
if !file.Name.NamePos.IsValid() {
400-
return nil, fmt.Errorf("missing package declaration")
402+
return nil, nil, fmt.Errorf("missing package declaration")
401403
}
402404
switch file.Name.Name {
403405
case pgf.File.Name.Name:
404406
xtest = false
405407
case pgf.File.Name.Name + "_test":
406408
xtest = true
407409
default:
408-
return nil, fmt.Errorf("invalid package declaration %q in test file %q", file.Name, testPGF)
410+
return nil, nil, fmt.Errorf("invalid package declaration %q in test file %q", file.Name, testPGF)
409411
}
410412

411413
eofRange, err = testPGF.PosRange(file.FileEnd, file.FileEnd)
412414
if err != nil {
413-
return nil, err
415+
return nil, nil, err
414416
}
415417

416418
// Collect all the imports from the foo_test.go.
417419
if testImports, err = collectImports(file); err != nil {
418-
return nil, err
420+
return nil, nil, err
419421
}
420422
}
421423

@@ -453,13 +455,13 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
453455
if xtest {
454456
// Reject if function/method is unexported.
455457
if !fn.Exported() {
456-
return nil, fmt.Errorf("cannot add test of unexported function %s to external test package %s_test", decl.Name, pgf.File.Name)
458+
return nil, nil, fmt.Errorf("cannot add test of unexported function %s to external test package %s_test", decl.Name, pgf.File.Name)
457459
}
458460

459461
// Reject if receiver is unexported.
460462
if sig.Recv() != nil {
461463
if _, ident, _ := goplsastutil.UnpackRecv(decl.Recv.List[0].Type); ident == nil || !ident.IsExported() {
462-
return nil, fmt.Errorf("cannot add external test for method %s.%s as receiver type is not exported", ident.Name, decl.Name)
464+
return nil, nil, fmt.Errorf("cannot add external test for method %s.%s as receiver type is not exported", ident.Name, decl.Name)
463465
}
464466
}
465467
// TODO(hxjiang): reject if the any input parameter type is unexported.
@@ -469,7 +471,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
469471

470472
testName, err := testName(fn)
471473
if err != nil {
472-
return nil, err
474+
return nil, nil, err
473475
}
474476

475477
data := testInfo{
@@ -525,7 +527,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
525527

526528
t, ok := recvType.(typesinternal.NamedOrAlias)
527529
if !ok {
528-
return nil, fmt.Errorf("the receiver type is neither named type nor alias type")
530+
return nil, nil, fmt.Errorf("the receiver type is neither named type nor alias type")
529531
}
530532

531533
var varName string
@@ -707,7 +709,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
707709
}
708710
importEdits, err := ComputeImportFixEdits(snapshot.Options().Local, testPGF.Src, importFixes...)
709711
if err != nil {
710-
return nil, fmt.Errorf("could not compute the import fix edits: %w", err)
712+
return nil, nil, fmt.Errorf("could not compute the import fix edits: %w", err)
711713
}
712714
edits = append(edits, importEdits...)
713715
} else {
@@ -740,21 +742,41 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
740742

741743
var test bytes.Buffer
742744
if err := testTmpl.Execute(&test, data); err != nil {
743-
return nil, err
745+
return nil, nil, err
744746
}
745747

746748
formatted, err := format.Source(test.Bytes())
747749
if err != nil {
748-
return nil, err
750+
return nil, nil, err
749751
}
750752

751753
edits = append(edits,
752754
protocol.TextEdit{
753755
Range: eofRange,
754756
NewText: string(formatted),
755-
})
757+
},
758+
)
759+
760+
// Show the line of generated test function.
761+
{
762+
line := eofRange.Start.Line
763+
for i := range len(edits) - 1 { // last edits is the func decl
764+
e := edits[i]
765+
oldLines := e.Range.End.Line - e.Range.Start.Line
766+
newLines := uint32(strings.Count(e.NewText, "\n"))
767+
line += (newLines - oldLines)
768+
}
769+
show = &protocol.Location{
770+
URI: testFH.URI(),
771+
Range: protocol.Range{
772+
// Test function template have a new line at beginning.
773+
Start: protocol.Position{Line: line + 1},
774+
End: protocol.Position{Line: line + 1},
775+
},
776+
}
777+
}
756778

757-
return append(changes, protocol.DocumentChangeEdit(testFH, edits)), nil
779+
return append(changes, protocol.DocumentChangeEdit(testFH, edits)), show, nil
758780
}
759781

760782
// testName returns the name of the function to use for the new function that

gopls/internal/server/command.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,13 +292,19 @@ func (c *commandHandler) AddTest(ctx context.Context, loc protocol.Location) (*p
292292
if deps.snapshot.FileKind(deps.fh) != file.Go {
293293
return fmt.Errorf("can't add test for non-Go file")
294294
}
295-
docedits, err := golang.AddTestForFunc(ctx, deps.snapshot, loc)
295+
docedits, show, err := golang.AddTestForFunc(ctx, deps.snapshot, loc)
296296
if err != nil {
297297
return err
298298
}
299-
return applyChanges(ctx, c.s.client, docedits)
299+
if err := applyChanges(ctx, c.s.client, docedits); err != nil {
300+
return err
301+
}
302+
303+
if show != nil {
304+
showDocumentImpl(ctx, c.s.client, protocol.URI(show.URI), &show.Range, c.s.options)
305+
}
306+
return nil
300307
})
301-
// TODO(hxjiang): move the cursor to the new test once edits applied.
302308
return result, err
303309
}
304310

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// Copyright 2025 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package misc
6+
7+
import (
8+
"testing"
9+
10+
"golang.org/x/tools/gopls/internal/protocol"
11+
"golang.org/x/tools/gopls/internal/settings"
12+
"golang.org/x/tools/gopls/internal/test/compare"
13+
. "golang.org/x/tools/gopls/internal/test/integration"
14+
)
15+
16+
// TestAddTest is a basic test of interaction with the "gopls.add_test" code action.
17+
func TestAddTest(t *testing.T) {
18+
const files = `
19+
-- go.mod --
20+
module example.com
21+
22+
-- a/a.go --
23+
package a
24+
25+
import(
26+
"context"
27+
)
28+
29+
func Foo(ctx context.Context, in string) string {return in}
30+
31+
-- a/a_test.go --
32+
package a_test
33+
34+
import(
35+
"testing"
36+
)
37+
38+
func TestExisting(t *testing.T) {}
39+
`
40+
const want = `package a_test
41+
42+
import (
43+
"context"
44+
"testing"
45+
46+
"example.com/a"
47+
)
48+
49+
func TestExisting(t *testing.T) {}
50+
51+
func TestFoo(t *testing.T) {
52+
tests := []struct {
53+
name string // description of this test case
54+
// Named input parameters for target function.
55+
in string
56+
want string
57+
}{
58+
// TODO: Add test cases.
59+
}
60+
for _, tt := range tests {
61+
t.Run(tt.name, func(t *testing.T) {
62+
got := a.Foo(context.Background(), tt.in)
63+
// TODO: update the condition below to compare got with tt.want.
64+
if true {
65+
t.Errorf("Foo() = %v, want %v", got, tt.want)
66+
}
67+
})
68+
}
69+
}
70+
`
71+
Run(t, files, func(t *testing.T, env *Env) {
72+
env.OpenFile("a/a.go")
73+
74+
loc := env.RegexpSearch("a/a.go", "Foo")
75+
actions, err := env.Editor.CodeAction(env.Ctx, loc, nil, protocol.CodeActionUnknownTrigger)
76+
if err != nil {
77+
t.Fatalf("CodeAction: %v", err)
78+
}
79+
action, err := codeActionByKind(actions, settings.AddTest)
80+
if err != nil {
81+
t.Fatal(err)
82+
}
83+
84+
// Execute the command.
85+
// Its side effect should be a single showDocument request.
86+
params := &protocol.ExecuteCommandParams{
87+
Command: action.Command.Command,
88+
Arguments: action.Command.Arguments,
89+
}
90+
91+
listen := env.Awaiter.ListenToShownDocuments()
92+
env.ExecuteCommand(params, nil)
93+
// Wait until we finish writing to the file.
94+
env.AfterChange()
95+
if got := env.BufferText("a/a_test.go"); got != want {
96+
t.Errorf("gopls.add_test returned unexpected diff (-want +got):\n%s", compare.Text(want, got))
97+
}
98+
99+
got := listen()
100+
if len(got) != 1 {
101+
t.Errorf("gopls.add_test: got %d showDocument requests, want 1: %v", len(got), got)
102+
} else {
103+
if want := protocol.URI(env.Sandbox.Workdir.URI("a/a_test.go")); got[0].URI != want {
104+
t.Errorf("gopls.add_test: got showDocument requests for %v, want %v", got[0].URI, want)
105+
}
106+
107+
// Pointing to the line of test function declaration.
108+
if want := (protocol.Range{
109+
Start: protocol.Position{
110+
Line: 11,
111+
},
112+
End: protocol.Position{
113+
Line: 11,
114+
},
115+
}); *got[0].Selection != want {
116+
t.Errorf("gopls.add_test: got showDocument requests selection for %v, want %v", *got[0].Selection, want)
117+
}
118+
}
119+
})
120+
}

0 commit comments

Comments
 (0)