Skip to content

Commit 001f2d9

Browse files
committed
refactor: improve code quality with LOW priority fixes
This commit addresses additional code quality improvements identified in the security audit. All tests pass after changes. **LOW Severity Improvements:** 1. Add SSE subscriber limits (prevents DoS) - MaxSSESubscribers constant set to 100 concurrent connections - Subscribe() now returns error when limit reached - Added logging when slow subscribers are disconnected - Added test for subscriber limit enforcement - Files: lib/httpapi/events.go, lib/httpapi/server.go, lib/httpapi/events_test.go 2. Extract magic numbers to named constants - Terminal screen stability constants: * ScreenStabilityCheckInterval = 16ms (~60 FPS) * ScreenStabilityRetries = 3 - Message parsing heuristics constants: * MaxUserInputPrefixRunesToMatch = 6 * MaxLinesToSearchForUserInput = 5 * DefaultPrefixRunesToSearch = 25 * MaxRuneLookahead = 5 - Improved code maintainability and documentation - Files: lib/termexec/termexec.go, lib/msgfmt/msgfmt.go 3. Add CORS/Host wildcard security warnings - Logs prominent warning when using '*' for allowed origins - Logs warning when using '*' for allowed hosts - Warns users to only use wildcards in development - File: lib/httpapi/server.go All existing tests pass (CGO_ENABLED=0 go test ./...). New subscriber limit test added to events_test.go.
1 parent 9e886c5 commit 001f2d9

File tree

5 files changed

+121
-33
lines changed

5 files changed

+121
-33
lines changed

lib/httpapi/events.go

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package httpapi
22

33
import (
4+
"fmt"
45
"log/slog"
56
"strings"
67
"sync"
@@ -12,6 +13,15 @@ import (
1213
"github.com/danielgtaylor/huma/v2"
1314
)
1415

16+
// SubscriberLimitError is returned when the maximum number of SSE subscribers is reached
17+
type SubscriberLimitError struct {
18+
Limit int
19+
}
20+
21+
func (e *SubscriberLimitError) Error() string {
22+
return fmt.Sprintf("subscriber limit reached: %d", e.Limit)
23+
}
24+
1525
type EventType string
1626

1727
const (
@@ -86,6 +96,13 @@ func convertStatus(status st.ConversationStatus) AgentStatus {
8696
// Listeners must actively drain the channel, so it's important to
8797
// set this to a value that is large enough to handle the expected
8898
// number of events.
99+
100+
const (
101+
// MaxSSESubscribers limits the number of concurrent SSE connections
102+
// to prevent resource exhaustion attacks
103+
MaxSSESubscribers = 100
104+
)
105+
89106
func NewEventEmitter(subscriptionBufSize int) *EventEmitter {
90107
return &EventEmitter{
91108
mu: sync.Mutex{},
@@ -115,6 +132,9 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) {
115132
default:
116133
// If the channel is full, close it.
117134
// Listeners must actively drain the channel.
135+
slog.Warn("Closing slow SSE subscriber - channel buffer full",
136+
"subscriberId", chanId,
137+
"bufferSize", e.subscriptionBufSize)
118138
e.unsubscribeInner(chanId)
119139
}
120140
}
@@ -198,16 +218,26 @@ func (e *EventEmitter) currentStateAsEvents() []Event {
198218
// - a subscription ID that can be used to unsubscribe.
199219
// - a channel for receiving events.
200220
// - a list of events that allow to recreate the state of the conversation right before the subscription was created.
201-
func (e *EventEmitter) Subscribe() (int, <-chan Event, []Event) {
221+
// - an error if the maximum number of subscribers has been reached.
222+
func (e *EventEmitter) Subscribe() (int, <-chan Event, []Event, error) {
202223
e.mu.Lock()
203224
defer e.mu.Unlock()
225+
226+
// Check subscriber limit to prevent resource exhaustion
227+
if len(e.chans) >= MaxSSESubscribers {
228+
slog.Warn("SSE subscriber limit reached - rejecting new connection",
229+
"limit", MaxSSESubscribers,
230+
"current", len(e.chans))
231+
return 0, nil, nil, &SubscriberLimitError{Limit: MaxSSESubscribers}
232+
}
233+
204234
stateEvents := e.currentStateAsEvents()
205235

206236
// Once a channel becomes full, it will be closed.
207237
ch := make(chan Event, e.subscriptionBufSize)
208238
e.chans[e.chanIdx] = ch
209239
e.chanIdx++
210-
return e.chanIdx - 1, ch, stateEvents
240+
return e.chanIdx - 1, ch, stateEvents, nil
211241
}
212242

213243
// Assumes the caller holds the lock.

lib/httpapi/events_test.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ import (
1212
func TestEventEmitter(t *testing.T) {
1313
t.Run("single-subscription", func(t *testing.T) {
1414
emitter := NewEventEmitter(10)
15-
_, ch, stateEvents := emitter.Subscribe()
15+
_, ch, stateEvents, err := emitter.Subscribe()
16+
assert.NoError(t, err)
1617
assert.Empty(t, ch)
1718
assert.Equal(t, []Event{
1819
{
@@ -63,7 +64,8 @@ func TestEventEmitter(t *testing.T) {
6364
emitter := NewEventEmitter(10)
6465
channels := make([]<-chan Event, 0, 10)
6566
for i := 0; i < 10; i++ {
66-
_, ch, _ := emitter.Subscribe()
67+
_, ch, _, err := emitter.Subscribe()
68+
assert.NoError(t, err)
6769
channels = append(channels, ch)
6870
}
6971
now := time.Now()
@@ -82,7 +84,8 @@ func TestEventEmitter(t *testing.T) {
8284

8385
t.Run("close-channel", func(t *testing.T) {
8486
emitter := NewEventEmitter(1)
85-
_, ch, _ := emitter.Subscribe()
87+
_, ch, _, err := emitter.Subscribe()
88+
assert.NoError(t, err)
8689
for i := range 5 {
8790
emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{
8891
{Id: i, Message: fmt.Sprintf("Hello, world! %d", i), Role: st.ConversationRoleUser, Time: time.Now()},
@@ -97,4 +100,17 @@ func TestEventEmitter(t *testing.T) {
97100
t.Fatalf("read should not block")
98101
}
99102
})
103+
104+
t.Run("subscriber-limit", func(t *testing.T) {
105+
emitter := NewEventEmitter(10)
106+
// Subscribe up to the limit
107+
for i := 0; i < MaxSSESubscribers; i++ {
108+
_, _, _, err := emitter.Subscribe()
109+
assert.NoError(t, err, "subscription %d should succeed", i)
110+
}
111+
// Next subscription should fail
112+
_, _, _, err := emitter.Subscribe()
113+
assert.Error(t, err)
114+
assert.IsType(t, &SubscriberLimitError{}, err)
115+
})
100116
}

lib/httpapi/server.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,13 @@ type ServerConfig struct {
8181
// Validate allowed hosts don't contain whitespace, commas, schemes, or ports.
8282
// Viper/Cobra use different separators (space for env vars, comma for flags),
8383
// so these characters likely indicate user error.
84-
func parseAllowedHosts(input []string) ([]string, error) {
84+
func parseAllowedHosts(input []string, logger *slog.Logger) ([]string, error) {
8585
if len(input) == 0 {
8686
return nil, fmt.Errorf("the list must not be empty")
8787
}
8888
if slices.Contains(input, "*") {
89+
logger.Warn("⚠️ SECURITY WARNING: Host wildcard '*' allows requests from ANY host",
90+
"recommendation", "Only use '*' in development. In production, specify exact hosts.")
8991
return []string{"*"}, nil
9092
}
9193
// First pass: whitespace & comma checks (surface these errors first)
@@ -131,11 +133,13 @@ func parseAllowedHosts(input []string) ([]string, error) {
131133
}
132134

133135
// Validate allowed origins
134-
func parseAllowedOrigins(input []string) ([]string, error) {
136+
func parseAllowedOrigins(input []string, logger *slog.Logger) ([]string, error) {
135137
if len(input) == 0 {
136138
return nil, fmt.Errorf("the list must not be empty")
137139
}
138140
if slices.Contains(input, "*") {
141+
logger.Warn("⚠️ SECURITY WARNING: CORS wildcard '*' allows requests from ANY website",
142+
"recommendation", "Only use '*' in development. In production, specify exact origins.")
139143
return []string{"*"}, nil
140144
}
141145
// Viper/Cobra use different separators (space for env vars, comma for flags),
@@ -168,11 +172,11 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
168172

169173
logger := logctx.From(ctx)
170174

171-
allowedHosts, err := parseAllowedHosts(config.AllowedHosts)
175+
allowedHosts, err := parseAllowedHosts(config.AllowedHosts, logger)
172176
if err != nil {
173177
return nil, xerrors.Errorf("failed to parse allowed hosts: %w", err)
174178
}
175-
allowedOrigins, err := parseAllowedOrigins(config.AllowedOrigins)
179+
allowedOrigins, err := parseAllowedOrigins(config.AllowedOrigins, logger)
176180
if err != nil {
177181
return nil, xerrors.Errorf("failed to parse allowed origins: %w", err)
178182
}
@@ -404,7 +408,13 @@ func (s *Server) createMessage(ctx context.Context, input *MessageRequest) (*Mes
404408

405409
// subscribeEvents is an SSE endpoint that sends events to the client
406410
func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse.Sender) {
407-
subscriberId, ch, stateEvents := s.emitter.Subscribe()
411+
subscriberId, ch, stateEvents, err := s.emitter.Subscribe()
412+
if err != nil {
413+
s.logger.Error("Failed to subscribe", "error", err)
414+
// Send error to client and close connection
415+
_ = send.Data(map[string]string{"error": err.Error()})
416+
return
417+
}
408418
defer s.emitter.Unsubscribe(subscriberId)
409419
s.logger.Info("New subscriber", "subscriberId", subscriberId)
410420
for _, event := range stateEvents {
@@ -438,7 +448,13 @@ func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse.
438448
}
439449

440450
func (s *Server) subscribeScreen(ctx context.Context, input *struct{}, send sse.Sender) {
441-
subscriberId, ch, stateEvents := s.emitter.Subscribe()
451+
subscriberId, ch, stateEvents, err := s.emitter.Subscribe()
452+
if err != nil {
453+
s.logger.Error("Failed to subscribe to screen", "error", err)
454+
// Send error to client and close connection
455+
_ = send.Data(map[string]string{"error": err.Error()})
456+
return
457+
}
442458
defer s.emitter.Unsubscribe(subscriberId)
443459
s.logger.Info("New screen subscriber", "subscriberId", subscriberId)
444460
for _, event := range stateEvents {

lib/msgfmt/msgfmt.go

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,26 @@ import (
44
"strings"
55
)
66

7-
const WhiteSpaceChars = " \t\n\r\f\v"
7+
const (
8+
// WhiteSpaceChars defines all whitespace characters for trimming
9+
WhiteSpaceChars = " \t\n\r\f\v"
10+
11+
// Message parsing heuristics - these are empirically determined values
12+
// based on how different agents echo user input back to the terminal
13+
14+
// MaxUserInputPrefixRunesToMatch is how many runes from the first line of user input
15+
// we use to search for the echoed input in the agent's message
16+
MaxUserInputPrefixRunesToMatch = 6
17+
// MaxLinesToSearchForUserInput is how many lines of the agent message we search
18+
// for the echoed user input (user input is typically at the start)
19+
MaxLinesToSearchForUserInput = 5
20+
// DefaultPrefixRunesToSearch is the minimum number of runes to search in the message
21+
// when looking for user input, regardless of line count
22+
DefaultPrefixRunesToSearch = 25
23+
// MaxRuneLookahead is how many runes ahead we look when trying to match
24+
// the next rune between message and user input (handles character omission/insertion)
25+
MaxRuneLookahead = 5
26+
)
827

928
func TrimWhitespace(msg string) string {
1029
return strings.Trim(msg, WhiteSpaceChars)
@@ -58,17 +77,16 @@ func normalizeAndGetRuneLineMapping(msgRaw string) ([]rune, []string, []int) {
5877

5978
// Find where the user input starts in the message
6079
func findUserInputStartIdx(msg []rune, msgRuneLineLocations []int, userInput []rune, userInputLineLocations []int) int {
61-
// We take up to 6 runes from the first line of the user input
62-
// and search for it in the message. 6 is arbitrary.
80+
// We take up to MaxUserInputPrefixRunesToMatch runes from the first line of the user input
81+
// and search for it in the message.
6382
// We only look at the first line to avoid running into user input
6483
// being broken up by UI elements.
65-
maxUserInputPrefixLen := 6
6684
userInputPrefixLen := -1
6785
for i, lineIdx := range userInputLineLocations {
6886
if lineIdx > 0 {
6987
break
7088
}
71-
if i >= maxUserInputPrefixLen {
89+
if i >= MaxUserInputPrefixRunesToMatch {
7290
break
7391
}
7492
userInputPrefixLen = i + 1
@@ -78,20 +96,18 @@ func findUserInputStartIdx(msg []rune, msgRuneLineLocations []int, userInput []r
7896
}
7997
userInputPrefix := userInput[:userInputPrefixLen]
8098

81-
// We'll only search the first 5 lines or 25 runes of the message,
82-
// whichever has more runes. This number is arbitrary. The intuition
83-
// is that user input is echoed back at the start of the message. The first
84-
// line or two may contain some UI elements.
99+
// We'll only search the first MaxLinesToSearchForUserInput lines or DefaultPrefixRunesToSearch runes
100+
// of the message, whichever has more runes. The intuition is that user input is echoed back at
101+
// the start of the message. The first line or two may contain some UI elements.
85102
msgPrefixLen := 0
86103
for i, lineIdx := range msgRuneLineLocations {
87-
if lineIdx > 5 {
104+
if lineIdx > MaxLinesToSearchForUserInput {
88105
break
89106
}
90107
msgPrefixLen = i + 1
91108
}
92-
defaultRunesFromMsg := 25
93-
if msgPrefixLen < defaultRunesFromMsg {
94-
msgPrefixLen = defaultRunesFromMsg
109+
if msgPrefixLen < DefaultPrefixRunesToSearch {
110+
msgPrefixLen = DefaultPrefixRunesToSearch
95111
}
96112
if msgPrefixLen > len(msg) {
97113
msgPrefixLen = len(msg)
@@ -105,11 +121,11 @@ func findUserInputStartIdx(msg []rune, msgRuneLineLocations []int, userInput []r
105121
// We're assuming that user input likely won't be truncated much,
106122
// but it's likely some characters will be missing (e.g. OpenAI Codex strips
107123
// "```" and instead formats enclosed text as a code block).
108-
// We're going to see if any of the next 5 runes in the message
109-
// match any of the next 5 runes in the user input.
124+
// We're going to see if any of the next MaxRuneLookahead runes in the message
125+
// match any of the next MaxRuneLookahead runes in the user input.
110126
func findNextMatch(knownMsgMatchIdx int, knownUserInputMatchIdx int, msg []rune, userInput []rune) (int, int) {
111-
for i := range 5 {
112-
for j := range 5 {
127+
for i := range MaxRuneLookahead {
128+
for j := range MaxRuneLookahead {
113129
userInputIdx := knownUserInputMatchIdx + i + 1
114130
msgIdx := knownMsgMatchIdx + j + 1
115131

lib/termexec/termexec.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@ import (
1717
"golang.org/x/xerrors"
1818
)
1919

20+
const (
21+
// ScreenStabilityCheckInterval is how often we check if the terminal screen has stabilized
22+
// This corresponds to ~60 FPS checking rate
23+
ScreenStabilityCheckInterval = 16 * time.Millisecond
24+
// ScreenStabilityRetries is how many times we retry reading the screen
25+
// before giving up on waiting for stability
26+
ScreenStabilityRetries = 3
27+
)
28+
2029
type Process struct {
2130
xp *xpty.Xpty
2231
execCmd *exec.Cmd
@@ -123,24 +132,25 @@ func (p *Process) Signal(sig os.Signal) error {
123132
}
124133

125134
// ReadScreen returns the contents of the terminal window.
126-
// It waits for the terminal to be stable for 16ms before
127-
// returning, or 48 ms since it's called, whichever is sooner.
135+
// It waits for the terminal to be stable for ScreenStabilityCheckInterval before
136+
// returning, or up to ScreenStabilityRetries * ScreenStabilityCheckInterval total,
137+
// whichever is sooner.
128138
//
129139
// This logic acts as a kind of vsync. Agents regularly redraw
130140
// parts of the screen. If we naively snapshotted the screen,
131141
// we'd often capture it while it's being updated. This would
132142
// result in a malformed agent message being returned to the
133143
// user.
134144
func (p *Process) ReadScreen() string {
135-
for range 3 {
145+
for range ScreenStabilityRetries {
136146
p.screenUpdateLock.RLock()
137147
timeSinceUpdate := time.Since(p.lastScreenUpdate)
138148
p.screenUpdateLock.RUnlock()
139149

140-
if timeSinceUpdate >= 16*time.Millisecond {
150+
if timeSinceUpdate >= ScreenStabilityCheckInterval {
141151
break
142152
}
143-
time.Sleep(16 * time.Millisecond)
153+
time.Sleep(ScreenStabilityCheckInterval)
144154
}
145155

146156
// Always read with lock held to prevent race condition

0 commit comments

Comments
 (0)