@@ -18,6 +18,7 @@ package pgvector
1818
1919import (
2020 "context"
21+ "fmt"
2122 "testing"
2223
2324 "github.com/cloudwego/eino/components/embedding"
@@ -27,6 +28,11 @@ import (
2728 "github.com/stretchr/testify/assert"
2829)
2930
31+ // Helper function for creating float64 pointers
32+ func float64Ptr (f float64 ) * float64 {
33+ return & f
34+ }
35+
3036// mockEmbedder is a mock implementation of embedding.Embedder for testing.
3137type mockEmbedder struct {
3238 vector []float64
@@ -311,14 +317,224 @@ func TestCalculateThresholdDistance(t *testing.T) {
311317 }
312318}
313319
320+ func TestNewRetrieverPingFailed (t * testing.T ) {
321+ ctx := context .Background ()
322+ config := & RetrieverConfig {
323+ Conn : & mockConn {pingFail : true },
324+ Embedding : & mockEmbedder {},
325+ }
326+
327+ _ , err := NewRetriever (ctx , config )
328+ assert .Error (t , err )
329+ assert .Contains (t , err .Error (), "failed to ping database" )
330+ }
331+
332+ func TestRetrieveSuccess (t * testing.T ) {
333+ ctx := context .Background ()
334+ config := & RetrieverConfig {
335+ Conn : & mockConnWithRows {},
336+ Embedding : & mockEmbedder {},
337+ DistanceFunction : DistanceCosine ,
338+ TopK : 5 ,
339+ }
340+
341+ r , err := NewRetriever (ctx , config )
342+ assert .NoError (t , err )
343+
344+ docs , err := r .Retrieve (ctx , "test query" )
345+ assert .NoError (t , err )
346+ assert .Equal (t , 2 , len (docs ))
347+ assert .Equal (t , "doc1" , docs [0 ].ID )
348+ assert .Equal (t , "doc2" , docs [1 ].ID )
349+ assert .Equal (t , 1.0 , docs [0 ].Score ())
350+ }
351+
352+ func TestRetrieveQueryFailed (t * testing.T ) {
353+ ctx := context .Background ()
354+ config := & RetrieverConfig {
355+ Conn : & mockConn {queryFail : true },
356+ Embedding : & mockEmbedder {},
357+ DistanceFunction : DistanceCosine ,
358+ TopK : 5 ,
359+ }
360+
361+ r , err := NewRetriever (ctx , config )
362+ assert .NoError (t , err )
363+
364+ _ , err = r .Retrieve (ctx , "test query" )
365+ assert .Error (t , err )
366+ assert .Contains (t , err .Error (), "query failed" )
367+ }
368+
369+ func TestRetrieveWithScoreThreshold (t * testing.T ) {
370+ ctx := context .Background ()
371+ threshold := 0.8
372+ config := & RetrieverConfig {
373+ Conn : & mockConnWithRows {},
374+ Embedding : & mockEmbedder {},
375+ DistanceFunction : DistanceCosine ,
376+ TopK : 5 ,
377+ ScoreThreshold : & threshold ,
378+ }
379+
380+ r , err := NewRetriever (ctx , config )
381+ assert .NoError (t , err )
382+
383+ docs , err := r .Retrieve (ctx , "test query" )
384+ assert .NoError (t , err )
385+ assert .Equal (t , 2 , len (docs ))
386+ }
387+
388+ func TestBuildSearchQuery (t * testing.T ) {
389+ tests := []struct {
390+ name string
391+ whereClause string
392+ scoreThreshold * float64
393+ distanceFunc DistanceFunction
394+ expectedSubstr string
395+ }{
396+ {
397+ name : "no filters" ,
398+ whereClause : "" ,
399+ scoreThreshold : nil ,
400+ distanceFunc : DistanceCosine ,
401+ expectedSubstr : "ORDER BY distance ASC LIMIT $2" ,
402+ },
403+ {
404+ name : "with where clause" ,
405+ whereClause : "metadata->>'category' = 'tech'" ,
406+ scoreThreshold : nil ,
407+ distanceFunc : DistanceCosine ,
408+ expectedSubstr : "WHERE metadata->>'category' = 'tech'" ,
409+ },
410+ {
411+ name : "with score threshold" ,
412+ whereClause : "" ,
413+ scoreThreshold : float64Ptr (0.8 ),
414+ distanceFunc : DistanceCosine ,
415+ expectedSubstr : "(embedding <=> $1) < 0.200000" ,
416+ },
417+ }
418+
419+ for _ , tt := range tests {
420+ t .Run (tt .name , func (t * testing.T ) {
421+ ctx := context .Background ()
422+ config := & RetrieverConfig {
423+ Conn : & mockConn {},
424+ Embedding : & mockEmbedder {},
425+ DistanceFunction : tt .distanceFunc ,
426+ }
427+ r , _ := NewRetriever (ctx , config )
428+
429+ query := r .buildSearchQuery (tt .whereClause , tt .scoreThreshold )
430+ assert .Contains (t , query , tt .expectedSubstr )
431+ })
432+ }
433+ }
434+
435+ // mockConnWithRows is a mock that returns actual rows
436+ type mockConnWithRows struct {}
437+
438+ func (m * mockConnWithRows ) Query (ctx context.Context , sql string , args ... any ) (pgx.Rows , error ) {
439+ return newMockRowsWithData (), nil
440+ }
441+
442+ func (m * mockConnWithRows ) Ping (ctx context.Context ) error {
443+ return nil
444+ }
445+
446+ type mockRowsWithData struct {
447+ currentRow int
448+ rows []struct {
449+ id string
450+ content string
451+ metadata map [string ]any
452+ distance float64
453+ }
454+ }
455+
456+ func newMockRowsWithData () * mockRowsWithData {
457+ return & mockRowsWithData {
458+ currentRow : 0 ,
459+ rows : []struct {
460+ id string
461+ content string
462+ metadata map [string ]any
463+ distance float64
464+ }{
465+ {
466+ id : "doc1" ,
467+ content : "test content 1" ,
468+ metadata : map [string ]any {"category" : "test" },
469+ distance : 0.0 ,
470+ },
471+ {
472+ id : "doc2" ,
473+ content : "test content 2" ,
474+ metadata : map [string ]any {"category" : "test" },
475+ distance : 0.1 ,
476+ },
477+ },
478+ }
479+ }
480+
481+ func (m * mockRowsWithData ) Close () {}
482+ func (m * mockRowsWithData ) Err () error { return nil }
483+ func (m * mockRowsWithData ) CommandTag () pgconn.CommandTag {
484+ return pgconn .NewCommandTag ("0 0 0" )
485+ }
486+ func (m * mockRowsWithData ) Next () bool {
487+ if m .currentRow < len (m .rows ) {
488+ m .currentRow ++
489+ return true
490+ }
491+ return false
492+ }
493+
494+ func (m * mockRowsWithData ) Scan (dest ... any ) error {
495+ if m .currentRow > 0 && m .currentRow <= len (m .rows ) {
496+ row := m .rows [m .currentRow - 1 ]
497+ if len (dest ) >= 4 {
498+ if str , ok := dest [0 ].(* string ); ok {
499+ * str = row .id
500+ }
501+ if str , ok := dest [1 ].(* string ); ok {
502+ * str = row .content
503+ }
504+ if meta , ok := dest [2 ].(* map [string ]any ); ok {
505+ * meta = row .metadata
506+ }
507+ if f , ok := dest [3 ].(* float64 ); ok {
508+ * f = row .distance
509+ }
510+ }
511+ }
512+ return nil
513+ }
514+
515+ func (m * mockRowsWithData ) Values () ([]any , error ) { return nil , nil }
516+ func (m * mockRowsWithData ) RawValues () [][]byte { return nil }
517+ func (m * mockRowsWithData ) Conn () * pgx.Conn { return nil }
518+ func (m * mockRowsWithData ) FieldDescriptions () []pgconn.FieldDescription { return nil }
519+
520+
314521// mockConn is a mock implementation of PgxConn for testing.
315- type mockConn struct {}
522+ type mockConn struct {
523+ pingFail bool
524+ queryFail bool
525+ }
316526
317527func (m * mockConn ) Query (ctx context.Context , sql string , args ... any ) (pgx.Rows , error ) {
528+ if m .queryFail {
529+ return nil , fmt .Errorf ("query failed" )
530+ }
318531 return & mockRows {}, nil
319532}
320533
321534func (m * mockConn ) Ping (ctx context.Context ) error {
535+ if m .pingFail {
536+ return fmt .Errorf ("ping failed" )
537+ }
322538 return nil
323539}
324540
0 commit comments