Skip to content

Commit 0691850

Browse files
mpywclaude
andcommitted
fix: factory pattern should check all conditional reassignments
The factory.go had a bug where BlockReturnsContextUsingFunc was short-circuiting when it found ANY nested func literal that uses context, even if that func wasn't actually returned. Fixed by: - Removing the short-circuit and only checking what's actually returned - Adding recursive context-checking for nested factories - Creating funcLitAssignmentsAllUseOrReturnContext to check ALL assignments from last unconditional onwards Also added FuncLitAssignment type and related methods to assignment.go to track conditional vs unconditional assignments. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 4fb020b commit 0691850

File tree

6 files changed

+236
-12
lines changed

6 files changed

+236
-12
lines changed

internal/probe/assignment.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,16 @@ import (
44
"go/ast"
55
"go/token"
66
"go/types"
7+
8+
"golang.org/x/tools/go/ast/inspector"
79
)
810

11+
// FuncLitAssignment represents a func literal assignment with its conditionality.
12+
type FuncLitAssignment struct {
13+
Lit *ast.FuncLit
14+
Conditional bool // true if inside if/for/switch/select
15+
}
16+
917
// FuncLitOfIdent is a convenience method that combines VarOf and FuncLitAssignedTo.
1018
// Returns the last func literal assignment found.
1119
func (c *Context) FuncLitOfIdent(ident *ast.Ident) *ast.FuncLit {
@@ -43,6 +51,62 @@ func (c *Context) FuncLitAssignedTo(v *types.Var, beforePos token.Pos) *ast.Func
4351
return result
4452
}
4553

54+
// FuncLitAssignmentsOfIdent returns ALL func literal assignments with conditionality info.
55+
func (c *Context) FuncLitAssignmentsOfIdent(ident *ast.Ident) []FuncLitAssignment {
56+
v := c.VarOf(ident)
57+
if v == nil {
58+
return nil
59+
}
60+
return c.FuncLitAssignmentsTo(v, token.NoPos)
61+
}
62+
63+
// FuncLitAssignmentsTo searches for ALL func literal assignments with conditionality info.
64+
func (c *Context) FuncLitAssignmentsTo(v *types.Var, beforePos token.Pos) []FuncLitAssignment {
65+
f := c.FileOf(v.Pos())
66+
if f == nil {
67+
return nil
68+
}
69+
70+
var results []FuncLitAssignment
71+
insp := inspector.New([]*ast.File{f})
72+
73+
insp.WithStack([]ast.Node{(*ast.AssignStmt)(nil)}, func(n ast.Node, push bool, stack []ast.Node) bool {
74+
if !push {
75+
return true
76+
}
77+
assign := n.(*ast.AssignStmt)
78+
if beforePos != token.NoPos && assign.Pos() >= beforePos {
79+
return true
80+
}
81+
fl := c.funcLitInAssignment(assign, v)
82+
if fl == nil {
83+
return true
84+
}
85+
86+
// Check if assignment is inside a control structure
87+
conditional := isInControlStructure(stack)
88+
89+
results = append(results, FuncLitAssignment{
90+
Lit: fl,
91+
Conditional: conditional,
92+
})
93+
return true
94+
})
95+
96+
return results
97+
}
98+
99+
// isInControlStructure checks if the stack contains a control structure.
100+
func isInControlStructure(stack []ast.Node) bool {
101+
for _, node := range stack {
102+
switch node.(type) {
103+
case *ast.IfStmt, *ast.ForStmt, *ast.RangeStmt, *ast.SwitchStmt, *ast.TypeSwitchStmt, *ast.SelectStmt:
104+
return true
105+
}
106+
}
107+
return false
108+
}
109+
46110
// funcLitInAssignment checks if the assignment assigns a func literal to v.
47111
func (c *Context) funcLitInAssignment(assign *ast.AssignStmt, v *types.Var) *ast.FuncLit {
48112
for i, lhs := range assign.Lhs {

internal/probe/factory.go

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
// BlockReturnsContextUsingFunc checks if a block's return statements
1010
// return functions that use context.
11+
// Only checks what's actually returned, not all nested func literals.
1112
func (c *Context) BlockReturnsContextUsingFunc(body *ast.BlockStmt, excludeFuncLit *ast.FuncLit) bool {
1213
if body == nil {
1314
return true
@@ -19,15 +20,8 @@ func (c *Context) BlockReturnsContextUsingFunc(body *ast.BlockStmt, excludeFuncL
1920
if usesContext {
2021
return false
2122
}
23+
// Skip descending into nested func literals (except excludeFuncLit which is the factory itself)
2224
if fl, ok := n.(*ast.FuncLit); ok && fl != excludeFuncLit {
23-
if c.FuncLitUsesContext(fl) {
24-
usesContext = true
25-
return false
26-
}
27-
if c.BlockReturnsContextUsingFunc(fl.Body, fl) {
28-
usesContext = true
29-
return false
30-
}
3125
return false
3226
}
3327

@@ -111,20 +105,52 @@ func (c *Context) IdentFactoryReturnsContextUsingFunc(ident *ast.Ident) bool {
111105
}
112106

113107
// returnedValueUsesContext checks if a returned value is a func that uses context.
108+
// For identifiers, checks ALL assignments from last unconditional onwards.
114109
func (c *Context) returnedValueUsesContext(result ast.Expr) bool {
115110
if innerFuncLit, ok := result.(*ast.FuncLit); ok {
116-
return c.FuncLitUsesContext(innerFuncLit)
111+
// Check if the func lit directly uses context OR returns a context-using func
112+
return c.FuncLitUsesContext(innerFuncLit) || c.BlockReturnsContextUsingFunc(innerFuncLit.Body, innerFuncLit)
117113
}
118114

119115
ident, ok := result.(*ast.Ident)
120116
if !ok {
121117
return false
122118
}
123119

124-
innerFuncLit := c.FuncLitOfIdent(ident)
125-
if innerFuncLit == nil {
120+
assigns := c.FuncLitAssignmentsOfIdent(ident)
121+
if len(assigns) == 0 {
126122
return false
127123
}
128124

129-
return c.FuncLitUsesContext(innerFuncLit)
125+
return c.funcLitAssignmentsAllUseOrReturnContext(assigns)
126+
}
127+
128+
// funcLitAssignmentsAllUseOrReturnContext checks if ALL func literal assignments from
129+
// last unconditional onwards use context OR return a context-using func.
130+
func (c *Context) funcLitAssignmentsAllUseOrReturnContext(assigns []FuncLitAssignment) bool {
131+
// Find the index of the last unconditional assignment
132+
lastUnconditionalIdx := -1
133+
for i := len(assigns) - 1; i >= 0; i-- {
134+
if !assigns[i].Conditional {
135+
lastUnconditionalIdx = i
136+
break
137+
}
138+
}
139+
140+
// Determine the starting point for checks
141+
startIdx := 0
142+
if lastUnconditionalIdx >= 0 {
143+
startIdx = lastUnconditionalIdx
144+
}
145+
146+
// Check all assignments from startIdx onwards
147+
// ALL must use context OR return context-using func
148+
for i := startIdx; i < len(assigns); i++ {
149+
lit := assigns[i].Lit
150+
// Check if the func lit directly uses context OR returns a context-using func
151+
if !c.FuncLitUsesContext(lit) && !c.BlockReturnsContextUsingFunc(lit.Body, lit) {
152+
return false
153+
}
154+
}
155+
return true
130156
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"title": "Factory returns conditionally reassigned variable - all use ctx",
3+
"targets": [
4+
"errgroup"
5+
],
6+
"variants": {
7+
"bad": null,
8+
"good": {
9+
"description": "Factory function returns a variable that is conditionally reassigned. All assignments use context, so all paths propagate it.",
10+
"functions": {
11+
"errgroup": "goodFactoryConditionalReassignAllUseCtx"
12+
}
13+
}
14+
},
15+
"level": "evil"
16+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"title": "Factory returns conditionally reassigned variable - first doesn't use ctx",
3+
"targets": [
4+
"errgroup"
5+
],
6+
"variants": {
7+
"bad": {
8+
"description": "Factory function returns a variable that is conditionally reassigned. First assignment ignores context, conditional reassignment uses it.",
9+
"functions": {
10+
"errgroup": "badFactoryConditionalReassignFirstNoCtx"
11+
}
12+
},
13+
"good": null
14+
},
15+
"level": "evil"
16+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"title": "Factory returns conditionally reassigned variable - first uses ctx",
3+
"targets": [
4+
"errgroup"
5+
],
6+
"variants": {
7+
"bad": {
8+
"description": "Factory function returns a variable that is conditionally reassigned. First assignment uses context, but conditional reassignment doesn't.",
9+
"functions": {
10+
"errgroup": "badFactoryConditionalReassignFirstUsesCtx"
11+
}
12+
},
13+
"good": null
14+
},
15+
"level": "evil"
16+
}

testdata/src/errgroup/evil.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,92 @@ func goodHigherOrderReturnsReassignedVariableWithCtx(ctx context.Context) {
684684
_ = g.Wait()
685685
}
686686

687+
// ===== FACTORY RETURN CONDITIONAL REASSIGNMENT =====
688+
// These patterns test conditional reassignment of variables returned by factory functions.
689+
690+
var factoryConditionFlag bool
691+
692+
// [BAD]: Factory returns conditionally reassigned variable - first uses ctx
693+
//
694+
// Factory function returns a variable that is conditionally reassigned.
695+
// First assignment uses context, but conditional reassignment doesn't.
696+
//
697+
// See also:
698+
// goroutine: badFactoryConditionalReassignFirstUsesCtx
699+
// waitgroup: badFactoryConditionalReassignFirstUsesCtx
700+
func badFactoryConditionalReassignFirstUsesCtx(ctx context.Context) {
701+
g := new(errgroup.Group)
702+
makeWorker := func() func() error {
703+
worker := func() error {
704+
_ = ctx // First uses ctx
705+
return nil
706+
}
707+
if factoryConditionFlag {
708+
worker = func() error {
709+
fmt.Println("no ctx") // Conditional doesn't use ctx
710+
return nil
711+
}
712+
}
713+
return worker
714+
}
715+
g.Go(makeWorker()) // want `errgroup.Group.Go\(\) closure should use context "ctx"`
716+
_ = g.Wait()
717+
}
718+
719+
// [BAD]: Factory returns conditionally reassigned variable - first doesn't use ctx
720+
//
721+
// Factory function returns a variable that is conditionally reassigned.
722+
// First assignment ignores context, conditional reassignment uses it.
723+
//
724+
// See also:
725+
// goroutine: badFactoryConditionalReassignFirstNoCtx
726+
// waitgroup: badFactoryConditionalReassignFirstNoCtx
727+
func badFactoryConditionalReassignFirstNoCtx(ctx context.Context) {
728+
g := new(errgroup.Group)
729+
makeWorker := func() func() error {
730+
worker := func() error {
731+
fmt.Println("no ctx") // First doesn't use ctx
732+
return nil
733+
}
734+
if factoryConditionFlag {
735+
worker = func() error {
736+
_ = ctx // Conditional uses ctx
737+
return nil
738+
}
739+
}
740+
return worker
741+
}
742+
g.Go(makeWorker()) // want `errgroup.Group.Go\(\) closure should use context "ctx"`
743+
_ = g.Wait()
744+
}
745+
746+
// [GOOD]: Factory returns conditionally reassigned variable - all use ctx
747+
//
748+
// Factory function returns a variable that is conditionally reassigned.
749+
// All assignments use context, so all paths propagate it.
750+
//
751+
// See also:
752+
// goroutine: goodFactoryConditionalReassignAllUseCtx
753+
// waitgroup: goodFactoryConditionalReassignAllUseCtx
754+
func goodFactoryConditionalReassignAllUseCtx(ctx context.Context) {
755+
g := new(errgroup.Group)
756+
makeWorker := func() func() error {
757+
worker := func() error {
758+
_ = ctx // First uses ctx
759+
return nil
760+
}
761+
if factoryConditionFlag {
762+
worker = func() error {
763+
_ = ctx // Conditional also uses ctx
764+
return nil
765+
}
766+
}
767+
return worker
768+
}
769+
g.Go(makeWorker()) // OK - all assignments use ctx
770+
_ = g.Wait()
771+
}
772+
687773
// ===== UNTRACEABLE PATTERNS (FUNCTION PARAMETERS) =====
688774

689775
// [LIMITATION]: Function from parameter - cannot trace

0 commit comments

Comments
 (0)