@@ -2,30 +2,36 @@ package patterns
22
33import (
44 "go/ast"
5+ "go/token"
56 "go/types"
67
78 "golang.org/x/tools/go/analysis"
89
910 "github.com/mpyw/goroutinectx/internal/typeutil"
1011)
1112
12- // funcLitHasContextParam checks if a function literal has a context.Context parameter.
13- func funcLitHasContextParam (cctx * CheckContext , lit * ast.FuncLit ) bool {
14- if lit . Type == nil || lit . Type .Params == nil {
13+ // funcTypeHasContextParam checks if a function type has a context.Context parameter.
14+ func funcTypeHasContextParam (cctx * CheckContext , fnType * ast.FuncType ) bool {
15+ if fnType == nil || fnType .Params == nil {
1516 return false
1617 }
17- for _ , field := range lit . Type .Params .List {
18+ for _ , field := range fnType .Params .List {
1819 typ := cctx .Pass .TypesInfo .TypeOf (field .Type )
1920 if typ == nil {
2021 continue
2122 }
22- if isContextType (typ ) {
23+ if typeutil . IsContextType (typ ) {
2324 return true
2425 }
2526 }
2627 return false
2728}
2829
30+ // funcLitHasContextParam checks if a function literal has a context.Context parameter.
31+ func funcLitHasContextParam (cctx * CheckContext , lit * ast.FuncLit ) bool {
32+ return funcTypeHasContextParam (cctx , lit .Type )
33+ }
34+
2935// funcLitUsesContext checks if a function literal references any context variable.
3036// It does NOT descend into nested func literals - they have their own scope and
3137// will be checked separately.
@@ -105,8 +111,9 @@ func argUsesContext(cctx *CheckContext, expr ast.Expr) bool {
105111}
106112
107113// findFuncLitAssignment searches for the func literal assigned to the variable.
108- // Returns the LAST assignment found (for cases where variables are reassigned).
109- func findFuncLitAssignment (cctx * CheckContext , v * types.Var ) * ast.FuncLit {
114+ // If beforePos is token.NoPos, returns the LAST assignment found.
115+ // If beforePos is set, returns the last assignment BEFORE that position.
116+ func findFuncLitAssignment (cctx * CheckContext , v * types.Var , beforePos token.Pos ) * ast.FuncLit {
110117 var result * ast.FuncLit
111118 declPos := v .Pos ()
112119
@@ -120,6 +127,10 @@ func findFuncLitAssignment(cctx *CheckContext, v *types.Var) *ast.FuncLit {
120127 if ! ok {
121128 return true
122129 }
130+ // Skip assignments at or after beforePos
131+ if beforePos != token .NoPos && assign .Pos () >= beforePos {
132+ return true
133+ }
123134 if fl := findFuncLitInAssignment (cctx , assign , v ); fl != nil {
124135 result = fl // Keep updating - we want the LAST assignment
125136 }
@@ -151,24 +162,29 @@ func findFuncLitInAssignment(cctx *CheckContext, assign *ast.AssignStmt, v *type
151162 return nil
152163}
153164
154- // factoryReturnsContextUsingFunc checks if a factory function 's return statements
165+ // blockReturnsContextUsingFunc checks if a block 's return statements
155166// return functions that use context. Recursively checks nested func literals.
156- func factoryReturnsContextUsingFunc (cctx * CheckContext , factory * ast.FuncLit ) bool {
167+ // excludeFuncLit can be set to exclude a specific FuncLit from being counted (e.g., the parent).
168+ func blockReturnsContextUsingFunc (cctx * CheckContext , body * ast.BlockStmt , excludeFuncLit * ast.FuncLit ) bool {
169+ if body == nil {
170+ return true // No body to check
171+ }
172+
157173 usesContext := false
158174
159- ast .Inspect (factory . Body , func (n ast.Node ) bool {
175+ ast .Inspect (body , func (n ast.Node ) bool {
160176 if usesContext {
161177 return false
162178 }
163179 // For nested func literals, check both direct usage and returned values
164- if fl , ok := n .(* ast.FuncLit ); ok && fl != factory {
180+ if fl , ok := n .(* ast.FuncLit ); ok && fl != excludeFuncLit {
165181 // Check if this nested func lit uses context directly
166182 if funcLitUsesContext (cctx , fl ) {
167183 usesContext = true
168184 return false
169185 }
170186 // Recursively check if it returns functions that use context
171- if factoryReturnsContextUsingFunc (cctx , fl ) {
187+ if blockReturnsContextUsingFunc (cctx , fl . Body , fl ) {
172188 usesContext = true
173189 return false
174190 }
@@ -192,6 +208,12 @@ func factoryReturnsContextUsingFunc(cctx *CheckContext, factory *ast.FuncLit) bo
192208 return usesContext
193209}
194210
211+ // factoryReturnsContextUsingFunc checks if a factory FuncLit's return statements
212+ // return functions that use context.
213+ func factoryReturnsContextUsingFunc (cctx * CheckContext , factory * ast.FuncLit ) bool {
214+ return blockReturnsContextUsingFunc (cctx , factory .Body , factory )
215+ }
216+
195217// returnedValueUsesContext checks if a returned value is a func that uses context.
196218func returnedValueUsesContext (cctx * CheckContext , result ast.Expr ) bool {
197219 // If it's a func literal, check directly
@@ -215,7 +237,7 @@ func returnedValueUsesContext(cctx *CheckContext, result ast.Expr) bool {
215237 return false
216238 }
217239
218- innerFuncLit := findFuncLitAssignment (cctx , v )
240+ innerFuncLit := findFuncLitAssignment (cctx , v , token . NoPos )
219241 if innerFuncLit == nil {
220242 return false
221243 }
0 commit comments