1010import com .carrotsearch .randomizedtesting .annotations .Name ;
1111import com .carrotsearch .randomizedtesting .annotations .ParametersFactory ;
1212
13- import org .elasticsearch .action .bulk .BulkRequestBuilder ;
14- import org .elasticsearch .action .index .IndexRequest ;
15- import org .elasticsearch .action .support .WriteRequest ;
13+ import org .elasticsearch .action .index .IndexRequestBuilder ;
14+ import org .elasticsearch .cluster .metadata .IndexMetadata ;
1615import org .elasticsearch .common .settings .Settings ;
1716import org .elasticsearch .xpack .esql .action .AbstractEsqlIntegTestCase ;
1817import org .junit .Before ;
1918
19+ import java .util .ArrayList ;
2020import java .util .HashMap ;
2121import java .util .List ;
2222import java .util .Locale ;
@@ -39,6 +39,8 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
3939 );
4040
4141 private final String indexType ;
42+ private int numDims ;
43+ private int numDocs ;
4244
4345 @ ParametersFactory
4446 public static Iterable <Object []> parameters () throws Exception {
@@ -49,15 +51,7 @@ public DenseVectorFieldTypeIT(@Name("indexType") String indexType) {
4951 this .indexType = indexType ;
5052 }
5153
52- private static Map <Integer , List <Float >> DOC_VALUES = new HashMap <>();
53- static {
54- DOC_VALUES .put (1 , List .of (1.0f , 2.0f , 3.0f ));
55- DOC_VALUES .put (2 , List .of (4.0f , 5.0f , 6.0f ));
56- DOC_VALUES .put (3 , List .of (7.0f , 8.0f , 9.0f ));
57- DOC_VALUES .put (4 , List .of (10.0f , 11.0f , 12.0f ));
58- DOC_VALUES .put (5 , List .of (13.0f , 14.0f , 15.0f ));
59- DOC_VALUES .put (6 , List .of (16.0f , 17.0f , 18.0f ));
60- }
54+ private Map <Integer , List <Float >> indexedDocs = new HashMap <>();
6155
6256 public void testRetrieveFieldType () {
6357 var query = """
@@ -66,57 +60,96 @@ public void testRetrieveFieldType() {
6660
6761 try (var resp = run (query )) {
6862 assertColumnNames (resp .columns (), List .of ("id" , "vector" ));
69- assertColumnTypes (resp .columns (), List .of ("long " , "double " ));
63+ assertColumnTypes (resp .columns (), List .of ("integer " , "dense_vector " ));
7064 }
7165 }
7266
7367 @ SuppressWarnings ("unchecked" )
74- public void testRetrieveDenseVectorFieldData () {
68+ public void testRetrieveOrderedDenseVectorFieldData () {
7569 var query = """
7670 FROM test
7771 | SORT id ASC
7872 """ ;
7973
8074 try (var resp = run (query )) {
8175 List <List <Object >> valuesList = EsqlTestUtils .getValuesList (resp );
82- DOC_VALUES .forEach ((id , vector ) -> {
83- var values = valuesList .get (id - 1 );
84- assertEquals (id . intValue (), (( Long ) values .get (0 )). intValue ( ));
76+ indexedDocs .forEach ((id , vector ) -> {
77+ var values = valuesList .get (id );
78+ assertEquals (id , values .get (0 ));
8579 List <Double > vectors = (List <Double >) values .get (1 );
8680 assertEquals (vector .size (), vectors .size ());
8781 for (int i = 0 ; i < vector .size (); i ++) {
88- assertEquals ((float ) vector .get (i ), vectors .get (i ).floatValue (), 0F );
82+ assertEquals (vector .get (i ), vectors .get (i ).floatValue (), 0F );
83+ }
84+ });
85+ }
86+ }
87+
88+ @ SuppressWarnings ("unchecked" )
89+ public void testRetrieveUnOrderedDenseVectorFieldData () {
90+ var query = "FROM test" ;
91+
92+ try (var resp = run (query )) {
93+ List <List <Object >> valuesList = EsqlTestUtils .getValuesList (resp );
94+ assertEquals (valuesList .size (), indexedDocs .size ());
95+ valuesList .forEach (value -> {;
96+ assertEquals (2 , value .size ());
97+ Integer id = (Integer ) value .get (0 );
98+ List <Double > vector = (List <Double >) value .get (1 );
99+
100+ List <Float > expectedVector = indexedDocs .get (id );
101+ for (int i = 0 ; i < vector .size (); i ++) {
102+ assertEquals (expectedVector .get (i ), vector .get (i ).floatValue (), 0F );
89103 }
90104 });
91105 }
92106 }
93107
94108 @ Before
95109 public void setup () {
110+ numDims = randomIntBetween (64 , 256 );
111+ numDocs = randomIntBetween (10 , 100 );
112+ for (int i = 0 ; i < numDocs ; i ++) {
113+ List <Float > vector = new ArrayList <>(numDims );
114+ for (int j = 0 ; j < numDims ; j ++) {
115+ // vector.add(randomFloat());
116+ vector .add (1.0f );
117+ }
118+ indexedDocs .put (i , vector );
119+ }
120+
96121 var indexName = "test" ;
97122 var client = client ().admin ().indices ();
98123 var mapping = String .format (Locale .ROOT , """
99- "id": integer,
124+ {
125+ "properties": {
126+ "id": {
127+ "type": "integer"
128+ },
100129 "vector": {
101- "type": "dense_vector",
102- "index_options": {
103- "type": "%s"
104- }
130+ "type": "dense_vector",
131+ "similarity": "l2_norm",
132+ "index_options": {
133+ "type": "%s"
134+ }
105135 }
136+ }
137+ }
106138 """ , indexType );
139+ Settings settings = Settings .builder ()
140+ .put (IndexMetadata .SETTING_NUMBER_OF_REPLICAS , 0 )
141+ .put (IndexMetadata .SETTING_NUMBER_OF_SHARDS , randomIntBetween (1 , 5 ))
142+ .build ();
107143 var CreateRequest = client .prepareCreate (indexName )
108144 .setSettings (Settings .builder ().put ("index.number_of_shards" , 1 ))
109- .setMapping (mapping );
145+ .setMapping (mapping )
146+ .setSettings (settings );
110147 assertAcked (CreateRequest );
111148
112- BulkRequestBuilder bulkRequestBuilder = client ().prepareBulk ();
113- for (var entry : DOC_VALUES .entrySet ()) {
114- bulkRequestBuilder .add (
115- new IndexRequest (indexName ).id (entry .getKey ().toString ()).source ("id" , entry .getKey (), "vector" , entry .getValue ())
116- );
149+ IndexRequestBuilder [] docs = new IndexRequestBuilder [numDocs ];
150+ for (int i = 0 ; i < numDocs ; i ++) {
151+ docs [i ] = prepareIndex ("test" ).setId ("" + i ).setSource ("id" , i , "vector" , indexedDocs .get (i ));
117152 }
118-
119- bulkRequestBuilder .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE ).get ();
120- ensureYellow (indexName );
153+ indexRandom (true , docs );
121154 }
122155}
0 commit comments