@@ -87,48 +87,78 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchIn
         if (pyTorchResult.getInferenceResult().length < 1) {
             throw new ElasticsearchStatusException("question answering result has no data", RestStatus.INTERNAL_SERVER_ERROR);
         }
+
+        // The result format is pairs of 'start' and 'end' logits,
+        // one pair for each span.
+        // Multiple spans occur where the context text is longer than
+        // the max sequence length, so the input must be windowed with
+        // overlap and evaluated in multiple calls.
+        // Note the response format changed in 8.9 due to the change in
+        // pytorch_inference to not process requests in batches.
+
+        // The output tensor is a 3d array of doubles.
+        // 1. The 1st index is the pairs of start and end for each span.
+        //    If there is 1 span there will be 2 elements in this dimension,
+        //    for 2 spans, 4 elements.
+        // 2. The 2nd index is the number of results per span.
+        //    This dimension is always equal to 1.
+        // 3. The 3rd index is the actual scores.
+        //    This is an array of doubles equal in size to the number of
+        //    input tokens plus any delimiters (e.g. SEP and CLS tokens)
+        //    added by the tokenizer.
+        //
+        // inferenceResult[span_index_start_end][0][scores]
+
         // Should be a collection of "starts" and "ends"
-        if (pyTorchResult.getInferenceResult().length != 2) {
+        if (pyTorchResult.getInferenceResult().length % 2 != 0) {
             throw new ElasticsearchStatusException(
-                "question answering result has invalid dimension, expected 2 found [{}]",
+                "question answering result has invalid dimension, number of dimensions must be a multiple of 2 found [{}]",
                 RestStatus.INTERNAL_SERVER_ERROR,
                 pyTorchResult.getInferenceResult().length
             );
         }
-        double[][] starts = pyTorchResult.getInferenceResult()[0];
-        double[][] ends = pyTorchResult.getInferenceResult()[1];
-        if (starts.length != ends.length) {
-            throw new ElasticsearchStatusException(
-                "question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]",
-                RestStatus.INTERNAL_SERVER_ERROR,
-                starts.length,
-                ends.length
-            );
-        }
+
+        final int numAnswersToGather = Math.max(numTopClasses, 1);
+        ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather);
         List<TokenizationResult.Tokens> tokensList = tokenization.getTokensBySequenceId().get(0);
-        if (starts.length != tokensList.size()) {
+
+        int numberOfSpans = pyTorchResult.getInferenceResult().length / 2;
+        if (numberOfSpans != tokensList.size()) {
             throw new ElasticsearchStatusException(
-                "question answering result has invalid dimensions; start positions number [{}] equal batched token size [{}]",
+                "question answering result has invalid dimensions; the number of spans [{}] does not match batched token size [{}]",
                 RestStatus.INTERNAL_SERVER_ERROR,
-                starts.length,
+                numberOfSpans,
                 tokensList.size()
             );
         }
-        final int numAnswersToGather = Math.max(numTopClasses, 1);

-        ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather);
-        for (int i = 0; i < starts.length; i++) {
+        for (int spanIndex = 0; spanIndex < numberOfSpans; spanIndex++) {
+            double[][] starts = pyTorchResult.getInferenceResult()[spanIndex * 2];
+            double[][] ends = pyTorchResult.getInferenceResult()[(spanIndex * 2) + 1];
+            assert starts.length == 1;
+            assert ends.length == 1;
+
+            if (starts.length != ends.length) {
+                throw new ElasticsearchStatusException(
+                    "question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]",
+                    RestStatus.INTERNAL_SERVER_ERROR,
+                    starts.length,
+                    ends.length
+                );
+            }
+
             topScores(
-                starts[i],
-                ends[i],
+                starts[0], // always 1 element in this dimension
+                ends[0],
                 numAnswersToGather,
                 finalEntries::insertWithOverflow,
-                tokensList.get(i).seqPairOffset(),
-                tokensList.get(i).tokenIds().length,
+                tokensList.get(spanIndex).seqPairOffset(),
+                tokensList.get(spanIndex).tokenIds().length,
                 maxAnswerLength,
-                i
+                spanIndex
             );
         }
+
         QuestionAnsweringInferenceResults.TopAnswerEntry[] topAnswerList =
             new QuestionAnsweringInferenceResults.TopAnswerEntry[numAnswersToGather];
         for (int i = numAnswersToGather - 1; i >= 0; i--) {
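The comment block introduced in this diff (inferenceResult[span_index_start_end][0][scores]) is the key to the new indexing. The following minimal standalone sketch exercises that layout; it is not Elasticsearch code: the class name, the fabricated logits, and the argMax helper are invented for illustration, and it only assumes a result tensor of shape [2 * numberOfSpans][1][numTokens] with even slots holding start logits and odd slots holding end logits, as the diff describes.

// Standalone illustration only; assumes the [2 * numberOfSpans][1][numTokens]
// layout described in the diff. None of these names exist in Elasticsearch.
public class QaResultLayoutSketch {

    public static void main(String[] args) {
        int numberOfSpans = 2;
        int numTokens = 5;

        // Fabricated logits: one (start, end) pair of score arrays per span.
        double[][][] inferenceResult = new double[numberOfSpans * 2][1][numTokens];
        inferenceResult[0][0] = new double[] { 0.1, 2.3, 0.4, 0.2, 0.1 }; // span 0 start logits
        inferenceResult[1][0] = new double[] { 0.2, 0.1, 3.1, 0.3, 0.2 }; // span 0 end logits
        inferenceResult[2][0] = new double[] { 1.9, 0.3, 0.1, 0.2, 0.1 }; // span 1 start logits
        inferenceResult[3][0] = new double[] { 0.1, 0.4, 0.2, 2.8, 0.3 }; // span 1 end logits

        for (int spanIndex = 0; spanIndex < numberOfSpans; spanIndex++) {
            // Same indexing as the PR: even slot = starts, odd slot = ends,
            // and the middle dimension always has a single element.
            double[] starts = inferenceResult[spanIndex * 2][0];
            double[] ends = inferenceResult[(spanIndex * 2) + 1][0];

            System.out.printf(
                "span %d: best start token %d, best end token %d%n",
                spanIndex,
                argMax(starts),
                argMax(ends)
            );
        }
    }

    // Naive helper used only for this sketch; the real code ranks start/end
    // combinations with topScores and a priority queue instead.
    private static int argMax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }
}

The real processResult does more than a per-array argMax: each starts[0]/ends[0] pair is handed to topScores, which pushes candidate answers into the ScoreAndIndicesPriorityQueue so only the top numAnswersToGather entries across all spans are kept.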