Skip to content

Commit 46d1522

Browse files
committed
internal/lsp: add extract to method code action
"Extract method" allows users to take a code fragment and move it to a separate method. This is available if the enclosing function is a method. Change-Id: Ib824f6b79b13ca73532223283a050946c90a47e7 Reviewed-on: https://go-review.googlesource.com/c/tools/+/330070 Trust: Suzy Mueller <[email protected]> Run-TryBot: Suzy Mueller <[email protected]> gopls-CI: kokoro <[email protected]> TryBot-Result: Go Bot <[email protected]> Reviewed-by: Rebecca Stambler <[email protected]>
1 parent c740bfd commit 46d1522

File tree

12 files changed

+964
-39
lines changed

12 files changed

+964
-39
lines changed

internal/lsp/cmd/test/cmdtest.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ func (r *runner) FunctionExtraction(t *testing.T, start span.Span, end span.Span
100100
//TODO: function extraction not supported on command line
101101
}
102102

103+
func (r *runner) MethodExtraction(t *testing.T, start span.Span, end span.Span) {
104+
//TODO: function extraction not supported on command line
105+
}
106+
103107
func (r *runner) AddImport(t *testing.T, uri span.URI, expectedImport string) {
104108
//TODO: import addition not supported on command line
105109
}

internal/lsp/code_action.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,8 @@ func extractionFixes(ctx context.Context, snapshot source.Snapshot, pkg source.P
289289
}
290290
puri := protocol.URIFromSpanURI(uri)
291291
var commands []protocol.Command
292-
if _, ok, _ := source.CanExtractFunction(snapshot.FileSet(), srng, pgf.Src, pgf.File); ok {
293-
cmd, err := command.NewApplyFixCommand("Extract to function", command.ApplyFixArgs{
292+
if _, ok, methodOk, _ := source.CanExtractFunction(snapshot.FileSet(), srng, pgf.Src, pgf.File); ok {
293+
cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{
294294
URI: puri,
295295
Fix: source.ExtractFunction,
296296
Range: rng,
@@ -299,6 +299,17 @@ func extractionFixes(ctx context.Context, snapshot source.Snapshot, pkg source.P
299299
return nil, err
300300
}
301301
commands = append(commands, cmd)
302+
if methodOk {
303+
cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{
304+
URI: puri,
305+
Fix: source.ExtractMethod,
306+
Range: rng,
307+
})
308+
if err != nil {
309+
return nil, err
310+
}
311+
commands = append(commands, cmd)
312+
}
302313
}
303314
if _, _, ok, _ := source.CanExtractVariable(srng, pgf.File); ok {
304315
cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{

internal/lsp/lsp_test.go

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ func (r *runner) FunctionExtraction(t *testing.T, start span.Span, end span.Span
583583
if err != nil {
584584
t.Fatal(err)
585585
}
586-
actions, err := r.server.CodeAction(r.ctx, &protocol.CodeActionParams{
586+
actionsRaw, err := r.server.CodeAction(r.ctx, &protocol.CodeActionParams{
587587
TextDocument: protocol.TextDocumentIdentifier{
588588
URI: protocol.URIFromSpanURI(uri),
589589
},
@@ -595,6 +595,12 @@ func (r *runner) FunctionExtraction(t *testing.T, start span.Span, end span.Span
595595
if err != nil {
596596
t.Fatal(err)
597597
}
598+
var actions []protocol.CodeAction
599+
for _, action := range actionsRaw {
600+
if action.Command.Title == "Extract function" {
601+
actions = append(actions, action)
602+
}
603+
}
598604
// Hack: We assume that we only get one code action per range.
599605
// TODO(rstambler): Support multiple code actions per test.
600606
if len(actions) == 0 || len(actions) > 1 {
@@ -618,6 +624,58 @@ func (r *runner) FunctionExtraction(t *testing.T, start span.Span, end span.Span
618624
}
619625
}
620626

627+
func (r *runner) MethodExtraction(t *testing.T, start span.Span, end span.Span) {
628+
uri := start.URI()
629+
m, err := r.data.Mapper(uri)
630+
if err != nil {
631+
t.Fatal(err)
632+
}
633+
spn := span.New(start.URI(), start.Start(), end.End())
634+
rng, err := m.Range(spn)
635+
if err != nil {
636+
t.Fatal(err)
637+
}
638+
actionsRaw, err := r.server.CodeAction(r.ctx, &protocol.CodeActionParams{
639+
TextDocument: protocol.TextDocumentIdentifier{
640+
URI: protocol.URIFromSpanURI(uri),
641+
},
642+
Range: rng,
643+
Context: protocol.CodeActionContext{
644+
Only: []protocol.CodeActionKind{"refactor.extract"},
645+
},
646+
})
647+
if err != nil {
648+
t.Fatal(err)
649+
}
650+
var actions []protocol.CodeAction
651+
for _, action := range actionsRaw {
652+
if action.Command.Title == "Extract method" {
653+
actions = append(actions, action)
654+
}
655+
}
656+
// Hack: We assume that we only get one matching code action per range.
657+
// TODO(rstambler): Support multiple code actions per test.
658+
if len(actions) == 0 || len(actions) > 1 {
659+
t.Fatalf("unexpected number of code actions, want 1, got %v", len(actions))
660+
}
661+
_, err = r.server.ExecuteCommand(r.ctx, &protocol.ExecuteCommandParams{
662+
Command: actions[0].Command.Command,
663+
Arguments: actions[0].Command.Arguments,
664+
})
665+
if err != nil {
666+
t.Fatal(err)
667+
}
668+
res := <-r.editRecv
669+
for u, got := range res {
670+
want := string(r.data.Golden("methodextraction_"+tests.SpanName(spn), u.Filename(), func() ([]byte, error) {
671+
return []byte(got), nil
672+
}))
673+
if want != got {
674+
t.Errorf("method extraction failed for %s:\n%s", u.Filename(), tests.Diff(t, want, got))
675+
}
676+
}
677+
}
678+
621679
func (r *runner) Definition(t *testing.T, spn span.Span, d tests.Definition) {
622680
sm, err := r.data.Mapper(d.Src.URI())
623681
if err != nil {

internal/lsp/source/extract.go

Lines changed: 97 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,17 @@ func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast.
139139
// Possible collisions include other function and variable names. Returns the next index to check for prefix.
140140
func generateAvailableIdentifier(pos token.Pos, file *ast.File, path []ast.Node, info *types.Info, prefix string, idx int) (string, int) {
141141
scopes := CollectScopes(info, path, pos)
142+
return generateIdentifier(idx, prefix, func(name string) bool {
143+
return file.Scope.Lookup(name) != nil || !isValidName(name, scopes)
144+
})
145+
}
146+
147+
func generateIdentifier(idx int, prefix string, hasCollision func(string) bool) (string, int) {
142148
name := prefix
143149
if idx != 0 {
144150
name += fmt.Sprintf("%d", idx)
145151
}
146-
for file.Scope.Lookup(name) != nil || !isValidName(name, scopes) {
152+
for hasCollision(name) {
147153
idx++
148154
name = fmt.Sprintf("%v%d", prefix, idx)
149155
}
@@ -177,28 +183,42 @@ type returnVariable struct {
177183
zeroVal ast.Expr
178184
}
179185

186+
// extractMethod refactors the selected block of code into a new method.
187+
func extractMethod(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
188+
return extractFunctionMethod(fset, rng, src, file, pkg, info, true)
189+
}
190+
180191
// extractFunction refactors the selected block of code into a new function.
192+
func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
193+
return extractFunctionMethod(fset, rng, src, file, pkg, info, false)
194+
}
195+
196+
// extractFunctionMethod refactors the selected block of code into a new function/method.
181197
// It also replaces the selected block of code with a call to the extracted
182198
// function. First, we manually adjust the selection range. We remove trailing
183199
// and leading whitespace characters to ensure the range is precisely bounded
184200
// by AST nodes. Next, we determine the variables that will be the parameters
185-
// and return values of the extracted function. Lastly, we construct the call
186-
// of the function and insert this call as well as the extracted function into
201+
// and return values of the extracted function/method. Lastly, we construct the call
202+
// of the function/method and insert this call as well as the extracted function/method into
187203
// their proper locations.
188-
func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
189-
p, ok, err := CanExtractFunction(fset, rng, src, file)
190-
if !ok {
191-
return nil, fmt.Errorf("extractFunction: cannot extract %s: %v",
204+
func extractFunctionMethod(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info, isMethod bool) (*analysis.SuggestedFix, error) {
205+
errorPrefix := "extractFunction"
206+
if isMethod {
207+
errorPrefix = "extractMethod"
208+
}
209+
p, ok, methodOk, err := CanExtractFunction(fset, rng, src, file)
210+
if (!ok && !isMethod) || (!methodOk && isMethod) {
211+
return nil, fmt.Errorf("%s: cannot extract %s: %v", errorPrefix,
192212
fset.Position(rng.Start), err)
193213
}
194214
tok, path, rng, outer, start := p.tok, p.path, p.rng, p.outer, p.start
195215
fileScope := info.Scopes[file]
196216
if fileScope == nil {
197-
return nil, fmt.Errorf("extractFunction: file scope is empty")
217+
return nil, fmt.Errorf("%s: file scope is empty", errorPrefix)
198218
}
199219
pkgScope := fileScope.Parent()
200220
if pkgScope == nil {
201-
return nil, fmt.Errorf("extractFunction: package scope is empty")
221+
return nil, fmt.Errorf("%s: package scope is empty", errorPrefix)
202222
}
203223

204224
// A return statement is non-nested if its parent node is equal to the parent node
@@ -235,6 +255,25 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.
235255
return nil, err
236256
}
237257

258+
var (
259+
receiverUsed bool
260+
receiver *ast.Field
261+
receiverName string
262+
receiverObj types.Object
263+
)
264+
if isMethod {
265+
if outer == nil || outer.Recv == nil || len(outer.Recv.List) == 0 {
266+
return nil, fmt.Errorf("%s: cannot extract need method receiver", errorPrefix)
267+
}
268+
receiver = outer.Recv.List[0]
269+
if len(receiver.Names) == 0 || receiver.Names[0] == nil {
270+
return nil, fmt.Errorf("%s: cannot extract need method receiver name", errorPrefix)
271+
}
272+
recvName := receiver.Names[0]
273+
receiverName = recvName.Name
274+
receiverObj = info.ObjectOf(recvName)
275+
}
276+
238277
var (
239278
params, returns []ast.Expr // used when calling the extracted function
240279
paramTypes, returnTypes []*ast.Field // used in the signature of the extracted function
@@ -308,6 +347,11 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.
308347
// extracted function. (1) it must be free (isFree), and (2) its first
309348
// use within the selection cannot be its own definition (isDefined).
310349
if v.free && !v.defined {
350+
// Skip the selector for a method.
351+
if isMethod && v.obj == receiverObj {
352+
receiverUsed = true
353+
continue
354+
}
311355
params = append(params, identifier)
312356
paramTypes = append(paramTypes, &ast.Field{
313357
Names: []*ast.Ident{identifier},
@@ -471,9 +515,17 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.
471515
if canDefine {
472516
sym = token.DEFINE
473517
}
474-
funName, _ := generateAvailableIdentifier(rng.Start, file, path, info, "newFunction", 0)
518+
var name, funName string
519+
if isMethod {
520+
name = "newMethod"
521+
// TODO(suzmue): generate a name that does not conflict for "newMethod".
522+
funName = name
523+
} else {
524+
name = "newFunction"
525+
funName, _ = generateAvailableIdentifier(rng.Start, file, path, info, name, 0)
526+
}
475527
extractedFunCall := generateFuncCall(hasNonNestedReturn, hasReturnValues, params,
476-
append(returns, getNames(retVars)...), funName, sym)
528+
append(returns, getNames(retVars)...), funName, sym, receiverName)
477529

478530
// Build the extracted function.
479531
newFunc := &ast.FuncDecl{
@@ -484,6 +536,18 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.
484536
},
485537
Body: extractedBlock,
486538
}
539+
if isMethod {
540+
var names []*ast.Ident
541+
if receiverUsed {
542+
names = append(names, ast.NewIdent(receiverName))
543+
}
544+
newFunc.Recv = &ast.FieldList{
545+
List: []*ast.Field{{
546+
Names: names,
547+
Type: receiver.Type,
548+
}},
549+
}
550+
}
487551

488552
// Create variable declarations for any identifiers that need to be initialized prior to
489553
// calling the extracted function. We do not manually initialize variables if every return
@@ -844,24 +908,24 @@ type fnExtractParams struct {
844908

845909
// CanExtractFunction reports whether the code in the given range can be
846910
// extracted to a function.
847-
func CanExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File) (*fnExtractParams, bool, error) {
911+
func CanExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File) (*fnExtractParams, bool, bool, error) {
848912
if rng.Start == rng.End {
849-
return nil, false, fmt.Errorf("start and end are equal")
913+
return nil, false, false, fmt.Errorf("start and end are equal")
850914
}
851915
tok := fset.File(file.Pos())
852916
if tok == nil {
853-
return nil, false, fmt.Errorf("no file for pos %v", fset.Position(file.Pos()))
917+
return nil, false, false, fmt.Errorf("no file for pos %v", fset.Position(file.Pos()))
854918
}
855919
rng = adjustRangeForWhitespace(rng, tok, src)
856920
path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
857921
if len(path) == 0 {
858-
return nil, false, fmt.Errorf("no path enclosing interval")
922+
return nil, false, false, fmt.Errorf("no path enclosing interval")
859923
}
860924
// Node that encloses the selection must be a statement.
861925
// TODO: Support function extraction for an expression.
862926
_, ok := path[0].(ast.Stmt)
863927
if !ok {
864-
return nil, false, fmt.Errorf("node is not a statement")
928+
return nil, false, false, fmt.Errorf("node is not a statement")
865929
}
866930

867931
// Find the function declaration that encloses the selection.
@@ -873,7 +937,7 @@ func CanExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *a
873937
}
874938
}
875939
if outer == nil {
876-
return nil, false, fmt.Errorf("no enclosing function")
940+
return nil, false, false, fmt.Errorf("no enclosing function")
877941
}
878942

879943
// Find the nodes at the start and end of the selection.
@@ -893,15 +957,15 @@ func CanExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *a
893957
return n.Pos() <= rng.End
894958
})
895959
if start == nil || end == nil {
896-
return nil, false, fmt.Errorf("range does not map to AST nodes")
960+
return nil, false, false, fmt.Errorf("range does not map to AST nodes")
897961
}
898962
return &fnExtractParams{
899963
tok: tok,
900964
path: path,
901965
rng: rng,
902966
outer: outer,
903967
start: start,
904-
}, true, nil
968+
}, true, outer.Recv != nil, nil
905969
}
906970

907971
// objUsed checks if the object is used within the range. It returns the first
@@ -1089,13 +1153,22 @@ func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]
10891153

10901154
// generateFuncCall constructs a call expression for the extracted function, described by the
10911155
// given parameters and return variables.
1092-
func generateFuncCall(hasNonNestedReturn, hasReturnVals bool, params, returns []ast.Expr, name string, token token.Token) ast.Node {
1156+
func generateFuncCall(hasNonNestedReturn, hasReturnVals bool, params, returns []ast.Expr, name string, token token.Token, selector string) ast.Node {
10931157
var replace ast.Node
1094-
if hasReturnVals {
1095-
callExpr := &ast.CallExpr{
1096-
Fun: ast.NewIdent(name),
1158+
callExpr := &ast.CallExpr{
1159+
Fun: ast.NewIdent(name),
1160+
Args: params,
1161+
}
1162+
if selector != "" {
1163+
callExpr = &ast.CallExpr{
1164+
Fun: &ast.SelectorExpr{
1165+
X: ast.NewIdent(selector),
1166+
Sel: ast.NewIdent(name),
1167+
},
10971168
Args: params,
10981169
}
1170+
}
1171+
if hasReturnVals {
10991172
if hasNonNestedReturn {
11001173
// Create a return statement that returns the result of the function call.
11011174
replace = &ast.ReturnStmt{
@@ -1111,10 +1184,7 @@ func generateFuncCall(hasNonNestedReturn, hasReturnVals bool, params, returns []
11111184
}
11121185
}
11131186
} else {
1114-
replace = &ast.CallExpr{
1115-
Fun: ast.NewIdent(name),
1116-
Args: params,
1117-
}
1187+
replace = callExpr
11181188
}
11191189
return replace
11201190
}

internal/lsp/source/fix.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ const (
3232
UndeclaredName = "undeclared_name"
3333
ExtractVariable = "extract_variable"
3434
ExtractFunction = "extract_function"
35+
ExtractMethod = "extract_method"
3536
)
3637

3738
// suggestedFixes maps a suggested fix command id to its handler.
@@ -40,6 +41,7 @@ var suggestedFixes = map[string]SuggestedFixFunc{
4041
UndeclaredName: undeclaredname.SuggestedFix,
4142
ExtractVariable: extractVariable,
4243
ExtractFunction: extractFunction,
44+
ExtractMethod: extractMethod,
4345
}
4446

4547
func SuggestedFixFromCommand(cmd protocol.Command, kind protocol.CodeActionKind) SuggestedFix {

internal/lsp/source/source_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,7 @@ func (r *runner) Link(t *testing.T, uri span.URI, wantLinks []tests.Link) {}
935935
func (r *runner) SuggestedFix(t *testing.T, spn span.Span, actionKinds []string, expectedActions int) {
936936
}
937937
func (r *runner) FunctionExtraction(t *testing.T, start span.Span, end span.Span) {}
938+
func (r *runner) MethodExtraction(t *testing.T, start span.Span, end span.Span) {}
938939
func (r *runner) CodeLens(t *testing.T, uri span.URI, want []protocol.CodeLens) {}
939940
func (r *runner) AddImport(t *testing.T, uri span.URI, expectedImport string) {}
940941

0 commit comments

Comments
 (0)