Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,11 @@ public Builder elementType(ElementType elementType) {
return this;
}

public Builder indexOptions(IndexOptions indexOptions) {
this.indexOptions.setValue(indexOptions);
return this;
}

@Override
public DenseVectorFieldMapper build(MapperBuilderContext context) {
// Validate again here because the dimensions or element type could have been set programmatically,
Expand Down Expand Up @@ -1177,7 +1182,7 @@ public final String toString() {
public abstract VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersion, ElementType elementType);
}

abstract static class IndexOptions implements ToXContent {
public abstract static class IndexOptions implements ToXContent {
final VectorIndexType type;

IndexOptions(VectorIndexType type) {
Expand All @@ -1186,7 +1191,7 @@ abstract static class IndexOptions implements ToXContent {

abstract KnnVectorsFormat getVectorsFormat(ElementType elementType);

final void validateElementType(ElementType elementType) {
public final void validateElementType(ElementType elementType) {
if (type.supportsElementType(elementType) == false) {
throw new IllegalArgumentException(
"[element_type] cannot be [" + elementType.toString() + "] when using index type [" + type + "]"
Expand Down Expand Up @@ -2319,7 +2324,11 @@ public FieldMapper.Builder getMergeBuilder() {
return new Builder(leafName(), indexCreatedVersion).init(this);
}

private static IndexOptions parseIndexOptions(String fieldName, Object propNode) {
public IndexOptions indexOptions() {
return indexOptions;
}

public static IndexOptions parseIndexOptions(String fieldName, Object propNode) {
@SuppressWarnings("unchecked")
Map<String, ?> indexOptionsMap = (Map<String, ?>) propNode;
Object typeNode = indexOptionsMap.remove("type");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
import java.util.function.BiConsumer;
import java.util.function.Function;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.parseIndexOptions;
import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
Expand Down Expand Up @@ -136,6 +137,10 @@ public static BiConsumer<String, MappingParserContext> validateParserContext(Str
};
}

private static Builder builder(FieldMapper in) {
return ((SemanticTextFieldMapper) in).builder;
}

public static class Builder extends FieldMapper.Builder {
private final boolean useLegacyFormat;

Expand Down Expand Up @@ -175,6 +180,20 @@ public static class Builder extends FieldMapper.Builder {
Objects::toString
).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings);

private final Parameter<DenseVectorFieldMapper.IndexOptions> indexOptions = new Parameter<>(
"index_options",
true,
() -> null,
(n, c, o) -> o == null ? null : parseIndexOptions(n, o),
m -> builder(m).indexOptions.get(),
(b, n, v) -> {
if (v != null) {
b.field(n, v);
}
},
Objects::toString
);

private final Parameter<Map<String, String>> meta = Parameter.metaParam();

private Function<MapperBuilderContext, ObjectMapper> inferenceFieldBuilder;
Expand All @@ -197,6 +216,7 @@ public Builder(String name, Function<Query, BitSetProducer> bitSetProducer, Inde
indexSettings.getIndexVersionCreated(),
useLegacyFormat,
modelSettings.get(),
indexOptions.get(),
bitSetProducer,
indexSettings
);
Expand Down Expand Up @@ -265,7 +285,8 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
useLegacyFormat,
meta.getValue()
),
builderParams(this, context)
builderParams(this, context),
this
);
}

Expand Down Expand Up @@ -306,9 +327,12 @@ private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, Map
}
}

private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, BuilderParams builderParams) {
private final Builder builder;

private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, BuilderParams builderParams, Builder builder) {
super(simpleName, mappedFieldType, builderParams);
ensureMultiFields(builderParams.multiFields().iterator());
this.builder = builder;
}

private void ensureMultiFields(Iterator<FieldMapper> mappers) {
Expand Down Expand Up @@ -910,18 +934,20 @@ private static ObjectMapper createInferenceField(
IndexVersion indexVersionCreated,
boolean useLegacyFormat,
@Nullable MinimalServiceSettings modelSettings,
@Nullable DenseVectorFieldMapper.IndexOptions indexOptions,
Function<Query, BitSetProducer> bitSetProducer,
IndexSettings indexSettings
) {
return new ObjectMapper.Builder(INFERENCE_FIELD, Optional.of(ObjectMapper.Subobjects.ENABLED)).dynamic(ObjectMapper.Dynamic.FALSE)
.add(createChunksField(indexVersionCreated, useLegacyFormat, modelSettings, bitSetProducer, indexSettings))
.add(createChunksField(indexVersionCreated, useLegacyFormat, modelSettings, indexOptions, bitSetProducer, indexSettings))
.build(context);
}

private static NestedObjectMapper.Builder createChunksField(
IndexVersion indexVersionCreated,
boolean useLegacyFormat,
@Nullable MinimalServiceSettings modelSettings,
@Nullable DenseVectorFieldMapper.IndexOptions indexOptions,
Function<Query, BitSetProducer> bitSetProducer,
IndexSettings indexSettings
) {
Expand All @@ -933,7 +959,7 @@ private static NestedObjectMapper.Builder createChunksField(
);
chunksField.dynamic(ObjectMapper.Dynamic.FALSE);
if (modelSettings != null) {
chunksField.add(createEmbeddingsField(indexSettings.getIndexVersionCreated(), modelSettings, useLegacyFormat));
chunksField.add(createEmbeddingsField(indexSettings.getIndexVersionCreated(), modelSettings, indexOptions, useLegacyFormat));
}
if (useLegacyFormat) {
var chunkTextField = new KeywordFieldMapper.Builder(TEXT_FIELD, indexVersionCreated).indexed(false).docValues(false);
Expand All @@ -947,6 +973,7 @@ private static NestedObjectMapper.Builder createChunksField(
private static Mapper.Builder createEmbeddingsField(
IndexVersion indexVersionCreated,
MinimalServiceSettings modelSettings,
DenseVectorFieldMapper.IndexOptions indexOptions,
boolean useLegacyFormat
) {
return switch (modelSettings.taskType()) {
Expand All @@ -970,6 +997,11 @@ private static Mapper.Builder createEmbeddingsField(
}
denseVectorMapperBuilder.dimensions(modelSettings.dimensions());
denseVectorMapperBuilder.elementType(modelSettings.elementType());
if (indexOptions != null) {
indexOptions.validateDimension(modelSettings.dimensions());
indexOptions.validateElementType(modelSettings.elementType());
denseVectorMapperBuilder.indexOptions(indexOptions);
}

yield denseVectorMapperBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.QueryBitSetProducer;
import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.CheckedBiConsumer;
import org.elasticsearch.common.CheckedBiFunction;
Expand Down Expand Up @@ -73,6 +72,7 @@

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
Expand Down Expand Up @@ -879,17 +879,29 @@ private MapperService mapperServiceForFieldWithModelSettings(
String searchInferenceId,
MinimalServiceSettings modelSettings
) throws IOException {
String mappingParams = "type=semantic_text,inference_id=" + inferenceId;
return mapperServiceForFieldWithModelSettingsAndIndexOptions(fieldName, inferenceId, searchInferenceId, modelSettings, null);
}

private MapperService mapperServiceForFieldWithModelSettingsAndIndexOptions(
String fieldName,
String inferenceId,
String searchInferenceId,
MinimalServiceSettings modelSettings,
DenseVectorFieldMapper.IndexOptions indexOptions
) throws IOException {
XContentBuilder mappingBuilder = JsonXContent.contentBuilder().startObject();
mappingBuilder.startObject("properties").startObject(fieldName).field("type", "semantic_text").field("inference_id", inferenceId);
if (searchInferenceId != null) {
mappingParams += ",search_inference_id=" + searchInferenceId;
mappingBuilder.field("search_inference_id", searchInferenceId);
}
if (indexOptions != null) {
mappingBuilder.field("index_options", indexOptions);
}

mappingBuilder.endObject().endObject().endObject();

MapperService mapperService = createMapperService(mapping(b -> {}), useLegacyFormat);
mapperService.merge(
"_doc",
new CompressedXContent(Strings.toString(PutMappingRequest.simpleMapping(fieldName, mappingParams))),
MapperService.MergeReason.MAPPING_UPDATE
);
mapperService.merge("_doc", new CompressedXContent(Strings.toString(mappingBuilder)), MapperService.MergeReason.MAPPING_UPDATE);

SemanticTextField semanticTextField = new SemanticTextField(
useLegacyFormat,
Expand Down Expand Up @@ -951,6 +963,105 @@ public void testExistsQueryDenseVector() throws IOException {
assertThat(existsQuery, instanceOf(ESToParentBlockJoinQuery.class));
}

public void testDenseVectorIndexOptions() throws IOException {
final String fieldName = "field";
final String inferenceId = "test_service";

List<DenseVectorFieldMapper.IndexOptions> indexOptionsList = List.of(
DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "hnsw"))),
DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "int8_hnsw"))),
DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "int4_hnsw"))),
DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "bbq_hnsw"))),
DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "flat"))),
DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "int8_flat"))),
DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "int4_flat"))),
DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "bbq_flat"))),
DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "hnsw", "m", 32, "ef_construction", 200)))
);

for (DenseVectorFieldMapper.IndexOptions indexOptions : indexOptionsList) {
BiConsumer<MapperService, DenseVectorFieldMapper.IndexOptions> assertMapperService = (m, e) -> {
Mapper mapper = m.mappingLookup().getMapper(fieldName);
assertThat(mapper, instanceOf(SemanticTextFieldMapper.class));
SemanticTextFieldMapper semanticTextFieldMapper = (SemanticTextFieldMapper) mapper;

FieldMapper fieldMapper = semanticTextFieldMapper.fieldType().getEmbeddingsField();
assertThat(fieldMapper, instanceOf(DenseVectorFieldMapper.class));
DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) fieldMapper;

assertThat(denseVectorFieldMapper.indexOptions(), equalTo(e));
};

MapperService floatMapperService = mapperServiceForFieldWithModelSettingsAndIndexOptions(
fieldName,
inferenceId,
inferenceId,
new MinimalServiceSettings(
TaskType.TEXT_EMBEDDING,
1024,
SimilarityMeasure.COSINE,
DenseVectorFieldMapper.ElementType.FLOAT
),
indexOptions
);
assertMapperService.accept(floatMapperService, indexOptions);
}
}

public void testDenseVectorIndexOptionsVaild() {
final String fieldName = "field";
final String inferenceId = "test_service";

{
DenseVectorFieldMapper.IndexOptions indexOptions = DenseVectorFieldMapper.parseIndexOptions(
fieldName,
new HashMap<>(Map.of("type", "int8_hnsw"))
);
MinimalServiceSettings invalidSettings = new MinimalServiceSettings(
TaskType.TEXT_EMBEDDING,
1024,
SimilarityMeasure.L2_NORM,
DenseVectorFieldMapper.ElementType.BYTE
);

Exception e = expectThrows(
DocumentParsingException.class,
() -> mapperServiceForFieldWithModelSettingsAndIndexOptions(
fieldName,
inferenceId,
inferenceId,
invalidSettings,
indexOptions
)
);
assertThat(e.getCause().getMessage(), containsString("cannot be [byte] when using index type [int8_hnsw]"));
}

{
DenseVectorFieldMapper.IndexOptions indexOptions = DenseVectorFieldMapper.parseIndexOptions(
fieldName,
new HashMap<>(Map.of("type", "bbq_hnsw"))
);
MinimalServiceSettings invalidSettings = new MinimalServiceSettings(
TaskType.TEXT_EMBEDDING,
10,
SimilarityMeasure.COSINE,
DenseVectorFieldMapper.ElementType.BYTE
);
Exception e = expectThrows(
DocumentParsingException.class,
() -> mapperServiceForFieldWithModelSettingsAndIndexOptions(
fieldName,
inferenceId,
inferenceId,
invalidSettings,
indexOptions
)
);
assertThat(e.getCause().getMessage(), containsString("bbq_hnsw does not support dimensions fewer than 64"));
}
}

@Override
protected void assertExistsQuery(MappedFieldType fieldType, Query query, LuceneDocument fields) {
// Until a doc is indexed, the query is rewritten as match no docs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,61 @@ setup:
- match: { "test-index.mappings.properties.dense_field.model_settings.task_type": text_embedding }
- length: { "test-index.mappings.properties.dense_field": 3 }

---
"Indexes dense vector document with index_options":

- do:
indices.create:
index: test-index-options
body:
mappings:
properties:
dense_field:
type: semantic_text
inference_id: dense-inference-id
index_options:
type: "hnsw"
m: 24
ef_construction: 200

- do:
index:
index: test-index-options
id: doc_2
body:
dense_field:
text: "these are not the droids you're looking for. He's free to go around"
inference:
inference_id: "dense-inference-id"
model_settings:
task_type: "text_embedding"
dimensions: 4
similarity: "cosine"
element_type: "float"
index_options:
type: "int8_hnsw"
m: 24
ef_construction: 100
confidence_interval: 0.9
chunks:
- text: "these are not the droids you're looking for"
embeddings: [0.04673296958208084, -0.03237321600317955, -0.02543032355606556, 0.056035321205854416]
- text: "He's free to go around"
embeddings: [0.00641461368650198, -0.0016253676731139421, -0.05126338079571724, 0.053438711911439896]

# Checks mapping is updated when first doc arrives
- do:
indices.get_mapping:
index: test-index-options

- match: { "test-index-options.mappings.properties.dense_field.type": "semantic_text" }
- match: { "test-index-options.mappings.properties.dense_field.inference_id": "dense-inference-id" }
- match: { "test-index-options.mappings.properties.dense_field.model_settings.task_type": "text_embedding" }
- match: { "test-index-options.mappings.properties.dense_field.index_options.type": "hnsw" }
- match: { "test-index-options.mappings.properties.dense_field.index_options.m": 24 }
- match: { "test-index-options.mappings.properties.dense_field.index_options.ef_construction": 200 }
- length: { "test-index-options.mappings.properties.dense_field": 4 }

---
"Field caps with text embedding":
- requires:
Expand Down