Skip to content
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
4f4c603
add inference custom model
Huaixinww Mar 7, 2025
e53b2e4
add unit test
Huaixinww Mar 7, 2025
c593b1a
spotless apply
Huaixinww Mar 7, 2025
0a851c1
add custom validation
Huaixinww Mar 7, 2025
e240f06
xpack core spotless apply
Huaixinww Mar 7, 2025
3ea3053
update commons-lang3's version
Huaixinww Mar 7, 2025
83daf69
Fix compilation after rebase
davidkyle Mar 25, 2025
3cb0cfb
Add missing licences and fix build checks
davidkyle Mar 25, 2025
a3c862c
Remove some unused code
davidkyle Mar 25, 2025
2b7e6fe
Update docs/changelog/125679.yaml
davidkyle Mar 26, 2025
6cc593d
Fix services it
davidkyle Mar 27, 2025
a4630e3
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 7, 2025
95f23f0
Continuing refactor of service settings
jonathan-buttner Apr 9, 2025
014f95b
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 9, 2025
189edba
Moving classes to reflect new structure
jonathan-buttner Apr 9, 2025
4fe3a1f
Refactoring service settings
jonathan-buttner Apr 9, 2025
4ef37f5
Refactoring the request
jonathan-buttner Apr 10, 2025
6bac18b
Adding files to handle generic error response
jonathan-buttner Apr 11, 2025
f644471
Making progress on tests
jonathan-buttner Apr 15, 2025
11cf7cc
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 15, 2025
f962d74
Adding more tests
jonathan-buttner Apr 16, 2025
eb63e8b
Adding more tests
jonathan-buttner Apr 18, 2025
adc3210
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 18, 2025
c9ff298
Adding tests for remaining parsers
jonathan-buttner Apr 21, 2025
de83271
More tests
jonathan-buttner Apr 22, 2025
34df922
Need to address quoted strings
jonathan-buttner Apr 24, 2025
b496732
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 28, 2025
097246b
Adding query parameter handling and tests
jonathan-buttner Apr 28, 2025
e7f6ac5
Adding encoding tests
jonathan-buttner Apr 29, 2025
a8c5241
Fixing embedding dimensions issue and test field names
jonathan-buttner Apr 29, 2025
3df0f70
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 29, 2025
ad55337
Fixing tests
jonathan-buttner Apr 29, 2025
4714fd3
[CI] Auto commit changes from spotless
Apr 29, 2025
d13191c
Removing licenses
jonathan-buttner Apr 29, 2025
12d46d7
Adding custom service tests
jonathan-buttner May 2, 2025
0134346
Merge branch 'custom-inference-service' of github.com:davidkyle/elast…
jonathan-buttner May 2, 2025
e6fefc4
[CI] Auto commit changes from spotless
May 2, 2025
eef7188
Correcting transport version number
jonathan-buttner May 2, 2025
83837c8
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 2, 2025
dc02425
Merge branch 'custom-inference-service' of github.com:davidkyle/elast…
jonathan-buttner May 2, 2025
59f75b9
Cleaning up
jonathan-buttner May 2, 2025
8a82163
Fixing counts
jonathan-buttner May 5, 2025
5f13d28
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 5, 2025
c211d83
Fixing rerank and chat completions
jonathan-buttner May 7, 2025
133ef4e
Missing a few changes
jonathan-buttner May 7, 2025
5c28ee8
Passing request to the error response handler
jonathan-buttner May 7, 2025
be84291
Merge remote-tracking branch 'origin/ml-expose-request-in-error-parse…
jonathan-buttner May 7, 2025
8d1bd22
Adding inference id to error parser log message
jonathan-buttner May 8, 2025
a0984c7
Reverting exposing request to error parsing logic
jonathan-buttner May 8, 2025
4242a37
Refactoring the error parsing logic
jonathan-buttner May 8, 2025
6492cd7
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/125679.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 125679
summary: Custom Inference Service
area: Machine Learning
type: enhancement
issues: []
5 changes: 5 additions & 0 deletions gradle/verification-metadata.xml
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,11 @@
<sha256 value="681e53c4ffd59fa12068803b259e3a83d43f07a47c112e748a187dee179eb31f" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="com.jayway.jsonpath" name="json-path" version="2.9.0">
<artifact name="json-path-2.9.0.jar">
<sha256 value="11a9ee6f88bb31f1450108d1cf6441377dec84aca075eb6bb2343be157575bea" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="com.jcraft" name="jsch" version="0.1.54">
<artifact name="jsch-0.1.54.jar">
<sha256 value="92eb273a3316762478fdd4fe03a0ce1842c56f496c9c12fe1235db80450e1fdb" origin="Generated by Gradle"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INDEX_METADATA_INCLUDES_RECENT_WRITE_LOAD = def(9_036_0_00);
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED = def(9_037_0_00);
public static final TransportVersion ESQL_REPORT_ORIGINAL_TYPES = def(9_038_00_0);
public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL = def(9_039_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ public enum TaskType implements Writeable {
SPARSE_EMBEDDING,
RERANK,
COMPLETION,
// Task type constant for the custom inference service.
// Its isAnyOrSame override ignores the argument and always returns true.
// NOTE(review): this makes CUSTOM report compatibility with every other task type —
// confirm that is intended rather than only matching CUSTOM itself.
CUSTOM {
@Override
public boolean isAnyOrSame(TaskType other) {
return true;
}
},
ANY {
@Override
public boolean isAnyOrSame(TaskType other) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,14 @@ public ActionRequestValidationException validate() {
return e;
}

if (taskType.equals(TaskType.CUSTOM)) {
if (query == null) {
var e = new ActionRequestValidationException();
e.addValidationError(format("Field [query] cannot be null for task type [%s]", TaskType.CUSTOM));
return e;
}
}

return null;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.ml.inference.results.CustomResults;

import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

/**
 * Inference service results for the {@link TaskType#CUSTOM} task type.
 * <p>
 * The payload is held as a generic {@code Map<String, Object>} and is written to and read from
 * the transport stream as a generic map.
 */
public class CustomServiceResults implements InferenceServiceResults {
    public static final String NAME = "custom_service_results";
    public static final String CUSTOM_TYPE = TaskType.CUSTOM.name().toLowerCase(Locale.ROOT);

    // Assigned exactly once by either constructor, hence final.
    final Map<String, Object> data;

    /**
     * Creates results wrapping the given payload map. The map is stored as-is (no copy is made).
     *
     * @param data the result payload
     */
    public CustomServiceResults(Map<String, Object> data) {
        this.data = data;
    }

    /**
     * Deserializes results from a stream by reading a generic map.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public CustomServiceResults(StreamInput in) throws IOException {
        this.data = in.readGenericMap();
    }

    /**
     * Renders the payload map under the {@link #CUSTOM_TYPE} name.
     */
    @Override
    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
        return ChunkedToXContentHelper.object(CUSTOM_TYPE, this.asMap());
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Serializes the payload as a generic map; the counterpart of {@link #CustomServiceResults(StreamInput)}.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeGenericMap(data);
    }

    /**
     * Delegates to {@link #transformToLegacyFormat()}.
     */
    @Override
    public List<? extends InferenceResults> transformToCoordinationFormat() {
        return transformToLegacyFormat();
    }

    /**
     * Wraps the payload in a single {@link CustomResults} instance.
     */
    @Override
    public List<? extends InferenceResults> transformToLegacyFormat() {
        return List.of(new CustomResults(data));
    }

    /**
     * Returns the payload map itself (not a copy).
     */
    @Override
    public Map<String, Object> asMap() {
        return data;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(NAME);
        sb.append(Integer.toHexString(hashCode()));
        sb.append("\n");
        // String.valueOf keeps toString usable (prints "null") when no payload map is present.
        sb.append(String.valueOf(this.asMap()));
        return sb.toString();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        CustomServiceResults that = (CustomServiceResults) o;
        // Null-safe comparison, consistent with hashCode() which also tolerates a null payload.
        return Objects.equals(data, that.data);
    }

    @Override
    public int hashCode() {
        return Objects.hash(data);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

/**
 * {@link InferenceResults} implementation for the {@link TaskType#CUSTOM} task type.
 * <p>
 * The payload is held as a generic {@code Map<String, Object>} and is written to and read from
 * the transport stream as a generic map.
 */
public class CustomResults implements InferenceResults {
    public static final String NAME = "custom_results";
    public static final String CUSTOM_TYPE = TaskType.CUSTOM.toString();

    // Assigned exactly once by either constructor, hence final.
    final Map<String, Object> data;

    /**
     * Creates results wrapping the given payload map. The map is stored as-is (no copy is made).
     *
     * @param data the result payload
     */
    public CustomResults(Map<String, Object> data) {
        this.data = data;
    }

    /**
     * Deserializes results from a stream by reading a generic map.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public CustomResults(StreamInput in) throws IOException {
        this.data = in.readGenericMap();
    }

    /**
     * Writes the payload map as a single field named {@link #CUSTOM_TYPE}.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.field(CUSTOM_TYPE, this.asMap());
        return builder;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Serializes the payload as a generic map; the counterpart of {@link #CustomResults(StreamInput)}.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeGenericMap(data);
    }

    @Override
    public String getResultsField() {
        return CUSTOM_TYPE;
    }

    /**
     * Returns the payload map itself (not a copy).
     */
    @Override
    public Map<String, Object> asMap() {
        return data;
    }

    /**
     * Returns a new map containing a single entry that maps {@code outputField} to the payload map.
     *
     * @param outputField the key under which the payload is placed
     */
    @Override
    public Map<String, Object> asMap(String outputField) {
        Map<String, Object> map = new LinkedHashMap<>();
        map.put(outputField, this.asMap());
        return map;
    }

    /**
     * Returns the payload map; identical to {@link #asMap()}.
     */
    @Override
    public Map<String, Object> predictedValue() {
        return this.asMap();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        CustomResults that = (CustomResults) o;
        // Null-safe comparison, consistent with hashCode() which also tolerates a null payload.
        return Objects.equals(data, that.data);
    }

    @Override
    public int hashCode() {
        return Objects.hash(data);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,23 @@ public void testValidation_Rerank() {
assertNull(e);
}

public void testValidation_Custom() {
    // A CUSTOM request built with the string "query" and a single input is expected to validate cleanly.
    InferenceAction.Request customRequest = new InferenceAction.Request(
        TaskType.CUSTOM,
        "model",
        "query",
        null,
        null,
        List.of("input"),
        null,
        null,
        null,
        false
    );

    assertNull(customRequest.validate());
}

public void testValidation_TextEmbedding_Null() {
InferenceAction.Request inputNullRequest = new InferenceAction.Request(
TaskType.TEXT_EMBEDDING,
Expand Down Expand Up @@ -309,6 +326,23 @@ public void testValidation_SparseEmbedding_WithTopN() {
);
}

public void testValidation_Custom_Empty() {
    // Same as the plain CUSTOM case but with an empty string in place of "query";
    // the test expects validate() to still return null (no validation error).
    InferenceAction.Request emptyQueryRequest = new InferenceAction.Request(
        TaskType.CUSTOM,
        "model",
        "",
        null,
        null,
        List.of("input"),
        null,
        null,
        null,
        false
    );

    assertNull(emptyQueryRequest.validate());
}

public void testValidation_Completion_WithInputType() {
InferenceAction.Request queryRequest = new InferenceAction.Request(
TaskType.COMPLETION,
Expand Down
48 changes: 48 additions & 0 deletions x-pack/plugin/inference/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,15 @@ dependencies {
implementation 'com.google.http-client:google-http-client-jackson2:1.42.3'
implementation "com.fasterxml.jackson.core:jackson-core:${versions.jackson}"
implementation 'com.google.api:gax-httpjson:0.105.1'
implementation 'com.jayway.jsonpath:json-path:2.9.0'
implementation 'io.grpc:grpc-context:1.49.2'
implementation 'io.opencensus:opencensus-api:0.31.1'
implementation 'io.opencensus:opencensus-contrib-http-util:0.31.1'
implementation 'net.minidev:json-smart:2.5.2'
implementation 'net.minidev:accessors-smart:2.5.2'
implementation "org.apache.commons:commons-lang3:${versions.commons_lang3}"
implementation 'org.apache.commons:commons-text:1.4'


/* AWS SDK v2 */
implementation ("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
Expand Down Expand Up @@ -212,6 +218,14 @@ tasks.named("thirdPartyAudit").configure {
)

ignoreMissingClasses(
'com.fasterxml.jackson.databind.JsonNode',
'com.fasterxml.jackson.databind.ObjectMapper',
'com.fasterxml.jackson.databind.ObjectReader',
'com.fasterxml.jackson.databind.node.ArrayNode',
'com.fasterxml.jackson.databind.node.JsonNodeFactory',
'com.fasterxml.jackson.databind.node.ObjectNode',
'com.fasterxml.jackson.databind.node.TextNode',
'com.fasterxml.jackson.databind.type.TypeFactory',
'com.google.api.AnnotationsProto',
'com.google.api.ClientProto',
'com.google.api.FieldBehaviorProto',
Expand Down Expand Up @@ -394,9 +408,43 @@ tasks.named("thirdPartyAudit").configure {
'software.amazon.awssdk.crt.http.HttpHeader',
'software.amazon.awssdk.crt.http.HttpRequest',
'software.amazon.awssdk.crt.http.HttpRequestBodyStream',
'jakarta.json.JsonArray',
'jakarta.json.JsonArrayBuilder',
'jakarta.json.JsonBuilderFactory',
'jakarta.json.JsonNumber',
'jakarta.json.JsonObject',
'jakarta.json.JsonObjectBuilder',
'jakarta.json.JsonReader',
'jakarta.json.JsonString',
'jakarta.json.JsonStructure',
'jakarta.json.JsonValue',
'jakarta.json.JsonValue$ValueType',
'jakarta.json.bind.Jsonb',
'jakarta.json.bind.JsonbBuilder',
'jakarta.json.bind.JsonbConfig',
'jakarta.json.spi.JsonProvider',
'jakarta.json.stream.JsonLocation',
'jakarta.json.stream.JsonParser',
'jakarta.json.stream.JsonParser$Event',
'org.apache.tapestry5.json.JSONArray',
'org.apache.tapestry5.json.JSONCollection',
'org.apache.tapestry5.json.JSONObject',
'org.codehaus.jettison.json.JSONArray',
'org.codehaus.jettison.json.JSONException',
'org.codehaus.jettison.json.JSONObject',
'org.codehaus.jettison.json.JSONTokener',
'org.json.JSONArray',
'org.json.JSONObject',
'org.json.JSONTokener',
'org.objectweb.asm.ClassWriter',
'org.objectweb.asm.Label',
'org.objectweb.asm.MethodVisitor',
'org.objectweb.asm.Type',
)
}



tasks.named('yamlRestTest') {
usesDefaultDistribution("to be triaged")
}
Expand Down
Loading
Loading