Skip to content

Commit 7cf7589

Browse files
committed
Trying to get tests to work
1 parent d306eae commit 7cf7589

File tree

5 files changed

+82
-22
lines changed

5 files changed

+82
-22
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java

Lines changed: 7 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.xcontent.XContentParserConfiguration;
2929
import org.elasticsearch.xcontent.XContentType;
3030
import org.elasticsearch.xcontent.support.MapXContentParser;
31+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
3132

3233
import java.io.IOException;
3334
import java.util.ArrayList;
@@ -50,6 +51,7 @@
5051
* {@link IndexVersions#INFERENCE_METADATA_FIELDS}, null otherwise.
5152
* @param inference The inference result.
5253
* @param contentType The {@link XContentType} used to store the embeddings chunks.
54+
* @param chunkingSettings The {@link ChunkingSettings} used to override model chunking defaults.
5355
*/
5456
public record SemanticTextField(
5557
boolean useLegacyFormat,
@@ -71,16 +73,7 @@ public record SemanticTextField(
7173
static final String CHUNKED_START_OFFSET_FIELD = "start_offset";
7274
static final String CHUNKED_END_OFFSET_FIELD = "end_offset";
7375
static final String MODEL_SETTINGS_FIELD = "model_settings";
74-
static final String TASK_TYPE_FIELD = "task_type";
75-
static final String DIMENSIONS_FIELD = "dimensions";
76-
static final String SIMILARITY_FIELD = "similarity";
77-
static final String ELEMENT_TYPE_FIELD = "element_type";
78-
// Chunking settings
7976
static final String CHUNKING_SETTINGS_FIELD = "chunking_settings";
80-
static final String STRATEGY_FIELD = "strategy";
81-
static final String MAX_CHUNK_SIZE_FIELD = "max_chunk_size";
82-
static final String OVERLAP_FIELD = "overlap";
83-
static final String SENTENCE_OVERLAP_FIELD = "sentence_overlap";
8477

8578
/**
 * Inference output carried by a semantic text field.
 *
 * @param inferenceId   The inference identifier associated with this result.
 * @param modelSettings The minimal service settings of the model; callers visible in the tests
 *                      may pass {@code null} here.
 * @param chunks        Lists of chunks grouped under a string key
 *                      (NOTE(review): the test helpers key this map by field name; confirm for production callers).
 */
public record InferenceResult(
    String inferenceId,
    MinimalServiceSettings modelSettings,
    Map<String, List<Chunk>> chunks
) {}
8679

@@ -194,6 +187,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
194187
private static final ConstructingObjectParser<SemanticTextField, ParserContext> SEMANTIC_TEXT_FIELD_PARSER =
195188
new ConstructingObjectParser<>(SemanticTextFieldMapper.CONTENT_TYPE, true, (args, context) -> {
196189
List<String> originalValues = (List<String>) args[0];
190+
InferenceResult inferenceResult = (InferenceResult) args[1];
191+
Map<String, Object> chunkingSettingsMap = (Map<String, Object>) args[2];
192+
ChunkingSettings chunkingSettings = chunkingSettingsMap != null ? ChunkingSettingsBuilder.fromMap(chunkingSettingsMap) : null;
197193
if (context.useLegacyFormat() == false) {
198194
if (originalValues != null && originalValues.isEmpty() == false) {
199195
throw new IllegalArgumentException("Unknown field [" + TEXT_FIELD + "]");
@@ -204,9 +200,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
204200
context.useLegacyFormat(),
205201
context.fieldName(),
206202
originalValues,
207-
(InferenceResult) args[1],
203+
inferenceResult,
208204
context.xContentType(),
209-
(ChunkingSettings) args[2]
205+
chunkingSettings
210206
);
211207
});
212208

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@
5555
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
5656
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
5757
import org.elasticsearch.xpack.inference.InferencePlugin;
58+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
5859
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
5960
import org.elasticsearch.xpack.inference.model.TestModel;
6061
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -543,6 +544,9 @@ private static BulkItemRequest[] randomBulkItemRequest(
543544
for (var entry : fieldInferenceMap.values()) {
544545
String field = entry.getName();
545546
var model = modelMap.get(entry.getInferenceId());
547+
ChunkingSettings chunkingSettings = entry.getChunkingSettings() != null
548+
? ChunkingSettingsBuilder.fromMap(new HashMap<>(entry.getChunkingSettings()))
549+
: null;
546550
Object inputObject = randomSemanticTextInput();
547551
String inputText = inputObject.toString();
548552
docMap.put(field, inputObject);
@@ -562,13 +566,21 @@ private static BulkItemRequest[] randomBulkItemRequest(
562566
useLegacyFormat,
563567
field,
564568
model,
569+
chunkingSettings,
565570
List.of(inputText),
566571
results,
567572
requestContentType
568573
);
569574
} else {
570575
Map<String, List<String>> inputTextMap = Map.of(field, List.of(inputText));
571-
semanticTextField = randomSemanticText(useLegacyFormat, field, model, List.of(inputText), requestContentType);
576+
semanticTextField = randomSemanticText(
577+
useLegacyFormat,
578+
field,
579+
model,
580+
chunkingSettings,
581+
List.of(inputText),
582+
requestContentType
583+
);
572584
model.putResult(inputText, toChunkedResult(useLegacyFormat, inputTextMap, semanticTextField));
573585
}
574586

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java

Lines changed: 26 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
2929
import org.elasticsearch.index.translog.Translog;
3030
import org.elasticsearch.inference.ChunkedInference;
31+
import org.elasticsearch.inference.ChunkingSettings;
3132
import org.elasticsearch.inference.Model;
3233
import org.elasticsearch.inference.SimilarityMeasure;
3334
import org.elasticsearch.inference.TaskType;
@@ -43,6 +44,7 @@
4344
import java.util.List;
4445

4546
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
47+
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettings;
4648
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingByte;
4749
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse;
4850
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.semanticTextFieldFromChunkedInferenceResults;
@@ -51,12 +53,14 @@
5153
public class SemanticInferenceMetadataFieldsRecoveryTests extends EngineTestCase {
5254
private final Model model1;
5355
private final Model model2;
56+
private final ChunkingSettings chunkingSettings;
5457
private final boolean useSynthetic;
5558
private final boolean useIncludesExcludes;
5659

5760
public SemanticInferenceMetadataFieldsRecoveryTests(boolean useSynthetic, boolean useIncludesExcludes) {
5861
this.model1 = randomModel(TaskType.TEXT_EMBEDDING);
5962
this.model2 = randomModel(TaskType.SPARSE_EMBEDDING);
63+
this.chunkingSettings = generateRandomChunkingSettings();
6064
this.useSynthetic = useSynthetic;
6165
this.useIncludesExcludes = useIncludesExcludes;
6266
}
@@ -105,6 +109,11 @@ protected String defaultMapping() {
105109
builder.field("similarity", model1.getServiceSettings().similarity().name());
106110
builder.field("element_type", model1.getServiceSettings().elementType().name());
107111
builder.endObject();
112+
if (chunkingSettings != null) {
113+
builder.startObject("chunking_settings");
114+
chunkingSettings.toXContent(builder, null);
115+
builder.endObject();
116+
}
108117
builder.endObject();
109118

110119
builder.startObject("semantic_2");
@@ -113,6 +122,11 @@ protected String defaultMapping() {
113122
builder.startObject("model_settings");
114123
builder.field("task_type", model2.getTaskType().name());
115124
builder.endObject();
125+
if (chunkingSettings != null) {
126+
builder.startObject("chunking_settings");
127+
chunkingSettings.toXContent(builder, null);
128+
builder.endObject();
129+
}
116130
builder.endObject();
117131

118132
builder.endObject();
@@ -244,8 +258,8 @@ private BytesReference randomSource() throws IOException {
244258
false,
245259
builder,
246260
List.of(
247-
randomSemanticText(false, "semantic_2", model2, randomInputs(), XContentType.JSON),
248-
randomSemanticText(false, "semantic_1", model1, randomInputs(), XContentType.JSON)
261+
randomSemanticText(false, "semantic_2", model2, chunkingSettings, randomInputs(), XContentType.JSON),
262+
randomSemanticText(false, "semantic_1", model1, chunkingSettings, randomInputs(), XContentType.JSON)
249263
)
250264
);
251265
builder.endObject();
@@ -256,6 +270,7 @@ private static SemanticTextField randomSemanticText(
256270
boolean useLegacyFormat,
257271
String fieldName,
258272
Model model,
273+
ChunkingSettings chunkingSettings,
259274
List<String> inputs,
260275
XContentType contentType
261276
) throws IOException {
@@ -267,7 +282,15 @@ private static SemanticTextField randomSemanticText(
267282
case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs, false);
268283
default -> throw new AssertionError("invalid task type: " + model.getTaskType().name());
269284
};
270-
return semanticTextFieldFromChunkedInferenceResults(useLegacyFormat, fieldName, model, inputs, results, contentType);
285+
return semanticTextFieldFromChunkedInferenceResults(
286+
useLegacyFormat,
287+
fieldName,
288+
model,
289+
chunkingSettings,
290+
inputs,
291+
results,
292+
contentType
293+
);
271294
}
272295

273296
private static List<String> randomInputs() {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java

Lines changed: 23 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,7 @@
5454
import org.elasticsearch.index.mapper.vectors.XFeatureField;
5555
import org.elasticsearch.index.query.SearchExecutionContext;
5656
import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
57+
import org.elasticsearch.inference.ChunkingSettings;
5758
import org.elasticsearch.inference.MinimalServiceSettings;
5859
import org.elasticsearch.inference.Model;
5960
import org.elasticsearch.inference.SimilarityMeasure;
@@ -90,6 +91,7 @@
9091
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
9192
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName;
9293
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.DEFAULT_ELSER_2_INFERENCE_ID;
94+
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettings;
9395
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
9496
import static org.hamcrest.Matchers.containsString;
9597
import static org.hamcrest.Matchers.equalTo;
@@ -642,6 +644,7 @@ public void testSuccessfulParse() throws IOException {
642644

643645
Model model1 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
644646
Model model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
647+
ChunkingSettings chunkingSettings = generateRandomChunkingSettings();
645648
XContentBuilder mapping = mapping(b -> {
646649
addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null);
647650
addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null);
@@ -670,8 +673,15 @@ public void testSuccessfulParse() throws IOException {
670673
useLegacyFormat,
671674
b,
672675
List.of(
673-
randomSemanticText(useLegacyFormat, fieldName1, model1, List.of("a b", "c"), XContentType.JSON),
674-
randomSemanticText(useLegacyFormat, fieldName2, model2, List.of("d e f"), XContentType.JSON)
676+
randomSemanticText(
677+
useLegacyFormat,
678+
fieldName1,
679+
model1,
680+
chunkingSettings,
681+
List.of("a b", "c"),
682+
XContentType.JSON
683+
),
684+
randomSemanticText(useLegacyFormat, fieldName2, model2, chunkingSettings, List.of("d e f"), XContentType.JSON)
675685
)
676686
)
677687
)
@@ -842,7 +852,15 @@ public void testDenseVectorElementType() throws IOException {
842852
public void testModelSettingsRequiredWithChunks() throws IOException {
843853
// Create inference results where model settings are set to null and chunks are provided
844854
Model model = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
845-
SemanticTextField randomSemanticText = randomSemanticText(useLegacyFormat, "field", model, List.of("a"), XContentType.JSON);
855+
ChunkingSettings chunkingSettings = generateRandomChunkingSettings();
856+
SemanticTextField randomSemanticText = randomSemanticText(
857+
useLegacyFormat,
858+
"field",
859+
model,
860+
chunkingSettings,
861+
List.of("a"),
862+
XContentType.JSON
863+
);
846864
SemanticTextField inferenceResults = new SemanticTextField(
847865
randomSemanticText.useLegacyFormat(),
848866
randomSemanticText.fieldName(),
@@ -853,7 +871,7 @@ public void testModelSettingsRequiredWithChunks() throws IOException {
853871
randomSemanticText.inference().chunks()
854872
),
855873
randomSemanticText.contentType(),
856-
SemanticTextFieldTests.generateRandomChunkingSettings()
874+
chunkingSettings
857875
);
858876

859877
MapperService mapperService = createMapperService(
@@ -898,7 +916,7 @@ private MapperService mapperServiceForFieldWithModelSettings(
898916
List.of(),
899917
new SemanticTextField.InferenceResult(inferenceId, modelSettings, Map.of()),
900918
XContentType.JSON,
901-
SemanticTextFieldTests.generateRandomChunkingSettings()
919+
generateRandomChunkingSettings()
902920
);
903921
XContentBuilder builder = JsonXContent.contentBuilder().startObject();
904922
if (useLegacyFormat) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -118,6 +118,7 @@ protected SemanticTextField createTestInstance() {
118118
useLegacyFormat,
119119
NAME,
120120
TestModel.createRandomInstance(),
121+
generateRandomChunkingSettings(),
121122
rawValues,
122123
randomFrom(XContentType.values())
123124
);
@@ -218,6 +219,7 @@ public static SemanticTextField randomSemanticText(
218219
boolean useLegacyFormat,
219220
String fieldName,
220221
Model model,
222+
ChunkingSettings chunkingSettings,
221223
List<String> inputs,
222224
XContentType contentType
223225
) throws IOException {
@@ -229,13 +231,22 @@ public static SemanticTextField randomSemanticText(
229231
case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs);
230232
default -> throw new AssertionError("invalid task type: " + model.getTaskType().name());
231233
};
232-
return semanticTextFieldFromChunkedInferenceResults(useLegacyFormat, fieldName, model, inputs, results, contentType);
234+
return semanticTextFieldFromChunkedInferenceResults(
235+
useLegacyFormat,
236+
fieldName,
237+
model,
238+
chunkingSettings,
239+
inputs,
240+
results,
241+
contentType
242+
);
233243
}
234244

235245
public static SemanticTextField semanticTextFieldFromChunkedInferenceResults(
236246
boolean useLegacyFormat,
237247
String fieldName,
238248
Model model,
249+
ChunkingSettings chunkingSettings,
239250
List<String> inputs,
240251
ChunkedInference results,
241252
XContentType contentType
@@ -273,7 +284,7 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults(
273284
Map.of(fieldName, chunks)
274285
),
275286
contentType,
276-
generateRandomChunkingSettings()
287+
chunkingSettings
277288
);
278289
}
279290

0 commit comments

Comments
 (0)