13
13
import org .elasticsearch .common .io .stream .StreamInput ;
14
14
import org .elasticsearch .common .io .stream .StreamOutput ;
15
15
import org .elasticsearch .compute .data .Block ;
16
+ import org .elasticsearch .compute .data .BlockUtils ;
16
17
import org .elasticsearch .compute .data .BytesRefBlock ;
17
- import org .elasticsearch .compute .data .DoubleVector ;
18
- import org .elasticsearch .compute .data .DoubleVectorBlock ;
18
+ import org .elasticsearch .compute .data .DoubleBlock ;
19
19
import org .elasticsearch .compute .data .Page ;
20
20
import org .elasticsearch .compute .operator .DriverContext ;
21
21
import org .elasticsearch .compute .operator .Operator ;
22
+ import org .elasticsearch .compute .operator .Warnings ;
22
23
import org .elasticsearch .core .Releasables ;
23
24
import org .elasticsearch .core .TimeValue ;
24
25
import org .elasticsearch .xcontent .XContentBuilder ;
28
29
import java .util .Collection ;
29
30
import java .util .Deque ;
30
31
import java .util .HashMap ;
32
+ import java .util .List ;
31
33
import java .util .Map ;
32
34
33
35
/**
40
42
*
41
43
*/
42
44
public class LinearScoreEvalOperator implements Operator {
43
- public record Factory (int discriminatorPosition , int scorePosition , LinearConfig linearConfig ) implements OperatorFactory {
45
+ public record Factory (
46
+ int discriminatorPosition ,
47
+ int scorePosition ,
48
+ LinearConfig linearConfig ,
49
+ String sourceText ,
50
+ int sourceLine ,
51
+ int sourceColumn
52
+ ) implements OperatorFactory {
44
53
45
54
@ Override
46
55
public Operator get (DriverContext driverContext ) {
47
- return new LinearScoreEvalOperator (discriminatorPosition , scorePosition , linearConfig );
56
+ return new LinearScoreEvalOperator (
57
+ driverContext ,
58
+ discriminatorPosition ,
59
+ scorePosition ,
60
+ linearConfig ,
61
+ sourceText ,
62
+ sourceLine ,
63
+ sourceColumn
64
+ );
48
65
}
49
66
50
67
@ Override
@@ -74,11 +91,30 @@ public String describe() {
74
91
private long rowsReceived = 0 ;
75
92
private long rowsEmitted = 0 ;
76
93
77
- public LinearScoreEvalOperator (int discriminatorPosition , int scorePosition , LinearConfig config ) {
94
+ private final String sourceText ;
95
+ private final int sourceLine ;
96
+ private final int sourceColumn ;
97
+ private Warnings warnings ;
98
+ private final DriverContext driverContext ;
99
+
100
+ public LinearScoreEvalOperator (
101
+ DriverContext driverContext ,
102
+ int discriminatorPosition ,
103
+ int scorePosition ,
104
+ LinearConfig config ,
105
+ String sourceText ,
106
+ int sourceLine ,
107
+ int sourceColumn
108
+ ) {
78
109
this .scorePosition = scorePosition ;
79
110
this .discriminatorPosition = discriminatorPosition ;
80
111
this .config = config ;
81
112
this .normalizer = createNormalizer (config .normalizer ());
113
+ this .driverContext = driverContext ;
114
+
115
+ this .sourceText = sourceText ;
116
+ this .sourceLine = sourceLine ;
117
+ this .sourceColumn = sourceColumn ;
82
118
83
119
finished = false ;
84
120
inputPages = new ArrayDeque <>();
@@ -123,25 +159,54 @@ private void createOutputPages() {
123
159
124
160
private void processInputPage (Page inputPage ) {
125
161
BytesRefBlock discriminatorBlock = inputPage .getBlock (discriminatorPosition );
126
- DoubleVectorBlock initialScoreBlock = inputPage .getBlock (scorePosition );
162
+ DoubleBlock initialScoreBlock = inputPage .getBlock (scorePosition );
127
163
128
164
Page newPage = null ;
129
165
Block scoreBlock = null ;
130
- DoubleVector .Builder scores = null ;
166
+ DoubleBlock .Builder scores = null ;
131
167
132
168
try {
133
- scores = discriminatorBlock .blockFactory ().newDoubleVectorBuilder (discriminatorBlock .getPositionCount ());
169
+ scores = discriminatorBlock .blockFactory ().newDoubleBlockBuilder (discriminatorBlock .getPositionCount ());
134
170
135
171
for (int i = 0 ; i < inputPage .getPositionCount (); i ++) {
136
- String discriminator = discriminatorBlock .getBytesRef (i , new BytesRef ()).utf8ToString ();
172
+ Object discriminatorValue = BlockUtils .toJavaObject (discriminatorBlock , i );
173
+
174
+ if (discriminatorValue == null ) {
175
+ warnings ().registerException (new IllegalArgumentException ("group column has null values; assigning null scores" ));
176
+ scores .appendNull ();
177
+ continue ;
178
+ } else if (discriminatorValue instanceof List <?>) {
179
+ warnings ().registerException (
180
+ new IllegalArgumentException ("group column contains multivalued entries; assigning null scores" )
181
+ );
182
+ scores .appendNull ();
183
+ continue ;
184
+ }
185
+ String discriminator = ((BytesRef ) discriminatorValue ).utf8ToString ();
137
186
138
187
var weight = config .weights ().get (discriminator ) == null ? 1.0 : config .weights ().get (discriminator );
139
188
140
- double score = initialScoreBlock .getDouble (i );
189
+ initialScoreBlock .doesHaveMultivaluedFields ();
190
+
191
+ Object scoreValue = BlockUtils .toJavaObject (initialScoreBlock , i );
192
+ if (scoreValue == null ) {
193
+ warnings ().registerException (new IllegalArgumentException ("score column has null values; assigning null scores" ));
194
+ scores .appendNull ();
195
+ continue ;
196
+ } else if (scoreValue instanceof List <?>) {
197
+ warnings ().registerException (
198
+ new IllegalArgumentException ("score column contains multivalued entries; assigning null scores" )
199
+ );
200
+ scores .appendNull ();
201
+ continue ;
202
+ }
203
+
204
+ double score = (double ) scoreValue ;
205
+
141
206
scores .appendDouble (weight * normalizer .normalize (score , discriminator ));
142
207
}
143
208
144
- scoreBlock = scores .build (). asBlock () ;
209
+ scoreBlock = scores .build ();
145
210
newPage = inputPage .appendBlock (scoreBlock );
146
211
147
212
int [] projections = new int [newPage .getBlockCount () - 1 ];
@@ -270,23 +335,43 @@ private Normalizer createNormalizer(LinearConfig.Normalizer normalizer) {
270
335
};
271
336
}
272
337
273
- private interface Normalizer {
274
- double normalize (double score , String discriminator );
338
+ private abstract static class Normalizer {
339
+ abstract double normalize (double score , String discriminator );
275
340
276
- void preprocess (Collection <Page > inputPages , int scorePosition , int discriminatorPosition );
341
+ abstract void preprocess (double score , String discriminator );
342
+
343
+ void finalizePreprocess () {};
344
+
345
+ void preprocess (Collection <Page > inputPages , int scorePosition , int discriminatorPosition ) {
346
+ for (Page inputPage : inputPages ) {
347
+ DoubleBlock scoreBlock = inputPage .getBlock (scorePosition );
348
+ BytesRefBlock discriminatorBlock = inputPage .getBlock (discriminatorPosition );
349
+
350
+ for (int i = 0 ; i < inputPage .getPositionCount (); i ++) {
351
+ Object scoreValue = BlockUtils .toJavaObject (scoreBlock , i );
352
+ Object discriminatorValue = BlockUtils .toJavaObject (discriminatorBlock , i );
353
+
354
+ if (scoreValue instanceof Double score && discriminatorValue instanceof BytesRef discriminator ) {
355
+ preprocess (score , discriminator .utf8ToString ());
356
+ }
357
+ }
358
+ }
359
+
360
+ finalizePreprocess ();
361
+ }
277
362
}
278
363
279
- private class NoneNormalizer implements Normalizer {
364
+ private static class NoneNormalizer extends Normalizer {
280
365
@ Override
281
366
public double normalize (double score , String discriminator ) {
282
367
return score ;
283
368
}
284
369
285
370
@ Override
286
- public void preprocess (Collection < Page > inputPages , int scorePosition , int discriminatorPosition ) {}
371
+ void preprocess (double score , String discriminator ) {}
287
372
}
288
373
289
- private class L2NormNormalizer implements Normalizer {
374
+ private static class L2NormNormalizer extends Normalizer {
290
375
private final Map <String , Double > l2Norms = new HashMap <>();
291
376
292
377
@ Override
@@ -297,24 +382,17 @@ public double normalize(double score, String discriminator) {
297
382
}
298
383
299
384
@ Override
300
- public void preprocess (Collection <Page > inputPages , int scorePosition , int discriminatorPosition ) {
301
- for (Page inputPage : inputPages ) {
302
- DoubleVectorBlock scoreBlock = inputPage .getBlock (scorePosition );
303
- BytesRefBlock discriminatorBlock = inputPage .getBlock (discriminatorPosition );
304
-
305
- for (int i = 0 ; i < inputPage .getPositionCount (); i ++) {
306
- double score = scoreBlock .getDouble (i );
307
- String discriminator = discriminatorBlock .getBytesRef (i , new BytesRef ()).utf8ToString ();
308
-
309
- l2Norms .compute (discriminator , (k , v ) -> v == null ? score * score : v + score * score );
310
- }
311
- }
385
+ void preprocess (double score , String discriminator ) {
386
+ l2Norms .compute (discriminator , (k , v ) -> v == null ? score * score : v + score * score );
387
+ }
312
388
389
+ @ Override
390
+ void finalizePreprocess () {
313
391
l2Norms .replaceAll ((k , v ) -> Math .sqrt (v ));
314
392
}
315
393
}
316
394
317
- private class MinMaxNormalizer implements Normalizer {
395
+ private static class MinMaxNormalizer extends Normalizer {
318
396
private final Map <String , Double > minScores = new HashMap <>();
319
397
private final Map <String , Double > maxScores = new HashMap <>();
320
398
@@ -334,19 +412,17 @@ public double normalize(double score, String discriminator) {
334
412
}
335
413
336
414
@ Override
337
- public void preprocess (Collection <Page > inputPages , int scorePosition , int discriminatorPosition ) {
338
- for (Page inputPage : inputPages ) {
339
- DoubleVectorBlock scoreBlock = inputPage .getBlock (scorePosition );
340
- BytesRefBlock discriminatorBlock = inputPage .getBlock (discriminatorPosition );
341
-
342
- for (int i = 0 ; i < inputPage .getPositionCount (); i ++) {
343
- double score = scoreBlock .getDouble (i );
344
- String discriminator = discriminatorBlock .getBytesRef (i , new BytesRef ()).utf8ToString ();
415
+ void preprocess (double score , String discriminator ) {
416
+ minScores .compute (discriminator , (key , value ) -> value == null ? score : Math .min (value , score ));
417
+ maxScores .compute (discriminator , (key , value ) -> value == null ? score : Math .max (value , score ));
418
+ }
419
+ }
345
420
346
- minScores .compute (discriminator , (key , value ) -> value == null ? score : Math .min (value , score ));
347
- maxScores .compute (discriminator , (key , value ) -> value == null ? score : Math .max (value , score ));
348
- }
349
- }
421
+ private Warnings warnings () {
422
+ if (warnings == null ) {
423
+ this .warnings = Warnings .createWarnings (driverContext .warningsMode (), sourceLine , sourceColumn , sourceText );
350
424
}
425
+
426
+ return warnings ;
351
427
}
352
428
}
0 commit comments