Skip to content

Commit 623f6a4

Browse files
authored
Merge pull request #4 from lazypower/fix/issue-2-extraction-correctness
Fix silent extraction lock-out (closes #2)
2 parents b82fdd8 + 78fad62 commit 623f6a4

13 files changed

Lines changed: 843 additions & 9 deletions

File tree

internal/cli/extract.go

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
package cli
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
"strings"
9+
10+
"github.com/lazypower/continuity/internal/hooks"
11+
"github.com/spf13/cobra"
12+
)
13+
14+
var (
15+
extractForce bool
16+
extractTranscript string
17+
extractBackfillEmpty bool
18+
)
19+
20+
var extractCmd = &cobra.Command{
21+
Use: "extract [session-id]",
22+
Short: "Re-run extraction for a session",
23+
Long: `Trigger memory extraction for a completed session.
24+
25+
Typical uses:
26+
continuity extract <session-id> — re-extract if not already done
27+
continuity extract <session-id> --force — re-extract even if marked done
28+
continuity extract --backfill-empty — unmark every session that was
29+
flagged as extracted but has
30+
no memories attributed to it
31+
32+
When a session-id is given, continuity auto-discovers the transcript at
33+
~/.claude/projects/*/<session-id>.jsonl. Pass --transcript to override.
34+
35+
Requires a running server (continuity serve).`,
36+
Args: cobra.MaximumNArgs(1),
37+
RunE: runExtract,
38+
}
39+
40+
func init() {
41+
extractCmd.Flags().BoolVar(&extractForce, "force", false, "Bypass the idempotency guard (re-extract already-extracted sessions)")
42+
extractCmd.Flags().StringVar(&extractTranscript, "transcript", "", "Path to transcript JSONL (overrides auto-discovery)")
43+
extractCmd.Flags().BoolVar(&extractBackfillEmpty, "backfill-empty", false, "Unmark every session marked extracted with zero attributed memories")
44+
}
45+
46+
func runExtract(cmd *cobra.Command, args []string) error {
47+
client := hooks.NewClient()
48+
if !client.Healthy() {
49+
return fmt.Errorf("continuity server is not running — start it with: continuity serve")
50+
}
51+
52+
if extractBackfillEmpty {
53+
if len(args) > 0 || extractForce || extractTranscript != "" {
54+
return fmt.Errorf("--backfill-empty cannot be combined with a session-id, --force, or --transcript")
55+
}
56+
return runBackfillEmpty(client)
57+
}
58+
59+
if len(args) != 1 {
60+
return fmt.Errorf("session-id is required (or use --backfill-empty)")
61+
}
62+
sessionID := strings.TrimSpace(args[0])
63+
if sessionID == "" {
64+
return fmt.Errorf("session-id is required")
65+
}
66+
67+
transcriptPath := extractTranscript
68+
if transcriptPath == "" {
69+
found, err := findTranscript(sessionID)
70+
if err != nil {
71+
return err
72+
}
73+
transcriptPath = found
74+
}
75+
if _, err := os.Stat(transcriptPath); err != nil {
76+
return fmt.Errorf("transcript not readable: %w", err)
77+
}
78+
79+
body, _ := json.Marshal(map[string]any{
80+
"transcript_path": transcriptPath,
81+
"force": extractForce,
82+
})
83+
84+
data, err := client.Post("/api/sessions/"+sessionID+"/extract", body)
85+
if err != nil {
86+
if len(data) > 0 {
87+
var resp struct {
88+
Error string `json:"error"`
89+
}
90+
if jsonErr := json.Unmarshal(data, &resp); jsonErr == nil && resp.Error != "" {
91+
return fmt.Errorf("%s", resp.Error)
92+
}
93+
}
94+
return fmt.Errorf("extract: %w", err)
95+
}
96+
97+
fmt.Printf("extraction queued for %s (transcript: %s, force: %v)\n", sessionID, transcriptPath, extractForce)
98+
fmt.Println("check serve.log for progress — extraction runs asynchronously")
99+
return nil
100+
}
101+
102+
func runBackfillEmpty(client *hooks.Client) error {
103+
data, err := client.Post("/api/sessions/unmark-empty-extractions", nil)
104+
if err != nil {
105+
if len(data) > 0 {
106+
var resp struct {
107+
Error string `json:"error"`
108+
}
109+
if jsonErr := json.Unmarshal(data, &resp); jsonErr == nil && resp.Error != "" {
110+
return fmt.Errorf("%s", resp.Error)
111+
}
112+
}
113+
return fmt.Errorf("backfill: %w", err)
114+
}
115+
116+
var resp struct {
117+
Status string `json:"status"`
118+
Unmarked int64 `json:"unmarked"`
119+
Error string `json:"error"`
120+
}
121+
if err := json.Unmarshal(data, &resp); err != nil {
122+
return fmt.Errorf("parse response: %w", err)
123+
}
124+
if resp.Error != "" {
125+
return fmt.Errorf("%s", resp.Error)
126+
}
127+
128+
fmt.Printf("unmarked %d session(s) that were extracted with no attributed memories\n", resp.Unmarked)
129+
if resp.Unmarked > 0 {
130+
fmt.Println("they will be re-extracted on their next Stop/SessionEnd hook,")
131+
fmt.Println("or force one now with: continuity extract <session-id> --force")
132+
}
133+
return nil
134+
}
135+
136+
// findTranscript searches ~/.claude/projects/*/<session-id>.jsonl for a
137+
// Claude Code transcript matching the given session id. The sessionID is
138+
// validated first — path separators or ".." would let a glob pattern escape
139+
// ~/.claude/projects, which is surprising for "auto-discovery". Callers who
140+
// genuinely need to point at a transcript outside that tree should pass
141+
// --transcript explicitly.
142+
func findTranscript(sessionID string) (string, error) {
143+
if err := validateSessionIDForGlob(sessionID); err != nil {
144+
return "", fmt.Errorf("%w — pass --transcript to point at a specific file", err)
145+
}
146+
home, err := os.UserHomeDir()
147+
if err != nil {
148+
return "", fmt.Errorf("resolve home dir: %w", err)
149+
}
150+
pattern := filepath.Join(home, ".claude", "projects", "*", sessionID+".jsonl")
151+
matches, err := filepath.Glob(pattern)
152+
if err != nil {
153+
return "", fmt.Errorf("glob transcripts: %w", err)
154+
}
155+
if len(matches) == 0 {
156+
return "", fmt.Errorf("no transcript found for session %s (looked in %s)", sessionID, pattern)
157+
}
158+
if len(matches) > 1 {
159+
return "", fmt.Errorf("multiple transcripts found for %s — pass --transcript to disambiguate:\n %s", sessionID, strings.Join(matches, "\n "))
160+
}
161+
return matches[0], nil
162+
}
163+
164+
// validateSessionIDForGlob rejects session IDs that would let the
165+
// auto-discovery glob escape ~/.claude/projects. Real Claude Code session
166+
// IDs are UUIDs, but continuity imports from other sources so we don't
167+
// require that — we just refuse anything that would traverse the filesystem.
168+
func validateSessionIDForGlob(sessionID string) error {
169+
if sessionID == "" {
170+
return fmt.Errorf("session-id is empty")
171+
}
172+
if strings.ContainsAny(sessionID, `/\`) {
173+
return fmt.Errorf("session-id %q contains a path separator", sessionID)
174+
}
175+
if sessionID == "." || sessionID == ".." || strings.Contains(sessionID, "..") {
176+
return fmt.Errorf("session-id %q contains path traversal", sessionID)
177+
}
178+
return nil
179+
}

internal/cli/extract_test.go

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
package cli
2+
3+
import (
4+
"encoding/json"
5+
"net/http/httptest"
6+
"os"
7+
"path/filepath"
8+
"strings"
9+
"testing"
10+
11+
"github.com/lazypower/continuity/internal/engine"
12+
"github.com/lazypower/continuity/internal/server"
13+
"github.com/lazypower/continuity/internal/store"
14+
)
15+
16+
// extractTestServer stands up an in-memory server and points the CLI client
17+
// at it via CONTINUITY_URL. Mirrors showTestServer.
18+
func extractTestServer(t *testing.T) *store.DB {
19+
t.Helper()
20+
db, err := store.OpenMemory()
21+
if err != nil {
22+
t.Fatalf("OpenMemory: %v", err)
23+
}
24+
t.Cleanup(func() { db.Close() })
25+
26+
eng := engine.New(db, nil)
27+
srv := server.New(db, eng, "test-version")
28+
ts := httptest.NewServer(srv)
29+
t.Cleanup(ts.Close)
30+
31+
prev := os.Getenv("CONTINUITY_URL")
32+
os.Setenv("CONTINUITY_URL", ts.URL)
33+
t.Cleanup(func() { os.Setenv("CONTINUITY_URL", prev) })
34+
35+
return db
36+
}
37+
38+
func resetExtractFlags() {
39+
extractForce = false
40+
extractTranscript = ""
41+
extractBackfillEmpty = false
42+
}
43+
44+
func writeDummyTranscript(t *testing.T) string {
45+
t.Helper()
46+
dir := t.TempDir()
47+
path := filepath.Join(dir, "transcript.jsonl")
48+
entries := []map[string]any{
49+
{"type": "user", "message": map[string]any{"role": "user", "content": "Hello please help me out with this task"}},
50+
{"type": "assistant", "message": map[string]any{"role": "assistant", "content": "Sure, I can help."}},
51+
{"type": "user", "message": map[string]any{"role": "user", "content": "Another message here"}},
52+
}
53+
f, err := os.Create(path)
54+
if err != nil {
55+
t.Fatalf("create transcript: %v", err)
56+
}
57+
defer f.Close()
58+
for _, e := range entries {
59+
data, _ := json.Marshal(e)
60+
if _, err := f.Write(data); err != nil {
61+
t.Fatalf("write transcript: %v", err)
62+
}
63+
if _, err := f.Write([]byte("\n")); err != nil {
64+
t.Fatalf("write transcript: %v", err)
65+
}
66+
}
67+
return path
68+
}
69+
70+
func TestExtractCLIWithExplicitTranscript(t *testing.T) {
71+
db := extractTestServer(t)
72+
db.InitSession("cli-sess", "proj")
73+
path := writeDummyTranscript(t)
74+
75+
resetExtractFlags()
76+
extractTranscript = path
77+
78+
out, err := captureStdout(t, func() error {
79+
return runExtract(extractCmd, []string{"cli-sess"})
80+
})
81+
if err != nil {
82+
t.Fatalf("runExtract: %v", err)
83+
}
84+
if !strings.Contains(out, "extraction queued") {
85+
t.Errorf("expected queued message, got: %s", out)
86+
}
87+
}
88+
89+
func TestExtractCLIBackfillEmpty(t *testing.T) {
90+
db := extractTestServer(t)
91+
db.InitSession("damaged", "proj")
92+
db.MarkExtracted("damaged")
93+
94+
resetExtractFlags()
95+
extractBackfillEmpty = true
96+
97+
out, err := captureStdout(t, func() error {
98+
return runExtract(extractCmd, nil)
99+
})
100+
if err != nil {
101+
t.Fatalf("runExtract --backfill-empty: %v", err)
102+
}
103+
if !strings.Contains(out, "unmarked 1 session") {
104+
t.Errorf("expected unmark count, got: %s", out)
105+
}
106+
107+
// Verify the damaged session was actually unmarked.
108+
s, _ := db.GetSession("damaged")
109+
if s.ExtractedAt != nil {
110+
t.Error("damaged session should have been unmarked via CLI path")
111+
}
112+
}
113+
114+
func TestExtractCLIBackfillExclusiveFlags(t *testing.T) {
115+
extractTestServer(t)
116+
117+
resetExtractFlags()
118+
extractBackfillEmpty = true
119+
extractForce = true
120+
121+
err := runExtract(extractCmd, nil)
122+
if err == nil {
123+
t.Fatal("expected error when --backfill-empty combined with --force")
124+
}
125+
if !strings.Contains(err.Error(), "cannot be combined") {
126+
t.Errorf("unexpected error: %v", err)
127+
}
128+
}
129+
130+
func TestExtractCLIRequiresSessionID(t *testing.T) {
131+
extractTestServer(t)
132+
resetExtractFlags()
133+
134+
err := runExtract(extractCmd, nil)
135+
if err == nil {
136+
t.Fatal("expected error when no session-id and no --backfill-empty")
137+
}
138+
}
139+
140+
func TestExtractCLITranscriptMissing(t *testing.T) {
141+
extractTestServer(t)
142+
resetExtractFlags()
143+
extractTranscript = "/no/such/path.jsonl"
144+
145+
err := runExtract(extractCmd, []string{"abc"})
146+
if err == nil {
147+
t.Fatal("expected error for missing transcript")
148+
}
149+
if !strings.Contains(err.Error(), "not readable") {
150+
t.Errorf("unexpected error: %v", err)
151+
}
152+
}
153+
154+
func TestValidateSessionIDForGlob(t *testing.T) {
155+
cases := []struct {
156+
name string
157+
sessionID string
158+
wantErr bool
159+
}{
160+
{"empty", "", true},
161+
{"uuid", "0f5d812c-69e1-40d0-9eef-5436f5721a80", false},
162+
{"plain slug", "my-session", false},
163+
{"forward slash", "foo/bar", true},
164+
{"backslash", `foo\bar`, true},
165+
{"parent traversal", "..", true},
166+
{"embedded traversal", "foo..bar", true},
167+
{"dot", ".", true},
168+
{"absolute escape", "/etc/passwd", true},
169+
}
170+
for _, tc := range cases {
171+
t.Run(tc.name, func(t *testing.T) {
172+
err := validateSessionIDForGlob(tc.sessionID)
173+
if (err != nil) != tc.wantErr {
174+
t.Errorf("validateSessionIDForGlob(%q) err=%v wantErr=%v", tc.sessionID, err, tc.wantErr)
175+
}
176+
})
177+
}
178+
}

internal/cli/root.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,5 @@ func init() {
2929
rootCmd.AddCommand(timelineCmd)
3030
rootCmd.AddCommand(installServiceCmd)
3131
rootCmd.AddCommand(uninstallServiceCmd)
32+
rootCmd.AddCommand(extractCmd)
3233
}

0 commit comments

Comments
 (0)