Skip to content

Commit 0859c7d

Browse files
committed
Lint
1 parent 46ef9b6 commit 0859c7d

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
import org.elasticsearch.xpack.esql.core.type.DataType;
2222
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
2323
import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator;
24+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
2425
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
2526

2627
import java.util.ArrayList;
28+
import java.util.Iterator;
2729
import java.util.List;
2830

2931
public class TextEmbeddingFunctionEvaluator implements InferenceFunctionEvaluator {
@@ -42,18 +44,40 @@ public void eval(FoldContext foldContext, ActionListener<Expression> listener) {
4244
assert f.inferenceId() != null && f.inferenceId().foldable() : "inferenceId should not be null and be foldable";
4345
assert f.inputText() != null && f.inputText().foldable() : "inputText should not be null and be foldable";
4446

45-
String inferenceId = BytesRefs.toString(f.inferenceId().fold(foldContext));
46-
String inputText = BytesRefs.toString(f.inputText().fold(foldContext));
47+
final String inferenceId = BytesRefs.toString(f.inferenceId().fold(foldContext));
48+
final String inputText = BytesRefs.toString(f.inputText().fold(foldContext));
4749

48-
//bulkInferenceRunner.executeBulk(inferenceRequest(inferenceId, inputText), listener.map(this::parseInferenceResponse));
50+
bulkInferenceRunner.executeBulk(new BulkInferenceRequestIterator() {
51+
private final Iterator<InferenceAction.Request> it = List.of(inferenceRequest(inferenceId, inputText)).iterator();
52+
53+
@Override
54+
public void close() {
55+
56+
}
57+
58+
@Override
59+
public boolean hasNext() {
60+
return it.hasNext();
61+
}
62+
63+
@Override
64+
public InferenceAction.Request next() {
65+
return it.next();
66+
}
67+
68+
@Override
69+
public int estimatedSize() {
70+
return 1;
71+
}
72+
}, listener.map(this::parseInferenceResponse));
4973
}
5074

51-
private InferenceAction.Request inferenceRequest(String inferenceId, String inputText) {
75+
private static InferenceAction.Request inferenceRequest(String inferenceId, String inputText) {
5276
return InferenceAction.Request.builder(inferenceId, TaskType.TEXT_EMBEDDING).setInput(List.of(inputText)).build();
5377
}
5478

55-
private Literal parseInferenceResponse(InferenceAction.Response response) {
56-
if (response.getResults() instanceof TextEmbeddingResults<?> textEmbeddingResults) {
79+
private Literal parseInferenceResponse(List<InferenceAction.Response> responses) {
80+
if (responses.getFirst().getResults() instanceof TextEmbeddingResults<?> textEmbeddingResults) {
5781
return parseInferenceResponse(textEmbeddingResults);
5882
}
5983
throw new IllegalArgumentException("Inference response should be of type TextEmbeddingResults");

0 commit comments

Comments
 (0)