Skip to content

Commit 4c933a5

Browse files
mpywclaude
andcommitted
feat(ssa): enhance SSA value finding and context variable detection
Improvements to SSA infrastructure for future pattern integration: 1. findSSAValue now handles multiple expression types: - FuncLit: finds MakeClosure instructions - Ident: traces variable declarations to find assigned values - CallExpr: finds SSA Call instructions - SelectorExpr: finds FieldAddr/Field instructions - IndexExpr: finds IndexAddr/Index/Lookup instructions 2. GetContextVars now includes FreeVars: - Captures context variables from enclosing scopes - Essential for nested closure analysis 3. findEnclosingFunc uses syntax range: - Uses full AST syntax range for position checking - Handles cases where GoStmt position precedes FuncLit position Note: These improvements prepare the infrastructure for future migration to SSA-based patterns. Current analyzer still uses AST-based checkers. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent ef7bebb commit 4c933a5

File tree

3 files changed

+171
-11
lines changed

3 files changed

+171
-11
lines changed

internal/patterns/pattern.go

Lines changed: 146 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,33 +25,170 @@ func (c *CheckContext) Report(pos token.Pos, msg string) {
2525
}
2626

2727
// findSSAValue finds the SSA value corresponding to an AST expression.
28+
// Handles: FuncLit, Ident, CallExpr, SelectorExpr, IndexExpr
2829
func (c *CheckContext) findSSAValue(fn *ssa.Function, expr ast.Expr) ssa.Value {
2930
if fn == nil || fn.Blocks == nil {
3031
return nil
3132
}
3233

33-
pos := expr.Pos()
34+
switch e := expr.(type) {
35+
case *ast.FuncLit:
36+
return c.findFuncLitValue(fn, e)
37+
case *ast.Ident:
38+
return c.findIdentValue(fn, e)
39+
case *ast.CallExpr:
40+
return c.findCallValue(fn, e)
41+
case *ast.SelectorExpr:
42+
return c.findSelectorValue(fn, e)
43+
case *ast.IndexExpr:
44+
return c.findIndexValue(fn, e)
45+
default:
46+
return c.findValueByPos(fn, expr.Pos())
47+
}
48+
}
3449

35-
// Search through all instructions in the function
50+
// findFuncLitValue finds the MakeClosure instruction for a function literal.
51+
func (c *CheckContext) findFuncLitValue(fn *ssa.Function, lit *ast.FuncLit) ssa.Value {
52+
pos := lit.Pos()
3653
for _, block := range fn.Blocks {
3754
for _, instr := range block.Instrs {
38-
// Check if this instruction's position matches
39-
if instr.Pos() == pos {
40-
if val, ok := instr.(ssa.Value); ok {
41-
return val
55+
if mc, ok := instr.(*ssa.MakeClosure); ok {
56+
if mc.Pos() == pos {
57+
return mc
4258
}
4359
}
60+
}
61+
}
62+
return nil
63+
}
64+
65+
// findIdentValue finds the SSA value for an identifier (variable reference).
66+
// For `g.Go(fn)` where `fn := func() {}`, finds the MakeClosure assigned to fn.
67+
func (c *CheckContext) findIdentValue(fn *ssa.Function, ident *ast.Ident) ssa.Value {
68+
obj := c.Pass.TypesInfo.ObjectOf(ident)
69+
if obj == nil {
70+
return nil
71+
}
72+
73+
v, ok := obj.(*types.Var)
74+
if !ok {
75+
return nil
76+
}
77+
78+
// Search for the value assigned to this variable
79+
// In SSA, we look for instructions that define a value at the variable's declaration position
80+
declPos := v.Pos()
4481

45-
// For MakeClosure, check the Fn position
82+
for _, block := range fn.Blocks {
83+
for _, instr := range block.Instrs {
84+
// Check MakeClosure at declaration position
4685
if mc, ok := instr.(*ssa.MakeClosure); ok {
47-
if mc.Pos() == pos {
86+
if mc.Pos() == declPos {
4887
return mc
4988
}
5089
}
90+
// Check Call result at declaration position
91+
if call, ok := instr.(*ssa.Call); ok {
92+
if call.Pos() == declPos {
93+
return call
94+
}
95+
}
96+
// Check any value-producing instruction at declaration position
97+
if val, ok := instr.(ssa.Value); ok {
98+
if instr.Pos() == declPos {
99+
return val
100+
}
101+
}
102+
}
103+
}
104+
105+
return nil
106+
}
107+
108+
// findCallValue finds the SSA Call instruction for a call expression.
109+
func (c *CheckContext) findCallValue(fn *ssa.Function, call *ast.CallExpr) ssa.Value {
110+
pos := call.Pos()
111+
// Try Fun position for method calls
112+
funPos := call.Fun.Pos()
113+
114+
for _, block := range fn.Blocks {
115+
for _, instr := range block.Instrs {
116+
if ssaCall, ok := instr.(*ssa.Call); ok {
117+
if ssaCall.Pos() == pos || ssaCall.Pos() == funPos {
118+
return ssaCall
119+
}
120+
}
121+
}
122+
}
123+
return nil
124+
}
125+
126+
// findSelectorValue finds the SSA value for a selector expression (field access).
127+
func (c *CheckContext) findSelectorValue(fn *ssa.Function, sel *ast.SelectorExpr) ssa.Value {
128+
pos := sel.Pos()
129+
selPos := sel.Sel.Pos()
130+
131+
for _, block := range fn.Blocks {
132+
for _, instr := range block.Instrs {
133+
// Check FieldAddr
134+
if fa, ok := instr.(*ssa.FieldAddr); ok {
135+
if fa.Pos() == pos || fa.Pos() == selPos {
136+
return fa
137+
}
138+
}
139+
// Check Field (for value types)
140+
if f, ok := instr.(*ssa.Field); ok {
141+
if f.Pos() == pos || f.Pos() == selPos {
142+
return f
143+
}
144+
}
145+
}
146+
}
147+
return nil
148+
}
149+
150+
// findIndexValue finds the SSA value for an index expression.
151+
func (c *CheckContext) findIndexValue(fn *ssa.Function, idx *ast.IndexExpr) ssa.Value {
152+
pos := idx.Pos()
153+
154+
for _, block := range fn.Blocks {
155+
for _, instr := range block.Instrs {
156+
// Check IndexAddr
157+
if ia, ok := instr.(*ssa.IndexAddr); ok {
158+
if ia.Pos() == pos {
159+
return ia
160+
}
161+
}
162+
// Check Index (for value types)
163+
if i, ok := instr.(*ssa.Index); ok {
164+
if i.Pos() == pos {
165+
return i
166+
}
167+
}
168+
// Check Lookup (for maps)
169+
if l, ok := instr.(*ssa.Lookup); ok {
170+
if l.Pos() == pos {
171+
return l
172+
}
173+
}
174+
}
175+
}
176+
return nil
177+
}
178+
179+
// findValueByPos finds any SSA value at the given position (fallback).
180+
func (c *CheckContext) findValueByPos(fn *ssa.Function, pos token.Pos) ssa.Value {
181+
for _, block := range fn.Blocks {
182+
for _, instr := range block.Instrs {
183+
if instr.Pos() == pos {
184+
if val, ok := instr.(ssa.Value); ok {
185+
return val
186+
}
187+
}
51188
}
52189
}
53190

54-
// Also check parameters and free variables
191+
// Check parameters
55192
for _, param := range fn.Params {
56193
if param.Pos() == pos {
57194
return param

internal/ssa/builder.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,14 @@ func (p *Program) EnclosingFunc(pos ast.Node) *ssa.Function {
6060
func (p *Program) findEnclosingFunc(fn *ssa.Function, pos ast.Node) *ssa.Function {
6161
// Check anonymous functions defined within this function
6262
for _, anon := range fn.AnonFuncs {
63-
if anon.Pos() <= pos.Pos() && pos.End() <= anon.Syntax().End() {
63+
syntax := anon.Syntax()
64+
if syntax == nil {
65+
continue
66+
}
67+
// Use the syntax's full range, not just anon.Pos()
68+
// anon.Pos() is the position of 'func' keyword, but for GoStmt
69+
// the position is the 'go' keyword which comes before 'func'
70+
if syntax.Pos() <= pos.Pos() && pos.End() <= syntax.End() {
6471
// Recursively check nested functions
6572
return p.findEnclosingFunc(anon, pos)
6673
}

internal/ssa/tracer.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,13 +307,15 @@ func isContextType(t types.Type) bool {
307307
return obj.Pkg() != nil && obj.Pkg().Path() == "context" && obj.Name() == "Context"
308308
}
309309

310-
// GetContextVars returns all context.Context variables from function parameters.
310+
// GetContextVars returns all context.Context variables available in a function.
311+
// This includes both parameters and captured variables (FreeVars) from enclosing scopes.
311312
func GetContextVars(fn *ssa.Function) []*types.Var {
312313
var ctxVars []*types.Var
313314
if fn == nil {
314315
return ctxVars
315316
}
316317

318+
// Check parameters
317319
sig := fn.Signature
318320
params := sig.Params()
319321
for i := 0; i < params.Len(); i++ {
@@ -323,5 +325,19 @@ func GetContextVars(fn *ssa.Function) []*types.Var {
323325
}
324326
}
325327

328+
// Check captured variables (FreeVars) from enclosing scopes
329+
for _, fv := range fn.FreeVars {
330+
if isContextType(fv.Type()) {
331+
// FreeVar doesn't have an underlying types.Var directly accessible,
332+
// but we can create a synthetic Var for matching purposes.
333+
// The package is derived from the parent function.
334+
var pkg *types.Package
335+
if fn.Pkg != nil {
336+
pkg = fn.Pkg.Pkg
337+
}
338+
ctxVars = append(ctxVars, types.NewVar(fv.Pos(), pkg, fv.Name(), fv.Type()))
339+
}
340+
}
341+
326342
return ctxVars
327343
}

0 commit comments

Comments
 (0)