Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.inference;

/**
* Contract for inference services that can report the window size of a
* reranking model.
*/
public interface RerankingInferenceService {

/**
* A conservative default window size for small reranking models
* (those limited to 512 input tokens).
* NOTE(review): presumably expressed in words, like the value returned by
* {@link #rerankerWindowSize(String)} - confirm.
*/
int CONSERVATIVE_DEFAULT_WINDOW_SIZE = 250;

/**
* Returns the reranking model's maximum window size, or an approximation
* of it, measured in number of words.
*
* @param modelId The model ID
* @return Window size in words
*/
int rerankerWindowSize(String modelId);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;

import java.io.IOException;
import java.util.Objects;

/**
 * Action type used to look up the reranker window size of an inference endpoint.
 * The request carries the inference entity ID; the response carries a single
 * integer window size.
 */
public class GetRerankerWindowSizeAction extends ActionType<GetRerankerWindowSizeAction.Response> {

    /** Registered name of this action. */
    public static final String NAME = "cluster:internal/xpack/inference/rerankwindowsize/get";

    /** Shared instance of the action type. */
    public static final GetRerankerWindowSizeAction INSTANCE = new GetRerankerWindowSizeAction();

    public GetRerankerWindowSizeAction() {
        super(NAME);
    }

    /**
     * Request identifying the inference endpoint whose window size is wanted.
     */
    public static class Request extends ActionRequest {

        private final String inferenceEntityId;

        /**
         * @param inferenceEntityId ID of the inference endpoint to query
         */
        public Request(String inferenceEntityId) {
            this.inferenceEntityId = inferenceEntityId;
        }

        /**
         * Deserializes a request: the parent state first, then the entity ID
         * as a string.
         *
         * @param in stream to read from
         * @throws IOException if reading from the stream fails
         */
        public Request(StreamInput in) throws IOException {
            super(in);
            this.inferenceEntityId = in.readString();
        }

        /**
         * @return ID of the inference endpoint this request targets
         */
        public String getInferenceEntityId() {
            return inferenceEntityId;
        }

        /**
         * No request-level validation is performed; always returns {@code null}.
         */
        @Override
        public ActionRequestValidationException validate() {
            return null;
        }

        /**
         * Serializes the parent state followed by the entity ID, mirroring
         * {@link #Request(StreamInput)}.
         */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            out.writeString(inferenceEntityId);
        }

        @Override
        public boolean equals(Object o) {
            // Only instances of exactly the same class can be equal.
            if (o != null && getClass() == o.getClass()) {
                Request other = (Request) o;
                return Objects.equals(inferenceEntityId, other.inferenceEntityId);
            }
            return false;
        }

        @Override
        public int hashCode() {
            return inferenceEntityId == null ? 0 : inferenceEntityId.hashCode();
        }
    }

    /**
     * Response holding the window size as a single integer.
     */
    public static class Response extends ActionResponse {

        private final int windowSize;

        /**
         * @param windowSize the window size to report
         */
        public Response(int windowSize) {
            this.windowSize = windowSize;
        }

        /**
         * Deserializes a response by reading the window size as a variable-length int.
         *
         * @param in stream to read from
         * @throws IOException if reading from the stream fails
         */
        public Response(StreamInput in) throws IOException {
            this.windowSize = in.readVInt();
        }

        /**
         * @return the reported window size
         */
        public int getWindowSize() {
            return windowSize;
        }

        /**
         * Serializes the window size as a variable-length int, mirroring
         * {@link #Response(StreamInput)}.
         */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeVInt(windowSize);
        }

        @Override
        public boolean equals(Object o) {
            // Only instances of exactly the same class can be equal.
            if (o != null && getClass() == o.getClass()) {
                Response other = (Response) o;
                return windowSize == other.windowSize;
            }
            return false;
        }

        @Override
        public int hashCode() {
            return Integer.hashCode(windowSize);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskSettings;
Expand Down Expand Up @@ -62,7 +63,7 @@ public TestRerankingModel(String inferenceEntityId, TestServiceSettings serviceS
}
}

public static class TestInferenceService extends AbstractTestInferenceService {
public static class TestInferenceService extends AbstractTestInferenceService implements RerankingInferenceService {
public static final String NAME = "test_reranking_service";

private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.RERANK);
Expand Down Expand Up @@ -200,6 +201,11 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
return TestServiceSettings.fromMap(serviceSettingsMap);
}

/**
* Test implementation: always reports a fixed window size of 333,
* regardless of the supplied {@code modelId}.
*/
@Override
public int rerankerWindowSize(String modelId) {
return 333;
}

public static class Configuration {
public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ public static Iterable<Object[]> parameters() {
@Before
public void setup() throws Exception {
ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
Utils.storeSparseModel(modelRegistry);
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
Utils.storeDenseModel(
"dense-endpoint",
modelRegistry,
randomIntBetween(1, 100),
// dot product means that we need normalized vectors; it's not worth doing that in this test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ public void setup() throws Exception {
() -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType))
);
int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);
Utils.storeSparseModel(modelRegistry);
Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType);
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
Utils.storeDenseModel("dense-endpoint", modelRegistry, dimensions, similarity, elementType);
}

@Override
Expand Down Expand Up @@ -122,27 +122,20 @@ public Settings indexSettings() {
}

public void testBulkOperations() throws Exception {
prepareCreate(INDEX_NAME).setMapping(
String.format(
Locale.ROOT,
"""
{
"properties": {
"sparse_field": {
"type": "semantic_text",
"inference_id": "%s"
},
"dense_field": {
"type": "semantic_text",
"inference_id": "%s"
}
}
prepareCreate(INDEX_NAME).setMapping(String.format(Locale.ROOT, """
{
"properties": {
"sparse_field": {
"type": "semantic_text",
"inference_id": "%s"
},
"dense_field": {
"type": "semantic_text",
"inference_id": "%s"
}
""",
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
TestDenseInferenceServiceExtension.TestInferenceService.NAME
)
).get();
}
}
""", "sparse-endpoint", "dense-endpoint")).get();
assertRandomBulkOperations(INDEX_NAME, isIndexRequest -> {
Map<String, Object> map = new HashMap<>();
map.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.integration;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.junit.Before;

import java.util.Collection;
import java.util.List;

import static org.hamcrest.Matchers.containsString;

/**
 * Integration test for {@link GetRerankerWindowSizeAction}: checks the window
 * size returned for a rerank endpoint and the error raised for an endpoint
 * that is not a reranker.
 */
@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435
public class RerankWindowSizeIT extends ESIntegTestCase {

    @Override
    protected Collection<Class<? extends Plugin>> nodePlugins() {
        return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class);
    }

    /**
     * Stores one rerank endpoint and one sparse endpoint in the master node's
     * model registry before each test.
     */
    @Before
    public void setup() throws Exception {
        ModelRegistry registry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
        Utils.storeRerankModel("rerank-endpoint", registry);
        Utils.storeSparseModel("sparse-endpoint", registry);
    }

    /** The rerank endpoint should report the window size 333. */
    public void testRerankWindowSizeAction() {
        GetRerankerWindowSizeAction.Response response = requestWindowSize("rerank-endpoint");
        assertEquals(333, response.getWindowSize());
    }

    /** Asking for the window size of a non-rerank endpoint should fail with a status exception. */
    public void testActionNotAReranker() {
        ElasticsearchStatusException failure = expectThrows(
            ElasticsearchStatusException.class,
            () -> requestWindowSize("sparse-endpoint")
        );
        assertThat(failure.getMessage(), containsString("Inference endpoint [sparse-endpoint] does not have the rerank task type"));
    }

    /**
     * Executes the window-size action for the given endpoint and blocks for the result.
     *
     * @param endpointId inference endpoint ID to query
     * @return the action response
     */
    private GetRerankerWindowSizeAction.Response requestWindowSize(String endpointId) {
        GetRerankerWindowSizeAction.Request request = new GetRerankerWindowSizeAction.Request(endpointId);
        return client().execute(GetRerankerWindowSizeAction.INSTANCE, request).actionGet();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.junit.Before;
Expand Down Expand Up @@ -68,8 +66,8 @@ public void setup() throws Exception {
() -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType))
);
int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);
Utils.storeSparseModel(modelRegistry);
Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType);
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
Utils.storeDenseModel("dense-endpoint", modelRegistry, dimensions, similarity, elementType);

Set<IndexVersion> availableVersions = IndexVersionUtils.allReleasedVersions()
.stream()
Expand Down Expand Up @@ -113,11 +111,11 @@ public void testSemanticText() throws Exception {
.startObject("properties")
.startObject(SPARSE_SEMANTIC_FIELD)
.field("type", "semantic_text")
.field("inference_id", TestSparseInferenceServiceExtension.TestInferenceService.NAME)
.field("inference_id", "sparse-endpoint")
.endObject()
.startObject(DENSE_SEMANTIC_FIELD)
.field("type", "semantic_text")
.field("inference_id", TestDenseInferenceServiceExtension.TestInferenceService.NAME)
.field("inference_id", "dense-endpoint")
.endObject()
.endObject()
.endObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction;
import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
Expand All @@ -72,6 +73,7 @@
import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction;
import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction;
import org.elasticsearch.xpack.inference.action.TransportGetInferenceServicesAction;
import org.elasticsearch.xpack.inference.action.TransportGetRerankerWindowSizeAction;
import org.elasticsearch.xpack.inference.action.TransportInferenceAction;
import org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy;
import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction;
Expand Down Expand Up @@ -234,7 +236,8 @@ public List<ActionHandler> getActions() {
new ActionHandler(XPackUsageFeatureAction.INFERENCE, TransportInferenceUsageAction.class),
new ActionHandler(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class),
new ActionHandler(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class),
new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class)
new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class),
new ActionHandler(GetRerankerWindowSizeAction.INSTANCE, TransportGetRerankerWindowSizeAction.class)
);
}

Expand Down
Loading