1
+ package com .baeldung .jvector ;
2
+
3
import static com.baeldung.jvector.VectorSearch.persistIndex;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import io.github.jbellis.jvector.disk.ReaderSupplier;
import io.github.jbellis.jvector.disk.ReaderSupplierFactory;
import io.github.jbellis.jvector.graph.GraphIndex;
import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
33
+
34
+ class VectorSearchTest {
35
+
36
+ private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider .getInstance ()
37
+ .getVectorTypeSupport ();
38
+ private static Path indexPath ;
39
+ private static Map <String , VectorFloat <?>> datasetVectors ;
40
+
41
+ @ BeforeAll
42
+ static void setup () throws IOException {
43
+ datasetVectors = new VectorSearchTest ().loadGlove6B50dDataSet (1000 );
44
+ indexPath = Files .createTempFile ("sample" , ".inline" );
45
+ persistIndex (new ArrayList <>(datasetVectors .values ()), indexPath );
46
+ }
47
+
48
+ @ Test
49
+ void givenLoadedDataset_whenPersistingIndex_thenPersistIndexInDisk () throws IOException {
50
+ try (ReaderSupplier readerSupplier = ReaderSupplierFactory .open (indexPath )) {
51
+ GraphIndex index = OnDiskGraphIndex .load (readerSupplier );
52
+ assertInstanceOf (OnDiskGraphIndex .class , index );
53
+ }
54
+ }
55
+
56
+ @ Test
57
+ void givenLoadedDataset_whenSearchingSimilarVectors_thenReturnValidSearchResult () throws IOException {
58
+ VectorFloat <?> queryVector = datasetVectors .get ("said" );
59
+ ArrayList <VectorFloat <?>> vectorsList = new ArrayList <>(datasetVectors .values ());
60
+
61
+ try (ReaderSupplier readerSupplier = ReaderSupplierFactory .open (indexPath )) {
62
+ GraphIndex index = OnDiskGraphIndex .load (readerSupplier );
63
+
64
+ SearchResult result = GraphSearcher .search (queryVector , 10 ,
65
+ new ListRandomAccessVectorValues (vectorsList , vectorsList .get (0 ).length ()),
66
+ VectorSimilarityFunction .EUCLIDEAN , index , Bits .ALL );
67
+
68
+ assertNotNull (result .getNodes ());
69
+ assertEquals (10 , result .getNodes ().length );
70
+ }
71
+ }
72
+
73
+ private Map <String , VectorFloat <?>> loadGlove6B50dDataSet (int limit ) throws IOException {
74
+ URL datasetResource = getClass ().getClassLoader ()
75
+ .getResource ("jvector/glove.6B.50d.txt" );
76
+ assertNotNull (datasetResource );
77
+
78
+ Map <String , VectorFloat <?>> vectors = new HashMap <>();
79
+
80
+ try (BufferedReader reader = new BufferedReader (new FileReader (datasetResource .getFile ()))) {
81
+ String line ;
82
+ int count = 0 ;
83
+ while ((line = reader .readLine ()) != null && count < limit ) {
84
+ String [] values = line .split (" " );
85
+ String word = values [0 ];
86
+ VectorFloat <?> vector = VECTOR_TYPE_SUPPORT .createFloatVector (50 );
87
+ for (int i = 0 ; i < 50 ; i ++) {
88
+ vector .set (i , Float .parseFloat (values [i + 1 ]));
89
+ }
90
+ vectors .put (word , vector );
91
+ count ++;
92
+ }
93
+ }
94
+ assertEquals (1000 , vectors .size ());
95
+ return vectors ;
96
+ }
97
+ }
0 commit comments