
Commit 03d772e

[ML] Adapt Question Answering processing for non-batched evaluation (#98167) (#98258)
1 parent 8e14de6 commit 03d772e

File tree

4 files changed: +125 -24 lines changed


docs/changelog/98167.yaml

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+pr: 98167
+summary: Fix failure processing Question Answering model output where the input has been spanned over multiple sequences
+area: Machine Learning
+type: bug
+issues:
+ - 97917

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java

Lines changed: 53 additions & 23 deletions
@@ -87,48 +87,78 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchIn
         if (pyTorchResult.getInferenceResult().length < 1) {
             throw new ElasticsearchStatusException("question answering result has no data", RestStatus.INTERNAL_SERVER_ERROR);
         }
+
+        // The result format is pairs of 'start' and 'end' logits,
+        // one pair for each span.
+        // Multiple spans occur where the context text is longer than
+        // the max sequence length, so the input must be windowed with
+        // overlap and evaluated in multiple calls.
+        // Note the response format changed in 8.9 due to the change in
+        // pytorch_inference to not process requests in batches.
+
+        // The output tensor is a 3d array of doubles.
+        // 1. The 1st index is the pairs of start and end for each span.
+        //    If there is 1 span there will be 2 elements in this dimension,
+        //    for 2 spans 4 elements.
+        // 2. The 2nd index is the number of results per span.
+        //    This dimension is always equal to 1.
+        // 3. The 3rd index is the actual scores.
+        //    This is an array of doubles equal in size to the number of
+        //    input tokens plus any delimiters (e.g. SEP and CLS tokens)
+        //    added by the tokenizer.
+        //
+        // inferenceResult[span_index_start_end][0][scores]
+
         // Should be a collection of "starts" and "ends"
-        if (pyTorchResult.getInferenceResult().length != 2) {
+        if (pyTorchResult.getInferenceResult().length % 2 != 0) {
             throw new ElasticsearchStatusException(
-                "question answering result has invalid dimension, expected 2 found [{}]",
+                "question answering result has invalid dimension, number of dimensions must be a multiple of 2 found [{}]",
                 RestStatus.INTERNAL_SERVER_ERROR,
                 pyTorchResult.getInferenceResult().length
             );
         }
-        double[][] starts = pyTorchResult.getInferenceResult()[0];
-        double[][] ends = pyTorchResult.getInferenceResult()[1];
-        if (starts.length != ends.length) {
-            throw new ElasticsearchStatusException(
-                "question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]",
-                RestStatus.INTERNAL_SERVER_ERROR,
-                starts.length,
-                ends.length
-            );
-        }
+
+        final int numAnswersToGather = Math.max(numTopClasses, 1);
+        ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather);
         List<TokenizationResult.Tokens> tokensList = tokenization.getTokensBySequenceId().get(0);
-        if (starts.length != tokensList.size()) {
+
+        int numberOfSpans = pyTorchResult.getInferenceResult().length / 2;
+        if (numberOfSpans != tokensList.size()) {
             throw new ElasticsearchStatusException(
-                "question answering result has invalid dimensions; start positions number [{}] equal batched token size [{}]",
+                "question answering result has invalid dimensions; the number of spans [{}] does not match batched token size [{}]",
                 RestStatus.INTERNAL_SERVER_ERROR,
-                starts.length,
+                numberOfSpans,
                 tokensList.size()
             );
         }
-        final int numAnswersToGather = Math.max(numTopClasses, 1);
 
-        ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather);
-        for (int i = 0; i < starts.length; i++) {
+        for (int spanIndex = 0; spanIndex < numberOfSpans; spanIndex++) {
+            double[][] starts = pyTorchResult.getInferenceResult()[spanIndex * 2];
+            double[][] ends = pyTorchResult.getInferenceResult()[(spanIndex * 2) + 1];
+            assert starts.length == 1;
+            assert ends.length == 1;
+
+            if (starts.length != ends.length) {
+                throw new ElasticsearchStatusException(
+                    "question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]",
+                    RestStatus.INTERNAL_SERVER_ERROR,
+                    starts.length,
+                    ends.length
+                );
+            }
+
             topScores(
-                starts[i],
-                ends[i],
+                starts[0], // always 1 element in this dimension
+                ends[0],
                 numAnswersToGather,
                 finalEntries::insertWithOverflow,
-                tokensList.get(i).seqPairOffset(),
-                tokensList.get(i).tokenIds().length,
+                tokensList.get(spanIndex).seqPairOffset(),
+                tokensList.get(spanIndex).tokenIds().length,
                 maxAnswerLength,
-                i
+                spanIndex
             );
         }
+
         QuestionAnsweringInferenceResults.TopAnswerEntry[] topAnswerList =
             new QuestionAnsweringInferenceResults.TopAnswerEntry[numAnswersToGather];
         for (int i = numAnswersToGather - 1; i >= 0; i--) {
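
To make the new layout concrete, here is a minimal standalone sketch of how a consumer walks the reshaped output tensor. It is not code from this commit; the method and variable names are hypothetical, and only the indexing scheme (start logits at 2*s, end logits at 2*s + 1, with a single-element middle dimension) mirrors the diff above.

    // Hypothetical sketch, not part of the commit: traverse the non-batched
    // Question Answering output where each span contributes a start/end pair.
    static void walkSpans(double[][][] inferenceResult) {
        // Mirrors the processor's validation: entries come in start/end pairs.
        if (inferenceResult.length % 2 != 0) {
            throw new IllegalStateException("expected start/end pairs, found " + inferenceResult.length);
        }
        int numberOfSpans = inferenceResult.length / 2;
        for (int spanIndex = 0; spanIndex < numberOfSpans; spanIndex++) {
            double[] startLogits = inferenceResult[spanIndex * 2][0];       // middle dimension is always 1
            double[] endLogits = inferenceResult[(spanIndex * 2) + 1][0];
            // startLogits.length == endLogits.length == number of tokens in
            // this window, including tokenizer delimiters such as CLS and SEP.
        }
    }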

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ public Map<Integer, List<Tokens>> getTokensBySequenceId() {
         return tokens.stream().collect(Collectors.groupingBy(Tokens::sequenceId));
     }
 
-    List<Tokens> getTokens() {
+    public List<Tokens> getTokens() {
         return tokens;
     }
 
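
This one-line visibility change supports the new unit test: getTokens() returns one Tokens entry per span/window, which the test (in the next file) reads to fabricate a model output of matching shape. A hedged sketch of that access pattern, with a hypothetical helper name:

    // Hypothetical helper, not in the commit: one Tokens entry per window,
    // so the list size is the number of spans the input was split into.
    static int countSpans(TokenizationResult tokenizationResult) {
        return tokenizationResult.getTokens().size();
    }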

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java

Lines changed: 65 additions & 0 deletions
@@ -20,6 +20,7 @@
 import java.io.IOException;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.DoubleStream;
 
 import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerTests.TEST_CASED_VOCAB;
 import static org.hamcrest.Matchers.closeTo;
@@ -168,4 +169,68 @@ public void testTopScoresMoreThanOne() {
         assertThat(topScores[1].endToken(), equalTo(5));
     }
 
+    public void testProcessorMuliptleSpans() throws IOException {
+        String question = "is Elasticsearch fun?";
+        String input = "Pancake day is fun with Elasticsearch and little red car";
+        int span = 4;
+        int maxSequenceLength = 14;
+        int numberTopClasses = 3;
+
+        BertTokenization tokenization = new BertTokenization(false, true, maxSequenceLength, Tokenization.Truncate.NONE, span);
+        BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, tokenization).build();
+        QuestionAnsweringConfig config = new QuestionAnsweringConfig(
+            question,
+            numberTopClasses,
+            10,
+            new VocabularyConfig("index_name"),
+            tokenization,
+            "prediction"
+        );
+        QuestionAnsweringProcessor processor = new QuestionAnsweringProcessor(tokenizer);
+        TokenizationResult tokenizationResult = processor.getRequestBuilder(config)
+            .buildRequest(List.of(input), "1", Tokenization.Truncate.NONE, span)
+            .tokenization();
+        assertThat(tokenizationResult.anyTruncated(), is(false));
+
+        // now we know what the tokenization looks like
+        // (number of spans and size of each) fake the
+        // question answering response
+
+        int numberSpans = tokenizationResult.getTokens().size();
+        double[][][] modelTensorOutput = new double[numberSpans * 2][][];
+        for (int i = 0; i < numberSpans; i++) {
+            var windowTokens = tokenizationResult.getTokens().get(i);
+            // size of output
+            int outputSize = windowTokens.tokenIds().length;
+            // generate low value -ve scores that will not mark
+            // the expected result with a high degree of probability
+            double[] starts = DoubleStream.generate(() -> -randomDoubleBetween(0.001, 1.0, true)).limit(outputSize).toArray();
+            double[] ends = DoubleStream.generate(() -> -randomDoubleBetween(0.001, 1.0, true)).limit(outputSize).toArray();
+            modelTensorOutput[i * 2] = new double[][] { starts };
+            modelTensorOutput[(i * 2) + 1] = new double[][] { ends };
+        }
+
+        int spanContainingTheAnswer = randomIntBetween(0, numberSpans - 1);
+
+        // insert numbers to mark the answer in the chosen span
+        int answerStart = tokenizationResult.getTokens().get(spanContainingTheAnswer).seqPairOffset(); // first token of second sequence
+        // last token of the second sequence ignoring the final SEP added by the BERT tokenizer
+        int answerEnd = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokenIds().length - 2;
+        modelTensorOutput[spanContainingTheAnswer * 2][0][answerStart] = 0.5;
+        modelTensorOutput[(spanContainingTheAnswer * 2) + 1][0][answerEnd] = 1.0;
+
+        NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(config);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(modelTensorOutput);
+        QuestionAnsweringInferenceResults result = (QuestionAnsweringInferenceResults) resultProcessor.processResult(
+            tokenizationResult,
+            pyTorchResult
+        );
+
+        // The expected answer is the full text of the span containing the answer
+        int expectedStart = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokens().get(1).get(0).startOffset();
+        int lastTokenPosition = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokens().get(1).size() - 1;
+        int expectedEnd = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokens().get(1).get(lastTokenPosition).endOffset();
+
+        assertThat(result.getAnswer(), equalTo(input.substring(expectedStart, expectedEnd)));
+    }
 }
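
As a reading aid, here is a small hedged sketch (hypothetical helper, not in the commit) of the shape invariants the fabricated tensor must satisfy for processResult() to accept it, per the validation added in QuestionAnsweringProcessor:

    // Hypothetical cross-check, not in the commit: the fake output must have
    // an even first dimension (one start/end pair per span) and a middle
    // dimension of exactly 1, matching the processor's asserts above.
    static void checkFakeOutputShape(double[][][] modelTensorOutput, int numberSpans) {
        assert modelTensorOutput.length == numberSpans * 2;
        for (double[][] startsOrEnds : modelTensorOutput) {
            assert startsOrEnds.length == 1;
        }
    }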
