Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gradle/verification-metadata.xml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
<trust group="org.elasticsearch.distribution.zip" name="elasticsearch"/>
<trust group="org.elasticsearch.ml"/>
<trust group="org.elasticsearch.plugin"/>
<trust group="com.jayway.jsonpath"/>
<trust file=".*-javadoc[.]jar" regex="true"/>
<trust file=".*-sources[.]jar" regex="true"/>
</trusted-artifacts>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INCLUDE_INDEX_MODE_IN_GET_DATA_STREAM = def(9_023_0_00);
public static final TransportVersion MAX_OPERATION_SIZE_REJECTIONS_ADDED = def(9_024_0_00);
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR = def(9_025_0_00);
public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL = def(9_026_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ public enum TaskType implements Writeable {
SPARSE_EMBEDDING,
RERANK,
COMPLETION,
// CUSTOM overrides isAnyOrSame to return true unconditionally, regardless of the task type passed in.
// NOTE(review): the ANY constant directly below also overrides this method; confirm that CUSTOM is really
// meant to act as a wildcard here rather than match only itself (possible copy/paste from ANY).
CUSTOM {
@Override
public boolean isAnyOrSame(TaskType other) {
return true;
}
},
ANY {
@Override
public boolean isAnyOrSame(TaskType other) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,14 @@ public ActionRequestValidationException validate() {
}
}

if (taskType.equals(TaskType.CUSTOM)) {
if (query == null) {
var e = new ActionRequestValidationException();
e.addValidationError(format("Field [query] cannot be null for task type [%s]", TaskType.CUSTOM));
return e;
}
}

return null;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.ml.inference.results.CustomResults;

import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

/**
 * Inference service results for the {@link TaskType#CUSTOM} task type, held as an untyped map.
 * <p>
 * The map is serialized with {@link StreamOutput#writeGenericMap} and deserialized with
 * {@link StreamInput#readGenericMap}. Neither constructor rejects a {@code null} map, so
 * {@link #equals(Object)}, {@link #hashCode()} and {@link #toString()} are written to tolerate one.
 */
public class CustomServiceResults implements InferenceServiceResults {
    public static final String NAME = "custom_service_results";
    public static final String CUSTOM_TYPE = TaskType.CUSTOM.name().toLowerCase(Locale.ROOT);

    // Assigned exactly once by a constructor; may be null because no constructor validates it.
    final Map<String, Object> data;

    /**
     * Wraps the given map. The map is stored by reference, not copied.
     *
     * @param data the result payload
     */
    public CustomServiceResults(Map<String, Object> data) {
        this.data = data;
    }

    /**
     * Reads the payload from the stream as a generic map.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public CustomServiceResults(StreamInput in) throws IOException {
        this.data = in.readGenericMap();
    }

    /**
     * Delegates to {@link ChunkedToXContentHelper#object} with {@link #CUSTOM_TYPE} and the map from {@link #asMap()}.
     */
    @Override
    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
        return ChunkedToXContentHelper.object(CUSTOM_TYPE, this.asMap());
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Writes the payload to the stream as a generic map.
     *
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeGenericMap(data);
    }

    /**
     * Same output as {@link #transformToLegacyFormat()}.
     */
    @Override
    public List<? extends InferenceResults> transformToCoordinationFormat() {
        return transformToLegacyFormat();
    }

    /**
     * @return a single-element list holding a {@link CustomResults} built from the same map
     */
    @Override
    public List<? extends InferenceResults> transformToLegacyFormat() {
        return List.of(new CustomResults(data));
    }

    /**
     * @return the wrapped map itself (not a copy)
     */
    @Override
    public Map<String, Object> asMap() {
        return data;
    }

    /**
     * @return {@link #NAME} immediately followed by the hex hash code, a newline, and the map's string form
     */
    @Override
    public String toString() {
        // String concatenation renders a null map as "null" instead of throwing, unlike asMap().toString().
        return NAME + Integer.toHexString(hashCode()) + "\n" + this.asMap();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        CustomServiceResults that = (CustomServiceResults) o;
        // Objects.equals keeps equals null-safe, matching the null-safe Objects.hash used by hashCode.
        return Objects.equals(data, that.data);
    }

    @Override
    public int hashCode() {
        return Objects.hash(data);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

/**
 * {@link InferenceResults} for the {@link TaskType#CUSTOM} task type, held as an untyped map.
 * <p>
 * The map is serialized with {@link StreamOutput#writeGenericMap} and deserialized with
 * {@link StreamInput#readGenericMap}. Neither constructor rejects a {@code null} map, so
 * {@link #equals(Object)} and {@link #hashCode()} are written to tolerate one.
 */
public class CustomResults implements InferenceResults {
    public static final String NAME = "custom_results";
    public static final String CUSTOM_TYPE = TaskType.CUSTOM.toString();

    // Assigned exactly once by a constructor; may be null because no constructor validates it.
    final Map<String, Object> data;

    /**
     * Wraps the given map. The map is stored by reference, not copied.
     *
     * @param data the result payload
     */
    public CustomResults(Map<String, Object> data) {
        this.data = data;
    }

    /**
     * Reads the payload from the stream as a generic map.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public CustomResults(StreamInput in) throws IOException {
        this.data = in.readGenericMap();
    }

    /**
     * Writes the map as a single field named {@link #CUSTOM_TYPE} on the supplied builder.
     *
     * @return the same builder, for chaining
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.field(CUSTOM_TYPE, this.asMap());
        return builder;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Writes the payload to the stream as a generic map.
     *
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeGenericMap(data);
    }

    @Override
    public String getResultsField() {
        return CUSTOM_TYPE;
    }

    /**
     * @return the wrapped map itself (not a copy)
     */
    @Override
    public Map<String, Object> asMap() {
        return data;
    }

    /**
     * @param outputField the key under which the payload is nested
     * @return a new insertion-ordered map with one entry: {@code outputField} mapped to {@link #asMap()}
     */
    @Override
    public Map<String, Object> asMap(String outputField) {
        Map<String, Object> map = new LinkedHashMap<>();
        map.put(outputField, this.asMap());
        return map;
    }

    /**
     * @return the same map as {@link #asMap()}
     */
    @Override
    public Map<String, Object> predictedValue() {
        return this.asMap();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        CustomResults that = (CustomResults) o;
        // Objects.equals keeps equals null-safe, matching the null-safe Objects.hash used by hashCode.
        return Objects.equals(data, that.data);
    }

    @Override
    public int hashCode() {
        return Objects.hash(data);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,21 @@ public void testValidation_Rerank() {
assertNull(e);
}

/** A CUSTOM request carrying a non-null query must pass validation. */
public void testValidation_Custom() {
    InferenceAction.Request customRequest = new InferenceAction.Request(
        TaskType.CUSTOM,
        "model",
        "query",
        List.of("input"),
        null,
        null,
        null,
        false
    );

    assertNull(customRequest.validate());
}

public void testValidation_TextEmbedding_Null() {
InferenceAction.Request inputNullRequest = new InferenceAction.Request(
TaskType.TEXT_EMBEDDING,
Expand Down Expand Up @@ -166,6 +181,37 @@ public void testValidation_Rerank_Empty() {
assertThat(queryEmptyError.getMessage(), is("Validation Failed: 1: Field [query] cannot be empty for task type [rerank];"));
}

/** A CUSTOM request with a null query must be rejected with the "cannot be null" validation message. */
public void testValidation_Custom_Null() {
    InferenceAction.Request requestWithoutQuery = new InferenceAction.Request(
        TaskType.CUSTOM,
        "model",
        null,
        List.of("input"),
        null,
        null,
        null,
        false
    );

    ActionRequestValidationException validationError = requestWithoutQuery.validate();

    assertNotNull(validationError);
    String actualMessage = validationError.getMessage();
    assertThat(actualMessage, is("Validation Failed: 1: Field [query] cannot be null for task type [custom];"));
}

/** A CUSTOM request with an empty (but non-null) query must pass validation. */
public void testValidation_Custom_Empty() {
    InferenceAction.Request queryEmptyRequest = new InferenceAction.Request(
        TaskType.CUSTOM,
        "model",
        "",
        List.of("input"),
        null,
        null,
        null,
        false
    );

    assertNull(queryEmptyRequest.validate());
}

public void testParseRequest_DefaultsInputTypeToIngest() throws IOException {
String singleInputRequest = """
{
Expand Down
6 changes: 6 additions & 0 deletions x-pack/plugin/inference/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies {
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')

api "com.ibm.icu:icu4j:${versions.icu4j}"
api "org.apache.commons:commons-lang3:${versions.commons_lang3}"

runtimeOnly 'com.google.guava:guava:32.0.1-jre'
implementation 'com.google.code.gson:gson:2.10'
Expand All @@ -57,6 +58,11 @@ dependencies {
implementation 'io.grpc:grpc-context:1.49.2'
implementation 'io.opencensus:opencensus-api:0.31.1'
implementation 'io.opencensus:opencensus-contrib-http-util:0.31.1'
implementation 'org.apache.commons:commons-text:1.4'
implementation 'com.jayway.jsonpath:json-path:2.9.0'
implementation 'net.minidev:json-smart:2.5.2'
implementation 'net.minidev:accessors-smart:2.5.2'


/* AWS SDK v2 */
implementation ("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
Expand Down
4 changes: 4 additions & 0 deletions x-pack/plugin/inference/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
requires org.reactivestreams;
requires org.elasticsearch.logging;
requires org.elasticsearch.sslconfig;
// NOTE(review): the visible build.gradle additions in this change cover commons-text, json-path,
// json-smart and accessors-smart, but nothing LDAP-related. Confirm that `requires unboundid.ldapsdk`
// is intentional and not an accidental IDE auto-insert.
requires org.apache.commons.text;
requires json.path;
requires unboundid.ldapsdk;
requires json.smart;

exports org.elasticsearch.xpack.inference.action;
exports org.elasticsearch.xpack.inference.registry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
Expand Down Expand Up @@ -149,6 +151,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);
addCustomWriteables(namedWriteables);

addUnifiedNamedWriteables(namedWriteables);

Expand Down Expand Up @@ -663,4 +666,11 @@ private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> nam
)
);
}

/**
 * Registers the named writeables of the custom inference service: its service settings first,
 * then its task settings.
 *
 * @param namedWriteables the registry entry list to append to
 */
private static void addCustomWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
    NamedWriteableRegistry.Entry serviceSettingsEntry = new NamedWriteableRegistry.Entry(
        ServiceSettings.class,
        CustomServiceSettings.NAME,
        CustomServiceSettings::new
    );
    namedWriteables.add(serviceSettingsEntry);

    NamedWriteableRegistry.Entry taskSettingsEntry = new NamedWriteableRegistry.Entry(
        TaskSettings.class,
        CustomTaskSettings.NAME,
        CustomTaskSettings::new
    );
    namedWriteables.add(taskSettingsEntry);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.custom.CustomService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
Expand Down Expand Up @@ -358,6 +359,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
context -> new CustomService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Loading