docs/changelog/121041.yaml — 5 additions, 0 deletions
@@ -0,0 +1,5 @@
pr: 121041
summary: Support configurable chunking in `semantic_text` fields
area: Relevance
type: enhancement
issues: []
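The change stores the per-field chunking configuration as a plain settings map. Below is a minimal sketch of the two shapes exercised by the tests later in this diff; the key names and strategies (word_boundary, sentence_boundary) come from the test helpers below, while the concrete values are only illustrative.

import java.util.Map;

// Word-boundary strategy: chunk size is measured in words, with word overlap between chunks.
Map<String, Object> wordBoundary = Map.of(
    "strategy", "word_boundary",
    "max_chunk_size", 100,
    "overlap", 20
);

// Sentence-boundary strategy: chunks end on sentence boundaries, with at most one sentence of overlap.
Map<String, Object> sentenceBoundary = Map.of(
    "strategy", "sentence_boundary",
    "max_chunk_size", 100,
    "sentence_overlap", 1
);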
TransportVersions.java
@@ -201,6 +201,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG_8_19 = def(8_841_0_17);

/*
* STOP! READ THIS FIRST! No, really,
InferenceFieldMetadata.java
@@ -12,6 +12,7 @@
import org.elasticsearch.TransportVersions;
import org.elasticsearch.cluster.Diff;
import org.elasticsearch.cluster.SimpleDiffable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.ToXContentFragment;
@@ -22,8 +23,11 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_CHUNKING_CONFIG_8_19;

/**
* Contains inference field data for fields.
* As inference is done in the coordinator node to avoid re-doing it at shard / replica level, the coordinator needs to check for the need
@@ -35,21 +39,30 @@ public final class InferenceFieldMetadata implements SimpleDiffable<InferenceFie
private static final String INFERENCE_ID_FIELD = "inference_id";
private static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id";
private static final String SOURCE_FIELDS_FIELD = "source_fields";
static final String CHUNKING_SETTINGS_FIELD = "chunking_settings";

private final String name;
private final String inferenceId;
private final String searchInferenceId;
private final String[] sourceFields;
private final Map<String, Object> chunkingSettings;

public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields) {
this(name, inferenceId, inferenceId, sourceFields);
public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields, Map<String, Object> chunkingSettings) {
this(name, inferenceId, inferenceId, sourceFields, chunkingSettings);
}

public InferenceFieldMetadata(String name, String inferenceId, String searchInferenceId, String[] sourceFields) {
public InferenceFieldMetadata(
String name,
String inferenceId,
String searchInferenceId,
String[] sourceFields,
Map<String, Object> chunkingSettings
) {
this.name = Objects.requireNonNull(name);
this.inferenceId = Objects.requireNonNull(inferenceId);
this.searchInferenceId = Objects.requireNonNull(searchInferenceId);
this.sourceFields = Objects.requireNonNull(sourceFields);
this.chunkingSettings = chunkingSettings != null ? Map.copyOf(chunkingSettings) : null;
}

public InferenceFieldMetadata(StreamInput input) throws IOException {
@@ -61,6 +74,11 @@ public InferenceFieldMetadata(StreamInput input) throws IOException {
this.searchInferenceId = this.inferenceId;
}
this.sourceFields = input.readStringArray();
if (input.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG_8_19)) {
this.chunkingSettings = input.readGenericMap();
} else {
this.chunkingSettings = null;
}
}

@Override
@@ -71,6 +89,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(searchInferenceId);
}
out.writeStringArray(sourceFields);
if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG_8_19)) {
out.writeGenericMap(chunkingSettings);
}
}
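A sketch of what the version gates above mean on the wire: nodes on SEMANTIC_TEXT_CHUNKING_CONFIG_8_19 or later write and read the chunking settings map, while older transport versions omit it, so it deserializes as null. The round trip below follows the common Elasticsearch stream-test pattern and is an assumption for illustration, not code from this PR; 'metadata' stands for an InferenceFieldMetadata instance with non-null chunking settings.

import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;

// Inside a test method that declares throws IOException.
try (BytesStreamOutput out = new BytesStreamOutput()) {
    out.setTransportVersion(TransportVersions.SEMANTIC_TEXT_CHUNKING_CONFIG_8_19);
    metadata.writeTo(out);

    StreamInput in = out.bytes().streamInput();
    in.setTransportVersion(TransportVersions.SEMANTIC_TEXT_CHUNKING_CONFIG_8_19);
    InferenceFieldMetadata copy = new InferenceFieldMetadata(in);
    // On this version the map survives the round trip; on an older transport version it is read back as null.
}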

@Override
@@ -81,16 +102,22 @@ public boolean equals(Object o) {
return Objects.equals(name, that.name)
&& Objects.equals(inferenceId, that.inferenceId)
&& Objects.equals(searchInferenceId, that.searchInferenceId)
&& Arrays.equals(sourceFields, that.sourceFields);
&& Arrays.equals(sourceFields, that.sourceFields)
&& Objects.equals(chunkingSettings, that.chunkingSettings);
}

@Override
public int hashCode() {
int result = Objects.hash(name, inferenceId, searchInferenceId);
int result = Objects.hash(name, inferenceId, searchInferenceId, chunkingSettings);
result = 31 * result + Arrays.hashCode(sourceFields);
return result;
}

@Override
public String toString() {
return Strings.toString(this);
}

public String getName() {
return name;
}
@@ -107,6 +134,10 @@ public String[] getSourceFields() {
return sourceFields;
}

public Map<String, Object> getChunkingSettings() {
return chunkingSettings;
}

public static Diff<InferenceFieldMetadata> readDiffFrom(StreamInput in) throws IOException {
return SimpleDiffable.readDiffFrom(InferenceFieldMetadata::new, in);
}
@@ -119,6 +150,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId);
}
builder.array(SOURCE_FIELDS_FIELD, sourceFields);
if (chunkingSettings != null) {
builder.startObject(CHUNKING_SETTINGS_FIELD);
builder.mapContents(chunkingSettings);
builder.endObject();
}
return builder.endObject();
}
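For orientation only, the x-content written above for a field that carries chunking settings would look roughly like the following; the field and endpoint names and the setting values are made up, the key names come from this class.

// "my_semantic_field": {
//   "inference_id": "my-elser-endpoint",
//   "search_inference_id": "my-search-endpoint",
//   "source_fields": ["body"],
//   "chunking_settings": {
//     "strategy": "sentence_boundary",
//     "max_chunk_size": 100,
//     "sentence_overlap": 1
//   }
// }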

@@ -131,6 +167,7 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws
String currentFieldName = null;
String inferenceId = null;
String searchInferenceId = null;
Map<String, Object> chunkingSettings = null;
List<String> inputFields = new ArrayList<>();
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
Expand All @@ -151,6 +188,8 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws
}
}
}
} else if (CHUNKING_SETTINGS_FIELD.equals(currentFieldName)) {
chunkingSettings = parser.map();
} else {
parser.skipChildren();
}
@@ -159,7 +198,8 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws
name,
inferenceId,
searchInferenceId == null ? inferenceId : searchInferenceId,
inputFields.toArray(String[]::new)
inputFields.toArray(String[]::new),
chunkingSettings
);
}
}
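Putting the pieces of this class together, a minimal construction sketch; the field and endpoint names are placeholders, and the settings map reuses the word_boundary shape from the tests:

import java.util.Map;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;

InferenceFieldMetadata metadata = new InferenceFieldMetadata(
    "my_semantic_field",                          // field name
    "my-elser-endpoint",                          // inference_id, reused as search_inference_id by this constructor
    new String[] { "my_semantic_field" },         // source_fields
    Map.of("strategy", "word_boundary", "max_chunk_size", 100, "overlap", 20)
);
// A null map would mean "fall back to the inference endpoint's own chunking settings".
Map<String, Object> chunking = metadata.getChunkingSettings();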
ChunkInferenceInput.java
@@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.inference;

import org.elasticsearch.core.Nullable;

import java.util.List;

public record ChunkInferenceInput(String input, @Nullable ChunkingSettings chunkingSettings) {

public ChunkInferenceInput(String input) {
this(input, null);
}

public static List<String> inputs(List<ChunkInferenceInput> chunkInferenceInputs) {
return chunkInferenceInputs.stream().map(ChunkInferenceInput::input).toList();
}
}
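A short usage sketch for the new record; the per-input ChunkingSettings override is optional (null falls back to the field's or model's configuration), and the static inputs(...) helper recovers the raw strings. The passage text is illustrative.

import java.util.List;
import org.elasticsearch.inference.ChunkInferenceInput;

List<ChunkInferenceInput> inputs = List.of(
    new ChunkInferenceInput("a long passage to be chunked"),   // no per-input override
    new ChunkInferenceInput("another passage", null)           // explicit null, same effect
);
List<String> rawText = ChunkInferenceInput.inputs(inputs);     // ["a long passage to be chunked", "another passage"]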
ChunkingSettings.java
@@ -12,6 +12,10 @@
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.xcontent.ToXContentObject;

import java.util.Map;

public interface ChunkingSettings extends ToXContentObject, VersionedNamedWriteable {
ChunkingStrategy getChunkingStrategy();

Map<String, Object> asMap();
}
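The new asMap() hook is what lets a parsed ChunkingSettings instance be stored in the InferenceFieldMetadata shown earlier; a hedged sketch of that hand-off, where 'settings' stands for any concrete implementation:

// 'settings' is some concrete ChunkingSettings implementation.
Map<String, Object> asStored = settings.asMap();
InferenceFieldMetadata metadata = new InferenceFieldMetadata(
    "my_semantic_field", "my-elser-endpoint", new String[] { "my_semantic_field" }, asStored
);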
InferenceService.java
@@ -133,18 +133,18 @@ void unifiedCompletionInfer(
/**
* Chunk long text.
*
* @param model The model
* @param query Inference query, mainly for re-ranking
* @param input Inference input
* @param taskSettings Settings in the request to override the model's defaults
* @param inputType For search, ingest etc
* @param timeout The timeout for the request
* @param listener Chunked Inference result listener
* @param model The model
* @param query Inference query, mainly for re-ranking
* @param input Inference input
* @param taskSettings Settings in the request to override the model's defaults
* @param inputType For search, ingest etc
* @param timeout The timeout for the request
* @param listener Chunked Inference result listener
*/
void chunkedInfer(
Model model,
@Nullable String query,
List<String> input,
List<ChunkInferenceInput> input,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
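With the signature change above, callers hand chunkedInfer a list of ChunkInferenceInput rather than bare strings, so a per-input chunking override can travel with each piece of text. A sketch of just the argument construction; 'perFieldChunkingSettings' is a hypothetical ChunkingSettings value, and the commented-out call only indicates where the list is passed.

List<ChunkInferenceInput> input = List.of(
    new ChunkInferenceInput("first passage"),
    new ChunkInferenceInput("second passage", perFieldChunkingSettings)   // hypothetical per-input override
);
// service.chunkedInfer(model, null, input, taskSettings, inputType, timeout, listener);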
@@ -694,7 +694,8 @@ private static InferenceFieldMetadata randomInferenceFieldMetadata(String name)
name,
randomIdentifier(),
randomIdentifier(),
randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new)
randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new),
InferenceFieldMetadataTests.generateRandomChunkingSettings()
);
}

InferenceFieldMetadataTests.java
@@ -15,8 +15,10 @@
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Map;
import java.util.function.Predicate;

import static org.elasticsearch.cluster.metadata.InferenceFieldMetadata.CHUNKING_SETTINGS_FIELD;
import static org.hamcrest.Matchers.equalTo;

public class InferenceFieldMetadataTests extends AbstractXContentTestCase<InferenceFieldMetadata> {
@@ -37,11 +39,6 @@ protected InferenceFieldMetadata createTestInstance() {
return createTestItem();
}

@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return p -> p.equals(""); // do not add elements at the top-level as any element at this level is parsed as a new inference field
}

@Override
protected InferenceFieldMetadata doParseInstance(XContentParser parser) throws IOException {
if (parser.nextToken() == XContentParser.Token.START_OBJECT) {
@@ -58,18 +55,57 @@ protected boolean supportsUnknownFields() {
return true;
}

@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
// do not add elements at the top-level as any element at this level is parsed as a new inference field,
// and do not add additional elements to chunking maps as they will fail parsing with extra data
return field -> field.equals("") || field.contains(CHUNKING_SETTINGS_FIELD);
}

private static InferenceFieldMetadata createTestItem() {
String name = randomAlphaOfLengthBetween(3, 10);
String inferenceId = randomIdentifier();
String searchInferenceId = randomIdentifier();
String[] inputFields = generateRandomStringArray(5, 10, false, false);
return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields);
Map<String, Object> chunkingSettings = generateRandomChunkingSettings();
return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields, chunkingSettings);
}

public static Map<String, Object> generateRandomChunkingSettings() {
if (randomBoolean()) {
return null; // Defaults to model chunking settings
}
return randomBoolean() ? generateRandomWordBoundaryChunkingSettings() : generateRandomSentenceBoundaryChunkingSettings();
}

private static Map<String, Object> generateRandomWordBoundaryChunkingSettings() {
return Map.of("strategy", "word_boundary", "max_chunk_size", randomIntBetween(20, 100), "overlap", randomIntBetween(1, 50));
}

private static Map<String, Object> generateRandomSentenceBoundaryChunkingSettings() {
return Map.of(
"strategy",
"sentence_boundary",
"max_chunk_size",
randomIntBetween(20, 100),
"sentence_overlap",
randomIntBetween(0, 1)
);
}

public void testNullCtorArgsThrowException() {
assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0]));
assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0]));
assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0]));
assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null));
assertThrows(
NullPointerException.class,
() -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0], Map.of())
);
assertThrows(
NullPointerException.class,
() -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0], Map.of())
);
assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0], Map.of()));
assertThrows(
NullPointerException.class,
() -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null, Map.of())
);
}
}
@@ -11,6 +11,7 @@

import org.apache.lucene.search.Query;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadataTests;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.plugins.MapperPlugin;
import org.elasticsearch.plugins.Plugin;
@@ -102,7 +103,13 @@ private static class TestInferenceFieldMapper extends FieldMapper implements Inf

@Override
public InferenceFieldMetadata getMetadata(Set<String> sourcePaths) {
return new InferenceFieldMetadata(fullPath(), INFERENCE_ID, SEARCH_INFERENCE_ID, sourcePaths.toArray(new String[0]));
return new InferenceFieldMetadata(
fullPath(),
INFERENCE_ID,
SEARCH_INFERENCE_ID,
sourcePaths.toArray(new String[0]),
InferenceFieldMetadataTests.generateRandomChunkingSettings()
);
}

@Override