1313import org .elasticsearch .common .io .stream .StreamInput ;
1414import org .elasticsearch .common .io .stream .StreamOutput ;
1515import org .elasticsearch .compute .data .Block ;
16+ import org .elasticsearch .compute .data .BlockUtils ;
1617import org .elasticsearch .compute .data .BytesRefBlock ;
17- import org .elasticsearch .compute .data .DoubleVector ;
18- import org .elasticsearch .compute .data .DoubleVectorBlock ;
18+ import org .elasticsearch .compute .data .DoubleBlock ;
1919import org .elasticsearch .compute .data .Page ;
2020import org .elasticsearch .compute .operator .DriverContext ;
2121import org .elasticsearch .compute .operator .Operator ;
22+ import org .elasticsearch .compute .operator .Warnings ;
2223import org .elasticsearch .core .Releasables ;
2324import org .elasticsearch .core .TimeValue ;
2425import org .elasticsearch .xcontent .XContentBuilder ;
2829import java .util .Collection ;
2930import java .util .Deque ;
3031import java .util .HashMap ;
32+ import java .util .List ;
3133import java .util .Map ;
3234
3335/**
4042 *
4143 */
4244public class LinearScoreEvalOperator implements Operator {
43- public record Factory (int discriminatorPosition , int scorePosition , LinearConfig linearConfig ) implements OperatorFactory {
45+ public record Factory (
46+ int discriminatorPosition ,
47+ int scorePosition ,
48+ LinearConfig linearConfig ,
49+ String sourceText ,
50+ int sourceLine ,
51+ int sourceColumn
52+ ) implements OperatorFactory {
4453
4554 @ Override
4655 public Operator get (DriverContext driverContext ) {
47- return new LinearScoreEvalOperator (discriminatorPosition , scorePosition , linearConfig );
56+ return new LinearScoreEvalOperator (
57+ driverContext ,
58+ discriminatorPosition ,
59+ scorePosition ,
60+ linearConfig ,
61+ sourceText ,
62+ sourceLine ,
63+ sourceColumn
64+ );
4865 }
4966
5067 @ Override
@@ -74,11 +91,30 @@ public String describe() {
7491 private long rowsReceived = 0 ;
7592 private long rowsEmitted = 0 ;
7693
77- public LinearScoreEvalOperator (int discriminatorPosition , int scorePosition , LinearConfig config ) {
94+ private final String sourceText ;
95+ private final int sourceLine ;
96+ private final int sourceColumn ;
97+ private Warnings warnings ;
98+ private final DriverContext driverContext ;
99+
100+ public LinearScoreEvalOperator (
101+ DriverContext driverContext ,
102+ int discriminatorPosition ,
103+ int scorePosition ,
104+ LinearConfig config ,
105+ String sourceText ,
106+ int sourceLine ,
107+ int sourceColumn
108+ ) {
78109 this .scorePosition = scorePosition ;
79110 this .discriminatorPosition = discriminatorPosition ;
80111 this .config = config ;
81112 this .normalizer = createNormalizer (config .normalizer ());
113+ this .driverContext = driverContext ;
114+
115+ this .sourceText = sourceText ;
116+ this .sourceLine = sourceLine ;
117+ this .sourceColumn = sourceColumn ;
82118
83119 finished = false ;
84120 inputPages = new ArrayDeque <>();
@@ -123,25 +159,54 @@ private void createOutputPages() {
123159
124160 private void processInputPage (Page inputPage ) {
125161 BytesRefBlock discriminatorBlock = inputPage .getBlock (discriminatorPosition );
126- DoubleVectorBlock initialScoreBlock = inputPage .getBlock (scorePosition );
162+ DoubleBlock initialScoreBlock = inputPage .getBlock (scorePosition );
127163
128164 Page newPage = null ;
129165 Block scoreBlock = null ;
130- DoubleVector .Builder scores = null ;
166+ DoubleBlock .Builder scores = null ;
131167
132168 try {
133- scores = discriminatorBlock .blockFactory ().newDoubleVectorBuilder (discriminatorBlock .getPositionCount ());
169+ scores = discriminatorBlock .blockFactory ().newDoubleBlockBuilder (discriminatorBlock .getPositionCount ());
134170
135171 for (int i = 0 ; i < inputPage .getPositionCount (); i ++) {
136- String discriminator = discriminatorBlock .getBytesRef (i , new BytesRef ()).utf8ToString ();
172+ Object discriminatorValue = BlockUtils .toJavaObject (discriminatorBlock , i );
173+
174+ if (discriminatorValue == null ) {
175+ warnings ().registerException (new IllegalArgumentException ("group column has null values; assigning null scores" ));
176+ scores .appendNull ();
177+ continue ;
178+ } else if (discriminatorValue instanceof List <?>) {
179+ warnings ().registerException (
180+ new IllegalArgumentException ("group column contains multivalued entries; assigning null scores" )
181+ );
182+ scores .appendNull ();
183+ continue ;
184+ }
185+ String discriminator = ((BytesRef ) discriminatorValue ).utf8ToString ();
137186
138187 var weight = config .weights ().get (discriminator ) == null ? 1.0 : config .weights ().get (discriminator );
139188
140- double score = initialScoreBlock .getDouble (i );
189+ initialScoreBlock .doesHaveMultivaluedFields ();
190+
191+ Object scoreValue = BlockUtils .toJavaObject (initialScoreBlock , i );
192+ if (scoreValue == null ) {
193+ warnings ().registerException (new IllegalArgumentException ("score column has null values; assigning null scores" ));
194+ scores .appendNull ();
195+ continue ;
196+ } else if (scoreValue instanceof List <?>) {
197+ warnings ().registerException (
198+ new IllegalArgumentException ("score column contains multivalued entries; assigning null scores" )
199+ );
200+ scores .appendNull ();
201+ continue ;
202+ }
203+
204+ double score = (double ) scoreValue ;
205+
141206 scores .appendDouble (weight * normalizer .normalize (score , discriminator ));
142207 }
143208
144- scoreBlock = scores .build (). asBlock () ;
209+ scoreBlock = scores .build ();
145210 newPage = inputPage .appendBlock (scoreBlock );
146211
147212 int [] projections = new int [newPage .getBlockCount () - 1 ];
@@ -270,23 +335,43 @@ private Normalizer createNormalizer(LinearConfig.Normalizer normalizer) {
270335 };
271336 }
272337
273- private interface Normalizer {
274- double normalize (double score , String discriminator );
338+ private abstract static class Normalizer {
339+ abstract double normalize (double score , String discriminator );
275340
276- void preprocess (Collection <Page > inputPages , int scorePosition , int discriminatorPosition );
341+ abstract void preprocess (double score , String discriminator );
342+
343+ void finalizePreprocess () {};
344+
345+ void preprocess (Collection <Page > inputPages , int scorePosition , int discriminatorPosition ) {
346+ for (Page inputPage : inputPages ) {
347+ DoubleBlock scoreBlock = inputPage .getBlock (scorePosition );
348+ BytesRefBlock discriminatorBlock = inputPage .getBlock (discriminatorPosition );
349+
350+ for (int i = 0 ; i < inputPage .getPositionCount (); i ++) {
351+ Object scoreValue = BlockUtils .toJavaObject (scoreBlock , i );
352+ Object discriminatorValue = BlockUtils .toJavaObject (discriminatorBlock , i );
353+
354+ if (scoreValue instanceof Double score && discriminatorValue instanceof BytesRef discriminator ) {
355+ preprocess (score , discriminator .utf8ToString ());
356+ }
357+ }
358+ }
359+
360+ finalizePreprocess ();
361+ }
277362 }
278363
279- private class NoneNormalizer implements Normalizer {
364+ private static class NoneNormalizer extends Normalizer {
280365 @ Override
281366 public double normalize (double score , String discriminator ) {
282367 return score ;
283368 }
284369
285370 @ Override
286- public void preprocess (Collection < Page > inputPages , int scorePosition , int discriminatorPosition ) {}
371+ void preprocess (double score , String discriminator ) {}
287372 }
288373
289- private class L2NormNormalizer implements Normalizer {
374+ private static class L2NormNormalizer extends Normalizer {
290375 private final Map <String , Double > l2Norms = new HashMap <>();
291376
292377 @ Override
@@ -297,24 +382,17 @@ public double normalize(double score, String discriminator) {
297382 }
298383
299384 @ Override
300- public void preprocess (Collection <Page > inputPages , int scorePosition , int discriminatorPosition ) {
301- for (Page inputPage : inputPages ) {
302- DoubleVectorBlock scoreBlock = inputPage .getBlock (scorePosition );
303- BytesRefBlock discriminatorBlock = inputPage .getBlock (discriminatorPosition );
304-
305- for (int i = 0 ; i < inputPage .getPositionCount (); i ++) {
306- double score = scoreBlock .getDouble (i );
307- String discriminator = discriminatorBlock .getBytesRef (i , new BytesRef ()).utf8ToString ();
308-
309- l2Norms .compute (discriminator , (k , v ) -> v == null ? score * score : v + score * score );
310- }
311- }
385+ void preprocess (double score , String discriminator ) {
386+ l2Norms .compute (discriminator , (k , v ) -> v == null ? score * score : v + score * score );
387+ }
312388
389+ @ Override
390+ void finalizePreprocess () {
313391 l2Norms .replaceAll ((k , v ) -> Math .sqrt (v ));
314392 }
315393 }
316394
317- private class MinMaxNormalizer implements Normalizer {
395+ private static class MinMaxNormalizer extends Normalizer {
318396 private final Map <String , Double > minScores = new HashMap <>();
319397 private final Map <String , Double > maxScores = new HashMap <>();
320398
@@ -334,19 +412,17 @@ public double normalize(double score, String discriminator) {
334412 }
335413
336414 @ Override
337- public void preprocess (Collection <Page > inputPages , int scorePosition , int discriminatorPosition ) {
338- for (Page inputPage : inputPages ) {
339- DoubleVectorBlock scoreBlock = inputPage .getBlock (scorePosition );
340- BytesRefBlock discriminatorBlock = inputPage .getBlock (discriminatorPosition );
341-
342- for (int i = 0 ; i < inputPage .getPositionCount (); i ++) {
343- double score = scoreBlock .getDouble (i );
344- String discriminator = discriminatorBlock .getBytesRef (i , new BytesRef ()).utf8ToString ();
415+ void preprocess (double score , String discriminator ) {
416+ minScores .compute (discriminator , (key , value ) -> value == null ? score : Math .min (value , score ));
417+ maxScores .compute (discriminator , (key , value ) -> value == null ? score : Math .max (value , score ));
418+ }
419+ }
345420
346- minScores .compute (discriminator , (key , value ) -> value == null ? score : Math .min (value , score ));
347- maxScores .compute (discriminator , (key , value ) -> value == null ? score : Math .max (value , score ));
348- }
349- }
421+ private Warnings warnings () {
422+ if (warnings == null ) {
423+ this .warnings = Warnings .createWarnings (driverContext .warningsMode (), sourceLine , sourceColumn , sourceText );
350424 }
425+
426+ return warnings ;
351427 }
352428}
0 commit comments