diff --git a/internal/cli/init.go b/internal/cli/init.go index 661e9c9..b6c56cf 100644 --- a/internal/cli/init.go +++ b/internal/cli/init.go @@ -93,10 +93,10 @@ func runInit(cmd *cobra.Command, args []string) error { autostartPath := filepath.Join(homeDir, ".continuity", "autostart") if initAutostart { - if err := os.MkdirAll(filepath.Dir(autostartPath), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(autostartPath), 0700); err != nil { return fmt.Errorf("create .continuity dir: %w", err) } - if err := os.WriteFile(autostartPath, []byte("enabled\n"), 0644); err != nil { + if err := os.WriteFile(autostartPath, []byte("enabled\n"), 0600); err != nil { return fmt.Errorf("write autostart marker: %w", err) } fmt.Println("Autostart enabled: continuity serve will launch automatically when needed.") diff --git a/internal/cli/serve.go b/internal/cli/serve.go index 9620db4..dca3b7f 100644 --- a/internal/cli/serve.go +++ b/internal/cli/serve.go @@ -107,8 +107,12 @@ func runServe(cmd *cobra.Command, args []string) error { addr := cfg.ListenAddr() httpServer := &http.Server{ - Addr: addr, - Handler: srv, + Addr: addr, + Handler: srv, + ReadTimeout: 10 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + MaxHeaderBytes: 1 << 20, // 1MB } // Graceful shutdown diff --git a/internal/hooks/autostart.go b/internal/hooks/autostart.go index 68055c0..1616779 100644 --- a/internal/hooks/autostart.go +++ b/internal/hooks/autostart.go @@ -56,11 +56,15 @@ func TryAutostart() bool { return false } logPath := filepath.Join(home, ".continuity", "serve.log") - logFile, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + logFile, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) if err != nil { fmt.Fprintf(os.Stderr, "continuity: autostart: open log: %v\n", err) return false } + // Tighten existing log files from previous installs (0644 → 0600) + if info, err := logFile.Stat(); err == nil && info.Mode().Perm()&0077 != 0 { + os.Chmod(logPath, 0600) + } devNull, err := os.Open(os.DevNull) if err != nil { diff --git a/internal/hooks/handler.go b/internal/hooks/handler.go index f7fa09e..7d19a31 100644 --- a/internal/hooks/handler.go +++ b/internal/hooks/handler.go @@ -6,11 +6,13 @@ import ( "io" ) +const maxHookInputSize = 10 << 20 // 10MB + // Handle reads HookInput from the given reader, dispatches to the appropriate // handler based on the event argument, and writes output to stdout. func Handle(event string, stdin io.Reader) { var input HookInput - if err := json.NewDecoder(stdin).Decode(&input); err != nil { + if err := json.NewDecoder(io.LimitReader(stdin, maxHookInputSize)).Decode(&input); err != nil { // Stdin may be empty for some events — degrade gracefully if event == "start" { WriteSessionStartOutput("") diff --git a/internal/server/middleware.go b/internal/server/middleware.go new file mode 100644 index 0000000..19b437c --- /dev/null +++ b/internal/server/middleware.go @@ -0,0 +1,53 @@ +package server + +import ( + "net" + "net/http" + "strings" +) + +const maxRequestBody = 1 << 20 // 1MB + +// normalizeHost extracts and normalizes the hostname from a Host header. +// Handles ports, bracketed IPv6, case folding, and trailing dots. +func normalizeHost(host string) string { + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } else if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + host = host[1 : len(host)-1] + } + host = strings.ToLower(host) + host = strings.TrimSuffix(host, ".") + return host +} + +// localhostOnly rejects requests where the Host header is not localhost. +// Prevents DNS rebinding attacks against the local API server. +func localhostOnly(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host := normalizeHost(r.Host) + if host != "localhost" && host != "127.0.0.1" && host != "::1" { + jsonError(w, "forbidden", http.StatusForbidden) + return + } + next.ServeHTTP(w, r) + }) +} + +// securityHeaders adds standard security headers to all responses. +func securityHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("Referrer-Policy", "no-referrer") + next.ServeHTTP(w, r) + }) +} + +// limitRequestBody caps the size of incoming request bodies to prevent OOM. +func limitRequestBody(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody) + next.ServeHTTP(w, r) + }) +} diff --git a/internal/server/routes.go b/internal/server/routes.go index c78e17e..aba185c 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -3,7 +3,6 @@ package server import ( "context" "encoding/json" - "fmt" "io" "log" "net/http" @@ -14,23 +13,32 @@ import ( "github.com/lazypower/continuity/internal/engine" ) +// jsonError writes a JSON error response with proper Content-Type and encoding. +// Prefer this over http.Error for consistent JSON responses. +func jsonError(w http.ResponseWriter, msg string, code int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + json.NewEncoder(w).Encode(map[string]string{"error": msg}) +} + func (s *Server) handleSessionInit(w http.ResponseWriter, r *http.Request) { var req struct { SessionID string `json:"session_id"` Project string `json:"project"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest) + jsonError(w, "invalid json", http.StatusBadRequest) return } if req.SessionID == "" { - http.Error(w, `{"error":"session_id required"}`, http.StatusBadRequest) + jsonError(w, "session_id required", http.StatusBadRequest) return } sess, err := s.db.InitSession(req.SessionID, req.Project) if err != nil { - http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusInternalServerError) + log.Printf("init session: %v", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } @@ -52,16 +60,17 @@ func (s *Server) handleAddObservation(w http.ResponseWriter, r *http.Request) { } body, err := io.ReadAll(r.Body) if err != nil { - http.Error(w, `{"error":"read body failed"}`, http.StatusBadRequest) + jsonError(w, "read body failed", http.StatusBadRequest) return } if err := json.Unmarshal(body, &req); err != nil { - http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest) + jsonError(w, "invalid json", http.StatusBadRequest) return } if err := s.db.AddObservation(sessionID, req.ToolName, req.ToolInput, req.ToolResponse); err != nil { - http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusInternalServerError) + log.Printf("add observation: %v", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } @@ -79,8 +88,9 @@ func (s *Server) handleCompleteSession(w http.ResponseWriter, r *http.Request) { if err := s.db.CompleteSession(sessionID); err != nil { // Not finding an active session is not a server error — the session // may have already been completed or never existed. Log but return OK. + log.Printf("complete session: %v", err) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]string{"status": "ok", "note": err.Error()}) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) return } @@ -92,7 +102,8 @@ func (s *Server) handleEndSession(w http.ResponseWriter, r *http.Request) { sessionID := chi.URLParam(r, "sessionID") if err := s.db.EndSession(sessionID); err != nil { - http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusInternalServerError) + log.Printf("end session: %v", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } @@ -108,7 +119,7 @@ func (s *Server) handleExtractSession(w http.ResponseWriter, r *http.Request) { Force bool `json:"force"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest) + jsonError(w, "invalid json", http.StatusBadRequest) return } @@ -144,11 +155,11 @@ func (s *Server) handleSignal(w http.ResponseWriter, r *http.Request) { Prompt string `json:"prompt"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest) + jsonError(w, "invalid json", http.StatusBadRequest) return } if req.Prompt == "" { - http.Error(w, `{"error":"prompt required"}`, http.StatusBadRequest) + jsonError(w, "prompt required", http.StatusBadRequest) return } @@ -180,9 +191,8 @@ func (s *Server) handleSignal(w http.ResponseWriter, r *http.Request) { func (s *Server) handleUnmarkEmptyExtractions(w http.ResponseWriter, r *http.Request) { n, err := s.db.UnmarkEmptyExtractions() if err != nil { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusInternalServerError) - json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + log.Printf("unmark empty extractions: %v", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } @@ -204,15 +214,12 @@ func (s *Server) handleGetMemory(w http.ResponseWriter, r *http.Request) { node, err := s.db.GetNodeByURI(uri) if err != nil { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusInternalServerError) - json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + log.Printf("get memory: %v", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } if node == nil { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusNotFound) - json.NewEncoder(w).Encode(map[string]string{"error": "memory not found: " + uri}) + jsonError(w, "memory not found", http.StatusNotFound) return } @@ -241,11 +248,11 @@ func (s *Server) handleRemember(w http.ResponseWriter, r *http.Request) { SessionID string `json:"session_id"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest) + jsonError(w, "invalid json", http.StatusBadRequest) return } if req.Category == "" || req.Name == "" || req.Summary == "" || req.Body == "" { - http.Error(w, `{"error":"category, name, summary, and body are required"}`, http.StatusBadRequest) + jsonError(w, "category, name, summary, and body are required", http.StatusBadRequest) return } @@ -268,9 +275,8 @@ func (s *Server) handleRemember(w http.ResponseWriter, r *http.Request) { SessionID: req.SessionID, }) if err != nil { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + log.Printf("remember: %v", err) + jsonError(w, "failed to store memory", http.StatusBadRequest) return } @@ -289,7 +295,7 @@ func (s *Server) handleRemember(w http.ResponseWriter, r *http.Request) { func (s *Server) handleSearch(w http.ResponseWriter, r *http.Request) { query := r.URL.Query().Get("q") if query == "" { - http.Error(w, `{"error":"q parameter required"}`, http.StatusBadRequest) + jsonError(w, "q parameter required", http.StatusBadRequest) return } @@ -304,6 +310,9 @@ func (s *Server) handleSearch(w http.ResponseWriter, r *http.Request) { limit = n } } + if limit > 100 { + limit = 100 + } category := r.URL.Query().Get("category") @@ -333,7 +342,8 @@ func (s *Server) handleSearch(w http.ResponseWriter, r *http.Request) { } if err != nil { - http.Error(w, fmt.Sprintf(`{"error":"%s"}`, err.Error()), http.StatusInternalServerError) + log.Printf("search: %v", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } @@ -384,7 +394,8 @@ func (s *Server) handleTimeline(w http.ResponseWriter, r *http.Request) { sessions, err := s.db.GetSessionsSince(sinceMs) if err != nil { - http.Error(w, fmt.Sprintf(`{"error":"%s"}`, err.Error()), http.StatusInternalServerError) + log.Printf("timeline: %v", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } @@ -415,7 +426,8 @@ func (s *Server) handleTimeline(w http.ResponseWriter, r *http.Request) { func (s *Server) handleProfile(w http.ResponseWriter, r *http.Request) { relProfile, err := s.db.GetNodeByURI("mem://user/profile/communication") if err != nil { - http.Error(w, fmt.Sprintf(`{"error":"%s"}`, err.Error()), http.StatusInternalServerError) + log.Printf("profile: %v", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } @@ -476,7 +488,8 @@ func (s *Server) handleTree(w http.ResponseWriter, r *http.Request) { // List roots roots, err := s.db.ListRoots() if err != nil { - http.Error(w, fmt.Sprintf(`{"error":"%s"}`, err.Error()), http.StatusInternalServerError) + log.Printf("tree roots: %v", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } for _, r := range roots { @@ -492,7 +505,8 @@ func (s *Server) handleTree(w http.ResponseWriter, r *http.Request) { // List children children, err := s.db.GetChildren(uri) if err != nil { - http.Error(w, fmt.Sprintf(`{"error":"%s"}`, err.Error()), http.StatusInternalServerError) + log.Printf("tree children: %v", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } for _, c := range children { diff --git a/internal/server/routes_test.go b/internal/server/routes_test.go index 5fe09df..26fe5b8 100644 --- a/internal/server/routes_test.go +++ b/internal/server/routes_test.go @@ -15,7 +15,7 @@ func TestSessionInit(t *testing.T) { srv := testServer(t) body := `{"session_id":"test-001","project":"/tmp/myproject"}` - req := httptest.NewRequest("POST", "/api/sessions/init", strings.NewReader(body)) + req := newTestRequest("POST", "/api/sessions/init", strings.NewReader(body)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -37,7 +37,7 @@ func TestSessionInitMissingID(t *testing.T) { srv := testServer(t) body := `{"project":"/tmp/myproject"}` - req := httptest.NewRequest("POST", "/api/sessions/init", strings.NewReader(body)) + req := newTestRequest("POST", "/api/sessions/init", strings.NewReader(body)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -51,13 +51,13 @@ func TestAddObservation(t *testing.T) { // Init session first initBody := `{"session_id":"test-001","project":"/tmp/myproject"}` - req := httptest.NewRequest("POST", "/api/sessions/init", strings.NewReader(initBody)) + req := newTestRequest("POST", "/api/sessions/init", strings.NewReader(initBody)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) // Add observation obsBody := `{"tool_name":"Bash","tool_input":"{\"command\":\"ls\"}","tool_response":"file1 file2"}` - req = httptest.NewRequest("POST", "/api/sessions/test-001/observations", strings.NewReader(obsBody)) + req = newTestRequest("POST", "/api/sessions/test-001/observations", strings.NewReader(obsBody)) w = httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -71,12 +71,12 @@ func TestCompleteSession(t *testing.T) { // Init session initBody := `{"session_id":"test-001","project":"/tmp/myproject"}` - req := httptest.NewRequest("POST", "/api/sessions/init", strings.NewReader(initBody)) + req := newTestRequest("POST", "/api/sessions/init", strings.NewReader(initBody)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) // Complete session - req = httptest.NewRequest("POST", "/api/sessions/test-001/complete", nil) + req = newTestRequest("POST", "/api/sessions/test-001/complete", nil) w = httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -96,12 +96,12 @@ func TestEndSession(t *testing.T) { // Init session initBody := `{"session_id":"test-001","project":"/tmp/myproject"}` - req := httptest.NewRequest("POST", "/api/sessions/init", strings.NewReader(initBody)) + req := newTestRequest("POST", "/api/sessions/init", strings.NewReader(initBody)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) // End session - req = httptest.NewRequest("POST", "/api/sessions/test-001/end", nil) + req = newTestRequest("POST", "/api/sessions/test-001/end", nil) w = httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -120,7 +120,7 @@ func TestSignalRouteNoEngine(t *testing.T) { srv := testServer(t) // engine is nil body := `{"prompt":"remember this: always use WAL mode"}` - req := httptest.NewRequest("POST", "/api/sessions/test-001/signal", strings.NewReader(body)) + req := newTestRequest("POST", "/api/sessions/test-001/signal", strings.NewReader(body)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -133,7 +133,7 @@ func TestSignalRouteMissingPrompt(t *testing.T) { srv := testServer(t) body := `{}` - req := httptest.NewRequest("POST", "/api/sessions/test-001/signal", strings.NewReader(body)) + req := newTestRequest("POST", "/api/sessions/test-001/signal", strings.NewReader(body)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -145,7 +145,7 @@ func TestSignalRouteMissingPrompt(t *testing.T) { func TestSignalRouteInvalidJSON(t *testing.T) { srv := testServer(t) - req := httptest.NewRequest("POST", "/api/sessions/test-001/signal", strings.NewReader("not json")) + req := newTestRequest("POST", "/api/sessions/test-001/signal", strings.NewReader("not json")) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -158,7 +158,7 @@ func TestGetContext(t *testing.T) { srv := testServer(t) // Empty context - req := httptest.NewRequest("GET", "/api/context", nil) + req := newTestRequest("GET", "/api/context", nil) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -191,7 +191,7 @@ func TestRememberRoute(t *testing.T) { srv := testServerWithEngine(t) body := `{"category":"preferences","name":"devbox","summary":"Always use devbox","body":"The project uses devbox shell to provide Go and SQLite tools."}` - req := httptest.NewRequest("POST", "/api/memories", strings.NewReader(body)) + req := newTestRequest("POST", "/api/memories", strings.NewReader(body)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -215,7 +215,7 @@ func TestRememberRouteUpdate(t *testing.T) { body := `{"category":"preferences","name":"devbox","summary":"Always use devbox","body":"The project uses devbox shell to provide Go and SQLite tools."}` // First call → created - req := httptest.NewRequest("POST", "/api/memories", strings.NewReader(body)) + req := newTestRequest("POST", "/api/memories", strings.NewReader(body)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) if w.Code != http.StatusCreated { @@ -224,7 +224,7 @@ func TestRememberRouteUpdate(t *testing.T) { // Second call with different content → updated body2 := `{"category":"preferences","name":"devbox","summary":"Updated devbox preference","body":"Updated: devbox shell provides Go, SQLite, and additional tooling."}` - req = httptest.NewRequest("POST", "/api/memories", strings.NewReader(body2)) + req = newTestRequest("POST", "/api/memories", strings.NewReader(body2)) w = httptest.NewRecorder() srv.ServeHTTP(w, req) if w.Code != http.StatusOK { @@ -243,7 +243,7 @@ func TestGetMemoryRoute(t *testing.T) { // Seed a memory via POST body := `{"category":"patterns","name":"test-journal","summary":"tiny test","body":"section A\n- entry 1\n\nsection B\n- entry 2\n"}` - req := httptest.NewRequest("POST", "/api/memories", strings.NewReader(body)) + req := newTestRequest("POST", "/api/memories", strings.NewReader(body)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) if w.Code != http.StatusCreated { @@ -251,7 +251,7 @@ func TestGetMemoryRoute(t *testing.T) { } // Read it back - req = httptest.NewRequest("GET", "/api/memories?uri=mem://agent/patterns/test-journal", nil) + req = newTestRequest("GET", "/api/memories?uri=mem://agent/patterns/test-journal", nil) w = httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -281,7 +281,7 @@ func TestGetMemoryRoute(t *testing.T) { func TestGetMemoryRouteNotFound(t *testing.T) { srv := testServerWithEngine(t) - req := httptest.NewRequest("GET", "/api/memories?uri=mem://agent/patterns/does-not-exist", nil) + req := newTestRequest("GET", "/api/memories?uri=mem://agent/patterns/does-not-exist", nil) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -293,7 +293,7 @@ func TestGetMemoryRouteNotFound(t *testing.T) { func TestGetMemoryRouteMissingURI(t *testing.T) { srv := testServerWithEngine(t) - req := httptest.NewRequest("GET", "/api/memories", nil) + req := newTestRequest("GET", "/api/memories", nil) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -306,7 +306,7 @@ func TestRememberRouteNoEngine(t *testing.T) { srv := testServer(t) // engine is nil body := `{"category":"preferences","name":"test","summary":"test","body":"test body with enough content for validation."}` - req := httptest.NewRequest("POST", "/api/memories", strings.NewReader(body)) + req := newTestRequest("POST", "/api/memories", strings.NewReader(body)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -330,7 +330,7 @@ func TestRememberRouteMissingFields(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("POST", "/api/memories", strings.NewReader(tt.body)) + req := newTestRequest("POST", "/api/memories", strings.NewReader(tt.body)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -344,7 +344,7 @@ func TestRememberRouteMissingFields(t *testing.T) { func TestRememberRouteInvalidJSON(t *testing.T) { srv := testServerWithEngine(t) - req := httptest.NewRequest("POST", "/api/memories", strings.NewReader("not json")) + req := newTestRequest("POST", "/api/memories", strings.NewReader("not json")) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -358,21 +358,21 @@ func TestGetContextWithSessions(t *testing.T) { // Create a completed session with observations initBody := `{"session_id":"old-001","project":"/tmp/myproject"}` - req := httptest.NewRequest("POST", "/api/sessions/init", strings.NewReader(initBody)) + req := newTestRequest("POST", "/api/sessions/init", strings.NewReader(initBody)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) obsBody := `{"tool_name":"Bash","tool_input":"{}","tool_response":"ok"}` - req = httptest.NewRequest("POST", "/api/sessions/old-001/observations", strings.NewReader(obsBody)) + req = newTestRequest("POST", "/api/sessions/old-001/observations", strings.NewReader(obsBody)) w = httptest.NewRecorder() srv.ServeHTTP(w, req) - req = httptest.NewRequest("POST", "/api/sessions/old-001/complete", nil) + req = newTestRequest("POST", "/api/sessions/old-001/complete", nil) w = httptest.NewRecorder() srv.ServeHTTP(w, req) // Get context for a new session - req = httptest.NewRequest("GET", "/api/context?session_id=new-001", nil) + req = newTestRequest("GET", "/api/context?session_id=new-001", nil) w = httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -405,7 +405,7 @@ func TestUnmarkEmptyExtractionsRoute(t *testing.T) { t.Fatalf("UpsertNode: %v", err) } - req := httptest.NewRequest("POST", "/api/sessions/unmark-empty-extractions", nil) + req := newTestRequest("POST", "/api/sessions/unmark-empty-extractions", nil) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -440,7 +440,7 @@ func TestExtractSessionRouteAcceptsForce(t *testing.T) { srv.db.InitSession("extract-001", "proj") body := `{"transcript_path":"/nonexistent/transcript.jsonl","force":true}` - req := httptest.NewRequest("POST", "/api/sessions/extract-001/extract", strings.NewReader(body)) + req := newTestRequest("POST", "/api/sessions/extract-001/extract", strings.NewReader(body)) w := httptest.NewRecorder() srv.ServeHTTP(w, req) diff --git a/internal/server/server.go b/internal/server/server.go index e157b4e..47f1bdb 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -41,7 +41,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) routes() { r := chi.NewRouter() r.Use(middleware.Recoverer) - r.Use(middleware.RealIP) + r.Use(securityHeaders) + r.Use(localhostOnly) + r.Use(limitRequestBody) r.Route("/api", func(r chi.Router) { r.Get("/health", s.handleHealth) @@ -90,8 +92,7 @@ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { "status": "ok", "version": s.version, "uptime": time.Since(s.started).Seconds(), - "db": dbOK, - "db_path": s.db.Path, + "db": dbOK, }) } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 8e089bf..386e02e 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -3,6 +3,7 @@ package server import ( "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "strings" @@ -11,6 +12,14 @@ import ( "github.com/lazypower/continuity/internal/store" ) +// newTestRequest wraps httptest.NewRequest with Host: localhost so the +// localhostOnly middleware doesn't reject test traffic. +func newTestRequest(method, url string, body io.Reader) *http.Request { + req := httptest.NewRequest(method, url, body) + req.Host = "localhost" + return req +} + func testServer(t *testing.T) *Server { t.Helper() db, err := store.OpenMemory() @@ -293,7 +302,7 @@ func TestTruncateAtSentence(t *testing.T) { func TestHealthEndpoint(t *testing.T) { srv := testServer(t) - req := httptest.NewRequest("GET", "/api/health", nil) + req := newTestRequest("GET", "/api/health", nil) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -330,7 +339,7 @@ func TestStubRoutes(t *testing.T) { } for _, s := range stubs { - req := httptest.NewRequest(s.method, s.path, nil) + req := newTestRequest(s.method, s.path, nil) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -353,7 +362,7 @@ func TestSearchRoute(t *testing.T) { srv := testServer(t) // Search without embedder returns 503 - req := httptest.NewRequest("GET", "/api/search?q=test", nil) + req := newTestRequest("GET", "/api/search?q=test", nil) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -362,7 +371,7 @@ func TestSearchRoute(t *testing.T) { } // Search without q param returns 400 - req = httptest.NewRequest("GET", "/api/search", nil) + req = newTestRequest("GET", "/api/search", nil) w = httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -374,7 +383,7 @@ func TestSearchRoute(t *testing.T) { func TestProfileRoute(t *testing.T) { srv := testServer(t) - req := httptest.NewRequest("GET", "/api/profile", nil) + req := newTestRequest("GET", "/api/profile", nil) w := httptest.NewRecorder() srv.ServeHTTP(w, req) @@ -394,7 +403,7 @@ func TestProfileRoute(t *testing.T) { func TestTreeRoute(t *testing.T) { srv := testServer(t) - req := httptest.NewRequest("GET", "/api/tree", nil) + req := newTestRequest("GET", "/api/tree", nil) w := httptest.NewRecorder() srv.ServeHTTP(w, req) diff --git a/internal/store/db.go b/internal/store/db.go index 299ee54..9a4a7eb 100644 --- a/internal/store/db.go +++ b/internal/store/db.go @@ -28,10 +28,14 @@ func DefaultDBPath() (string, error) { // configures pragmas, and runs migrations. func Open(path string) (*DB, error) { dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0755); err != nil { + if err := os.MkdirAll(dir, 0700); err != nil { return nil, fmt.Errorf("create db dir: %w", err) } + // Tighten permissions on existing installs — MkdirAll/Open only set + // permissions on creation, so pre-existing dirs/files need explicit chmod. + hardenPermissions(dir, path) + sqlDB, err := sql.Open("sqlite", path) if err != nil { return nil, fmt.Errorf("open sqlite: %w", err) @@ -68,6 +72,19 @@ func OpenMemory() (*DB, error) { return db, nil } +// hardenPermissions tightens file/directory permissions for existing installs. +// MkdirAll/OpenFile only set permissions on creation — this fixes pre-existing files. +func hardenPermissions(dir, dbPath string) { + if info, err := os.Stat(dir); err == nil && info.Mode().Perm()&0077 != 0 { + _ = os.Chmod(dir, 0700) + } + for _, f := range []string{dbPath, dbPath + "-wal", dbPath + "-shm"} { + if info, err := os.Stat(f); err == nil && info.Mode().Perm()&0077 != 0 { + _ = os.Chmod(f, 0600) + } + } +} + func (db *DB) configurePragmas() error { pragmas := []string{ "PRAGMA journal_mode=WAL", diff --git a/internal/store/observations.go b/internal/store/observations.go index 9416f3a..6145735 100644 --- a/internal/store/observations.go +++ b/internal/store/observations.go @@ -2,12 +2,13 @@ package store import ( "fmt" + "log" "time" ) -// maxToolResponseSize is the maximum size of tool_response stored in the DB. +// maxToolFieldSize is the maximum size of tool_input and tool_response stored in the DB. // Prevents bloat — Phase 2 extraction processes full transcript anyway. -const maxToolResponseSize = 10 * 1024 // 10KB +const maxToolFieldSize = 10 * 1024 // 10KB // Observation represents a single tool use recorded during a session. type Observation struct { @@ -19,10 +20,15 @@ type Observation struct { CreatedAt int64 } -// AddObservation stores a tool use observation. Truncates tool_response to 10KB. +// AddObservation stores a tool use observation. Truncates large fields to prevent DB bloat. func (db *DB) AddObservation(sessionID, toolName, toolInput, toolResponse string) error { - if len(toolResponse) > maxToolResponseSize { - toolResponse = toolResponse[:maxToolResponseSize] + if len(toolInput) > maxToolFieldSize { + log.Printf("observation: tool_input truncated for session %s: %d → %d bytes", sessionID, len(toolInput), maxToolFieldSize) + toolInput = toolInput[:maxToolFieldSize] + } + if len(toolResponse) > maxToolFieldSize { + log.Printf("observation: tool_response truncated for session %s: %d → %d bytes", sessionID, len(toolResponse), maxToolFieldSize) + toolResponse = toolResponse[:maxToolFieldSize] } now := time.Now().UnixMilli() diff --git a/internal/store/observations_test.go b/internal/store/observations_test.go index 5a33c77..04bc1ea 100644 --- a/internal/store/observations_test.go +++ b/internal/store/observations_test.go @@ -42,15 +42,19 @@ func TestAddObservationTruncation(t *testing.T) { } defer db.Close() - bigResponse := strings.Repeat("x", 20*1024) // 20KB - err = db.AddObservation("sess-001", "Bash", "{}", bigResponse) + bigInput := strings.Repeat("i", 20*1024) // 20KB + bigResponse := strings.Repeat("r", 20*1024) // 20KB + err = db.AddObservation("sess-001", "Bash", bigInput, bigResponse) if err != nil { t.Fatalf("AddObservation: %v", err) } obs, _ := db.GetObservations("sess-001") - if len(obs[0].ToolResponse) != maxToolResponseSize { - t.Errorf("ToolResponse length = %d, want %d", len(obs[0].ToolResponse), maxToolResponseSize) + if len(obs[0].ToolInput) != maxToolFieldSize { + t.Errorf("ToolInput length = %d, want %d", len(obs[0].ToolInput), maxToolFieldSize) + } + if len(obs[0].ToolResponse) != maxToolFieldSize { + t.Errorf("ToolResponse length = %d, want %d", len(obs[0].ToolResponse), maxToolFieldSize) } } diff --git a/ui/src/components/Header.svelte b/ui/src/components/Header.svelte index 0ccd9b6..d33cb29 100644 --- a/ui/src/components/Header.svelte +++ b/ui/src/components/Header.svelte @@ -3,9 +3,9 @@ import ThemeToggle from './ThemeToggle.svelte'; const tabs: { id: Tab; label: string; icon: string }[] = [ - { id: 'tree', label: 'Tree', icon: '◈' }, - { id: 'search', label: 'Search', icon: '⚲' }, - { id: 'profile', label: 'Profile', icon: '♦' }, + { id: 'tree', label: 'Tree', icon: '\u25C8' }, + { id: 'search', label: 'Search', icon: '\u26B2' }, + { id: 'profile', label: 'Profile', icon: '\u2666' }, ]; function setTab(tab: Tab) { @@ -36,7 +36,7 @@ : 'text-[var(--text-secondary)] hover:text-[var(--text-primary)] hover:bg-[var(--bg-card)]'}" > - {@html tab.icon} + {tab.icon} {tab.label} diff --git a/ui/src/components/ProfilePanel.svelte b/ui/src/components/ProfilePanel.svelte index ba67c0d..f6339c6 100644 --- a/ui/src/components/ProfilePanel.svelte +++ b/ui/src/components/ProfilePanel.svelte @@ -22,14 +22,27 @@ } } - function formatProfile(text: string): string { - // Simple markdown-ish rendering for section headers - return text - .replace(/^## (\d+)\. (.+)$/gm, '