Skip to content

Commit be4a5b5

Browse files
committed
test(memory): add mock LLM tests and integration tests
Unit tests: Extract with mock LLM (verifies prompt construction, JSON parsing, filtering), Reconcile with mock (NOOP, UPDATE, DELETE, ADD decisions). Integration tests: real LLM extraction and end-to-end extract→reconcile flow, skipped when no API keys set.
1 parent c787e7b commit be4a5b5

File tree

3 files changed

+395
-5
lines changed

3 files changed

+395
-5
lines changed

memory/extract_test.go

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,30 @@
11
package memory
22

3-
import "testing"
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/priyanshujain/openbotkit/provider"
8+
)
9+
10+
type mockLLM struct {
11+
response string
12+
err error
13+
lastReq *provider.ChatRequest
14+
}
15+
16+
func (m *mockLLM) Chat(_ context.Context, req provider.ChatRequest) (*provider.ChatResponse, error) {
17+
m.lastReq = &req
18+
if m.err != nil {
19+
return nil, m.err
20+
}
21+
return &provider.ChatResponse{
22+
Content: []provider.ContentBlock{
23+
{Type: provider.ContentText, Text: m.response},
24+
},
25+
StopReason: provider.StopEndTurn,
26+
}, nil
27+
}
428

529
func TestPreFilter(t *testing.T) {
630
messages := []string{
@@ -102,3 +126,70 @@ func TestIsAck(t *testing.T) {
102126
}
103127
}
104128
}
129+
130+
func TestExtractWithMockLLM(t *testing.T) {
131+
llm := &mockLLM{
132+
response: `[{"content": "User prefers dark mode", "category": "preference"}, {"content": "User's name is Priyanshu", "category": "identity"}]`,
133+
}
134+
135+
messages := []string{
136+
"My name is Priyanshu and I prefer dark mode in all my editors",
137+
"I've been working on this project for a while now",
138+
}
139+
140+
facts, err := Extract(context.Background(), llm, messages)
141+
if err != nil {
142+
t.Fatalf("Extract: %v", err)
143+
}
144+
if len(facts) != 2 {
145+
t.Fatalf("expected 2 facts, got %d", len(facts))
146+
}
147+
if facts[0].Content != "User prefers dark mode" {
148+
t.Errorf("fact[0].Content = %q", facts[0].Content)
149+
}
150+
if facts[1].Category != "identity" {
151+
t.Errorf("fact[1].Category = %q", facts[1].Category)
152+
}
153+
154+
// Verify the prompt was constructed correctly.
155+
if llm.lastReq == nil {
156+
t.Fatal("expected LLM to be called")
157+
}
158+
if llm.lastReq.System != extractionPrompt {
159+
t.Error("expected system prompt to be the extraction prompt")
160+
}
161+
if len(llm.lastReq.Messages) != 1 {
162+
t.Fatalf("expected 1 message, got %d", len(llm.lastReq.Messages))
163+
}
164+
}
165+
166+
func TestExtractAllFiltered(t *testing.T) {
167+
llm := &mockLLM{response: "should not be called"}
168+
169+
messages := []string{"ok", "yes", "thanks"}
170+
171+
facts, err := Extract(context.Background(), llm, messages)
172+
if err != nil {
173+
t.Fatalf("Extract: %v", err)
174+
}
175+
if len(facts) != 0 {
176+
t.Fatalf("expected 0 facts, got %d", len(facts))
177+
}
178+
if llm.lastReq != nil {
179+
t.Error("LLM should not have been called when all messages filtered")
180+
}
181+
}
182+
183+
func TestExtractEmptyResponse(t *testing.T) {
184+
llm := &mockLLM{response: `[]`}
185+
186+
messages := []string{"I've been thinking about this problem for a while"}
187+
188+
facts, err := Extract(context.Background(), llm, messages)
189+
if err != nil {
190+
t.Fatalf("Extract: %v", err)
191+
}
192+
if len(facts) != 0 {
193+
t.Fatalf("expected 0 facts, got %d", len(facts))
194+
}
195+
}

memory/integration_test.go

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
package memory
2+
3+
import (
4+
"context"
5+
"os"
6+
"testing"
7+
8+
"github.com/priyanshujain/openbotkit/provider"
9+
"github.com/priyanshujain/openbotkit/provider/anthropic"
10+
"github.com/priyanshujain/openbotkit/provider/gemini"
11+
"github.com/priyanshujain/openbotkit/provider/openai"
12+
)
13+
14+
type providerTestCase struct {
15+
name string
16+
provider provider.Provider
17+
model string
18+
}
19+
20+
func availableProviders(t *testing.T) []providerTestCase {
21+
t.Helper()
22+
var providers []providerTestCase
23+
24+
if key := os.Getenv("ANTHROPIC_API_KEY"); key != "" {
25+
providers = append(providers, providerTestCase{
26+
name: "anthropic",
27+
provider: anthropic.New(key),
28+
model: "claude-sonnet-4-6",
29+
})
30+
}
31+
if key := os.Getenv("OPENAI_API_KEY"); key != "" {
32+
providers = append(providers, providerTestCase{
33+
name: "openai",
34+
provider: openai.New(key),
35+
model: "gpt-4o-mini",
36+
})
37+
}
38+
if key := os.Getenv("GEMINI_API_KEY"); key != "" {
39+
providers = append(providers, providerTestCase{
40+
name: "gemini",
41+
provider: gemini.New(key),
42+
model: "gemini-2.0-flash",
43+
})
44+
}
45+
46+
if len(providers) == 0 {
47+
t.Skip("no API keys set — skipping integration tests (set ANTHROPIC_API_KEY, OPENAI_API_KEY, or GEMINI_API_KEY)")
48+
}
49+
return providers
50+
}
51+
52+
type providerLLM struct {
53+
p provider.Provider
54+
model string
55+
}
56+
57+
func (pl *providerLLM) Chat(ctx context.Context, req provider.ChatRequest) (*provider.ChatResponse, error) {
58+
req.Model = pl.model
59+
return pl.p.Chat(ctx, req)
60+
}
61+
62+
func TestIntegration_Extract(t *testing.T) {
63+
for _, tc := range availableProviders(t) {
64+
t.Run(tc.name, func(t *testing.T) {
65+
llm := &providerLLM{p: tc.provider, model: tc.model}
66+
67+
messages := []string{
68+
"My name is Alice and I'm a software engineer at TechCorp",
69+
"I really prefer using Go for backend development over Python",
70+
"I'm currently building a personal assistant called BotKit",
71+
}
72+
73+
facts, err := Extract(context.Background(), llm, messages)
74+
if err != nil {
75+
t.Fatalf("Extract: %v", err)
76+
}
77+
78+
if len(facts) == 0 {
79+
t.Fatal("expected at least 1 fact extracted")
80+
}
81+
82+
// Verify facts have valid categories.
83+
validCategories := map[string]bool{
84+
"identity": true, "preference": true,
85+
"relationship": true, "project": true,
86+
}
87+
for _, f := range facts {
88+
if f.Content == "" {
89+
t.Error("fact has empty content")
90+
}
91+
if !validCategories[f.Category] {
92+
t.Errorf("fact has invalid category %q: %q", f.Category, f.Content)
93+
}
94+
}
95+
})
96+
}
97+
}
98+
99+
func TestIntegration_ExtractAndReconcile(t *testing.T) {
100+
for _, tc := range availableProviders(t) {
101+
t.Run(tc.name, func(t *testing.T) {
102+
db := testDB(t)
103+
if err := Migrate(db); err != nil {
104+
t.Fatalf("migrate: %v", err)
105+
}
106+
107+
llm := &providerLLM{p: tc.provider, model: tc.model}
108+
109+
messages := []string{
110+
"My name is Bob and I live in San Francisco",
111+
"I prefer dark mode in all my code editors",
112+
"I'm working on an open source project called DataFlow",
113+
}
114+
115+
// Extract facts.
116+
facts, err := Extract(context.Background(), llm, messages)
117+
if err != nil {
118+
t.Fatalf("Extract: %v", err)
119+
}
120+
if len(facts) == 0 {
121+
t.Fatal("expected at least 1 fact")
122+
}
123+
124+
// Reconcile into empty DB (should all ADD).
125+
result, err := Reconcile(context.Background(), db, llm, facts)
126+
if err != nil {
127+
t.Fatalf("Reconcile: %v", err)
128+
}
129+
130+
if result.Added == 0 {
131+
t.Error("expected at least 1 add")
132+
}
133+
134+
count, _ := Count(db)
135+
if count == 0 {
136+
t.Fatal("expected memories in DB after reconciliation")
137+
}
138+
139+
// Verify memories are retrievable.
140+
memories, err := List(db)
141+
if err != nil {
142+
t.Fatalf("List: %v", err)
143+
}
144+
for _, m := range memories {
145+
if m.Content == "" {
146+
t.Error("memory has empty content")
147+
}
148+
if m.Source != "history" {
149+
t.Errorf("memory source = %q, want 'history'", m.Source)
150+
}
151+
}
152+
153+
// Second extraction with same facts — should mostly NOOP/skip.
154+
result2, err := Reconcile(context.Background(), db, llm, facts)
155+
if err != nil {
156+
t.Fatalf("second Reconcile: %v", err)
157+
}
158+
159+
// Count should not have grown much (some ADD is OK if LLM decides differently).
160+
count2, _ := Count(db)
161+
t.Logf("first run: added=%d, count=%d; second run: added=%d, updated=%d, skipped=%d, count=%d",
162+
result.Added, count, result2.Added, result2.Updated, result2.Skipped, count2)
163+
})
164+
}
165+
}

0 commit comments

Comments
 (0)