Skip to content

Commit bbbf0e0

Browse files
authored
feat: kill active connections when allow rule is deleted or denied (#3)
Track active proxy connections by the rule ID that authorized them using a new ConnTracker. When a rule is deleted or changed to deny, all connections authorized by that rule are cancelled immediately by closing both sides of the tunnel. Key changes: - ConnTracker maps rule IDs to cancellable connections - Bypass plugin writes authorizing rule ID into context - SOCKS5 and HTTP handlers register connections with ConnTracker - Pipe actively closes both connections on context cancellation - API and HTMX handlers call CancelByRule on rule delete/deny
1 parent 2037ff4 commit bbbf0e0

File tree

11 files changed

+377
-32
lines changed

11 files changed

+377
-32
lines changed

cmd/greyproxy/program.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ func (p *program) buildGreyproxyService() error {
262262
shared.Cache = tmpSvc.Cache
263263
shared.Bus = tmpSvc.Bus
264264
shared.Waiters = tmpSvc.Waiters
265+
shared.ConnTracker = greyproxy.NewConnTracker()
265266
shared.Version = version
266267

267268
// Collect listening ports for the health endpoint
@@ -288,7 +289,7 @@ func (p *program) buildGreyproxyService() error {
288289
// Create and register gost plugins
289290
autherPlugin := greyproxy_plugins.NewAuther()
290291
admissionPlugin := greyproxy_plugins.NewAdmission()
291-
bypassPlugin := greyproxy_plugins.NewBypass(shared.DB, shared.Cache, shared.Bus, shared.Waiters)
292+
bypassPlugin := greyproxy_plugins.NewBypass(shared.DB, shared.Cache, shared.Bus, shared.Waiters, shared.ConnTracker)
292293
resolverPlugin := greyproxy_plugins.NewResolver(shared.Cache)
293294

294295
registry.AutherRegistry().Register(gaCfg.Auther, autherPlugin)
@@ -302,7 +303,7 @@ func (p *program) buildGreyproxyService() error {
302303
// Build HTTP router with REST API + HTMX UI + WebSocket
303304
router, g := greyproxy_api.NewRouter(shared, gaCfg.PathPrefix)
304305
greyproxy_ui.RegisterPageRoutes(g, shared.DB, shared.Bus)
305-
greyproxy_ui.RegisterHTMXRoutes(g, shared.DB, shared.Bus, shared.Waiters)
306+
greyproxy_ui.RegisterHTMXRoutes(g, shared.DB, shared.Bus, shared.Waiters, shared.ConnTracker)
306307

307308
// Create the actual service
308309
svc := &greyproxy.Service{}

internal/gostx/ctx/value.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,30 @@ func ClientIDFromContext(ctx context.Context) ClientID {
9090
v, _ := ctx.Value(clientIDKey{}).(ClientID)
9191
return v
9292
}
93+
94+
// ConnCanceller allows registering connection cancel functions by rule ID.
95+
// Implemented by greyproxy.ConnTracker.
96+
type ConnCanceller interface {
97+
Register(ruleID int64, cancel context.CancelFunc) uint64
98+
Unregister(ruleID int64, id uint64)
99+
}
100+
101+
// BypassResult is a mutable container placed in the context before calling
102+
// bypass.Contains. The bypass plugin fills in the RuleID and Tracker when
103+
// a connection is allowed by a rule, so the handler can register the
104+
// connection for cancellation if the rule is later deleted.
105+
type BypassResult struct {
106+
RuleID int64
107+
Tracker ConnCanceller
108+
}
109+
110+
type bypassResultKey struct{}
111+
112+
func ContextWithBypassResult(ctx context.Context, result *BypassResult) context.Context {
113+
return context.WithValue(ctx, bypassResultKey{}, result)
114+
}
115+
116+
func BypassResultFromContext(ctx context.Context) *BypassResult {
117+
v, _ := ctx.Value(bypassResultKey{}).(*BypassResult)
118+
return v
119+
}

internal/gostx/handler/http/handler.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -297,17 +297,32 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
297297

298298
ctx = xctx.ContextWithClientID(ctx, xctx.ClientID(clientID))
299299

300-
if h.options.Bypass != nil &&
301-
h.options.Bypass.Contains(ctx, network, addr, bypass.WithService(h.options.Service)) {
302-
resp.StatusCode = http.StatusForbidden
300+
var bypassResult *xctx.BypassResult
301+
if h.options.Bypass != nil {
302+
bypassResult = &xctx.BypassResult{}
303+
checkCtx := xctx.ContextWithBypassResult(ctx, bypassResult)
304+
if h.options.Bypass.Contains(checkCtx, network, addr, bypass.WithService(h.options.Service)) {
305+
resp.StatusCode = http.StatusForbidden
306+
307+
if log.IsLevelEnabled(logger.TraceLevel) {
308+
dump, _ := httputil.DumpResponse(resp, false)
309+
log.Trace(string(dump))
310+
}
311+
log.Debug("bypass: ", addr)
312+
resp.Write(conn)
313+
return xbypass.ErrBypass
314+
}
303315

304-
if log.IsLevelEnabled(logger.TraceLevel) {
305-
dump, _ := httputil.DumpResponse(resp, false)
306-
log.Trace(string(dump))
316+
// Register this connection for cancellation if the rule is later deleted.
317+
if bypassResult.RuleID != 0 && bypassResult.Tracker != nil {
318+
var pipeCancel context.CancelFunc
319+
ctx, pipeCancel = context.WithCancel(ctx)
320+
connID := bypassResult.Tracker.Register(bypassResult.RuleID, pipeCancel)
321+
defer func() {
322+
bypassResult.Tracker.Unregister(bypassResult.RuleID, connID)
323+
pipeCancel()
324+
}()
307325
}
308-
log.Debug("bypass: ", addr)
309-
resp.Write(conn)
310-
return xbypass.ErrBypass
311326
}
312327

313328
if network == "udp" {

internal/gostx/handler/socks/v5/connect.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ
5656
}
5757

5858
if h.options.Bypass != nil {
59-
bypassCtx, bypassCancel := context.WithCancel(ctx)
59+
bypassResult := &xctx.BypassResult{}
60+
resultCtx := xctx.ContextWithBypassResult(ctx, bypassResult)
61+
bypassCtx, bypassCancel := context.WithCancel(resultCtx)
6062

6163
// Monitor the client TCP connection for close during the bypass check.
6264
// During SOCKS5 CONNECT, the client waits for the server reply before
@@ -96,6 +98,17 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ
9698
log.Debug("bypass: ", address)
9799
return resp.Write(conn)
98100
}
101+
102+
// Register this connection for cancellation if the rule is later deleted.
103+
if bypassResult.RuleID != 0 && bypassResult.Tracker != nil {
104+
var pipeCancel context.CancelFunc
105+
ctx, pipeCancel = context.WithCancel(ctx)
106+
connID := bypassResult.Tracker.Register(bypassResult.RuleID, pipeCancel)
107+
defer func() {
108+
bypassResult.Tracker.Unregister(bypassResult.RuleID, connID)
109+
pipeCancel()
110+
}()
111+
}
99112
}
100113

101114
switch h.md.hash {

internal/gostx/internal/net/pipe.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@ func Pipe(ctx context.Context, rw1, rw2 io.ReadWriteCloser) error {
4343
select {
4444
case <-done:
4545
case <-ctx.Done():
46-
return nil
46+
// Close both sides to unblock the pipeBuffer goroutines.
47+
rw1.Close()
48+
rw2.Close()
49+
<-done
50+
return ctx.Err()
4751
}
4852

4953
select {

internal/greyproxy/api/router.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@ import (
1010

1111
// Shared holds shared state passed to all handlers.
1212
type Shared struct {
13-
DB *greyproxy.DB
14-
Cache *greyproxy.DNSCache
15-
Bus *greyproxy.EventBus
16-
Waiters *greyproxy.WaiterTracker
17-
Version string
18-
Ports map[string]int
13+
DB *greyproxy.DB
14+
Cache *greyproxy.DNSCache
15+
Bus *greyproxy.EventBus
16+
Waiters *greyproxy.WaiterTracker
17+
ConnTracker *greyproxy.ConnTracker
18+
Version string
19+
Ports map[string]int
1920
}
2021

2122
// NewRouter creates the Gin router with all routes.

internal/greyproxy/api/rules.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package api
22

33
import (
4+
"log/slog"
45
"net/http"
56
"strconv"
67

@@ -81,6 +82,11 @@ func RulesUpdateHandler(s *Shared) gin.HandlerFunc {
8182
return
8283
}
8384

85+
// If the rule was changed to deny, cancel connections that relied on it.
86+
if rule.Action == "deny" && s.ConnTracker != nil {
87+
s.ConnTracker.CancelByRule(id)
88+
}
89+
8490
c.JSON(http.StatusOK, rule.ToJSON())
8591
}
8692
}
@@ -111,12 +117,19 @@ func RulesDeleteHandler(s *Shared) gin.HandlerFunc {
111117
return
112118
}
113119

120+
slog.Info("api: deleting rule", "rule_id", id)
121+
114122
deleted, err := greyproxy.DeleteRule(s.DB, id)
115123
if err != nil {
116124
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
117125
return
118126
}
119127

128+
slog.Info("api: rule deleted, cancelling connections", "rule_id", id, "deleted", deleted)
129+
if deleted && s.ConnTracker != nil {
130+
s.ConnTracker.CancelByRule(id)
131+
}
132+
120133
c.JSON(http.StatusOK, gin.H{"status": "ok", "deleted": deleted})
121134
}
122135
}

internal/greyproxy/conn_tracker.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package greyproxy
2+
3+
import (
4+
"context"
5+
"log/slog"
6+
"sync"
7+
"sync/atomic"
8+
)
9+
10+
var nextConnID atomic.Uint64
11+
12+
// ConnTracker tracks active proxy connections by the rule ID that authorized
13+
// them. When a rule is deleted or changed to deny, all connections that were
14+
// allowed by that rule can be cancelled immediately.
15+
type ConnTracker struct {
16+
mu sync.Mutex
17+
conns map[int64]map[uint64]context.CancelFunc
18+
}
19+
20+
func NewConnTracker() *ConnTracker {
21+
return &ConnTracker{
22+
conns: make(map[int64]map[uint64]context.CancelFunc),
23+
}
24+
}
25+
26+
// Register associates a cancel function with a rule ID and returns an ID
27+
// that can be used to unregister later.
28+
func (ct *ConnTracker) Register(ruleID int64, cancel context.CancelFunc) uint64 {
29+
id := nextConnID.Add(1)
30+
31+
ct.mu.Lock()
32+
defer ct.mu.Unlock()
33+
34+
if ct.conns[ruleID] == nil {
35+
ct.conns[ruleID] = make(map[uint64]context.CancelFunc)
36+
}
37+
ct.conns[ruleID][id] = cancel
38+
39+
slog.Info("conn_tracker: registered", "conn_id", id, "rule_id", ruleID, "total_for_rule", len(ct.conns[ruleID]))
40+
return id
41+
}
42+
43+
// Unregister removes a previously registered connection.
44+
// Called when a connection ends naturally.
45+
func (ct *ConnTracker) Unregister(ruleID int64, id uint64) {
46+
ct.mu.Lock()
47+
defer ct.mu.Unlock()
48+
49+
if m, ok := ct.conns[ruleID]; ok {
50+
delete(m, id)
51+
if len(m) == 0 {
52+
delete(ct.conns, ruleID)
53+
}
54+
slog.Info("conn_tracker: unregistered", "conn_id", id, "rule_id", ruleID)
55+
}
56+
}
57+
58+
// CancelByRule cancels all active connections that were authorized by the
59+
// given rule ID and removes them from tracking.
60+
func (ct *ConnTracker) CancelByRule(ruleID int64) {
61+
ct.mu.Lock()
62+
cancels := ct.conns[ruleID]
63+
delete(ct.conns, ruleID)
64+
ct.mu.Unlock()
65+
66+
if len(cancels) == 0 {
67+
slog.Info("conn_tracker: cancel by rule, no active connections", "rule_id", ruleID)
68+
return
69+
}
70+
71+
slog.Info("conn_tracker: cancel by rule, killing connections", "rule_id", ruleID, "count", len(cancels))
72+
for id, cancel := range cancels {
73+
slog.Info("conn_tracker: cancelling conn", "conn_id", id, "rule_id", ruleID)
74+
cancel()
75+
}
76+
}

0 commit comments

Comments
 (0)