Skip to content

Commit b9f0c04

Browse files
committed
test(retriever): improve test coverage to 88.6%
1 parent e484ea5 commit b9f0c04

File tree

1 file changed

+217
-1
lines changed

1 file changed

+217
-1
lines changed

components/retriever/pgvector/retriever_test.go

Lines changed: 217 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package pgvector
1818

1919
import (
2020
"context"
21+
"fmt"
2122
"testing"
2223

2324
"github.com/cloudwego/eino/components/embedding"
@@ -27,6 +28,11 @@ import (
2728
"github.com/stretchr/testify/assert"
2829
)
2930

31+
// Helper function for creating float64 pointers
32+
func float64Ptr(f float64) *float64 {
33+
return &f
34+
}
35+
3036
// mockEmbedder is a mock implementation of embedding.Embedder for testing.
3137
type mockEmbedder struct {
3238
vector []float64
@@ -311,14 +317,224 @@ func TestCalculateThresholdDistance(t *testing.T) {
311317
}
312318
}
313319

320+
func TestNewRetrieverPingFailed(t *testing.T) {
321+
ctx := context.Background()
322+
config := &RetrieverConfig{
323+
Conn: &mockConn{pingFail: true},
324+
Embedding: &mockEmbedder{},
325+
}
326+
327+
_, err := NewRetriever(ctx, config)
328+
assert.Error(t, err)
329+
assert.Contains(t, err.Error(), "failed to ping database")
330+
}
331+
332+
func TestRetrieveSuccess(t *testing.T) {
333+
ctx := context.Background()
334+
config := &RetrieverConfig{
335+
Conn: &mockConnWithRows{},
336+
Embedding: &mockEmbedder{},
337+
DistanceFunction: DistanceCosine,
338+
TopK: 5,
339+
}
340+
341+
r, err := NewRetriever(ctx, config)
342+
assert.NoError(t, err)
343+
344+
docs, err := r.Retrieve(ctx, "test query")
345+
assert.NoError(t, err)
346+
assert.Equal(t, 2, len(docs))
347+
assert.Equal(t, "doc1", docs[0].ID)
348+
assert.Equal(t, "doc2", docs[1].ID)
349+
assert.Equal(t, 1.0, docs[0].Score())
350+
}
351+
352+
func TestRetrieveQueryFailed(t *testing.T) {
353+
ctx := context.Background()
354+
config := &RetrieverConfig{
355+
Conn: &mockConn{queryFail: true},
356+
Embedding: &mockEmbedder{},
357+
DistanceFunction: DistanceCosine,
358+
TopK: 5,
359+
}
360+
361+
r, err := NewRetriever(ctx, config)
362+
assert.NoError(t, err)
363+
364+
_, err = r.Retrieve(ctx, "test query")
365+
assert.Error(t, err)
366+
assert.Contains(t, err.Error(), "query failed")
367+
}
368+
369+
func TestRetrieveWithScoreThreshold(t *testing.T) {
370+
ctx := context.Background()
371+
threshold := 0.8
372+
config := &RetrieverConfig{
373+
Conn: &mockConnWithRows{},
374+
Embedding: &mockEmbedder{},
375+
DistanceFunction: DistanceCosine,
376+
TopK: 5,
377+
ScoreThreshold: &threshold,
378+
}
379+
380+
r, err := NewRetriever(ctx, config)
381+
assert.NoError(t, err)
382+
383+
docs, err := r.Retrieve(ctx, "test query")
384+
assert.NoError(t, err)
385+
assert.Equal(t, 2, len(docs))
386+
}
387+
388+
func TestBuildSearchQuery(t *testing.T) {
389+
tests := []struct {
390+
name string
391+
whereClause string
392+
scoreThreshold *float64
393+
distanceFunc DistanceFunction
394+
expectedSubstr string
395+
}{
396+
{
397+
name: "no filters",
398+
whereClause: "",
399+
scoreThreshold: nil,
400+
distanceFunc: DistanceCosine,
401+
expectedSubstr: "ORDER BY distance ASC LIMIT $2",
402+
},
403+
{
404+
name: "with where clause",
405+
whereClause: "metadata->>'category' = 'tech'",
406+
scoreThreshold: nil,
407+
distanceFunc: DistanceCosine,
408+
expectedSubstr: "WHERE metadata->>'category' = 'tech'",
409+
},
410+
{
411+
name: "with score threshold",
412+
whereClause: "",
413+
scoreThreshold: float64Ptr(0.8),
414+
distanceFunc: DistanceCosine,
415+
expectedSubstr: "(embedding <=> $1) < 0.200000",
416+
},
417+
}
418+
419+
for _, tt := range tests {
420+
t.Run(tt.name, func(t *testing.T) {
421+
ctx := context.Background()
422+
config := &RetrieverConfig{
423+
Conn: &mockConn{},
424+
Embedding: &mockEmbedder{},
425+
DistanceFunction: tt.distanceFunc,
426+
}
427+
r, _ := NewRetriever(ctx, config)
428+
429+
query := r.buildSearchQuery(tt.whereClause, tt.scoreThreshold)
430+
assert.Contains(t, query, tt.expectedSubstr)
431+
})
432+
}
433+
}
434+
435+
// mockConnWithRows is a mock that returns actual rows
436+
type mockConnWithRows struct{}
437+
438+
func (m *mockConnWithRows) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
439+
return newMockRowsWithData(), nil
440+
}
441+
442+
func (m *mockConnWithRows) Ping(ctx context.Context) error {
443+
return nil
444+
}
445+
446+
type mockRowsWithData struct {
447+
currentRow int
448+
rows []struct {
449+
id string
450+
content string
451+
metadata map[string]any
452+
distance float64
453+
}
454+
}
455+
456+
func newMockRowsWithData() *mockRowsWithData {
457+
return &mockRowsWithData{
458+
currentRow: 0,
459+
rows: []struct {
460+
id string
461+
content string
462+
metadata map[string]any
463+
distance float64
464+
}{
465+
{
466+
id: "doc1",
467+
content: "test content 1",
468+
metadata: map[string]any{"category": "test"},
469+
distance: 0.0,
470+
},
471+
{
472+
id: "doc2",
473+
content: "test content 2",
474+
metadata: map[string]any{"category": "test"},
475+
distance: 0.1,
476+
},
477+
},
478+
}
479+
}
480+
481+
func (m *mockRowsWithData) Close() {}
482+
func (m *mockRowsWithData) Err() error { return nil }
483+
func (m *mockRowsWithData) CommandTag() pgconn.CommandTag {
484+
return pgconn.NewCommandTag("0 0 0")
485+
}
486+
func (m *mockRowsWithData) Next() bool {
487+
if m.currentRow < len(m.rows) {
488+
m.currentRow++
489+
return true
490+
}
491+
return false
492+
}
493+
494+
func (m *mockRowsWithData) Scan(dest ...any) error {
495+
if m.currentRow > 0 && m.currentRow <= len(m.rows) {
496+
row := m.rows[m.currentRow-1]
497+
if len(dest) >= 4 {
498+
if str, ok := dest[0].(*string); ok {
499+
*str = row.id
500+
}
501+
if str, ok := dest[1].(*string); ok {
502+
*str = row.content
503+
}
504+
if meta, ok := dest[2].(*map[string]any); ok {
505+
*meta = row.metadata
506+
}
507+
if f, ok := dest[3].(*float64); ok {
508+
*f = row.distance
509+
}
510+
}
511+
}
512+
return nil
513+
}
514+
515+
func (m *mockRowsWithData) Values() ([]any, error) { return nil, nil }
516+
func (m *mockRowsWithData) RawValues() [][]byte { return nil }
517+
func (m *mockRowsWithData) Conn() *pgx.Conn { return nil }
518+
func (m *mockRowsWithData) FieldDescriptions() []pgconn.FieldDescription { return nil }
519+
520+
314521
// mockConn is a mock implementation of PgxConn for testing.
315-
type mockConn struct{}
522+
type mockConn struct {
523+
pingFail bool
524+
queryFail bool
525+
}
316526

317527
func (m *mockConn) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
528+
if m.queryFail {
529+
return nil, fmt.Errorf("query failed")
530+
}
318531
return &mockRows{}, nil
319532
}
320533

321534
func (m *mockConn) Ping(ctx context.Context) error {
535+
if m.pingFail {
536+
return fmt.Errorf("ping failed")
537+
}
322538
return nil
323539
}
324540

0 commit comments

Comments
 (0)