Skip to content

Commit fc8adf5

Browse files
mpywclaude
andcommitted
refactor: consolidate assignment finding and pointer unwrapping
- Add CheckContext.FindCallExprAssignment (mirrors FindFuncLitAssignment) - Remove duplicate deriverFindCallExprAssignmentBefore from should_call_deriver - Export typeutil.UnwrapPointer for reuse - Update funcspec, registry, spawnerlabel to use typeutil.UnwrapPointer 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 1e70e11 commit fc8adf5

File tree

6 files changed

+62
-59
lines changed

6 files changed

+62
-59
lines changed

internal/context/context.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,58 @@ func (c *CheckContext) findFuncLitInAssignment(assign *ast.AssignStmt, v *types.
164164
return nil
165165
}
166166

167+
// FindCallExprAssignment searches for the call expression assigned to the variable.
168+
// If beforePos is token.NoPos, returns the LAST assignment found.
169+
// If beforePos is set, returns the last assignment BEFORE that position.
170+
func (c *CheckContext) FindCallExprAssignment(v *types.Var, beforePos token.Pos) *ast.CallExpr {
171+
var result *ast.CallExpr
172+
declPos := v.Pos()
173+
174+
for _, f := range c.Pass.Files {
175+
if f.Pos() > declPos || declPos >= f.End() {
176+
continue
177+
}
178+
179+
ast.Inspect(f, func(n ast.Node) bool {
180+
assign, ok := n.(*ast.AssignStmt)
181+
if !ok {
182+
return true
183+
}
184+
// Skip assignments at or after beforePos
185+
if beforePos != token.NoPos && assign.Pos() >= beforePos {
186+
return true
187+
}
188+
if call := c.findCallExprInAssignment(assign, v); call != nil {
189+
result = call // Keep updating - we want the LAST assignment
190+
}
191+
return true
192+
})
193+
break
194+
}
195+
196+
return result
197+
}
198+
199+
// findCallExprInAssignment checks if the assignment assigns a call expression to v.
200+
func (c *CheckContext) findCallExprInAssignment(assign *ast.AssignStmt, v *types.Var) *ast.CallExpr {
201+
for i, lhs := range assign.Lhs {
202+
ident, ok := lhs.(*ast.Ident)
203+
if !ok {
204+
continue
205+
}
206+
if c.Pass.TypesInfo.ObjectOf(ident) != v {
207+
continue
208+
}
209+
if i >= len(assign.Rhs) {
210+
continue
211+
}
212+
if call, ok := assign.Rhs[i].(*ast.CallExpr); ok {
213+
return call
214+
}
215+
}
216+
return nil
217+
}
218+
167219
// BlockReturnsContextUsingFunc checks if a block's return statements
168220
// return functions that use context. Recursively checks nested func literals.
169221
// excludeFuncLit can be set to exclude a specific FuncLit from being counted (e.g., the parent).

internal/funcspec/funcspec.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"unicode"
99

1010
"golang.org/x/tools/go/analysis"
11+
12+
"github.com/mpyw/goroutinectx/internal/typeutil"
1113
)
1214

1315
// Spec holds parsed components of a function specification.
@@ -76,11 +78,7 @@ func (s Spec) Matches(fn *types.Func) bool {
7678
return false
7779
}
7880

79-
recvType := recv.Type()
80-
// Handle pointer receivers
81-
if ptr, ok := recvType.(*types.Pointer); ok {
82-
recvType = ptr.Elem()
83-
}
81+
recvType := typeutil.UnwrapPointer(recv.Type())
8482

8583
named, ok := recvType.(*types.Named)
8684
if !ok {

internal/patterns/should_call_deriver.go

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package patterns
22

33
import (
44
"go/ast"
5-
"go/token"
65
"go/types"
76
"strings"
87

@@ -71,7 +70,7 @@ func (p *ShouldCallDeriver) identCallsDeriver(cctx *context.CheckContext, ident
7170
}
7271

7372
// Try to find a CallExpr assignment (e.g., task := NewTask(fn))
74-
callExpr := deriverFindCallExprAssignmentBefore(cctx, v, ident.Pos())
73+
callExpr := cctx.FindCallExprAssignment(v, ident.Pos())
7574
if callExpr != nil {
7675
return p.callExprCallsDeriver(cctx, callExpr)
7776
}
@@ -210,45 +209,3 @@ func (p *ShouldCallDeriver) funcLitReturnCallsDeriver(cctx *context.CheckContext
210209
func (p *ShouldCallDeriver) Message(apiName string, _ string) string {
211210
return apiName + "() callback should call goroutine deriver"
212211
}
213-
214-
// deriverFindCallExprAssignmentBefore finds the last CallExpr assigned to variable before pos.
215-
func deriverFindCallExprAssignmentBefore(cctx *context.CheckContext, v *types.Var, beforePos token.Pos) *ast.CallExpr {
216-
var result *ast.CallExpr
217-
declPos := v.Pos()
218-
219-
for _, f := range cctx.Pass.Files {
220-
if f.Pos() > declPos || declPos >= f.End() {
221-
continue
222-
}
223-
224-
ast.Inspect(f, func(n ast.Node) bool {
225-
assign, ok := n.(*ast.AssignStmt)
226-
if !ok {
227-
return true
228-
}
229-
if beforePos != token.NoPos && assign.Pos() >= beforePos {
230-
return true
231-
}
232-
233-
for i, lhs := range assign.Lhs {
234-
ident, ok := lhs.(*ast.Ident)
235-
if !ok {
236-
continue
237-
}
238-
if cctx.Pass.TypesInfo.ObjectOf(ident) != v {
239-
continue
240-
}
241-
if i >= len(assign.Rhs) {
242-
continue
243-
}
244-
if call, ok := assign.Rhs[i].(*ast.CallExpr); ok {
245-
result = call
246-
}
247-
}
248-
return true
249-
})
250-
break
251-
}
252-
253-
return result
254-
}

internal/registry/registry.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"golang.org/x/tools/go/analysis"
99

1010
"github.com/mpyw/goroutinectx/internal/patterns"
11+
"github.com/mpyw/goroutinectx/internal/typeutil"
1112
)
1213

1314
// Entry represents a registered API with its pattern.
@@ -88,10 +89,7 @@ func (r *Registry) isMethodCall(pass *analysis.Pass, sel *ast.SelectorExpr, api
8889
return false
8990
}
9091

91-
// Unwrap pointer
92-
if ptr, ok := typ.(*types.Pointer); ok {
93-
typ = ptr.Elem()
94-
}
92+
typ = typeutil.UnwrapPointer(typ)
9593

9694
named, ok := typ.(*types.Named)
9795
if !ok {

internal/spawnerlabel/detector.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,7 @@ func isGotaskTaskType(pass *analysis.Pass, expr ast.Expr) bool {
169169
return false
170170
}
171171

172-
if ptr, ok := typ.(*types.Pointer); ok {
173-
typ = ptr.Elem()
174-
}
172+
typ = typeutil.UnwrapPointer(typ)
175173

176174
named, ok := typ.(*types.Named)
177175
if !ok {

internal/typeutil/typeutil.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func IsNamedType(pass *analysis.Pass, expr ast.Expr, pkgPath, typeName string) b
2424

2525
// isNamedTypeFromType checks if the type matches the given package path and type name.
2626
func isNamedTypeFromType(t types.Type, pkgPath, typeName string) bool {
27-
t = unwrapPointer(t)
27+
t = UnwrapPointer(t)
2828

2929
named, ok := t.(*types.Named)
3030
if !ok {
@@ -39,8 +39,8 @@ func isNamedTypeFromType(t types.Type, pkgPath, typeName string) bool {
3939
return obj.Pkg().Path() == pkgPath && obj.Name() == typeName
4040
}
4141

42-
// unwrapPointer returns the element type if t is a pointer, otherwise returns t.
43-
func unwrapPointer(t types.Type) types.Type {
42+
// UnwrapPointer returns the element type if t is a pointer, otherwise returns t.
43+
func UnwrapPointer(t types.Type) types.Type {
4444
if ptr, ok := t.(*types.Pointer); ok {
4545
return ptr.Elem()
4646
}

0 commit comments

Comments
 (0)