Skip to content

Commit 5f3906d

Browse files
Adding dimensions to jina preconfigured endpoint
1 parent 630b65b commit 5f3906d

File tree

3 files changed

+121
-3
lines changed

3 files changed

+121
-3
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
222222
new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
223223
DEFAULT_MULTILINGUAL_EMBED_MODEL_ID,
224224
defaultDenseTextEmbeddingsSimilarity(),
225-
null,
225+
DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
226226
null
227227
),
228228
EmptyTaskSettings.INSTANCE,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,20 @@ public int rateLimitGroupingHash() {
4646
return Objects.hash(this.getServiceSettings().modelId());
4747
}
4848

49+
@Override
50+
public boolean equals(Object o) {
51+
if (o == null || getClass() != o.getClass()) return false;
52+
if (super.equals(o) == false) return false;
53+
ElasticInferenceServiceModel that = (ElasticInferenceServiceModel) o;
54+
return Objects.equals(rateLimitServiceSettings, that.rateLimitServiceSettings)
55+
&& Objects.equals(elasticInferenceServiceComponents, that.elasticInferenceServiceComponents);
56+
}
57+
58+
@Override
59+
public int hashCode() {
60+
return Objects.hash(super.hashCode(), rateLimitServiceSettings, elasticInferenceServiceComponents);
61+
}
62+
4963
public RateLimitSettings rateLimitSettings() {
5064
return rateLimitServiceSettings.rateLimitSettings();
5165
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.ElasticsearchStatusException;
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.action.support.PlainActionFuture;
14+
import org.elasticsearch.action.support.TestPlainActionFuture;
1415
import org.elasticsearch.common.ValidationException;
1516
import org.elasticsearch.common.bytes.BytesArray;
1617
import org.elasticsearch.common.bytes.BytesReference;
@@ -45,6 +46,7 @@
4546
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
4647
import org.elasticsearch.xpack.inference.InferencePlugin;
4748
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
49+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
4850
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
4951
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
5052
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -57,11 +59,15 @@
5759
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
5860
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
5961
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
62+
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
6063
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests;
64+
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings;
6165
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
6266
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
67+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
6368
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
6469
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
70+
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
6571
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
6672
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
6773
import org.hamcrest.MatcherAssert;
@@ -91,9 +97,20 @@
9197
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
9298
import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender;
9399
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
100+
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1;
101+
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1;
102+
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_ELSER_2_MODEL_ID;
103+
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_ELSER_ENDPOINT_ID_V2;
104+
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID;
105+
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_MULTILINGUAL_EMBED_MODEL_ID;
106+
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_RERANK_ENDPOINT_ID_V1;
107+
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_RERANK_MODEL_ID_V1;
108+
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS;
109+
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity;
94110
import static org.hamcrest.CoreMatchers.instanceOf;
95111
import static org.hamcrest.CoreMatchers.is;
96112
import static org.hamcrest.Matchers.contains;
113+
import static org.hamcrest.Matchers.containsInAnyOrder;
97114
import static org.hamcrest.Matchers.empty;
98115
import static org.hamcrest.Matchers.equalTo;
99116
import static org.hamcrest.Matchers.hasSize;
@@ -1298,8 +1315,8 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
12981315
".jina-embeddings-v3",
12991316
MinimalServiceSettings.textEmbedding(
13001317
ElasticInferenceService.NAME,
1301-
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
1302-
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
1318+
DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
1319+
defaultDenseTextEmbeddingsSimilarity(),
13031320
DenseVectorFieldMapper.ElementType.FLOAT
13041321
),
13051322
service
@@ -1328,6 +1345,93 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
13281345
}
13291346
}
13301347

1348+
public void testDefaultConfigs_Returns_DefaultEndpointsModels() throws Exception {
1349+
String responseJson = """
1350+
{
1351+
"models": [
1352+
{
1353+
"model_name": "rainbow-sprinkles",
1354+
"task_types": ["chat"]
1355+
},
1356+
{
1357+
"model_name": "elser_model_2",
1358+
"task_types": ["embed/text/sparse"]
1359+
},
1360+
{
1361+
"model_name": "jina-embeddings-v3",
1362+
"task_types": ["embed/text/dense"]
1363+
},
1364+
{
1365+
"model_name": "elastic-rerank-v1",
1366+
"task_types": ["rerank/text/text-similarity"]
1367+
}
1368+
]
1369+
}
1370+
""";
1371+
1372+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
1373+
1374+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
1375+
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
1376+
ensureAuthorizationCallFinished(service);
1377+
var listener = new TestPlainActionFuture<List<Model>>();
1378+
1379+
service.defaultConfigs(listener);
1380+
var models = listener.actionGet(TIMEOUT);
1381+
1382+
var elasticInferenceServiceComponents = new ElasticInferenceServiceComponents(getUrl(webServer));
1383+
1384+
assertThat(
1385+
models,
1386+
containsInAnyOrder(
1387+
new ElasticInferenceServiceCompletionModel(
1388+
DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1,
1389+
TaskType.CHAT_COMPLETION,
1390+
ElasticInferenceService.NAME,
1391+
new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1),
1392+
EmptyTaskSettings.INSTANCE,
1393+
EmptySecretSettings.INSTANCE,
1394+
elasticInferenceServiceComponents
1395+
),
1396+
new ElasticInferenceServiceSparseEmbeddingsModel(
1397+
DEFAULT_ELSER_ENDPOINT_ID_V2,
1398+
TaskType.SPARSE_EMBEDDING,
1399+
ElasticInferenceService.NAME,
1400+
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null),
1401+
EmptyTaskSettings.INSTANCE,
1402+
EmptySecretSettings.INSTANCE,
1403+
elasticInferenceServiceComponents,
1404+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
1405+
),
1406+
new ElasticInferenceServiceDenseTextEmbeddingsModel(
1407+
DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID,
1408+
TaskType.TEXT_EMBEDDING,
1409+
ElasticInferenceService.NAME,
1410+
new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
1411+
DEFAULT_MULTILINGUAL_EMBED_MODEL_ID,
1412+
defaultDenseTextEmbeddingsSimilarity(),
1413+
DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
1414+
null
1415+
),
1416+
EmptyTaskSettings.INSTANCE,
1417+
EmptySecretSettings.INSTANCE,
1418+
elasticInferenceServiceComponents,
1419+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
1420+
),
1421+
new ElasticInferenceServiceRerankModel(
1422+
DEFAULT_RERANK_ENDPOINT_ID_V1,
1423+
TaskType.RERANK,
1424+
ElasticInferenceService.NAME,
1425+
new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1),
1426+
EmptyTaskSettings.INSTANCE,
1427+
EmptySecretSettings.INSTANCE,
1428+
elasticInferenceServiceComponents
1429+
)
1430+
)
1431+
);
1432+
}
1433+
}
1434+
13311435
public void testUnifiedCompletionError() {
13321436
var e = assertThrows(UnifiedChatCompletionException.class, () -> testUnifiedStream(404, """
13331437
{

0 commit comments

Comments
 (0)