Skip to content

Commit fa3d081

Browse files
committed
Merge branch 'main' into fix/factory-conditional-reassignment
2 parents 0691850 + e8659b1 commit fa3d081

12 files changed

+432
-29
lines changed

internal/checkers/goroutine.go

Lines changed: 76 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@ func (*Goroutine) checkFromAST(cctx *probe.Context, stmt *ast.GoStmt) bool {
6161
}
6262

6363
if ident, ok := call.Fun.(*ast.Ident); ok {
64-
funcLit := cctx.FuncLitOfIdent(ident)
65-
if funcLit == nil {
64+
assigns := cctx.FuncLitAssignmentsOfIdent(ident)
65+
if len(assigns) == 0 {
6666
return true
6767
}
68-
return cctx.FuncLitCapturesContext(funcLit)
68+
return cctx.FuncLitsAllCaptureContext(assigns)
6969
}
7070

7171
if sel, ok := call.Fun.(*ast.SelectorExpr); ok {
@@ -166,16 +166,35 @@ func (c *GoroutineDerive) checkFromSSA(cctx *probe.Context, lit *ast.FuncLit) (*
166166
}
167167

168168
func (c *GoroutineDerive) checkIdent(cctx *probe.Context, ident *ast.Ident) bool {
169-
funcLit := cctx.FuncLitOfIdent(ident)
170-
if funcLit == nil {
169+
assigns := cctx.FuncLitAssignmentsOfIdent(ident)
170+
if len(assigns) == 0 {
171171
return true
172172
}
173173

174-
if cctx.FuncLitHasContextParam(funcLit) {
175-
return true
174+
// Find the index of the last unconditional assignment
175+
lastUnconditionalIdx := -1
176+
for i := len(assigns) - 1; i >= 0; i-- {
177+
if !assigns[i].Conditional {
178+
lastUnconditionalIdx = i
179+
break
180+
}
181+
}
182+
183+
// Determine the starting point for checks
184+
startIdx := 0
185+
if lastUnconditionalIdx >= 0 {
186+
startIdx = lastUnconditionalIdx
187+
}
188+
189+
// Check all assignments from startIdx onwards
190+
for i := startIdx; i < len(assigns); i++ {
191+
lit := assigns[i].Lit
192+
if !cctx.FuncLitHasContextParam(lit) && !c.derivers.SatisfiesAnyGroup(cctx.Pass, lit.Body) {
193+
return false
194+
}
176195
}
177196

178-
return c.derivers.SatisfiesAnyGroup(cctx.Pass, funcLit.Body)
197+
return true
179198
}
180199

181200
func (c *GoroutineDerive) checkHigherOrder(cctx *probe.Context, innerCall *ast.CallExpr) bool {
@@ -192,14 +211,34 @@ func (c *GoroutineDerive) checkHigherOrder(cctx *probe.Context, innerCall *ast.C
192211
}
193212

194213
if ident, ok := fun.(*ast.Ident); ok {
195-
funcLit := cctx.FuncLitOfIdent(ident)
196-
if funcLit == nil {
214+
assigns := cctx.FuncLitAssignmentsOfIdent(ident)
215+
if len(assigns) == 0 {
197216
return true
198217
}
199-
if cctx.FuncLitHasContextParam(funcLit) {
200-
return true
218+
219+
// Find the index of the last unconditional assignment
220+
lastUnconditionalIdx := -1
221+
for i := len(assigns) - 1; i >= 0; i-- {
222+
if !assigns[i].Conditional {
223+
lastUnconditionalIdx = i
224+
break
225+
}
226+
}
227+
228+
// Determine the starting point for checks
229+
startIdx := 0
230+
if lastUnconditionalIdx >= 0 {
231+
startIdx = lastUnconditionalIdx
201232
}
202-
return c.factoryReturnsCallingFunc(cctx, funcLit)
233+
234+
// Check all assignments from startIdx onwards
235+
for i := startIdx; i < len(assigns); i++ {
236+
lit := assigns[i].Lit
237+
if !cctx.FuncLitHasContextParam(lit) && !c.factoryReturnsCallingFunc(cctx, lit) {
238+
return false
239+
}
240+
}
241+
return true
203242
}
204243

205244
return true
@@ -254,13 +293,32 @@ func (c *GoroutineDerive) returnedValueCalls(cctx *probe.Context, result ast.Exp
254293
return false
255294
}
256295

257-
innerFuncLit := cctx.FuncLitOfIdent(ident)
258-
if innerFuncLit == nil {
296+
assigns := cctx.FuncLitAssignmentsOfIdent(ident)
297+
if len(assigns) == 0 {
259298
return false
260299
}
261300

262-
if cctx.FuncLitHasContextParam(innerFuncLit) {
263-
return true
301+
// Find the index of the last unconditional assignment
302+
lastUnconditionalIdx := -1
303+
for i := len(assigns) - 1; i >= 0; i-- {
304+
if !assigns[i].Conditional {
305+
lastUnconditionalIdx = i
306+
break
307+
}
264308
}
265-
return c.derivers.SatisfiesAnyGroup(cctx.Pass, innerFuncLit.Body)
309+
310+
// Determine the starting point for checks
311+
startIdx := 0
312+
if lastUnconditionalIdx >= 0 {
313+
startIdx = lastUnconditionalIdx
314+
}
315+
316+
// Check all assignments from startIdx onwards
317+
for i := startIdx; i < len(assigns); i++ {
318+
lit := assigns[i].Lit
319+
if !cctx.FuncLitHasContextParam(lit) && !c.derivers.SatisfiesAnyGroup(cctx.Pass, lit.Body) {
320+
return false
321+
}
322+
}
323+
return true
266324
}

internal/checkers/spawner.go

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,11 @@ func (c *SpawnCallbackChecker) checkArgFromAST(cctx *probe.Context, arg ast.Expr
150150
}
151151

152152
if ident, ok := arg.(*ast.Ident); ok {
153-
funcLit := cctx.FuncLitOfIdent(ident)
154-
if funcLit == nil {
153+
assigns := cctx.FuncLitAssignmentsOfIdent(ident)
154+
if len(assigns) == 0 {
155155
return true
156156
}
157-
return c.checkFuncLitAST(cctx, funcLit)
157+
return c.checkFuncLitAssignments(cctx, assigns)
158158
}
159159

160160
if call, ok := arg.(*ast.CallExpr); ok {
@@ -172,6 +172,34 @@ func (c *SpawnCallbackChecker) checkArgFromAST(cctx *probe.Context, arg ast.Expr
172172
return true
173173
}
174174

175+
// checkFuncLitAssignments checks all func literal assignments from last unconditional onwards.
176+
// ALL must pass for the check to succeed.
177+
func (c *SpawnCallbackChecker) checkFuncLitAssignments(cctx *probe.Context, assigns []probe.FuncLitAssignment) bool {
178+
// Find the index of the last unconditional assignment
179+
lastUnconditionalIdx := -1
180+
for i := len(assigns) - 1; i >= 0; i-- {
181+
if !assigns[i].Conditional {
182+
lastUnconditionalIdx = i
183+
break
184+
}
185+
}
186+
187+
// Determine the starting point for checks
188+
startIdx := 0
189+
if lastUnconditionalIdx >= 0 {
190+
startIdx = lastUnconditionalIdx
191+
}
192+
193+
// Check all assignments from startIdx onwards
194+
// ALL must pass (because conditional assignments may override)
195+
for i := startIdx; i < len(assigns); i++ {
196+
if !c.checkFuncLitAST(cctx, assigns[i].Lit) {
197+
return false
198+
}
199+
}
200+
return true
201+
}
202+
175203
// checkFuncLitAST checks a func literal using AST-based analysis.
176204
func (c *SpawnCallbackChecker) checkFuncLitAST(cctx *probe.Context, lit *ast.FuncLit) bool {
177205
// Check context capture
@@ -333,11 +361,11 @@ func (c *SpawnerChecker) checkFuncArg(cctx *probe.Context, arg ast.Expr) bool {
333361
}
334362

335363
if ident, ok := arg.(*ast.Ident); ok {
336-
funcLit := cctx.FuncLitOfIdent(ident)
337-
if funcLit == nil {
364+
assigns := cctx.FuncLitAssignmentsOfIdent(ident)
365+
if len(assigns) == 0 {
338366
return true
339367
}
340-
return c.checkFuncLitAST(cctx, funcLit)
368+
return c.checkFuncLitAssignments(cctx, assigns)
341369
}
342370

343371
if call, ok := arg.(*ast.CallExpr); ok {
@@ -347,6 +375,34 @@ func (c *SpawnerChecker) checkFuncArg(cctx *probe.Context, arg ast.Expr) bool {
347375
return true
348376
}
349377

378+
// checkFuncLitAssignments checks all func literal assignments from last unconditional onwards.
379+
// ALL must pass for the check to succeed.
380+
func (c *SpawnerChecker) checkFuncLitAssignments(cctx *probe.Context, assigns []probe.FuncLitAssignment) bool {
381+
// Find the index of the last unconditional assignment
382+
lastUnconditionalIdx := -1
383+
for i := len(assigns) - 1; i >= 0; i-- {
384+
if !assigns[i].Conditional {
385+
lastUnconditionalIdx = i
386+
break
387+
}
388+
}
389+
390+
// Determine the starting point for checks
391+
startIdx := 0
392+
if lastUnconditionalIdx >= 0 {
393+
startIdx = lastUnconditionalIdx
394+
}
395+
396+
// Check all assignments from startIdx onwards
397+
// ALL must pass (because conditional assignments may override)
398+
for i := startIdx; i < len(assigns); i++ {
399+
if !c.checkFuncLitAST(cctx, assigns[i].Lit) {
400+
return false
401+
}
402+
}
403+
return true
404+
}
405+
350406
// checkFuncLitSSA checks a func literal using SSA analysis for SpawnerChecker.
351407
func (c *SpawnerChecker) checkFuncLitSSA(cctx *probe.Context, lit *ast.FuncLit) (bool, bool) {
352408
if cctx.SSAProg == nil || cctx.Tracer == nil {

internal/probe/assignment.go

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,26 @@ func (c *Context) FuncLitOfIdent(ident *ast.Ident) *ast.FuncLit {
2424
return c.FuncLitAssignedTo(v, token.NoPos)
2525
}
2626

27+
// FuncLitsOfIdent returns ALL func literals assigned to the identifier's variable.
28+
// This is needed for conditional reassignment patterns where different branches
29+
// assign different closures to the same variable.
30+
func (c *Context) FuncLitsOfIdent(ident *ast.Ident) []*ast.FuncLit {
31+
v := c.VarOf(ident)
32+
if v == nil {
33+
return nil
34+
}
35+
return c.FuncLitsAssignedTo(v, token.NoPos)
36+
}
37+
38+
// FuncLitAssignmentsOfIdent returns ALL func literal assignments with conditionality info.
39+
func (c *Context) FuncLitAssignmentsOfIdent(ident *ast.Ident) []FuncLitAssignment {
40+
v := c.VarOf(ident)
41+
if v == nil {
42+
return nil
43+
}
44+
return c.FuncLitAssignmentsTo(v, token.NoPos)
45+
}
46+
2747
// FuncLitAssignedTo searches for the func literal assigned to the variable.
2848
// If beforePos is token.NoPos, returns the LAST assignment found.
2949
// If beforePos is set, returns the last assignment BEFORE that position.
@@ -51,13 +71,32 @@ func (c *Context) FuncLitAssignedTo(v *types.Var, beforePos token.Pos) *ast.Func
5171
return result
5272
}
5373

54-
// FuncLitAssignmentsOfIdent returns ALL func literal assignments with conditionality info.
55-
func (c *Context) FuncLitAssignmentsOfIdent(ident *ast.Ident) []FuncLitAssignment {
56-
v := c.VarOf(ident)
57-
if v == nil {
74+
// FuncLitsAssignedTo searches for ALL func literals assigned to the variable.
75+
// If beforePos is token.NoPos, returns ALL assignments found.
76+
// If beforePos is set, returns all assignments BEFORE that position.
77+
// This is needed for conditional reassignment patterns.
78+
func (c *Context) FuncLitsAssignedTo(v *types.Var, beforePos token.Pos) []*ast.FuncLit {
79+
f := c.FileOf(v.Pos())
80+
if f == nil {
5881
return nil
5982
}
60-
return c.FuncLitAssignmentsTo(v, token.NoPos)
83+
84+
var results []*ast.FuncLit
85+
ast.Inspect(f, func(n ast.Node) bool {
86+
assign, ok := n.(*ast.AssignStmt)
87+
if !ok {
88+
return true
89+
}
90+
if beforePos != token.NoPos && assign.Pos() >= beforePos {
91+
return true
92+
}
93+
if fl := c.funcLitInAssignment(assign, v); fl != nil {
94+
results = append(results, fl)
95+
}
96+
return true
97+
})
98+
99+
return results
61100
}
62101

63102
// FuncLitAssignmentsTo searches for ALL func literal assignments with conditionality info.

internal/probe/capture.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,41 @@ func (c *Context) FuncLitCapturesContext(lit *ast.FuncLit) bool {
5353
return c.FuncLitHasContextParam(lit) || c.FuncLitUsesContext(lit)
5454
}
5555

56+
// FuncLitsAllCaptureContext checks if func literals properly capture context.
57+
// Uses conditionality information to determine the correct check:
58+
// - Find the last unconditional assignment
59+
// - Check all assignments from that point onwards (including conditional ones)
60+
// - ALL must capture context for the check to pass
61+
func (c *Context) FuncLitsAllCaptureContext(assigns []FuncLitAssignment) bool {
62+
if len(assigns) == 0 {
63+
return true
64+
}
65+
66+
// Find the index of the last unconditional assignment
67+
lastUnconditionalIdx := -1
68+
for i := len(assigns) - 1; i >= 0; i-- {
69+
if !assigns[i].Conditional {
70+
lastUnconditionalIdx = i
71+
break
72+
}
73+
}
74+
75+
// Determine the starting point for checks
76+
startIdx := 0
77+
if lastUnconditionalIdx >= 0 {
78+
startIdx = lastUnconditionalIdx
79+
}
80+
81+
// Check all assignments from startIdx onwards
82+
// ALL must capture context (because conditional assignments may override)
83+
for i := startIdx; i < len(assigns); i++ {
84+
if !c.FuncLitCapturesContext(assigns[i].Lit) {
85+
return false
86+
}
87+
}
88+
return true
89+
}
90+
5691
// FuncLitUsesContext checks if a function literal references any context variable.
5792
// Does NOT descend into nested func literals.
5893
func (c *Context) FuncLitUsesContext(lit *ast.FuncLit) bool {
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"title": "Conditional reassignment - all paths use ctx",
3+
"targets": [
4+
"goroutine"
5+
],
6+
"variants": {
7+
"bad": null,
8+
"good": {
9+
"description": "All assignments use context, so all paths propagate it.",
10+
"functions": {
11+
"goroutine": "goodConditionalReassignAllUseCtx"
12+
}
13+
}
14+
},
15+
"level": "evil"
16+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"title": "Conditional reassignment - first doesn't use ctx, conditional does",
3+
"targets": [
4+
"goroutine"
5+
],
6+
"variants": {
7+
"bad": {
8+
"description": "First assignment ignores context, conditional reassignment uses it.",
9+
"functions": {
10+
"goroutine": "badConditionalReassignFirstNoCtx"
11+
}
12+
},
13+
"good": null
14+
},
15+
"level": "evil"
16+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"title": "Conditional reassignment - first uses ctx, conditional doesn't",
3+
"targets": [
4+
"goroutine"
5+
],
6+
"variants": {
7+
"bad": {
8+
"description": "First assignment uses context, but conditional reassignment doesn't.",
9+
"functions": {
10+
"goroutine": "badConditionalReassignFirstUsesCtx"
11+
}
12+
},
13+
"good": null
14+
},
15+
"level": "evil"
16+
}

0 commit comments

Comments
 (0)