Skip to content

Commit f00076e

Browse files
mpywclaude
andcommitted
feat(ssa): add SSA-based deriver detection with IIFE/defer distinction
This commit adds SSA-based deriver function detection that: - Traverses into immediately-invoked function expressions (IIFE) - Distinguishes between regular calls and defer statements - Reports a special error message when deriver is only in defer Key changes: - Add CheckDeriverCalls to SSA tracer with DeriverResult struct - Update GoStmtPattern interface to return GoStmtResult with DeferOnly flag - Add DeferMessage method for defer-specific error messages - Former PARTIAL test cases now pass as GOOD (SSA detects IIFE correctly) Behavioral changes: - Deriver calls in IIFE are now correctly detected - Deriver calls only in defer will report: "goroutine calls X in defer, but it should be called at goroutine start" TODO: Add test case for defer-only scenario 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 05163d9 commit f00076e

File tree

9 files changed

+323
-53
lines changed

9 files changed

+323
-53
lines changed

internal/checker/checker.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,20 @@ func (c *Checker) checkGoStmt(cctx *patterns.CheckContext, stmt *ast.GoStmt, sco
106106
continue
107107
}
108108

109-
if !pattern.CheckGoStmt(cctx, stmt) {
110-
msg := pattern.Message(scope.ctxName())
109+
result := pattern.CheckGoStmt(cctx, stmt)
110+
if result.OK {
111+
continue
112+
}
113+
114+
var msg string
115+
if result.DeferOnly {
116+
// Deriver found but only in defer - use special message
117+
msg = pattern.DeferMessage(scope.ctxName())
118+
} else {
119+
msg = pattern.Message(scope.ctxName())
120+
}
121+
122+
if msg != "" {
111123
cctx.Pass.Reportf(stmt.Pos(), "%s", msg)
112124
}
113125
}

internal/patterns/gostmt.go

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,29 @@ import (
77
"github.com/mpyw/goroutinectx/internal/directives/deriver"
88
)
99

10+
// GoStmtResult represents the result of a go statement pattern check.
11+
type GoStmtResult struct {
12+
// OK indicates the pattern is satisfied (no error).
13+
OK bool
14+
// DeferOnly indicates the deriver was found but only in defer statements.
15+
// This is only relevant for deriver patterns.
16+
DeferOnly bool
17+
}
18+
1019
// GoStmtPattern defines the interface for go statement patterns.
1120
type GoStmtPattern interface {
1221
// Name returns a human-readable name for the pattern.
1322
Name() string
1423

1524
// CheckGoStmt checks if the pattern is satisfied for the given go statement.
16-
// Returns true if the pattern is satisfied (no error).
17-
CheckGoStmt(cctx *CheckContext, stmt *ast.GoStmt) bool
25+
CheckGoStmt(cctx *CheckContext, stmt *ast.GoStmt) GoStmtResult
1826

1927
// Message returns the diagnostic message when the pattern is violated.
2028
Message(ctxName string) string
29+
30+
// DeferMessage returns the diagnostic message when deriver is only in defer.
31+
// Returns empty string if not applicable.
32+
DeferMessage(ctxName string) string
2133
}
2234

2335
// GoStmtCapturesCtx checks that a go statement's closure captures the outer context.
@@ -27,19 +39,19 @@ func (*GoStmtCapturesCtx) Name() string {
2739
return "GoStmtCapturesCtx"
2840
}
2941

30-
func (*GoStmtCapturesCtx) CheckGoStmt(cctx *CheckContext, stmt *ast.GoStmt) bool {
42+
func (*GoStmtCapturesCtx) CheckGoStmt(cctx *CheckContext, stmt *ast.GoStmt) GoStmtResult {
3143
// If no context names in scope (from AST), nothing to check
3244
if len(cctx.CtxNames) == 0 {
33-
return true
45+
return GoStmtResult{OK: true}
3446
}
3547

3648
// Try SSA-based check first (more accurate, includes nested closures)
3749
if result, ok := checkGoStmtFromSSA(cctx, stmt); ok {
38-
return result
50+
return GoStmtResult{OK: result}
3951
}
4052

4153
// Fall back to AST-based check when SSA fails
42-
return checkGoStmtFromAST(cctx, stmt)
54+
return GoStmtResult{OK: checkGoStmtFromAST(cctx, stmt)}
4355
}
4456

4557
// checkGoStmtFromSSA uses SSA analysis to check if a goroutine captures context.
@@ -74,6 +86,10 @@ func (*GoStmtCapturesCtx) Message(ctxName string) string {
7486
return "goroutine does not propagate context \"" + ctxName + "\""
7587
}
7688

89+
func (*GoStmtCapturesCtx) DeferMessage(_ string) string {
90+
return "" // Not applicable for context capture pattern
91+
}
92+
7793
// checkGoStmtFromAST falls back to AST-based analysis for go statements.
7894
func checkGoStmtFromAST(cctx *CheckContext, stmt *ast.GoStmt) bool {
7995
call := stmt.Call
@@ -263,9 +279,9 @@ func (*GoStmtCallsDeriver) Name() string {
263279
return "GoStmtCallsDeriver"
264280
}
265281

266-
func (p *GoStmtCallsDeriver) CheckGoStmt(cctx *CheckContext, stmt *ast.GoStmt) bool {
282+
func (p *GoStmtCallsDeriver) CheckGoStmt(cctx *CheckContext, stmt *ast.GoStmt) GoStmtResult {
267283
if p.Matcher == nil || p.Matcher.IsEmpty() {
268-
return true // No deriver configured
284+
return GoStmtResult{OK: true} // No deriver configured
269285
}
270286

271287
call := stmt.Call
@@ -274,22 +290,54 @@ func (p *GoStmtCallsDeriver) CheckGoStmt(cctx *CheckContext, stmt *ast.GoStmt) b
274290
if lit, ok := call.Fun.(*ast.FuncLit); ok {
275291
// Skip if closure has its own context parameter
276292
if funcLitHasContextParam(cctx, lit) {
277-
return true
293+
return GoStmtResult{OK: true}
278294
}
279-
return p.Matcher.SatisfiesAnyGroup(cctx.Pass, lit.Body)
295+
296+
// Try SSA-based check first (detects IIFE, distinguishes defer)
297+
if result, ok := p.checkDeriverFromSSA(cctx, lit); ok {
298+
return result
299+
}
300+
301+
// Fall back to AST-based check
302+
return GoStmtResult{OK: p.Matcher.SatisfiesAnyGroup(cctx.Pass, lit.Body)}
280303
}
281304

282305
// For go fn()() (higher-order), check the factory function
283306
if innerCall, ok := call.Fun.(*ast.CallExpr); ok {
284-
return p.checkHigherOrderDeriver(cctx, innerCall)
307+
return GoStmtResult{OK: p.checkHigherOrderDeriver(cctx, innerCall)}
285308
}
286309

287310
// For go fn() where fn is an identifier, trace the variable
288311
if ident, ok := call.Fun.(*ast.Ident); ok {
289-
return p.checkIdentDeriver(cctx, ident)
312+
return GoStmtResult{OK: p.checkIdentDeriver(cctx, ident)}
290313
}
291314

292-
return true // Can't analyze, assume OK
315+
return GoStmtResult{OK: true} // Can't analyze, assume OK
316+
}
317+
318+
// checkDeriverFromSSA uses SSA analysis to check deriver calls.
319+
// Returns (result, true) if SSA analysis succeeded, or (GoStmtResult{}, false) if it failed.
320+
func (p *GoStmtCallsDeriver) checkDeriverFromSSA(cctx *CheckContext, lit *ast.FuncLit) (GoStmtResult, bool) {
321+
if cctx.SSAProg == nil || cctx.Tracer == nil {
322+
return GoStmtResult{}, false
323+
}
324+
325+
ssaFn := cctx.SSAProg.FindFuncLit(lit)
326+
if ssaFn == nil {
327+
return GoStmtResult{}, false
328+
}
329+
330+
result := cctx.Tracer.CheckDeriverCalls(ssaFn, p.Matcher)
331+
332+
if result.FoundAtStart {
333+
return GoStmtResult{OK: true}, true
334+
}
335+
336+
if result.FoundOnlyInDefer {
337+
return GoStmtResult{OK: false, DeferOnly: true}, true
338+
}
339+
340+
return GoStmtResult{OK: false}, true
293341
}
294342

295343
// checkIdentDeriver checks go fn() patterns where fn is a variable.
@@ -425,3 +473,7 @@ func (p *GoStmtCallsDeriver) returnedValueCallsDeriver(cctx *CheckContext, resul
425473
func (p *GoStmtCallsDeriver) Message(_ string) string {
426474
return "goroutine should call " + p.Matcher.Original + " to derive context"
427475
}
476+
477+
func (p *GoStmtCallsDeriver) DeferMessage(_ string) string {
478+
return "goroutine calls " + p.Matcher.Original + " in defer, but it should be called at goroutine start"
479+
}

internal/ssa/tracer.go

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"golang.org/x/tools/go/ssa"
99

1010
"github.com/mpyw/goroutinectx/internal/directives/carrier"
11+
"github.com/mpyw/goroutinectx/internal/directives/deriver"
1112
"github.com/mpyw/goroutinectx/internal/typeutil"
1213
)
1314

@@ -369,3 +370,211 @@ func GetContextVars(fn *ssa.Function) []*types.Var {
369370

370371
return ctxVars
371372
}
373+
374+
// =============================================================================
375+
// Deriver Function Detection
376+
// =============================================================================
377+
378+
// DeriverResult represents the result of deriver function detection.
379+
type DeriverResult struct {
380+
// FoundAtStart indicates the deriver is called at goroutine start (not in defer)
381+
FoundAtStart bool
382+
// FoundOnlyInDefer indicates the deriver is called, but only in defer statements
383+
FoundOnlyInDefer bool
384+
}
385+
386+
// CheckDeriverCalls checks if a closure calls any of the required deriver functions.
387+
// It traverses into immediately-invoked function expressions (IIFE) but tracks
388+
// whether calls are made in defer statements.
389+
func (t *Tracer) CheckDeriverCalls(closure *ssa.Function, matcher *deriver.Matcher) DeriverResult {
390+
if closure == nil || matcher == nil || matcher.IsEmpty() {
391+
return DeriverResult{FoundAtStart: true} // No check needed
392+
}
393+
394+
// Collect all function calls with their defer status
395+
calls := t.collectDeriverCalls(closure, false, make(map[*ssa.Function]bool))
396+
397+
// Check if any OR group is satisfied
398+
for _, andGroup := range matcher.OrGroups {
399+
startResult := t.checkAndGroup(calls, andGroup, false)
400+
if startResult {
401+
return DeriverResult{FoundAtStart: true}
402+
}
403+
}
404+
405+
// Check if deriver is only in defer
406+
for _, andGroup := range matcher.OrGroups {
407+
deferResult := t.checkAndGroup(calls, andGroup, true)
408+
if deferResult {
409+
return DeriverResult{FoundOnlyInDefer: true}
410+
}
411+
}
412+
413+
return DeriverResult{}
414+
}
415+
416+
// deriverCall represents a function call with its context (defer or not).
417+
type deriverCall struct {
418+
fn *types.Func
419+
inDefer bool
420+
}
421+
422+
// collectDeriverCalls collects all function calls in a closure, including IIFE.
423+
func (t *Tracer) collectDeriverCalls(fn *ssa.Function, inDefer bool, visited map[*ssa.Function]bool) []deriverCall {
424+
if fn == nil || visited[fn] {
425+
return nil
426+
}
427+
visited[fn] = true
428+
429+
var calls []deriverCall
430+
431+
for _, block := range fn.Blocks {
432+
for _, instr := range block.Instrs {
433+
switch v := instr.(type) {
434+
case *ssa.Call:
435+
// Regular function call
436+
if calledFn := t.extractCalledFunc(v); calledFn != nil {
437+
calls = append(calls, deriverCall{fn: calledFn, inDefer: inDefer})
438+
}
439+
// Check for IIFE: call where the callee is a MakeClosure
440+
if iifeFn := t.extractIIFE(v); iifeFn != nil {
441+
// Traverse into the IIFE with the same defer status
442+
calls = append(calls, t.collectDeriverCalls(iifeFn, inDefer, visited)...)
443+
}
444+
445+
case *ssa.Defer:
446+
// Deferred function call - mark as inDefer
447+
if calledFn := t.extractCalledFuncFromCallCommon(&v.Call); calledFn != nil {
448+
calls = append(calls, deriverCall{fn: calledFn, inDefer: true})
449+
}
450+
// Check for deferred IIFE
451+
if iifeFn := t.extractIIFEFromCallCommon(&v.Call); iifeFn != nil {
452+
// Traverse into the deferred IIFE with inDefer=true
453+
calls = append(calls, t.collectDeriverCalls(iifeFn, true, visited)...)
454+
}
455+
}
456+
}
457+
}
458+
459+
return calls
460+
}
461+
462+
// extractCalledFunc extracts the types.Func from a Call instruction.
463+
func (t *Tracer) extractCalledFunc(call *ssa.Call) *types.Func {
464+
return t.extractCalledFuncFromCallCommon(&call.Call)
465+
}
466+
467+
// extractCalledFuncFromCallCommon extracts the types.Func from a CallCommon.
468+
func (t *Tracer) extractCalledFuncFromCallCommon(call *ssa.CallCommon) *types.Func {
469+
if call.IsInvoke() {
470+
// Interface method call
471+
return call.Method
472+
}
473+
474+
// Static call
475+
if fn := call.StaticCallee(); fn != nil {
476+
if obj, ok := fn.Object().(*types.Func); ok {
477+
return obj
478+
}
479+
}
480+
481+
return nil
482+
}
483+
484+
// extractIIFE checks if a Call instruction is an IIFE (immediately invoked function expression).
485+
// Returns the called function if it's an IIFE, nil otherwise.
486+
func (t *Tracer) extractIIFE(call *ssa.Call) *ssa.Function {
487+
return t.extractIIFEFromCallCommon(&call.Call)
488+
}
489+
490+
// extractIIFEFromCallCommon checks if a CallCommon is an IIFE.
491+
func (t *Tracer) extractIIFEFromCallCommon(call *ssa.CallCommon) *ssa.Function {
492+
if call.IsInvoke() {
493+
return nil
494+
}
495+
496+
// Check if the callee is a MakeClosure
497+
if mc, ok := call.Value.(*ssa.MakeClosure); ok {
498+
if fn, ok := mc.Fn.(*ssa.Function); ok {
499+
return fn
500+
}
501+
}
502+
503+
// Check if the callee is a direct function reference
504+
if fn, ok := call.Value.(*ssa.Function); ok {
505+
// Only count as IIFE if it's an anonymous function (has no name in package scope)
506+
if fn.Parent() != nil {
507+
return fn
508+
}
509+
}
510+
511+
return nil
512+
}
513+
514+
// checkAndGroup checks if all specs in an AND group are satisfied.
515+
// If includeDefer is false, only non-defer calls are considered.
516+
// If includeDefer is true, all calls (including defer) are considered.
517+
func (t *Tracer) checkAndGroup(calls []deriverCall, andGroup []deriver.FuncSpec, includeDefer bool) bool {
518+
for _, spec := range andGroup {
519+
found := false
520+
for _, call := range calls {
521+
if !includeDefer && call.inDefer {
522+
continue
523+
}
524+
if t.matchesSpec(call.fn, spec) {
525+
found = true
526+
break
527+
}
528+
}
529+
if !found {
530+
return false
531+
}
532+
}
533+
return true
534+
}
535+
536+
// matchesSpec checks if a types.Func matches the given deriver function spec.
537+
func (t *Tracer) matchesSpec(fn *types.Func, spec deriver.FuncSpec) bool {
538+
if fn == nil {
539+
return false
540+
}
541+
542+
if fn.Name() != spec.FuncName {
543+
return false
544+
}
545+
546+
if fn.Pkg() == nil || fn.Pkg().Path() != spec.PkgPath {
547+
return false
548+
}
549+
550+
if spec.TypeName != "" {
551+
return t.matchesMethod(fn, spec.TypeName)
552+
}
553+
554+
return true
555+
}
556+
557+
// matchesMethod checks if a types.Func is a method on the expected type.
558+
func (t *Tracer) matchesMethod(fn *types.Func, typeName string) bool {
559+
sig, ok := fn.Type().(*types.Signature)
560+
if !ok {
561+
return false
562+
}
563+
564+
recv := sig.Recv()
565+
if recv == nil {
566+
return false
567+
}
568+
569+
recvType := recv.Type()
570+
if ptr, ok := recvType.(*types.Pointer); ok {
571+
recvType = ptr.Elem()
572+
}
573+
574+
named, ok := recvType.(*types.Named)
575+
if !ok {
576+
return false
577+
}
578+
579+
return named.Obj().Name() == typeName
580+
}

0 commit comments

Comments
 (0)