1919import org .elasticsearch .index .mapper .InferenceMetadataFieldsMapper ;
2020import org .elasticsearch .index .query .NestedQueryBuilder ;
2121import org .elasticsearch .inference .Model ;
22+ import org .elasticsearch .inference .SimilarityMeasure ;
2223import org .elasticsearch .inference .TaskType ;
2324import org .elasticsearch .search .fetch .subphase .highlight .HighlightBuilder ;
2425import org .elasticsearch .test .rest .ObjectPath ;
4647
4748public class SemanticTextUpgradeIT extends AbstractUpgradeTestCase {
4849 private static final String INDEX_BASE_NAME = "semantic_text_test_index" ;
49- private static final String SEMANTIC_TEXT_FIELD = "semantic_field" ;
50+ private static final String SPARSE_FIELD = "sparse_field" ;
51+ private static final String DENSE_FIELD = "dense_field" ;
5052
5153 private static Model SPARSE_MODEL ;
54+ private static Model DENSE_MODEL ;
5255
5356 private final boolean useLegacyFormat ;
5457
5558 @ BeforeClass
5659 public static void beforeClass () {
5760 SPARSE_MODEL = TestModel .createRandomInstance (TaskType .SPARSE_EMBEDDING );
61+ // Exclude dot product because we are not producing unit length vectors
62+ DENSE_MODEL = TestModel .createRandomInstance (TaskType .TEXT_EMBEDDING , List .of (SimilarityMeasure .DOT_PRODUCT ));
5863 }
5964
6065 public SemanticTextUpgradeIT (boolean useLegacyFormat ) {
@@ -79,13 +84,17 @@ private void createAndPopulateIndex() throws IOException {
7984 final String mapping = Strings .format ("""
8085 {
8186 "properties": {
87+ "%s": {
88+ "type": "semantic_text",
89+ "inference_id": "%s"
90+ },
8291 "%s": {
8392 "type": "semantic_text",
8493 "inference_id": "%s"
8594 }
8695 }
8796 }
88- """ , SEMANTIC_TEXT_FIELD , SPARSE_MODEL .getInferenceEntityId ());
97+ """ , SPARSE_FIELD , SPARSE_MODEL . getInferenceEntityId (), DENSE_FIELD , DENSE_MODEL .getInferenceEntityId ());
8998
9099 CreateIndexResponse response = createIndex (
91100 indexName ,
@@ -99,8 +108,8 @@ private void createAndPopulateIndex() throws IOException {
99108
100109 private void performIndexQueryHighlightOps () throws IOException {
101110 indexDoc ("doc_2" , List .of ("another test value" ));
102- ObjectPath queryObjectPath = semanticQuery ("test value" , 3 );
103- assertQueryResponse (queryObjectPath );
111+ ObjectPath queryObjectPath = semanticQuery (SPARSE_FIELD , "test value" , 3 );
112+ assertQueryResponse (queryObjectPath , SPARSE_FIELD );
104113 }
105114
106115 private String getIndexName () {
@@ -109,21 +118,30 @@ private String getIndexName() {
109118
110119 private void indexDoc (String id , List <String > semanticTextFieldValue ) throws IOException {
111120 final String indexName = getIndexName ();
112- final SemanticTextField semanticTextField = randomSemanticText (
121+ final SemanticTextField sparseFieldValue = randomSemanticText (
113122 useLegacyFormat ,
114- SEMANTIC_TEXT_FIELD ,
123+ SPARSE_FIELD ,
115124 SPARSE_MODEL ,
116125 null ,
117126 semanticTextFieldValue ,
118127 XContentType .JSON
119128 );
129+ final SemanticTextField denseFieldValue = randomSemanticText (
130+ useLegacyFormat ,
131+ DENSE_FIELD ,
132+ DENSE_MODEL ,
133+ null ,
134+ semanticTextFieldValue ,
135+ XContentType .JSON
136+ );
120137
121138 XContentBuilder builder = XContentFactory .jsonBuilder ();
122139 builder .startObject ();
123140 if (useLegacyFormat == false ) {
124- builder .field (semanticTextField .fieldName (), semanticTextFieldValue );
141+ builder .field (sparseFieldValue .fieldName (), semanticTextFieldValue );
142+ builder .field (denseFieldValue .fieldName (), semanticTextFieldValue );
125143 }
126- addSemanticTextInferenceResults (useLegacyFormat , builder , List .of (semanticTextField ));
144+ addSemanticTextInferenceResults (useLegacyFormat , builder , List .of (sparseFieldValue , denseFieldValue ));
127145 builder .endObject ();
128146
129147 RequestOptions requestOptions = RequestOptions .DEFAULT .toBuilder ().addParameter ("refresh" , "true" ).build ();
@@ -135,20 +153,20 @@ private void indexDoc(String id, List<String> semanticTextFieldValue) throws IOE
135153 assertOK (response );
136154 }
137155
138- private ObjectPath semanticQuery (String query , Integer numOfHighlightFragments ) throws IOException {
156+ private ObjectPath semanticQuery (String field , String query , Integer numOfHighlightFragments ) throws IOException {
139157 // We can't perform a real semantic query because that requires performing inference, so instead we perform an equivalent nested
140158 // query
141159 List <WeightedToken > weightedTokens = Arrays .stream (query .split ("\\ s" )).map (t -> new WeightedToken (t , 1.0f )).toList ();
142160 SparseVectorQueryBuilder sparseVectorQueryBuilder = new SparseVectorQueryBuilder (
143- SemanticTextField .getEmbeddingsFieldName (SEMANTIC_TEXT_FIELD ),
161+ SemanticTextField .getEmbeddingsFieldName (field ),
144162 weightedTokens ,
145163 null ,
146164 null ,
147165 null ,
148166 null
149167 );
150168 NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder (
151- SemanticTextField .getChunksFieldName (SEMANTIC_TEXT_FIELD ),
169+ SemanticTextField .getChunksFieldName (field ),
152170 sparseVectorQueryBuilder ,
153171 ScoreMode .Max
154172 );
@@ -157,7 +175,7 @@ private ObjectPath semanticQuery(String query, Integer numOfHighlightFragments)
157175 builder .startObject ();
158176 builder .field ("query" , nestedQueryBuilder );
159177 if (numOfHighlightFragments != null ) {
160- HighlightBuilder .Field highlightField = new HighlightBuilder .Field (SEMANTIC_TEXT_FIELD );
178+ HighlightBuilder .Field highlightField = new HighlightBuilder .Field (field );
161179 highlightField .numOfFragments (numOfHighlightFragments );
162180
163181 HighlightBuilder highlightBuilder = new HighlightBuilder ();
@@ -175,7 +193,7 @@ private ObjectPath semanticQuery(String query, Integer numOfHighlightFragments)
175193 }
176194
177195 @ SuppressWarnings ("unchecked" )
178- private static void assertQueryResponse (ObjectPath queryObjectPath ) throws IOException {
196+ private static void assertQueryResponse (ObjectPath queryObjectPath , String field ) throws IOException {
179197 final Map <String , List <String >> expectedHighlights = Map .of (
180198 "doc_1" ,
181199 List .of ("a test value" , "with multiple test values" ),
@@ -198,7 +216,7 @@ private static void assertQueryResponse(ObjectPath queryObjectPath) throws IOExc
198216
199217 List <String > expectedHighlight = expectedHighlights .get (id );
200218 assertThat (expectedHighlight , notNullValue ());
201- assertThat (((Map <String , Object >) hitMap .get ("highlight" )).get (SEMANTIC_TEXT_FIELD ), equalTo (expectedHighlight ));
219+ assertThat (((Map <String , Object >) hitMap .get ("highlight" )).get (field ), equalTo (expectedHighlight ));
202220 }
203221
204222 assertThat (docIds , equalTo (Set .of ("doc_1" , "doc_2" )));
0 commit comments