Skip to content

Commit 9d7a23f

Browse files
committed
Create the text embedding request iterator
1 parent cee2935 commit 9d7a23f

File tree

6 files changed

+563
-58
lines changed

6 files changed

+563
-58
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ public class EsqlFunctionRegistry {
272272
}
273273

274274
// Translation table for error messaging in the following function
275-
private static final String[] NUM_NAMES = {"zero", "one", "two", "three", "four", "five", "six"};
275+
private static final String[] NUM_NAMES = { "zero", "one", "two", "three", "four", "five", "six" };
276276

277277
// list of functions grouped by type of functions (aggregate, statistics, math etc) and ordered alphabetically inside each group
278278
// a single function will have one entry for itself with its name associated to its instance and, also, one entry for each alias
@@ -353,7 +353,7 @@ private static FunctionDefinition[][] functions() {
353353
def(Values.class, uni(Values::new), "values"),
354354
def(WeightedAvg.class, bi(WeightedAvg::new), "weighted_avg"),
355355
def(Present.class, uni(Present::new), "present"),
356-
def(Absent.class, uni(Absent::new), "absent")},
356+
def(Absent.class, uni(Absent::new), "absent") },
357357
// math
358358
new FunctionDefinition[] {
359359
def(Abs.class, Abs::new, "abs"),
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.elasticsearch.compute.data.BytesRefBlock;
12+
import org.elasticsearch.core.Releasable;
13+
import org.elasticsearch.core.Releasables;
14+
15+
/**
16+
* Helper class that reads text strings from a {@link BytesRefBlock}.
17+
* This class is used by inference operators to extract text content from block data.
18+
*/
19+
public class InputTextReader implements Releasable {
20+
private final BytesRefBlock textBlock;
21+
private final StringBuilder strBuilder = new StringBuilder();
22+
private BytesRef readBuffer = new BytesRef();
23+
24+
public InputTextReader(BytesRefBlock textBlock) {
25+
this.textBlock = textBlock;
26+
}
27+
28+
/**
29+
* Reads the text string at the given position.
30+
* Multiple values at the position are concatenated with newlines.
31+
*
32+
* @param pos the position index in the block
33+
* @return the text string at the position, or null if the position contains a null value
34+
*/
35+
public String readText(int pos) {
36+
return readText(pos, Integer.MAX_VALUE);
37+
}
38+
39+
/**
40+
* Reads the text string at the given position.
41+
*
42+
* @param pos the position index in the block
43+
* @param limit the maximum number of value to read from the position
44+
* @return the text string at the position, or null if the position contains a null value
45+
*/
46+
public String readText(int pos, int limit) {
47+
if (textBlock.isNull(pos)) {
48+
return null;
49+
}
50+
51+
strBuilder.setLength(0);
52+
int maxPos = Math.min(limit, textBlock.getValueCount(pos));
53+
for (int valueIndex = 0; valueIndex < maxPos; valueIndex++) {
54+
readBuffer = textBlock.getBytesRef(textBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
55+
strBuilder.append(readBuffer.utf8ToString());
56+
if (valueIndex != maxPos - 1) {
57+
strBuilder.append("\n");
58+
}
59+
}
60+
61+
return strBuilder.toString();
62+
}
63+
64+
/**
65+
* Returns the total number of positions (text entries) in the block.
66+
*/
67+
public int estimatedSize() {
68+
return textBlock.getPositionCount();
69+
}
70+
71+
@Override
72+
public void close() {
73+
textBlock.allowPassingToDifferentDriver();
74+
Releasables.close(textBlock);
75+
}
76+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java

Lines changed: 6 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@
77

88
package org.elasticsearch.xpack.esql.inference.completion;
99

10-
import org.apache.lucene.util.BytesRef;
1110
import org.elasticsearch.compute.data.BytesRefBlock;
12-
import org.elasticsearch.core.Releasable;
1311
import org.elasticsearch.core.Releasables;
1412
import org.elasticsearch.inference.TaskType;
1513
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
14+
import org.elasticsearch.xpack.esql.inference.InputTextReader;
1615
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
1716

1817
import java.util.List;
@@ -24,7 +23,7 @@
2423
*/
2524
public class CompletionOperatorRequestIterator implements BulkInferenceRequestIterator {
2625

27-
private final PromptReader promptReader;
26+
private final InputTextReader textReader;
2827
private final String inferenceId;
2928
private final int size;
3029
private int currentPos = 0;
@@ -36,7 +35,7 @@ public class CompletionOperatorRequestIterator implements BulkInferenceRequestIt
3635
* @param inferenceId The ID of the inference model to invoke.
3736
*/
3837
public CompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) {
39-
this.promptReader = new PromptReader(promptBlock);
38+
this.textReader = new InputTextReader(promptBlock);
4039
this.size = promptBlock.getPositionCount();
4140
this.inferenceId = inferenceId;
4241
}
@@ -52,7 +51,7 @@ public InferenceAction.Request next() {
5251
throw new NoSuchElementException();
5352
}
5453

55-
return inferenceRequest(promptReader.readPrompt(currentPos++));
54+
return inferenceRequest(textReader.readText(currentPos++));
5655
}
5756

5857
/**
@@ -68,60 +67,11 @@ private InferenceAction.Request inferenceRequest(String prompt) {
6867

6968
@Override
7069
public int estimatedSize() {
71-
return promptReader.estimatedSize();
70+
return textReader.estimatedSize();
7271
}
7372

7473
@Override
7574
public void close() {
76-
Releasables.close(promptReader);
77-
}
78-
79-
/**
80-
* Helper class that reads prompts from a {@link BytesRefBlock}.
81-
*/
82-
private static class PromptReader implements Releasable {
83-
private final BytesRefBlock promptBlock;
84-
private final StringBuilder strBuilder = new StringBuilder();
85-
private BytesRef readBuffer = new BytesRef();
86-
87-
private PromptReader(BytesRefBlock promptBlock) {
88-
this.promptBlock = promptBlock;
89-
}
90-
91-
/**
92-
* Reads the prompt string at the given position..
93-
*
94-
* @param pos the position index in the block
95-
*/
96-
public String readPrompt(int pos) {
97-
if (promptBlock.isNull(pos)) {
98-
return null;
99-
}
100-
101-
strBuilder.setLength(0);
102-
103-
for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) {
104-
readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
105-
strBuilder.append(readBuffer.utf8ToString());
106-
if (valueIndex != promptBlock.getValueCount(pos) - 1) {
107-
strBuilder.append("\n");
108-
}
109-
}
110-
111-
return strBuilder.toString();
112-
}
113-
114-
/**
115-
* Returns the total number of positions (prompts) in the block.
116-
*/
117-
public int estimatedSize() {
118-
return promptBlock.getPositionCount();
119-
}
120-
121-
@Override
122-
public void close() {
123-
promptBlock.allowPassingToDifferentDriver();
124-
Releasables.close(promptBlock);
125-
}
75+
Releasables.close(textReader);
12676
}
12777
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.textembedding;
9+
10+
import org.elasticsearch.compute.data.BytesRefBlock;
11+
import org.elasticsearch.core.Releasables;
12+
import org.elasticsearch.inference.TaskType;
13+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
14+
import org.elasticsearch.xpack.esql.inference.InputTextReader;
15+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
16+
17+
import java.util.List;
18+
import java.util.NoSuchElementException;
19+
20+
/**
21+
* This iterator reads text inputs from a {@link BytesRefBlock} and converts them into individual {@link InferenceAction.Request} instances
22+
* of type {@link TaskType#TEXT_EMBEDDING}.
23+
*/
24+
public class TextEmbeddingOperatorRequestIterator implements BulkInferenceRequestIterator {
25+
26+
private final InputTextReader textReader;
27+
private final String inferenceId;
28+
private final int size;
29+
private int currentPos = 0;
30+
31+
/**
32+
* Constructs a new iterator from the given block of text inputs.
33+
*
34+
* @param textBlock The input block containing text to embed.
35+
* @param inferenceId The ID of the inference model to invoke.
36+
*/
37+
public TextEmbeddingOperatorRequestIterator(BytesRefBlock textBlock, String inferenceId) {
38+
this.textReader = new InputTextReader(textBlock);
39+
this.size = textBlock.getPositionCount();
40+
this.inferenceId = inferenceId;
41+
}
42+
43+
@Override
44+
public boolean hasNext() {
45+
return currentPos < size;
46+
}
47+
48+
@Override
49+
public InferenceAction.Request next() {
50+
if (hasNext() == false) {
51+
throw new NoSuchElementException();
52+
}
53+
54+
/*
55+
* Keep only the first value in case of multi-valued fields.
56+
* TODO: check if it is consistent with how the query vector builder is working.
57+
*/
58+
return inferenceRequest(textReader.readText(currentPos++, 1));
59+
}
60+
61+
/**
62+
* Wraps a single text string into an {@link InferenceAction.Request} for text embedding.
63+
*/
64+
private InferenceAction.Request inferenceRequest(String text) {
65+
if (text == null) {
66+
return null;
67+
}
68+
69+
return InferenceAction.Request.builder(inferenceId, TaskType.TEXT_EMBEDDING).setInput(List.of(text)).build();
70+
}
71+
72+
@Override
73+
public int estimatedSize() {
74+
return textReader.estimatedSize();
75+
}
76+
77+
@Override
78+
public void close() {
79+
Releasables.close(textReader);
80+
}
81+
}

0 commit comments

Comments
 (0)