[ML] Integrate SageMaker with OpenAI Embeddings #126856
Conversation
Hi @prwhelan, I've created a changelog YAML for you.
...asticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java
jonathan-buttner left a comment
Looking good! Just left a few thoughts.
...ence/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerClient.java
...ava/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerServiceSettings.java
```java
    return builder.endObject();
}

private static <T> void optionalField(String name, T value, XContentBuilder builder) throws IOException {
```
Nice, might be helpful to have this in a utility class somewhere eventually because we have to do stuff like this a lot.
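To illustrate the reviewer's point, a shared helper like the one below could live in a utility class. This is a hedged sketch: `FieldBuilder` is a self-contained stand-in for Elasticsearch's `XContentBuilder` (which isn't runnable outside the codebase), and the class name `XContentUtils` is hypothetical, not an actual class in the PR.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Stand-in for org.elasticsearch.xcontent.XContentBuilder, so the sketch compiles on its own.
class FieldBuilder {
    private final Map<String, Object> fields = new LinkedHashMap<>();

    FieldBuilder field(String name, Object value) {
        fields.put(name, value);
        return this;
    }

    Map<String, Object> build() {
        return fields;
    }
}

final class XContentUtils {
    private XContentUtils() {}

    // Writes the field only when the value is non-null, mirroring the
    // optionalField helper from the diff above.
    static <T> void optionalField(String name, T value, FieldBuilder builder) {
        if (value != null) {
            builder.field(name, value);
        }
    }
}
```

Centralizing this avoids repeating the null check at every optional-field call site in the serialization code.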
...c/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchema.java
...rg/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java
...rg/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchemaPayload.java
...asticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java
...ence/src/test/java/org/elasticsearch/xpack/inference/services/InferenceSettingsTestCase.java
Pinging @elastic/ml-core (Team:ML)
jonathan-buttner left a comment
Looks good! Just a reminder to add docs in the elasticsearch-specification repo.
```java
        return Collections.unmodifiableMap(configurationMap);
    });
new LazyInitializable<>(
    () -> configuration(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).collect(
```
nit: Would Map.of() work instead of using a stream?
Oh I see we're combining multiple streams in a separate place 👍
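The trade-off discussed above can be sketched as follows. When all entries are known statically, `Map.of()` is simplest; when several sources each contribute a `Stream` of entries (as when per-task-type configurations are combined), collecting the concatenated stream is the natural fit. The class and field names here are illustrative, not from the PR.

```java
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class ConfigMerge {
    // A single static set of entries: Map.of() is the clearest option.
    static final Map<String, String> SINGLE = Map.of("api", "openai", "region", "us-east-1");

    // Multiple entry streams merged into one unmodifiable map: a stream
    // collector handles this case directly, which Map.of() cannot.
    static Map<String, String> merged(Stream<Map.Entry<String, String>> a,
                                      Stream<Map.Entry<String, String>> b) {
        return Stream.concat(a, b)
            .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue));
    }
}
```

Note that `toUnmodifiableMap` throws on duplicate keys, so the merged streams must contribute distinct keys.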
davidkyle left a comment
LGTM
```java
} else {
    ExceptionsHelper.maybeError(t).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
    log.atWarn().withThrowable(t).log("Unknown failure calling SageMaker.");
    listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
```
Suggested change:
```java
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker.", t));
```
```java
public void subscribe(Flow.Subscriber<? super ResponseStream> subscriber) {
    if (holder.compareAndSet(null, Tuple.tuple(null, subscriber)) == false) {
        log.debug("Subscriber connecting to publisher.");
        var publisher = holder.getAndSet(null).v1();
```
Other implementations of this method call onError() if a subscriber is already set, should this do the same?
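The convention the reviewer refers to comes from the Reactive Streams rules behind `java.util.concurrent.Flow`: a publisher that supports only one subscriber should signal `onError` (after `onSubscribe`) to any additional subscriber rather than silently replacing or ignoring it. A minimal self-contained sketch of that pattern, unrelated to the actual `ResponseStream` publisher in the PR:

```java
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical single-subscriber publisher illustrating the review comment.
class SingleUsePublisher<T> implements Flow.Publisher<T> {
    private static final Flow.Subscription NOOP = new Flow.Subscription() {
        @Override public void request(long n) {}
        @Override public void cancel() {}
    };

    private final AtomicBoolean subscribed = new AtomicBoolean();

    @Override
    public void subscribe(Flow.Subscriber<? super T> subscriber) {
        if (subscribed.compareAndSet(false, true)) {
            subscriber.onSubscribe(NOOP);
        } else {
            // Reactive Streams convention: onSubscribe must precede any other
            // signal, then reject the extra subscriber via onError.
            subscriber.onSubscribe(NOOP);
            subscriber.onError(new IllegalStateException("Only one subscriber is supported"));
        }
    }
}
```

This makes the failure visible to the second caller instead of leaving it waiting on a publisher that will never deliver.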
```java
    Map<String, Object> config,
    ActionListener<Model> parsedModelListener
) {
    ActionListener.completeWith(parsedModelListener, () -> modelBuilder.fromRequest(modelId, taskType, NAME, config));
```
Nice
```java
public class SageMakerService implements InferenceService {
    public static final String NAME = "sagemaker";
    private static final int DEFAULT_BATCH_SIZE = 2048;
```
Seems like a big number. 2048 may be an optimal size for SageMaker, but a batch this size would use quite a lot of memory and isn't sympathetic to how the inference API works.
```java
Map.entry(
    API,
    new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The API format that your SageMaker Endpoint expects.")
        .setLabel("Api")
```
Suggested change:
```java
        .setLabel("API")
```
```java
public final void testXContentRoundTrip() throws IOException {
    var instance = createTestInstance();
    var instanceAsMap = toMap(instance);
    var roundTripInstance = fromMutableMap(new HashMap<>(instanceAsMap));
```
🙇
💔 Backport failed
You can use sqren/backport to manually backport by running
Integrating with SageMaker.

Current design:
- SageMaker accepts any byte payload, which can be text, CSV, or JSON. `api` represents the structure of the payload that we will send, for example `openai`, `elastic`, `common`, probably `cohere` or `huggingface` as well.
- `api` implementations are extensions of `SageMakerSchemaPayload`, which supports:
  - "extra" service and task settings specific to the payload structure, so `cohere` would require `embedding_type` and `openai` would require `dimensions` in the `service_settings`
  - conversion logic from model, service settings, task settings, and input to `SdkBytes`
  - conversion logic from the responding `SdkBytes` to `InferenceServiceResults`
- Everything else is tunneling: there are a number of base `service_settings` and `task_settings` that are independent of the API format that we will store and set.
- We let the SDK do the bulk of the work in terms of connection details, rate limiting, retries, etc.
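The payload abstraction described above can be sketched roughly as follows. This is a simplified, self-contained illustration: `Bytes` and `Results` stand in for the real `SdkBytes` and `InferenceServiceResults` types, and the interface and method names here are assumptions, not the actual `SageMakerSchemaPayload` signatures.

```java
import java.nio.charset.StandardCharsets;
import java.util.List;

// Stand-ins for SdkBytes and InferenceServiceResults.
record Bytes(byte[] data) {}
record Results(List<String> values) {}

// Each `api` value maps to one implementation that owns the wire format
// in both directions: request serialization and response parsing.
interface SchemaPayload {
    String api();                               // e.g. "openai", "elastic"
    Bytes requestBytes(List<String> input);     // model/settings/input -> endpoint payload
    Results responseResults(Bytes response);    // endpoint payload -> inference results
}

// An "openai"-style implementation serializes input into the JSON shape
// that kind of endpoint expects.
class OpenAiLikePayload implements SchemaPayload {
    @Override public String api() { return "openai"; }

    @Override public Bytes requestBytes(List<String> input) {
        String json = "{\"input\":[\"" + String.join("\",\"", input) + "\"]}";
        return new Bytes(json.getBytes(StandardCharsets.UTF_8));
    }

    @Override public Results responseResults(Bytes response) {
        // Real code would parse the JSON embedding response; this just echoes it.
        return new Results(List.of(new String(response.data(), StandardCharsets.UTF_8)));
    }
}
```

This keeps the per-format knowledge in one place, so the service layer only tunnels bytes and the shared base settings stay format-independent.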