Skip to content

Commit 1342a73

Browse files
lazypowerclaude
andcommitted
Address Copilot review on #8
- Fix middleware ordering: securityHeaders before localhostOnly so 403s get headers - Normalize Host header: case-insensitive, handle bracketed IPv6 [::1], trim trailing dots - Return JSON from localhostOnly rejection (was plaintext) - Migrate all remaining http.Error calls to jsonError for consistent JSON responses - Rename maxToolResponseSize → maxToolFieldSize (now covers both fields) - Add tool_input truncation assertion to TestAddObservationTruncation - Remove stderr output from hardenPermissions (library code, silent chmod) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 811901d commit 1342a73

6 files changed

Lines changed: 48 additions & 36 deletions

File tree

internal/server/middleware.go

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,31 @@ package server
33
import (
44
"net"
55
"net/http"
6+
"strings"
67
)
78

89
const maxRequestBody = 1 << 20 // 1MB
910

11+
// normalizeHost extracts and normalizes the hostname from a Host header.
12+
// Handles ports, bracketed IPv6, case folding, and trailing dots.
13+
func normalizeHost(host string) string {
14+
if h, _, err := net.SplitHostPort(host); err == nil {
15+
host = h
16+
} else if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
17+
host = host[1 : len(host)-1]
18+
}
19+
host = strings.ToLower(host)
20+
host = strings.TrimSuffix(host, ".")
21+
return host
22+
}
23+
1024
// localhostOnly rejects requests where the Host header is not localhost.
1125
// Prevents DNS rebinding attacks against the local API server.
1226
func localhostOnly(next http.Handler) http.Handler {
1327
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
14-
host := r.Host
15-
if h, _, err := net.SplitHostPort(host); err == nil {
16-
host = h
17-
}
28+
host := normalizeHost(r.Host)
1829
if host != "localhost" && host != "127.0.0.1" && host != "::1" {
19-
http.Error(w, "Forbidden", http.StatusForbidden)
30+
jsonError(w, "forbidden", http.StatusForbidden)
2031
return
2132
}
2233
next.ServeHTTP(w, r)

internal/server/routes.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ import (
1313
"github.com/lazypower/continuity/internal/engine"
1414
)
1515

16-
// jsonError writes a JSON error response. All error responses should use this
17-
// to avoid JSON injection via string concatenation.
16+
// jsonError writes a JSON error response with proper Content-Type and encoding.
17+
// Prefer this over http.Error for consistent JSON responses.
1818
func jsonError(w http.ResponseWriter, msg string, code int) {
1919
w.Header().Set("Content-Type", "application/json")
2020
w.WriteHeader(code)
@@ -27,11 +27,11 @@ func (s *Server) handleSessionInit(w http.ResponseWriter, r *http.Request) {
2727
Project string `json:"project"`
2828
}
2929
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
30-
http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest)
30+
jsonError(w, "invalid json", http.StatusBadRequest)
3131
return
3232
}
3333
if req.SessionID == "" {
34-
http.Error(w, `{"error":"session_id required"}`, http.StatusBadRequest)
34+
jsonError(w, "session_id required", http.StatusBadRequest)
3535
return
3636
}
3737

@@ -60,11 +60,11 @@ func (s *Server) handleAddObservation(w http.ResponseWriter, r *http.Request) {
6060
}
6161
body, err := io.ReadAll(r.Body)
6262
if err != nil {
63-
http.Error(w, `{"error":"read body failed"}`, http.StatusBadRequest)
63+
jsonError(w, "read body failed", http.StatusBadRequest)
6464
return
6565
}
6666
if err := json.Unmarshal(body, &req); err != nil {
67-
http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest)
67+
jsonError(w, "invalid json", http.StatusBadRequest)
6868
return
6969
}
7070

@@ -119,7 +119,7 @@ func (s *Server) handleExtractSession(w http.ResponseWriter, r *http.Request) {
119119
Force bool `json:"force"`
120120
}
121121
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
122-
http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest)
122+
jsonError(w, "invalid json", http.StatusBadRequest)
123123
return
124124
}
125125

@@ -155,11 +155,11 @@ func (s *Server) handleSignal(w http.ResponseWriter, r *http.Request) {
155155
Prompt string `json:"prompt"`
156156
}
157157
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
158-
http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest)
158+
jsonError(w, "invalid json", http.StatusBadRequest)
159159
return
160160
}
161161
if req.Prompt == "" {
162-
http.Error(w, `{"error":"prompt required"}`, http.StatusBadRequest)
162+
jsonError(w, "prompt required", http.StatusBadRequest)
163163
return
164164
}
165165

@@ -248,11 +248,11 @@ func (s *Server) handleRemember(w http.ResponseWriter, r *http.Request) {
248248
SessionID string `json:"session_id"`
249249
}
250250
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
251-
http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest)
251+
jsonError(w, "invalid json", http.StatusBadRequest)
252252
return
253253
}
254254
if req.Category == "" || req.Name == "" || req.Summary == "" || req.Body == "" {
255-
http.Error(w, `{"error":"category, name, summary, and body are required"}`, http.StatusBadRequest)
255+
jsonError(w, "category, name, summary, and body are required", http.StatusBadRequest)
256256
return
257257
}
258258

@@ -295,7 +295,7 @@ func (s *Server) handleRemember(w http.ResponseWriter, r *http.Request) {
295295
func (s *Server) handleSearch(w http.ResponseWriter, r *http.Request) {
296296
query := r.URL.Query().Get("q")
297297
if query == "" {
298-
http.Error(w, `{"error":"q parameter required"}`, http.StatusBadRequest)
298+
jsonError(w, "q parameter required", http.StatusBadRequest)
299299
return
300300
}
301301

internal/server/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
4141
func (s *Server) routes() {
4242
r := chi.NewRouter()
4343
r.Use(middleware.Recoverer)
44-
r.Use(localhostOnly)
4544
r.Use(securityHeaders)
45+
r.Use(localhostOnly)
4646
r.Use(limitRequestBody)
4747

4848
r.Route("/api", func(r chi.Router) {

internal/store/db.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,14 @@ func OpenMemory() (*DB, error) {
7373
}
7474

7575
// hardenPermissions tightens file/directory permissions for existing installs.
76+
// MkdirAll/OpenFile only set permissions on creation — this fixes pre-existing files.
7677
func hardenPermissions(dir, dbPath string) {
7778
if info, err := os.Stat(dir); err == nil && info.Mode().Perm()&0077 != 0 {
78-
if err := os.Chmod(dir, 0700); err == nil {
79-
fmt.Fprintf(os.Stderr, " security: tightened %s to 0700\n", dir)
80-
}
79+
_ = os.Chmod(dir, 0700)
8180
}
8281
for _, f := range []string{dbPath, dbPath + "-wal", dbPath + "-shm"} {
8382
if info, err := os.Stat(f); err == nil && info.Mode().Perm()&0077 != 0 {
84-
if err := os.Chmod(f, 0600); err == nil {
85-
fmt.Fprintf(os.Stderr, " security: tightened %s to 0600\n", f)
86-
}
83+
_ = os.Chmod(f, 0600)
8784
}
8885
}
8986
}

internal/store/observations.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ import (
66
"time"
77
)
88

9-
// maxToolResponseSize is the maximum size of tool_response stored in the DB.
9+
// maxToolFieldSize is the maximum size of tool_input and tool_response stored in the DB.
1010
// Prevents bloat — Phase 2 extraction processes full transcript anyway.
11-
const maxToolResponseSize = 10 * 1024 // 10KB
11+
const maxToolFieldSize = 10 * 1024 // 10KB
1212

1313
// Observation represents a single tool use recorded during a session.
1414
type Observation struct {
@@ -22,13 +22,13 @@ type Observation struct {
2222

2323
// AddObservation stores a tool use observation. Truncates large fields to prevent DB bloat.
2424
func (db *DB) AddObservation(sessionID, toolName, toolInput, toolResponse string) error {
25-
if len(toolInput) > maxToolResponseSize {
26-
log.Printf("observation: tool_input truncated for session %s: %d → %d bytes", sessionID, len(toolInput), maxToolResponseSize)
27-
toolInput = toolInput[:maxToolResponseSize]
25+
if len(toolInput) > maxToolFieldSize {
26+
log.Printf("observation: tool_input truncated for session %s: %d → %d bytes", sessionID, len(toolInput), maxToolFieldSize)
27+
toolInput = toolInput[:maxToolFieldSize]
2828
}
29-
if len(toolResponse) > maxToolResponseSize {
30-
log.Printf("observation: tool_response truncated for session %s: %d → %d bytes", sessionID, len(toolResponse), maxToolResponseSize)
31-
toolResponse = toolResponse[:maxToolResponseSize]
29+
if len(toolResponse) > maxToolFieldSize {
30+
log.Printf("observation: tool_response truncated for session %s: %d → %d bytes", sessionID, len(toolResponse), maxToolFieldSize)
31+
toolResponse = toolResponse[:maxToolFieldSize]
3232
}
3333

3434
now := time.Now().UnixMilli()

internal/store/observations_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,19 @@ func TestAddObservationTruncation(t *testing.T) {
4242
}
4343
defer db.Close()
4444

45-
bigResponse := strings.Repeat("x", 20*1024) // 20KB
46-
err = db.AddObservation("sess-001", "Bash", "{}", bigResponse)
45+
bigInput := strings.Repeat("i", 20*1024) // 20KB
46+
bigResponse := strings.Repeat("r", 20*1024) // 20KB
47+
err = db.AddObservation("sess-001", "Bash", bigInput, bigResponse)
4748
if err != nil {
4849
t.Fatalf("AddObservation: %v", err)
4950
}
5051

5152
obs, _ := db.GetObservations("sess-001")
52-
if len(obs[0].ToolResponse) != maxToolResponseSize {
53-
t.Errorf("ToolResponse length = %d, want %d", len(obs[0].ToolResponse), maxToolResponseSize)
53+
if len(obs[0].ToolInput) != maxToolFieldSize {
54+
t.Errorf("ToolInput length = %d, want %d", len(obs[0].ToolInput), maxToolFieldSize)
55+
}
56+
if len(obs[0].ToolResponse) != maxToolFieldSize {
57+
t.Errorf("ToolResponse length = %d, want %d", len(obs[0].ToolResponse), maxToolFieldSize)
5458
}
5559
}
5660

0 commit comments

Comments
 (0)