Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/cli/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ func runInit(cmd *cobra.Command, args []string) error {
autostartPath := filepath.Join(homeDir, ".continuity", "autostart")

if initAutostart {
if err := os.MkdirAll(filepath.Dir(autostartPath), 0755); err != nil {
if err := os.MkdirAll(filepath.Dir(autostartPath), 0700); err != nil {
return fmt.Errorf("create .continuity dir: %w", err)
}
if err := os.WriteFile(autostartPath, []byte("enabled\n"), 0644); err != nil {
if err := os.WriteFile(autostartPath, []byte("enabled\n"), 0600); err != nil {
return fmt.Errorf("write autostart marker: %w", err)
}
fmt.Println("Autostart enabled: continuity serve will launch automatically when needed.")
Expand Down
8 changes: 6 additions & 2 deletions internal/cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,12 @@ func runServe(cmd *cobra.Command, args []string) error {
addr := cfg.ListenAddr()

httpServer := &http.Server{
Addr: addr,
Handler: srv,
Addr: addr,
Handler: srv,
ReadTimeout: 10 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
MaxHeaderBytes: 1 << 20, // 1MB
}

// Graceful shutdown
Expand Down
6 changes: 5 additions & 1 deletion internal/hooks/autostart.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,15 @@ func TryAutostart() bool {
return false
}
logPath := filepath.Join(home, ".continuity", "serve.log")
logFile, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
logFile, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
fmt.Fprintf(os.Stderr, "continuity: autostart: open log: %v\n", err)
return false
}
// Tighten existing log files from previous installs (0644 → 0600)
if info, err := logFile.Stat(); err == nil && info.Mode().Perm()&0077 != 0 {
os.Chmod(logPath, 0600)
}

devNull, err := os.Open(os.DevNull)
if err != nil {
Expand Down
4 changes: 3 additions & 1 deletion internal/hooks/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import (
"io"
)

const maxHookInputSize = 10 << 20 // 10MB

// Handle reads HookInput from the given reader, dispatches to the appropriate
// handler based on the event argument, and writes output to stdout.
func Handle(event string, stdin io.Reader) {
var input HookInput
if err := json.NewDecoder(stdin).Decode(&input); err != nil {
if err := json.NewDecoder(io.LimitReader(stdin, maxHookInputSize)).Decode(&input); err != nil {
// Stdin may be empty for some events — degrade gracefully
if event == "start" {
WriteSessionStartOutput("")
Expand Down
53 changes: 53 additions & 0 deletions internal/server/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package server

import (
"net"
"net/http"
"strings"
)

const maxRequestBody = 1 << 20 // 1MB

// normalizeHost extracts and normalizes the hostname from a Host header.
// Handles ports, bracketed IPv6, case folding, and trailing dots.
func normalizeHost(host string) string {
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
} else if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
host = host[1 : len(host)-1]
}
host = strings.ToLower(host)
host = strings.TrimSuffix(host, ".")
return host
}

// localhostOnly rejects requests where the Host header is not localhost.
// Prevents DNS rebinding attacks against the local API server.
func localhostOnly(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host := normalizeHost(r.Host)
if host != "localhost" && host != "127.0.0.1" && host != "::1" {
jsonError(w, "forbidden", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}

// securityHeaders adds standard security headers to all responses.
func securityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("Referrer-Policy", "no-referrer")
next.ServeHTTP(w, r)
})
}

// limitRequestBody caps the size of incoming request bodies to prevent OOM.
func limitRequestBody(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
next.ServeHTTP(w, r)
})
}
78 changes: 46 additions & 32 deletions internal/server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package server
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
Expand All @@ -14,23 +13,32 @@ import (
"github.com/lazypower/continuity/internal/engine"
)

// jsonError writes a JSON error response with proper Content-Type and encoding.
// Prefer this over http.Error for consistent JSON responses.
func jsonError(w http.ResponseWriter, msg string, code int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
json.NewEncoder(w).Encode(map[string]string{"error": msg})
}

func (s *Server) handleSessionInit(w http.ResponseWriter, r *http.Request) {
var req struct {
SessionID string `json:"session_id"`
Project string `json:"project"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest)
jsonError(w, "invalid json", http.StatusBadRequest)
return
}
if req.SessionID == "" {
http.Error(w, `{"error":"session_id required"}`, http.StatusBadRequest)
jsonError(w, "session_id required", http.StatusBadRequest)
return
}

sess, err := s.db.InitSession(req.SessionID, req.Project)
if err != nil {
http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusInternalServerError)
log.Printf("init session: %v", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return
}

Expand All @@ -52,16 +60,17 @@ func (s *Server) handleAddObservation(w http.ResponseWriter, r *http.Request) {
}
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, `{"error":"read body failed"}`, http.StatusBadRequest)
jsonError(w, "read body failed", http.StatusBadRequest)
return
}
if err := json.Unmarshal(body, &req); err != nil {
http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest)
jsonError(w, "invalid json", http.StatusBadRequest)
return
}

if err := s.db.AddObservation(sessionID, req.ToolName, req.ToolInput, req.ToolResponse); err != nil {
http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusInternalServerError)
log.Printf("add observation: %v", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return
}

Expand All @@ -79,8 +88,9 @@ func (s *Server) handleCompleteSession(w http.ResponseWriter, r *http.Request) {
if err := s.db.CompleteSession(sessionID); err != nil {
// Not finding an active session is not a server error — the session
// may have already been completed or never existed. Log but return OK.
log.Printf("complete session: %v", err)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"status": "ok", "note": err.Error()})
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
return
}

Expand All @@ -92,7 +102,8 @@ func (s *Server) handleEndSession(w http.ResponseWriter, r *http.Request) {
sessionID := chi.URLParam(r, "sessionID")

if err := s.db.EndSession(sessionID); err != nil {
http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusInternalServerError)
log.Printf("end session: %v", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return
}

Expand All @@ -108,7 +119,7 @@ func (s *Server) handleExtractSession(w http.ResponseWriter, r *http.Request) {
Force bool `json:"force"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest)
jsonError(w, "invalid json", http.StatusBadRequest)
return
}

Expand Down Expand Up @@ -144,11 +155,11 @@ func (s *Server) handleSignal(w http.ResponseWriter, r *http.Request) {
Prompt string `json:"prompt"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest)
jsonError(w, "invalid json", http.StatusBadRequest)
return
}
if req.Prompt == "" {
http.Error(w, `{"error":"prompt required"}`, http.StatusBadRequest)
jsonError(w, "prompt required", http.StatusBadRequest)
return
}

Expand Down Expand Up @@ -180,9 +191,8 @@ func (s *Server) handleSignal(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleUnmarkEmptyExtractions(w http.ResponseWriter, r *http.Request) {
n, err := s.db.UnmarkEmptyExtractions()
if err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
log.Printf("unmark empty extractions: %v", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return
}

Expand All @@ -204,15 +214,12 @@ func (s *Server) handleGetMemory(w http.ResponseWriter, r *http.Request) {

node, err := s.db.GetNodeByURI(uri)
if err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
log.Printf("get memory: %v", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return
}
if node == nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotFound)
json.NewEncoder(w).Encode(map[string]string{"error": "memory not found: " + uri})
jsonError(w, "memory not found", http.StatusNotFound)
return
}

Expand Down Expand Up @@ -241,11 +248,11 @@ func (s *Server) handleRemember(w http.ResponseWriter, r *http.Request) {
SessionID string `json:"session_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest)
jsonError(w, "invalid json", http.StatusBadRequest)
return
}
if req.Category == "" || req.Name == "" || req.Summary == "" || req.Body == "" {
http.Error(w, `{"error":"category, name, summary, and body are required"}`, http.StatusBadRequest)
jsonError(w, "category, name, summary, and body are required", http.StatusBadRequest)
return
}

Expand All @@ -268,9 +275,8 @@ func (s *Server) handleRemember(w http.ResponseWriter, r *http.Request) {
SessionID: req.SessionID,
})
if err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
log.Printf("remember: %v", err)
jsonError(w, "failed to store memory", http.StatusBadRequest)
return
}

Expand All @@ -289,7 +295,7 @@ func (s *Server) handleRemember(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleSearch(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query().Get("q")
if query == "" {
http.Error(w, `{"error":"q parameter required"}`, http.StatusBadRequest)
jsonError(w, "q parameter required", http.StatusBadRequest)
return
}

Expand All @@ -304,6 +310,9 @@ func (s *Server) handleSearch(w http.ResponseWriter, r *http.Request) {
limit = n
}
}
if limit > 100 {
limit = 100
}

category := r.URL.Query().Get("category")

Expand Down Expand Up @@ -333,7 +342,8 @@ func (s *Server) handleSearch(w http.ResponseWriter, r *http.Request) {
}

if err != nil {
http.Error(w, fmt.Sprintf(`{"error":"%s"}`, err.Error()), http.StatusInternalServerError)
log.Printf("search: %v", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return
}

Expand Down Expand Up @@ -384,7 +394,8 @@ func (s *Server) handleTimeline(w http.ResponseWriter, r *http.Request) {

sessions, err := s.db.GetSessionsSince(sinceMs)
if err != nil {
http.Error(w, fmt.Sprintf(`{"error":"%s"}`, err.Error()), http.StatusInternalServerError)
log.Printf("timeline: %v", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return
}

Expand Down Expand Up @@ -415,7 +426,8 @@ func (s *Server) handleTimeline(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleProfile(w http.ResponseWriter, r *http.Request) {
relProfile, err := s.db.GetNodeByURI("mem://user/profile/communication")
if err != nil {
http.Error(w, fmt.Sprintf(`{"error":"%s"}`, err.Error()), http.StatusInternalServerError)
log.Printf("profile: %v", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return
}

Expand Down Expand Up @@ -476,7 +488,8 @@ func (s *Server) handleTree(w http.ResponseWriter, r *http.Request) {
// List roots
roots, err := s.db.ListRoots()
if err != nil {
http.Error(w, fmt.Sprintf(`{"error":"%s"}`, err.Error()), http.StatusInternalServerError)
log.Printf("tree roots: %v", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return
}
for _, r := range roots {
Expand All @@ -492,7 +505,8 @@ func (s *Server) handleTree(w http.ResponseWriter, r *http.Request) {
// List children
children, err := s.db.GetChildren(uri)
if err != nil {
http.Error(w, fmt.Sprintf(`{"error":"%s"}`, err.Error()), http.StatusInternalServerError)
log.Printf("tree children: %v", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return
}
for _, c := range children {
Expand Down
Loading
Loading