2121import org .elasticsearch .xpack .esql .core .type .DataType ;
2222import org .elasticsearch .xpack .esql .expression .function .inference .TextEmbedding ;
2323import org .elasticsearch .xpack .esql .inference .InferenceFunctionEvaluator ;
24+ import org .elasticsearch .xpack .esql .inference .bulk .BulkInferenceRequestIterator ;
2425import org .elasticsearch .xpack .esql .inference .bulk .BulkInferenceRunner ;
2526
2627import java .util .ArrayList ;
28+ import java .util .Iterator ;
2729import java .util .List ;
2830
2931public class TextEmbeddingFunctionEvaluator implements InferenceFunctionEvaluator {
@@ -42,18 +44,40 @@ public void eval(FoldContext foldContext, ActionListener<Expression> listener) {
4244 assert f .inferenceId () != null && f .inferenceId ().foldable () : "inferenceId should not be null and be foldable" ;
4345 assert f .inputText () != null && f .inputText ().foldable () : "inputText should not be null and be foldable" ;
4446
45- String inferenceId = BytesRefs .toString (f .inferenceId ().fold (foldContext ));
46- String inputText = BytesRefs .toString (f .inputText ().fold (foldContext ));
47+ final String inferenceId = BytesRefs .toString (f .inferenceId ().fold (foldContext ));
48+ final String inputText = BytesRefs .toString (f .inputText ().fold (foldContext ));
4749
48- //bulkInferenceRunner.executeBulk(inferenceRequest(inferenceId, inputText), listener.map(this::parseInferenceResponse));
50+ bulkInferenceRunner .executeBulk (new BulkInferenceRequestIterator () {
51+ private final Iterator <InferenceAction .Request > it = List .of (inferenceRequest (inferenceId , inputText )).iterator ();
52+
53+ @ Override
54+ public void close () {
55+
56+ }
57+
58+ @ Override
59+ public boolean hasNext () {
60+ return it .hasNext ();
61+ }
62+
63+ @ Override
64+ public InferenceAction .Request next () {
65+ return it .next ();
66+ }
67+
68+ @ Override
69+ public int estimatedSize () {
70+ return 1 ;
71+ }
72+ }, listener .map (this ::parseInferenceResponse ));
4973 }
5074
51- private InferenceAction .Request inferenceRequest (String inferenceId , String inputText ) {
75+ private static InferenceAction .Request inferenceRequest (String inferenceId , String inputText ) {
5276 return InferenceAction .Request .builder (inferenceId , TaskType .TEXT_EMBEDDING ).setInput (List .of (inputText )).build ();
5377 }
5478
55- private Literal parseInferenceResponse (InferenceAction .Response response ) {
56- if (response .getResults () instanceof TextEmbeddingResults <?> textEmbeddingResults ) {
79+ private Literal parseInferenceResponse (List < InferenceAction .Response > responses ) {
80+ if (responses . getFirst () .getResults () instanceof TextEmbeddingResults <?> textEmbeddingResults ) {
5781 return parseInferenceResponse (textEmbeddingResults );
5882 }
5983 throw new IllegalArgumentException ("Inference response should be of type TextEmbeddingResults" );
0 commit comments