Skip to content

Commit f60d951

Browse files
committed
add transportversion, javadocs
1 parent 761d2ef commit f60d951

File tree

13 files changed

+112
-22
lines changed

13 files changed

+112
-22
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ static TransportVersion def(int id) {
162162
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
163163
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19);
164164
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20);
165+
public static final TransportVersion ML_INFERENCE_SAGEMAKER_8_19 = def(8_841_0_21);
165166
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
166167
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
167168
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@@ -232,6 +233,7 @@ static TransportVersion def(int id) {
232233
public static final TransportVersion BATCHED_QUERY_EXECUTION_DELAYABLE_WRITABLE = def(9_057_0_00);
233234
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL = def(9_058_0_00);
234235
public static final TransportVersion COMPRESS_DELAYABLE_WRITEABLE = def(9_059_0_00);
236+
public static final TransportVersion ML_INFERENCE_SAGEMAKER = def(9_060_0_00);
235237

236238
/*
237239
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.TransportVersion;
12+
import org.elasticsearch.TransportVersions;
1213
import org.elasticsearch.action.ActionListener;
1314
import org.elasticsearch.action.support.SubscribableListener;
1415
import org.elasticsearch.core.Nullable;
@@ -267,7 +268,7 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> listen
267268

268269
@Override
269270
public TransportVersion getMinimalSupportedVersion() {
270-
return TransportVersion.current();
271+
return TransportVersions.ML_INFERENCE_SAGEMAKER;
271272
}
272273

273274
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
import java.util.Optional;
2323

2424
/**
25+
* This model represents all models in SageMaker. SageMaker maintains a base set of settings and configurations, and this model manages
26+
* those. Any settings that are required for a specific model are stored in the {@link SageMakerStoredServiceSchema} and
27+
* {@link SageMakerStoredTaskSchema}.
2528
* Design:
2629
* - Region is stored in ServiceSettings and is used to create the SageMaker client.
2730
* - RateLimiting is based on AWS Service Quota, metered by account and region. The SDK client handles rate limiting internally. In order to

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerServiceSettings.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.sagemaker.model;
99

1010
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
1112
import org.elasticsearch.common.ValidationException;
1213
import org.elasticsearch.common.io.stream.StreamInput;
1314
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -28,6 +29,10 @@
2829
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
2930
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
3031

32+
/**
33+
* Maintains the settings for SageMaker that cannot be changed without impacting semantic search and AI assistants.
34+
* Model-specific settings are stored in {@link SageMakerStoredServiceSchema}.
35+
*/
3136
record SageMakerServiceSettings(
3237
String endpointName,
3338
String region,
@@ -75,7 +80,7 @@ public String getWriteableName() {
7580

7681
@Override
7782
public TransportVersion getMinimalSupportedVersion() {
78-
return TransportVersion.current();
83+
return TransportVersions.ML_INFERENCE_SAGEMAKER;
7984
}
8085

8186
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.sagemaker.model;
99

1010
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
1112
import org.elasticsearch.common.ValidationException;
1213
import org.elasticsearch.common.io.stream.StreamInput;
1314
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -22,6 +23,9 @@
2223

2324
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
2425

26+
/**
27+
* Maintains mutable settings for SageMaker. Model-specific settings are stored in {@link SageMakerStoredTaskSchema}.
28+
*/
2529
record SageMakerTaskSettings(
2630
@Nullable String customAttributes,
2731
@Nullable String enableExplanations,
@@ -92,7 +96,7 @@ public String getWriteableName() {
9296

9397
@Override
9498
public TransportVersion getMinimalSupportedVersion() {
95-
return TransportVersion.current();
99+
return TransportVersions.ML_INFERENCE_SAGEMAKER;
96100
}
97101

98102
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchema.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
import java.util.Map;
2929
import java.util.stream.Stream;
3030

31+
/**
32+
* All the logic that is required to call any SageMaker model is handled within this Schema class.
33+
* Any model-specific logic is handled within the associated {@link SageMakerSchemaPayload}.
34+
* This schema is specific to SageMaker's non-streaming API. For streaming, see {@link SageMakerStreamSchema}.
35+
*/
3136
public class SageMakerSchema {
3237
private static final String INTERNAL_DEPENDENCY_ERROR = "Received an internal dependency error from SageMaker for [%s]";
3338
private static final String INTERNAL_FAILURE = "Received an internal failure from SageMaker for [%s]";
@@ -128,10 +133,6 @@ protected Tuple<String, RestStatus> errorMessageAndStatus(SageMakerModel model,
128133
return Tuple.tuple(errorMessage, restStatus);
129134
}
130135

131-
public String api() {
132-
return schemaPayload.api();
133-
}
134-
135136
public SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
136137
return schemaPayload.apiServiceSettings(serviceSettings, validationException);
137138
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayload.java

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,39 @@
2323
import java.util.stream.Stream;
2424

2525
public interface SageMakerSchemaPayload {
26+
27+
/**
28+
* The model API keyword that users will supply in the service settings when creating the request.
29+
* Automatically registered in {@link SageMakerSchemas}.
30+
*/
2631
String api();
2732

33+
/**
34+
* The supported TaskTypes for this model API.
35+
* Automatically registered in {@link SageMakerSchemas}.
36+
*/
2837
EnumSet<TaskType> supportedTasks();
2938

39+
/**
40+
* Implement this if the model requires extra ServiceSettings that can be saved to the model index.
41+
* This can be accessed via {@link SageMakerModel#apiServiceSettings()}.
42+
*/
3043
default SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
3144
return SageMakerStoredServiceSchema.NO_OP;
3245
}
3346

47+
/**
48+
* Implement this if the model requires extra TaskSettings that can be saved to the model index.
49+
* This can be accessed via {@link SageMakerModel#apiTaskSettings()}.
50+
*/
3451
default SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
3552
return SageMakerStoredTaskSchema.NO_OP;
3653
}
3754

3855
/**
39-
* Automatically register the required registry entries with {@link org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider}.
56+
* Creates the exception that must be thrown if {@link SageMakerModel#apiServiceSettings()} or {@link SageMakerModel#apiTaskSettings()} returns the wrong
57+
* object types.
4058
*/
41-
default Stream<NamedWriteableRegistry.Entry> namedWriteables() {
42-
return Stream.of();
43-
}
44-
4559
default Exception createUnsupportedSchemaException(SageMakerModel model) {
4660
return new IllegalArgumentException(
4761
Strings.format(
@@ -54,12 +68,31 @@ default Exception createUnsupportedSchemaException(SageMakerModel model) {
5468
);
5569
}
5670

71+
/**
72+
* Automatically register the required registry entries with {@link org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider}.
73+
*/
74+
default Stream<NamedWriteableRegistry.Entry> namedWriteables() {
75+
return Stream.of();
76+
}
77+
78+
/**
79+
* The MIME type of the response from SageMaker.
80+
*/
5781
String accept(SageMakerModel model);
5882

83+
/**
84+
* The MIME type of the request to SageMaker.
85+
*/
5986
String contentType(SageMakerModel model);
6087

88+
/**
89+
* Translate to the body of the request in the MIME type specified by {@link #contentType(SageMakerModel)}.
90+
*/
6191
SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception;
6292

93+
/**
94+
* Translate from the body of the response in the MIME type specified by {@link #accept(SageMakerModel)}.
95+
*/
6396
InferenceServiceResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception;
6497

6598
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424

2525
import static org.elasticsearch.core.Strings.format;
2626

27+
/**
28+
* The mapping and registry for all supported model APIs.
29+
*/
2730
public class SageMakerSchemas {
2831
private static final Map<TaskAndApi, SageMakerSchema> schemas;
2932
private static final Map<TaskAndApi, SageMakerStreamSchema> streamSchemas;
@@ -33,6 +36,9 @@ public class SageMakerSchemas {
3336
private static final EnumSet<TaskType> supportedTaskTypes;
3437

3538
static {
39+
/*
40+
* Add new model APIs to the register call.
41+
*/
3642
schemas = register(new OpenAiTextEmbeddingPayload());
3743

3844
streamSchemas = schemas.entrySet()

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
99

1010
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
1112
import org.elasticsearch.common.io.stream.StreamOutput;
1213
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
1314
import org.elasticsearch.xcontent.ToXContentFragment;
1415
import org.elasticsearch.xcontent.XContentBuilder;
1516

17+
/**
18+
* Contains any model-specific settings that are stored in SageMakerServiceSettings.
19+
*/
1620
public interface SageMakerStoredServiceSchema extends ToXContentFragment, VersionedNamedWriteable {
1721
SageMakerStoredServiceSchema NO_OP = new SageMakerStoredServiceSchema() {
1822
private static final String NAME = "noop_sagemaker_service_schema";
@@ -24,7 +28,7 @@ public String getWriteableName() {
2428

2529
@Override
2630
public TransportVersion getMinimalSupportedVersion() {
27-
return TransportVersion.current();
31+
return TransportVersions.ML_INFERENCE_SAGEMAKER;
2832
}
2933

3034
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
99

1010
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
1112
import org.elasticsearch.common.ValidationException;
1213
import org.elasticsearch.common.io.stream.StreamOutput;
1314
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
@@ -16,6 +17,10 @@
1617

1718
import java.util.Map;
1819

20+
/**
21+
* Contains any model-specific settings that are stored in SageMakerTaskSettings.
22+
* Because TaskSettings are updatable, this object must be able to mutate itself, which we handle through the {@link Builder}.
23+
*/
1924
public interface SageMakerStoredTaskSchema extends ToXContentFragment, VersionedNamedWriteable {
2025
SageMakerStoredTaskSchema NO_OP = new SageMakerStoredTaskSchema() {
2126

@@ -44,7 +49,7 @@ public String getWriteableName() {
4449

4550
@Override
4651
public TransportVersion getMinimalSupportedVersion() {
47-
return TransportVersion.current(); // TODO
52+
return TransportVersions.ML_INFERENCE_SAGEMAKER;
4853
}
4954

5055
@Override
@@ -60,9 +65,17 @@ default SageMakerStoredTaskSchema update(Map<String, Object> map, ValidationExce
6065
return toBuilder().fromMap(map, exception).build();
6166
}
6267

68+
/**
69+
* This is called during {@link #update(Map, ValidationException)}.
70+
* Implementations should set the current field values in the Builder, as the update function is expected to overwrite them.
71+
*/
6372
Builder toBuilder();
6473

6574
interface Builder {
75+
/**
76+
* The map will either come from the PUT request or the stored value in the model index.
77+
* It must match the map written by toXContent.
78+
*/
6679
Builder fromMap(Map<String, Object> map, ValidationException exception);
6780

6881
SageMakerStoredTaskSchema build();

0 commit comments

Comments
 (0)