diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java index aaab14941d4bb..cbe7e7be51902 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java @@ -17,8 +17,11 @@ import org.elasticsearch.test.ESIntegTestCase; import java.util.List; +import java.util.Map; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; import static org.hamcrest.Matchers.equalTo; @@ -75,4 +78,89 @@ public void testSimpleNested() throws Exception { ); } + public void testNestedKNnnSearch() { + testNestedWithTwoSegments(false); + } + + public void testNestedKNnnSearchWithMultipleSegments() { + testNestedWithTwoSegments(true); + } + + private void testNestedWithTwoSegments(boolean flush) { + assertAcked(prepareCreate("test").setMapping(""" + { + "properties": { + "name": { + "type": "keyword" + }, + "nested": { + "type": "nested", + "properties": { + "paragraph_id": { + "type": "keyword" + }, + "vector": { + "type": "dense_vector", + "dims": 5, + "similarity": "l2_norm", + "index_options": { + "type": "hnsw" + } + } + } + } + } + } + """).setSettings(Settings.builder().put(indexSettings()).put("index.number_of_shards", 1))); + ensureGreen(); + + prepareIndex("test").setId("1") + .setSource( + "name", + "dog", + "nested", + new Object[] { + Map.of("paragraph_id", 0, "vector", new float[] { 230.0f, 300.33f, -34.8988f, 15.555f, -200.0f }), + Map.of("paragraph_id", 1, "vector", new float[] { 240.0f, 300f, -3f, 1f, -20f }) } + ) + .get(); + + prepareIndex("test").setId("2") + .setSource( + "name", + "cat", + "nested", + new Object[] { + Map.of("paragraph_id", 0, "vector", new float[] { -0.5f, 100.0f, -13f, 14.8f, -156.0f }), + Map.of("paragraph_id", 1, "vector", new float[] { 0f, 100.0f, 0f, 14.8f, -156.0f }), + Map.of("paragraph_id", 2, "vector", new float[] { 0f, 1.0f, 0f, 1.8f, -15.0f }) } + ) + .get(); + + if (flush) { + refresh("test"); + } + + prepareIndex("test").setId("3") + .setSource( + "name", + "rat", + "nested", + new Object[] { Map.of("paragraph_id", 0, "vector", new float[] { 0.5f, 111.3f, -13.0f, 14.8f, -156.0f }) } + ) + .get(); + + waitForRelocation(ClusterHealthStatus.GREEN); + refresh(); + + var knn = new KnnSearchBuilder("nested.vector", new float[] { -0.5f, 90.0f, -10f, 14.8f, -156.0f }, 2, 3, null, null); + var request = prepareSearch("test").addFetchField("name").setKnnSearch(List.of(knn)); + assertNoFailuresAndResponse(request, response -> { + assertHitCount(response, 2); + assertEquals("2", response.getHits().getHits()[0].getId()); + assertEquals("cat", response.getHits().getHits()[0].field("name").getValue()); + assertEquals("3", response.getHits().getHits()[1].getId()); + assertEquals("rat", response.getHits().getHits()[1].field("name").getValue()); + }); + } }