Skip to content

Commit e81a031

Browse files
mpywclaude
andcommitted
refactor(patterns): consolidate duplicate utilities across packages
- Remove duplicate isContextType implementations (patterns, ssa) → Use single typeutil.IsContextType everywhere - Unify funcTypeHasContextParam for FuncLit and FuncDecl → Base function accepts *ast.FuncType - Consolidate findFuncLitAssignment variants → Single function with optional beforePos parameter - Unify factory returns context check logic → blockReturnsContextUsingFunc as common base Result: -95 lines of duplicated code 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent a658139 commit e81a031

File tree

7 files changed

+54
-149
lines changed

7 files changed

+54
-149
lines changed

internal/patterns/callback_receives_ctx.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package patterns
33
import (
44
"fmt"
55
"go/ast"
6+
7+
"github.com/mpyw/goroutinectx/internal/typeutil"
68
)
79

810
// CallbackReceivesCtx checks APIs where the callback receives context as its first parameter.
@@ -39,23 +41,23 @@ func (*CallbackReceivesCtx) contextArgUsesVar(cctx *CheckContext, ctxArg ast.Exp
3941
// For simple identifier, check if it's a context type from scope
4042
if ident, ok := ctxArg.(*ast.Ident); ok {
4143
obj := cctx.Pass.TypesInfo.ObjectOf(ident)
42-
if obj != nil && isContextType(obj.Type()) {
44+
if obj != nil && typeutil.IsContextType(obj.Type()) {
4345
return true
4446
}
4547
}
4648

4749
// For call expressions (e.g., context.WithCancel(ctx)), check the result type
4850
if callExpr, ok := ctxArg.(*ast.CallExpr); ok {
4951
typ := cctx.Pass.TypesInfo.TypeOf(callExpr)
50-
if typ != nil && isContextType(typ) {
52+
if typ != nil && typeutil.IsContextType(typ) {
5153
return true
5254
}
5355
}
5456

5557
// For selector expressions (e.g., req.Context())
5658
if sel, ok := ctxArg.(*ast.SelectorExpr); ok {
5759
typ := cctx.Pass.TypesInfo.TypeOf(sel)
58-
if typ != nil && isContextType(typ) {
60+
if typ != nil && typeutil.IsContextType(typ) {
5961
return true
6062
}
6163
}

internal/patterns/closure_captures_ctx.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func closureCheckFromAST(cctx *CheckContext, callbackArg ast.Expr) bool {
8383
if !ok {
8484
return false // Can't trace
8585
}
86-
funcLit := findFuncLitAssignment(cctx, v)
86+
funcLit := findFuncLitAssignment(cctx, v, token.NoPos)
8787
if funcLit == nil {
8888
return false // Can't trace (channel receive, type assertion, etc.)
8989
}
@@ -337,7 +337,7 @@ func closureCheckFactoryCall(cctx *CheckContext, call *ast.CallExpr) bool {
337337
if !ok {
338338
return false
339339
}
340-
funcLit := findFuncLitAssignment(cctx, v)
340+
funcLit := findFuncLitAssignment(cctx, v, token.NoPos)
341341
if funcLit == nil {
342342
return false
343343
}

internal/patterns/gostmt.go

Lines changed: 7 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package patterns
22

33
import (
44
"go/ast"
5+
"go/token"
56
"go/types"
67

78
"github.com/mpyw/goroutinectx/internal/directives/deriver"
@@ -161,7 +162,7 @@ func goStmtCheckIdentFactory(cctx *CheckContext, ident *ast.Ident) bool {
161162

162163
// Handle local variable pointing to a func literal
163164
if v, ok := obj.(*types.Var); ok {
164-
funcLit := findFuncLitAssignment(cctx, v)
165+
funcLit := findFuncLitAssignment(cctx, v, token.NoPos)
165166
if funcLit == nil {
166167
return true // Can't trace
167168
}
@@ -210,63 +211,12 @@ func goStmtFindFuncDecl(cctx *CheckContext, fn *types.Func) *ast.FuncDecl {
210211

211212
// goStmtFuncDeclHasContextParam checks if a function declaration has a context.Context parameter.
212213
func goStmtFuncDeclHasContextParam(cctx *CheckContext, decl *ast.FuncDecl) bool {
213-
if decl.Type == nil || decl.Type.Params == nil {
214-
return false
215-
}
216-
for _, field := range decl.Type.Params.List {
217-
typ := cctx.Pass.TypesInfo.TypeOf(field.Type)
218-
if typ == nil {
219-
continue
220-
}
221-
if isContextType(typ) {
222-
return true
223-
}
224-
}
225-
return false
214+
return funcTypeHasContextParam(cctx, decl.Type)
226215
}
227216

228217
// goStmtFactoryDeclReturnsCtxFunc checks if a function declaration returns funcs that use context.
229-
// For nested factories, this recursively checks if any deeply nested function uses context.
230218
func goStmtFactoryDeclReturnsCtxFunc(cctx *CheckContext, decl *ast.FuncDecl) bool {
231-
if decl.Body == nil {
232-
return true // No body to check
233-
}
234-
235-
usesContext := false
236-
ast.Inspect(decl.Body, func(n ast.Node) bool {
237-
if usesContext {
238-
return false
239-
}
240-
// For nested func literals, check both direct usage and returned values
241-
if fl, ok := n.(*ast.FuncLit); ok {
242-
// Check if this nested func lit uses context directly
243-
if funcLitUsesContext(cctx, fl) {
244-
usesContext = true
245-
return false
246-
}
247-
// Recursively check if it returns functions that use context
248-
if factoryReturnsContextUsingFunc(cctx, fl) {
249-
usesContext = true
250-
return false
251-
}
252-
return false // Don't descend into nested func literals (we handle them recursively)
253-
}
254-
255-
ret, ok := n.(*ast.ReturnStmt)
256-
if !ok {
257-
return true
258-
}
259-
260-
for _, result := range ret.Results {
261-
if returnedValueUsesContext(cctx, result) {
262-
usesContext = true
263-
return false
264-
}
265-
}
266-
return true
267-
})
268-
269-
return usesContext
219+
return blockReturnsContextUsingFunc(cctx, decl.Body, nil)
270220
}
271221

272222
// GoStmtCallsDeriver checks that a go statement's closure calls a deriver function.
@@ -352,7 +302,7 @@ func (p *GoStmtCallsDeriver) checkIdentDeriver(cctx *CheckContext, ident *ast.Id
352302
return true // Not a variable
353303
}
354304

355-
funcLit := findFuncLitAssignment(cctx, v)
305+
funcLit := findFuncLitAssignment(cctx, v, token.NoPos)
356306
if funcLit == nil {
357307
return true // Can't trace
358308
}
@@ -388,7 +338,7 @@ func (p *GoStmtCallsDeriver) checkHigherOrderDeriver(cctx *CheckContext, innerCa
388338
if !ok {
389339
return true // Not a variable (could be a function)
390340
}
391-
funcLit := findFuncLitAssignment(cctx, v)
341+
funcLit := findFuncLitAssignment(cctx, v, token.NoPos)
392342
if funcLit == nil {
393343
return true // Can't trace
394344
}
@@ -462,7 +412,7 @@ func (p *GoStmtCallsDeriver) returnedValueCallsDeriver(cctx *CheckContext, resul
462412
return false
463413
}
464414

465-
innerFuncLit := findFuncLitAssignment(cctx, v)
415+
innerFuncLit := findFuncLitAssignment(cctx, v, token.NoPos)
466416
if innerFuncLit == nil {
467417
return false
468418
}

internal/patterns/pattern.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package patterns
44
import (
55
"go/ast"
66
"go/token"
7-
"go/types"
87

98
"golang.org/x/tools/go/analysis"
109

@@ -29,16 +28,6 @@ func (c *CheckContext) Report(pos token.Pos, msg string) {
2928
c.Pass.Reportf(pos, "%s", msg)
3029
}
3130

32-
// isContextType checks if a type is context.Context.
33-
func isContextType(t types.Type) bool {
34-
named, ok := t.(*types.Named)
35-
if !ok {
36-
return false
37-
}
38-
obj := named.Obj()
39-
return obj.Pkg() != nil && obj.Pkg().Path() == "context" && obj.Name() == "Context"
40-
}
41-
4231
// Pattern defines the interface for context propagation patterns.
4332
type Pattern interface {
4433
// Name returns a human-readable name for the pattern.

internal/patterns/should_call_deriver.go

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func (p *ShouldCallDeriver) identCallsDeriver(cctx *CheckContext, ident *ast.Ide
6464
}
6565

6666
// Try to find a FuncLit assignment
67-
funcLit := deriverFindFuncLitAssignmentBefore(cctx, v, ident.Pos())
67+
funcLit := findFuncLitAssignment(cctx, v, ident.Pos())
6868
if funcLit != nil {
6969
return p.Matcher.SatisfiesAnyGroup(cctx.Pass, funcLit.Body)
7070
}
@@ -146,7 +146,7 @@ func (p *ShouldCallDeriver) factoryReturnCallsDeriver(cctx *CheckContext, call *
146146
return false
147147
}
148148

149-
funcLit := deriverFindFuncLitAssignmentBefore(cctx, v, call.Pos())
149+
funcLit := findFuncLitAssignment(cctx, v, call.Pos())
150150
if funcLit == nil {
151151
return false
152152
}
@@ -210,48 +210,6 @@ func (p *ShouldCallDeriver) Message(apiName string, _ string) string {
210210
return apiName + "() callback should call goroutine deriver"
211211
}
212212

213-
// deriverFindFuncLitAssignmentBefore finds the last FuncLit assigned to variable before pos.
214-
func deriverFindFuncLitAssignmentBefore(cctx *CheckContext, v *types.Var, beforePos token.Pos) *ast.FuncLit {
215-
var result *ast.FuncLit
216-
declPos := v.Pos()
217-
218-
for _, f := range cctx.Pass.Files {
219-
if f.Pos() > declPos || declPos >= f.End() {
220-
continue
221-
}
222-
223-
ast.Inspect(f, func(n ast.Node) bool {
224-
assign, ok := n.(*ast.AssignStmt)
225-
if !ok {
226-
return true
227-
}
228-
if beforePos != token.NoPos && assign.Pos() >= beforePos {
229-
return true
230-
}
231-
232-
for i, lhs := range assign.Lhs {
233-
ident, ok := lhs.(*ast.Ident)
234-
if !ok {
235-
continue
236-
}
237-
if cctx.Pass.TypesInfo.ObjectOf(ident) != v {
238-
continue
239-
}
240-
if i >= len(assign.Rhs) {
241-
continue
242-
}
243-
if fl, ok := assign.Rhs[i].(*ast.FuncLit); ok {
244-
result = fl
245-
}
246-
}
247-
return true
248-
})
249-
break
250-
}
251-
252-
return result
253-
}
254-
255213
// deriverFindCallExprAssignmentBefore finds the last CallExpr assigned to variable before pos.
256214
func deriverFindCallExprAssignmentBefore(cctx *CheckContext, v *types.Var, beforePos token.Pos) *ast.CallExpr {
257215
var result *ast.CallExpr

internal/patterns/util.go

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,36 @@ package patterns
22

33
import (
44
"go/ast"
5+
"go/token"
56
"go/types"
67

78
"golang.org/x/tools/go/analysis"
89

910
"github.com/mpyw/goroutinectx/internal/typeutil"
1011
)
1112

12-
// funcLitHasContextParam checks if a function literal has a context.Context parameter.
13-
func funcLitHasContextParam(cctx *CheckContext, lit *ast.FuncLit) bool {
14-
if lit.Type == nil || lit.Type.Params == nil {
13+
// funcTypeHasContextParam checks if a function type has a context.Context parameter.
14+
func funcTypeHasContextParam(cctx *CheckContext, fnType *ast.FuncType) bool {
15+
if fnType == nil || fnType.Params == nil {
1516
return false
1617
}
17-
for _, field := range lit.Type.Params.List {
18+
for _, field := range fnType.Params.List {
1819
typ := cctx.Pass.TypesInfo.TypeOf(field.Type)
1920
if typ == nil {
2021
continue
2122
}
22-
if isContextType(typ) {
23+
if typeutil.IsContextType(typ) {
2324
return true
2425
}
2526
}
2627
return false
2728
}
2829

30+
// funcLitHasContextParam checks if a function literal has a context.Context parameter.
31+
func funcLitHasContextParam(cctx *CheckContext, lit *ast.FuncLit) bool {
32+
return funcTypeHasContextParam(cctx, lit.Type)
33+
}
34+
2935
// funcLitUsesContext checks if a function literal references any context variable.
3036
// It does NOT descend into nested func literals - they have their own scope and
3137
// will be checked separately.
@@ -105,8 +111,9 @@ func argUsesContext(cctx *CheckContext, expr ast.Expr) bool {
105111
}
106112

107113
// findFuncLitAssignment searches for the func literal assigned to the variable.
108-
// Returns the LAST assignment found (for cases where variables are reassigned).
109-
func findFuncLitAssignment(cctx *CheckContext, v *types.Var) *ast.FuncLit {
114+
// If beforePos is token.NoPos, returns the LAST assignment found.
115+
// If beforePos is set, returns the last assignment BEFORE that position.
116+
func findFuncLitAssignment(cctx *CheckContext, v *types.Var, beforePos token.Pos) *ast.FuncLit {
110117
var result *ast.FuncLit
111118
declPos := v.Pos()
112119

@@ -120,6 +127,10 @@ func findFuncLitAssignment(cctx *CheckContext, v *types.Var) *ast.FuncLit {
120127
if !ok {
121128
return true
122129
}
130+
// Skip assignments at or after beforePos
131+
if beforePos != token.NoPos && assign.Pos() >= beforePos {
132+
return true
133+
}
123134
if fl := findFuncLitInAssignment(cctx, assign, v); fl != nil {
124135
result = fl // Keep updating - we want the LAST assignment
125136
}
@@ -151,24 +162,29 @@ func findFuncLitInAssignment(cctx *CheckContext, assign *ast.AssignStmt, v *type
151162
return nil
152163
}
153164

154-
// factoryReturnsContextUsingFunc checks if a factory function's return statements
165+
// blockReturnsContextUsingFunc checks if a block's return statements
155166
// return functions that use context. Recursively checks nested func literals.
156-
func factoryReturnsContextUsingFunc(cctx *CheckContext, factory *ast.FuncLit) bool {
167+
// excludeFuncLit can be set to exclude a specific FuncLit from being counted (e.g., the parent).
168+
func blockReturnsContextUsingFunc(cctx *CheckContext, body *ast.BlockStmt, excludeFuncLit *ast.FuncLit) bool {
169+
if body == nil {
170+
return true // No body to check
171+
}
172+
157173
usesContext := false
158174

159-
ast.Inspect(factory.Body, func(n ast.Node) bool {
175+
ast.Inspect(body, func(n ast.Node) bool {
160176
if usesContext {
161177
return false
162178
}
163179
// For nested func literals, check both direct usage and returned values
164-
if fl, ok := n.(*ast.FuncLit); ok && fl != factory {
180+
if fl, ok := n.(*ast.FuncLit); ok && fl != excludeFuncLit {
165181
// Check if this nested func lit uses context directly
166182
if funcLitUsesContext(cctx, fl) {
167183
usesContext = true
168184
return false
169185
}
170186
// Recursively check if it returns functions that use context
171-
if factoryReturnsContextUsingFunc(cctx, fl) {
187+
if blockReturnsContextUsingFunc(cctx, fl.Body, fl) {
172188
usesContext = true
173189
return false
174190
}
@@ -192,6 +208,12 @@ func factoryReturnsContextUsingFunc(cctx *CheckContext, factory *ast.FuncLit) bo
192208
return usesContext
193209
}
194210

211+
// factoryReturnsContextUsingFunc checks if a factory FuncLit's return statements
212+
// return functions that use context.
213+
func factoryReturnsContextUsingFunc(cctx *CheckContext, factory *ast.FuncLit) bool {
214+
return blockReturnsContextUsingFunc(cctx, factory.Body, factory)
215+
}
216+
195217
// returnedValueUsesContext checks if a returned value is a func that uses context.
196218
func returnedValueUsesContext(cctx *CheckContext, result ast.Expr) bool {
197219
// If it's a func literal, check directly
@@ -215,7 +237,7 @@ func returnedValueUsesContext(cctx *CheckContext, result ast.Expr) bool {
215237
return false
216238
}
217239

218-
innerFuncLit := findFuncLitAssignment(cctx, v)
240+
innerFuncLit := findFuncLitAssignment(cctx, v, token.NoPos)
219241
if innerFuncLit == nil {
220242
return false
221243
}

0 commit comments

Comments
 (0)