@@ -16,6 +16,7 @@ private DefaultElserFeatureFlag() {}
    private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("inference_default_elser");

    public static boolean isEnabled() {
-        return FEATURE_FLAG.isEnabled();
+        return true;
+        // return FEATURE_FLAG.isEnabled();
    }
}
@@ -68,6 +68,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;
@@ -680,25 +681,13 @@ public void chunkedInfer(
                esModel.getConfigurations().getChunkingSettings()
            ).batchRequestsWithListeners(listener);

-            for (var batch : batchedRequests) {
-                var inferenceRequest = buildInferenceRequest(
-                    esModel.mlNodeDeploymentId(),
-                    EmptyConfigUpdate.INSTANCE,
-                    batch.batch().inputs(),
-                    inputType,
-                    timeout
-                );
-
-                ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
-                    .delegateFailureAndWrap(
-                        (l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l)
-                    );
-
-                var maybeDeployListener = mlResultsListener.delegateResponse(
-                    (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener)
-                );
-
-                client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
-            }
+            if (batchedRequests.isEmpty()) {
+                listener.onResponse(List.of());
+            } else {
+                // Avoid filling the inference queue by executing the batches in series.
+                // Each batch contains up to EMBEDDING_MAX_BATCH_SIZE inference requests.
+                var sequentialRunner = new BatchIterator(esModel, inputType, timeout, batchedRequests);
+                sequentialRunner.run();
+            }
        } else {
            listener.onFailure(notElasticsearchModelException(model));
@@ -1018,6 +1007,82 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) {
}
}

+    /**
+     * Iterates over the batches, executing a limited number of requests at a time to
+     * avoid filling the ML node inference queue.
+     *
+     * First, a single request is executed, which can also trigger deploying a model
+     * if necessary. When this request is successfully executed, a callback executes
+     * N requests in parallel next. Each of these requests also has a callback that
+     * executes one more request, so that at all times N requests are in flight. This
+     * continues until all requests have been executed.
+     */
+    class BatchIterator {
+        private static final int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200
+
+        private final AtomicInteger index = new AtomicInteger();
+        private final ElasticsearchInternalModel esModel;
+        private final List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners;
+        private final InputType inputType;
+        private final TimeValue timeout;
+
+        BatchIterator(
+            ElasticsearchInternalModel esModel,
+            InputType inputType,
+            TimeValue timeout,
+            List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners
+        ) {
+            this.esModel = esModel;
+            this.requestAndListeners = requestAndListeners;
+            this.inputType = inputType;
+            this.timeout = timeout;
+        }
+
+        void run() {
+            // The first request may deploy the model, and upon completion runs
+            // NUM_REQUESTS_INFLIGHT more requests in parallel.
+            inferenceExecutor.execute(() -> inferBatch(NUM_REQUESTS_INFLIGHT, true));
+        }
+
+        private void inferBatch(int runAfterCount, boolean maybeDeploy) {
+            int batchIndex = index.getAndIncrement();
+            if (batchIndex >= requestAndListeners.size()) {
+                return;
+            }
+            executeRequest(batchIndex, maybeDeploy, () -> {
+                for (int i = 0; i < runAfterCount; i++) {
+                    // Subsequent requests may not deploy the model, because the first request
+                    // already did so. Upon completion, each runs one more request.
+                    inferenceExecutor.execute(() -> inferBatch(1, false));
+                }
+            });
+        }
+
+        private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAfter) {
+            EmbeddingRequestChunker.BatchRequestAndListener batch = requestAndListeners.get(batchIndex);
+            var inferenceRequest = buildInferenceRequest(
+                esModel.mlNodeDeploymentId(),
+                EmptyConfigUpdate.INSTANCE,
+                batch.batch().inputs(),
+                inputType,
+                timeout
+            );
+            logger.trace("Executing batch index={}", batchIndex);
+
+            ActionListener<InferModelAction.Response> listener = batch.listener()
+                .delegateFailureAndWrap(
+                    (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
+                );
+            if (runAfter != null) {
+                listener = ActionListener.runAfter(listener, runAfter);
+            }
+            if (maybeDeploy) {
+                listener = listener.delegateResponse((l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l));
+            }
+            client.execute(InferModelAction.INSTANCE, inferenceRequest, listener);
+        }
+    }
+
public static class Configuration {
public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();
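Editorial aside, not part of the diff: the Javadoc on BatchIterator describes a simple throttling pattern, one initial request (which may also trigger a model deployment), followed by a fan-out of NUM_REQUESTS_INFLIGHT workers, each of which schedules exactly one successor when it completes. The following is a minimal, self-contained sketch of that pattern under stated assumptions: the names ThrottledRunner, send and MAX_INFLIGHT are hypothetical, a plain ExecutorService stands in for the service's inference executor, and a synchronous send stands in for the asynchronous ML client call.

import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class ThrottledRunner {
    private static final int MAX_INFLIGHT = 20;

    private final AtomicInteger index = new AtomicInteger();
    private final List<String> batches;      // stand-in for BatchRequestAndListener
    private final ExecutorService executor;  // stand-in for the inference executor

    ThrottledRunner(List<String> batches, ExecutorService executor) {
        this.batches = batches;
        this.executor = executor;
    }

    void run() {
        // The first task runs alone; when it completes it fans out MAX_INFLIGHT workers.
        executor.execute(() -> next(MAX_INFLIGHT));
    }

    private void next(int fanOut) {
        int i = index.getAndIncrement();
        if (i >= batches.size()) {
            return; // every batch has already been claimed
        }
        send(batches.get(i), () -> {
            for (int j = 0; j < fanOut; j++) {
                // Each completed request schedules exactly one successor, so at most
                // MAX_INFLIGHT requests are in flight once the initial fan-out has happened.
                executor.execute(() -> next(1));
            }
        });
    }

    private void send(String batch, Runnable onDone) {
        // Placeholder for the asynchronous inference call; completes immediately here.
        System.out.println("executing " + batch);
        onDone.run();
    }

    public static void main(String[] args) throws InterruptedException {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        new ThrottledRunner(List.of("batch-1", "batch-2", "batch-3", "batch-4"), pool).run();
        pool.shutdown();
        pool.awaitTermination(5, TimeUnit.SECONDS);
    }
}

With the in-flight cap of 20 and the batch size implied by the "// * batch size = 200" comment (an assumed EMBEDDING_MAX_BATCH_SIZE of 10), roughly 200 chunk inputs at most would be queued on the ML node at any one time.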
@@ -24,12 +24,25 @@
import java.util.concurrent.atomic.AtomicReference;

import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;

public class EmbeddingRequestChunkerTests extends ESTestCase {

+    public void testEmptyInput() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, empty());
+    }
+
+    public void testBlankInput() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+    }
+
    public void testShortInputsAreSingleBatch() {
        String input = "one chunk";
        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
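Editorial aside, not part of the diff: the two new tests line up with the batchedRequests.isEmpty() fast path added to chunkedInfer. An empty input list yields no batches, which is exactly the case the service now short-circuits, while a blank string is still a real input and yields one batch. The sketch below ties the two together; inputs, embeddingType, finalListener, esModel, inputType and timeout are assumed to be in scope, and the chunker arguments are simply the values used in the tests above, not service defaults.

// Sketch only: how the chunker's output feeds the new fast path in chunkedInfer.
var batches = new EmbeddingRequestChunker(inputs, 100, 100, 10, embeddingType).batchRequestsWithListeners(finalListener);
if (batches.isEmpty()) {
    // No inputs means no ML requests; answer immediately with an empty result.
    finalListener.onResponse(List.of());
} else {
    // Otherwise hand the batches to the throttled BatchIterator introduced above.
    new BatchIterator(esModel, inputType, timeout, batches).run();
}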