@@ -10,6 +10,7 @@ import (
1010 "fmt"
1111 "go/ast"
1212 "go/format"
13+ "go/token"
1314 "go/types"
1415 "regexp"
1516 "slices"
@@ -112,7 +113,7 @@ outer:
112113 }
113114
114115 file , _ := cursorutil.FirstEnclosing [* ast.File ](curRet )
115- matches := analysisinternal . MatchingIdents (retTyps , file , ret .Pos (), info , pass .Pkg )
116+ matches := MatchingIdents (retTyps , file , ret .Pos (), info , pass .Pkg )
116117 qual := typesinternal .FileQualifier (file , pass .Pkg )
117118 for i , retTyp := range retTyps {
118119 var match ast.Expr
@@ -242,3 +243,95 @@ func isZeroExpr(expr ast.Expr) bool {
242243 return false
243244 }
244245}
246+
247+ // MatchingIdents finds the names of all identifiers in 'node' that match any of the given types.
248+ // 'pos' represents the position at which the identifiers may be inserted. 'pos' must be within
249+ // the scope of each of identifier we select. Otherwise, we will insert a variable at 'pos' that
250+ // is unrecognized.
251+ //
252+ // This function is shared with the 'fillstruct' analyzer.
253+ func MatchingIdents (typs []types.Type , node ast.Node , pos token.Pos , info * types.Info , pkg * types.Package ) map [types.Type ][]string {
254+
255+ // Initialize matches to contain the variable types we are searching for.
256+ matches := make (map [types.Type ][]string )
257+ for _ , typ := range typs {
258+ if typ == nil {
259+ continue // TODO(adonovan): is this reachable?
260+ }
261+ matches [typ ] = nil // create entry
262+ }
263+
264+ seen := map [types.Object ]struct {}{}
265+ ast .Inspect (node , func (n ast.Node ) bool {
266+ if n == nil {
267+ return false
268+ }
269+ // Prevent circular definitions. If 'pos' is within an assignment statement, do not
270+ // allow any identifiers in that assignment statement to be selected. Otherwise,
271+ // we could do the following, where 'x' satisfies the type of 'f0':
272+ //
273+ // x := fakeStruct{f0: x}
274+ //
275+ if assign , ok := n .(* ast.AssignStmt ); ok && pos > assign .Pos () && pos <= assign .End () {
276+ return false
277+ }
278+ if n .End () > pos {
279+ return n .Pos () <= pos
280+ }
281+ ident , ok := n .(* ast.Ident )
282+ if ! ok || ident .Name == "_" {
283+ return true
284+ }
285+ obj := info .Defs [ident ]
286+ if obj == nil || obj .Type () == nil {
287+ return true
288+ }
289+ if _ , ok := obj .(* types.TypeName ); ok {
290+ return true
291+ }
292+ // Prevent duplicates in matches' values.
293+ if _ , ok = seen [obj ]; ok {
294+ return true
295+ }
296+ seen [obj ] = struct {}{}
297+ // Find the scope for the given position. Then, check whether the object
298+ // exists within the scope.
299+ innerScope := pkg .Scope ().Innermost (pos )
300+ if innerScope == nil {
301+ return true
302+ }
303+ _ , foundObj := innerScope .LookupParent (ident .Name , pos )
304+ if foundObj != obj {
305+ return true
306+ }
307+ // The object must match one of the types that we are searching for.
308+ // TODO(adonovan): opt: use typeutil.Map?
309+ if names , ok := matches [obj .Type ()]; ok {
310+ matches [obj .Type ()] = append (names , ident .Name )
311+ } else {
312+ // If the object type does not exactly match
313+ // any of the target types, greedily find the first
314+ // target type that the object type can satisfy.
315+ for typ := range matches {
316+ if equivalentTypes (obj .Type (), typ ) {
317+ matches [typ ] = append (matches [typ ], ident .Name )
318+ }
319+ }
320+ }
321+ return true
322+ })
323+ return matches
324+ }
325+
326+ func equivalentTypes (want , got types.Type ) bool {
327+ if types .Identical (want , got ) {
328+ return true
329+ }
330+ // Code segment to help check for untyped equality from (golang/go#32146).
331+ if rhs , ok := want .(* types.Basic ); ok && rhs .Info ()& types .IsUntyped > 0 {
332+ if lhs , ok := got .Underlying ().(* types.Basic ); ok {
333+ return rhs .Info ()& types .IsConstType == lhs .Info ()& types .IsConstType
334+ }
335+ }
336+ return types .AssignableTo (want , got )
337+ }
0 commit comments