Skip to content

Commit c787e7b

Browse files
committed
refactor(memory): introduce LLM interface for testability
Replaces concrete *provider.Router with LLM interface in Extract and Reconcile. Adds RouterLLM adapter. Enables mock-based testing.
1 parent c54c275 commit c787e7b

File tree

4 files changed

+36
-14
lines changed

4 files changed

+36
-14
lines changed

internal/cli/memory/extract.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,16 @@ var extractCmd = &cobra.Command{
7070
return fmt.Errorf("migrate memory: %w", err)
7171
}
7272

73-
// Create provider router.
74-
router, err := buildRouter(cfg)
73+
// Create LLM client.
74+
llm, err := buildLLM(cfg)
7575
if err != nil {
76-
return fmt.Errorf("build router: %w", err)
76+
return fmt.Errorf("build LLM: %w", err)
7777
}
7878

7979
ctx := context.Background()
8080

8181
// Extract candidate facts.
82-
candidates, err := memory.Extract(ctx, router, messages)
82+
candidates, err := memory.Extract(ctx, llm, messages)
8383
if err != nil {
8484
return fmt.Errorf("extract: %w", err)
8585
}
@@ -90,7 +90,7 @@ var extractCmd = &cobra.Command{
9090
}
9191

9292
// Reconcile with existing memories.
93-
result, err := memory.Reconcile(ctx, memDB, router, candidates)
93+
result, err := memory.Reconcile(ctx, memDB, llm, candidates)
9494
if err != nil {
9595
return fmt.Errorf("reconcile: %w", err)
9696
}
@@ -130,12 +130,13 @@ func loadRecentMessages(db *store.DB, lastN int) ([]string, error) {
130130
return messages, rows.Err()
131131
}
132132

133-
func buildRouter(cfg *config.Config) (*provider.Router, error) {
133+
func buildLLM(cfg *config.Config) (memory.LLM, error) {
134134
registry, err := provider.NewRegistry(cfg.Models)
135135
if err != nil {
136136
return nil, fmt.Errorf("create provider registry: %w", err)
137137
}
138-
return provider.NewRouter(registry, cfg.Models), nil
138+
router := provider.NewRouter(registry, cfg.Models)
139+
return &memory.RouterLLM{Router: router, Tier: provider.TierFast}, nil
139140
}
140141

141142
func init() {

memory/extract.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type ExtractResult struct {
3232
Skipped int
3333
}
3434

35-
func Extract(ctx context.Context, router *provider.Router, messages []string) ([]CandidateFact, error) {
35+
func Extract(ctx context.Context, llm LLM, messages []string) ([]CandidateFact, error) {
3636
filtered := preFilter(messages)
3737
if len(filtered) == 0 {
3838
return nil, nil
@@ -48,7 +48,7 @@ func Extract(ctx context.Context, router *provider.Router, messages []string) ([
4848
MaxTokens: 2048,
4949
}
5050

51-
resp, err := router.Chat(ctx, provider.TierFast, req)
51+
resp, err := llm.Chat(ctx, req)
5252
if err != nil {
5353
return nil, fmt.Errorf("extraction LLM call: %w", err)
5454
}

memory/reconcile.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ type reconcileDecision struct {
2929
Content string `json:"content"`
3030
}
3131

32-
func Reconcile(ctx context.Context, db *store.DB, router *provider.Router, candidates []CandidateFact) (*ExtractResult, error) {
32+
func Reconcile(ctx context.Context, db *store.DB, llm LLM, candidates []CandidateFact) (*ExtractResult, error) {
3333
result := &ExtractResult{}
3434

3535
for _, candidate := range candidates {
@@ -52,7 +52,7 @@ func Reconcile(ctx context.Context, db *store.DB, router *provider.Router, candi
5252
continue
5353
}
5454

55-
decision, err := reconcileWithLLM(ctx, router, existing, candidate)
55+
decision, err := reconcileWithLLM(ctx, llm, existing, candidate)
5656
if err != nil {
5757
result.Skipped++
5858
continue
@@ -87,7 +87,7 @@ func Reconcile(ctx context.Context, db *store.DB, router *provider.Router, candi
8787
return result, nil
8888
}
8989

90-
func reconcileWithLLM(ctx context.Context, router *provider.Router, existing []Memory, candidate CandidateFact) (*reconcileDecision, error) {
90+
func reconcileWithLLM(ctx context.Context, llm LLM, existing []Memory, candidate CandidateFact) (*reconcileDecision, error) {
9191
var existingLines []string
9292
for _, m := range existing {
9393
existingLines = append(existingLines, fmt.Sprintf("[ID=%d] %s", m.ID, m.Content))
@@ -103,7 +103,7 @@ func reconcileWithLLM(ctx context.Context, router *provider.Router, existing []M
103103
MaxTokens: 256,
104104
}
105105

106-
resp, err := router.Chat(ctx, provider.TierFast, req)
106+
resp, err := llm.Chat(ctx, req)
107107
if err != nil {
108108
return nil, fmt.Errorf("reconciliation LLM call: %w", err)
109109
}

memory/types.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,27 @@
11
package memory
22

3-
import "time"
3+
import (
4+
"context"
5+
"time"
6+
7+
"github.com/priyanshujain/openbotkit/provider"
8+
)
9+
10+
// LLM is the interface used by Extract and Reconcile for LLM calls.
11+
// Satisfied by RouterLLM adapter or any mock in tests.
12+
type LLM interface {
13+
Chat(ctx context.Context, req provider.ChatRequest) (*provider.ChatResponse, error)
14+
}
15+
16+
// RouterLLM adapts a provider.Router to the LLM interface using a fixed tier.
17+
type RouterLLM struct {
18+
Router *provider.Router
19+
Tier provider.ModelTier
20+
}
21+
22+
func (r *RouterLLM) Chat(ctx context.Context, req provider.ChatRequest) (*provider.ChatResponse, error) {
23+
return r.Router.Chat(ctx, r.Tier, req)
24+
}
425

526
type Category string
627

0 commit comments

Comments
 (0)