Skip to content

Commit e8659b1

Browse files
authored
Merge pull request #23 from mpyw/fix/errgroup-conditional-reassignment
fix: track conditional reassignments in errgroup/spawner checkers
2 parents 1069525 + 0e6eafc commit e8659b1

File tree

5 files changed

+185
-6
lines changed

5 files changed

+185
-6
lines changed

internal/checkers/spawner.go

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,11 @@ func (c *SpawnCallbackChecker) checkArgFromAST(cctx *probe.Context, arg ast.Expr
150150
}
151151

152152
if ident, ok := arg.(*ast.Ident); ok {
153-
funcLit := cctx.FuncLitOfIdent(ident)
154-
if funcLit == nil {
153+
assigns := cctx.FuncLitAssignmentsOfIdent(ident)
154+
if len(assigns) == 0 {
155155
return true
156156
}
157-
return c.checkFuncLitAST(cctx, funcLit)
157+
return c.checkFuncLitAssignments(cctx, assigns)
158158
}
159159

160160
if call, ok := arg.(*ast.CallExpr); ok {
@@ -172,6 +172,34 @@ func (c *SpawnCallbackChecker) checkArgFromAST(cctx *probe.Context, arg ast.Expr
172172
return true
173173
}
174174

175+
// checkFuncLitAssignments checks all func literal assignments from last unconditional onwards.
176+
// ALL must pass for the check to succeed.
177+
func (c *SpawnCallbackChecker) checkFuncLitAssignments(cctx *probe.Context, assigns []probe.FuncLitAssignment) bool {
178+
// Find the index of the last unconditional assignment
179+
lastUnconditionalIdx := -1
180+
for i := len(assigns) - 1; i >= 0; i-- {
181+
if !assigns[i].Conditional {
182+
lastUnconditionalIdx = i
183+
break
184+
}
185+
}
186+
187+
// Determine the starting point for checks
188+
startIdx := 0
189+
if lastUnconditionalIdx >= 0 {
190+
startIdx = lastUnconditionalIdx
191+
}
192+
193+
// Check all assignments from startIdx onwards
194+
// ALL must pass (because conditional assignments may override)
195+
for i := startIdx; i < len(assigns); i++ {
196+
if !c.checkFuncLitAST(cctx, assigns[i].Lit) {
197+
return false
198+
}
199+
}
200+
return true
201+
}
202+
175203
// checkFuncLitAST checks a func literal using AST-based analysis.
176204
func (c *SpawnCallbackChecker) checkFuncLitAST(cctx *probe.Context, lit *ast.FuncLit) bool {
177205
// Check context capture
@@ -333,11 +361,11 @@ func (c *SpawnerChecker) checkFuncArg(cctx *probe.Context, arg ast.Expr) bool {
333361
}
334362

335363
if ident, ok := arg.(*ast.Ident); ok {
336-
funcLit := cctx.FuncLitOfIdent(ident)
337-
if funcLit == nil {
364+
assigns := cctx.FuncLitAssignmentsOfIdent(ident)
365+
if len(assigns) == 0 {
338366
return true
339367
}
340-
return c.checkFuncLitAST(cctx, funcLit)
368+
return c.checkFuncLitAssignments(cctx, assigns)
341369
}
342370

343371
if call, ok := arg.(*ast.CallExpr); ok {
@@ -347,6 +375,34 @@ func (c *SpawnerChecker) checkFuncArg(cctx *probe.Context, arg ast.Expr) bool {
347375
return true
348376
}
349377

378+
// checkFuncLitAssignments checks all func literal assignments from last unconditional onwards.
379+
// ALL must pass for the check to succeed.
380+
func (c *SpawnerChecker) checkFuncLitAssignments(cctx *probe.Context, assigns []probe.FuncLitAssignment) bool {
381+
// Find the index of the last unconditional assignment
382+
lastUnconditionalIdx := -1
383+
for i := len(assigns) - 1; i >= 0; i-- {
384+
if !assigns[i].Conditional {
385+
lastUnconditionalIdx = i
386+
break
387+
}
388+
}
389+
390+
// Determine the starting point for checks
391+
startIdx := 0
392+
if lastUnconditionalIdx >= 0 {
393+
startIdx = lastUnconditionalIdx
394+
}
395+
396+
// Check all assignments from startIdx onwards
397+
// ALL must pass (because conditional assignments may override)
398+
for i := startIdx; i < len(assigns); i++ {
399+
if !c.checkFuncLitAST(cctx, assigns[i].Lit) {
400+
return false
401+
}
402+
}
403+
return true
404+
}
405+
350406
// checkFuncLitSSA checks a func literal using SSA analysis for SpawnerChecker.
351407
func (c *SpawnerChecker) checkFuncLitSSA(cctx *probe.Context, lit *ast.FuncLit) (bool, bool) {
352408
if cctx.SSAProg == nil || cctx.Tracer == nil {
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"title": "Conditional reassignment - all paths use ctx",
3+
"targets": [
4+
"errgroup"
5+
],
6+
"variants": {
7+
"bad": null,
8+
"good": {
9+
"description": "All assignments use context, so all paths propagate it.",
10+
"functions": {
11+
"errgroup": "goodConditionalReassignAllUseCtx"
12+
}
13+
}
14+
},
15+
"level": "evil"
16+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"title": "Conditional reassignment - first doesn't use ctx, conditional does",
3+
"targets": [
4+
"errgroup"
5+
],
6+
"variants": {
7+
"bad": {
8+
"description": "First assignment ignores context, conditional reassignment uses it.",
9+
"functions": {
10+
"errgroup": "badConditionalReassignFirstNoCtx"
11+
}
12+
},
13+
"good": null
14+
},
15+
"level": "evil"
16+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"title": "Conditional reassignment - first uses ctx, conditional doesn't",
3+
"targets": [
4+
"errgroup"
5+
],
6+
"variants": {
7+
"bad": {
8+
"description": "First assignment uses context, but conditional reassignment doesn't.",
9+
"functions": {
10+
"errgroup": "badConditionalReassignFirstUsesCtx"
11+
}
12+
},
13+
"good": null
14+
},
15+
"level": "evil"
16+
}

testdata/src/errgroup/evil.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,3 +843,78 @@ func badNestedFactoryWithoutCtx(ctx context.Context) {
843843
g.Go(makeFactory()()) // want `errgroup.Group.Go\(\) closure should use context "ctx"`
844844
_ = g.Wait()
845845
}
846+
847+
// ===== CONDITIONAL REASSIGNMENT PATTERNS =====
848+
// These patterns test conditional reassignment of function variables.
849+
// ALL assignments from last unconditional onwards must capture context.
850+
851+
var conditionFlag bool
852+
853+
// [BAD]: Conditional reassignment - first uses ctx, conditional doesn't
854+
//
855+
// First assignment uses context, but conditional reassignment doesn't.
856+
//
857+
// See also:
858+
// goroutine: badConditionalReassignFirstUsesCtx
859+
// waitgroup: badConditionalReassignFirstUsesCtx
860+
func badConditionalReassignFirstUsesCtx(ctx context.Context) {
861+
g := new(errgroup.Group)
862+
fn := func() error {
863+
_ = ctx
864+
return nil
865+
}
866+
if conditionFlag {
867+
fn = func() error {
868+
fmt.Println("no ctx")
869+
return nil
870+
}
871+
}
872+
g.Go(fn) // want `errgroup.Group.Go\(\) closure should use context "ctx"`
873+
_ = g.Wait()
874+
}
875+
876+
// [BAD]: Conditional reassignment - first doesn't use ctx, conditional does
877+
//
878+
// First assignment ignores context, conditional reassignment uses it.
879+
//
880+
// See also:
881+
// goroutine: badConditionalReassignFirstNoCtx
882+
// waitgroup: badConditionalReassignFirstNoCtx
883+
func badConditionalReassignFirstNoCtx(ctx context.Context) {
884+
g := new(errgroup.Group)
885+
fn := func() error {
886+
fmt.Println("no ctx")
887+
return nil
888+
}
889+
if conditionFlag {
890+
fn = func() error {
891+
_ = ctx
892+
return nil
893+
}
894+
}
895+
g.Go(fn) // want `errgroup.Group.Go\(\) closure should use context "ctx"`
896+
_ = g.Wait()
897+
}
898+
899+
// [GOOD]: Conditional reassignment - all paths use ctx
900+
//
901+
// All assignments use context, so all paths propagate it.
902+
//
903+
// See also:
904+
// goroutine: goodConditionalReassignAllUseCtx
905+
// waitgroup: goodConditionalReassignAllUseCtx
906+
func goodConditionalReassignAllUseCtx(ctx context.Context) {
907+
g := new(errgroup.Group)
908+
fn := func() error {
909+
_ = ctx
910+
return nil
911+
}
912+
if conditionFlag {
913+
fn = func() error {
914+
_ = ctx
915+
return nil
916+
}
917+
}
918+
g.Go(fn) // OK - all assignments use ctx
919+
_ = g.Wait()
920+
}

0 commit comments

Comments
 (0)