@@ -352,29 +352,21 @@ private[hive] case class HiveUDAFFunction(
     HiveEvaluator(evaluator, evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors))
   }
 
-  // The UDAF evaluator used to merge partial aggregation results.
+  // The UDAF evaluator used to consume partial aggregation results and produce final results.
+  // Hive `ObjectInspector` used to inspect final results.
   @transient
-  private lazy val partial2ModeEvaluator = {
+  private lazy val finalHiveEvaluator = {
     val evaluator = newEvaluator()
-    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partial1HiveEvaluator.objectInspector))
-    evaluator
+    HiveEvaluator(
+      evaluator,
+      evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
   }
 
   // Spark SQL data type of partial aggregation results
   @transient
   private lazy val partialResultDataType =
     inspectorToDataType(partial1HiveEvaluator.objectInspector)
 
-  // The UDAF evaluator used to compute the final result from a partial aggregation result objects.
-  // Hive `ObjectInspector` used to inspect the final aggregation result object.
-  @transient
-  private lazy val finalHiveEvaluator = {
-    val evaluator = newEvaluator()
-    HiveEvaluator(
-      evaluator,
-      evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
-  }
-
   // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
   @transient
   private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
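For context on the hunk above: Hive's `GenericUDAFEvaluator.Mode` defines four evaluation modes, and this patch drops the dedicated `PARTIAL2` evaluator in favor of a single `FINAL` evaluator that handles both merging and terminating. Below is a minimal sketch (plain Scala, no Hive dependency; the names mirror Hive's documented `Mode` semantics) of how the modes map onto aggregation phases:

```scala
// Sketch only: models Hive's GenericUDAFEvaluator.Mode semantics without a Hive
// dependency, to show why PARTIAL1 + FINAL cover Spark's two aggregation phases.
object UdafModes extends App {
  final case class Mode(name: String, input: String, calls: String, output: String)

  val modes = Seq(
    Mode("PARTIAL1", "original rows",   "iterate + terminatePartial", "partial aggregation"),
    Mode("PARTIAL2", "partial results", "merge + terminatePartial",   "partial aggregation"),
    Mode("FINAL",    "partial results", "merge + terminate",          "final result"),
    Mode("COMPLETE", "original rows",   "iterate + terminate",        "final result"))

  // After this patch, only the PARTIAL1 (map-side) and FINAL (reduce-side)
  // evaluators are kept; FINAL's merge() subsumes what PARTIAL2 was used for.
  modes.foreach(m => println(f"${m.name}%-8s ${m.input}%-15s -> ${m.output} (${m.calls})"))
}
```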
@@ -401,25 +393,43 @@ private[hive] case class HiveUDAFFunction(
     s"$name($distinct${children.map(_.sql).mkString(", ")})"
   }
 
-  override def createAggregationBuffer(): AggregationBuffer =
-    partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+  // The Hive UDAF may create different buffers to handle different inputs: original data or
+  // aggregate buffer. However, the Spark UDAF framework does not expose this information when
+  // creating the buffer. Here we return null and create the buffer on demand in `update` and
+  // `merge`, so that we know which input we are dealing with.
+  override def createAggregationBuffer(): AggregationBuffer = null
 
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
   override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
+    // The input is original data, so we create the buffer with the partial1 evaluator.
+    val nonNullBuffer = if (buffer == null) {
+      partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+    } else {
+      buffer
+    }
+
     partial1HiveEvaluator.evaluator.iterate(
-      buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
-    buffer
+      nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+    nonNullBuffer
   }
 
   override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
+    // The input is an aggregate buffer, so we create the buffer with the final evaluator.
+    val nonNullBuffer = if (buffer == null) {
+      finalHiveEvaluator.evaluator.getNewAggregationBuffer
+    } else {
+      buffer
+    }
+
     // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // these `AggregationBuffer`s into this format before shuffling partial aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
-    partial2ModeEvaluator.merge(buffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
-    buffer
+    finalHiveEvaluator.evaluator.merge(
+      nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
+    nonNullBuffer
   }
 
   override def eval(buffer: AggregationBuffer): Any = {
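The `nonNullBuffer` pattern above can be read in isolation: `createAggregationBuffer()` returns `null`, and the first `update()` or `merge()` call allocates the real buffer once it is clear which evaluator must own it. A standalone sketch of the same lazy-initialization idea (the `Buf` class and `newBuffer()` are placeholders, not Spark or Hive APIs):

```scala
// Standalone sketch of the on-demand buffer creation used in the patch above.
object LazyBufferDemo extends App {
  final class Buf { var sum: Long = 0L }
  def newBuffer(): Buf = new Buf

  // Mirrors the patch: the framework hands us a null buffer first; we allocate
  // lazily at the first update() call, once we know which phase we are in.
  def update(buffer: Buf, input: Long): Buf = {
    val nonNullBuffer = if (buffer == null) newBuffer() else buffer
    nonNullBuffer.sum += input
    nonNullBuffer
  }

  val result = Seq(1L, 2L, 3L).foldLeft(null: Buf)(update)
  println(result.sum) // prints 6
}
```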
@@ -450,11 +460,19 @@ private[hive] case class HiveUDAFFunction(
     private val mutableRow = new GenericInternalRow(1)
 
     def serialize(buffer: AggregationBuffer): Array[Byte] = {
+      // The buffer may be null if there is no input. It's unclear whether the Hive UDAF accepts a
+      // null buffer, so for safety we create an empty buffer here.
+      val nonNullBuffer = if (buffer == null) {
+        partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+      } else {
+        buffer
+      }
+
       // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
       // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
       // Then we can unwrap it to a Spark SQL value.
       mutableRow.update(0, partialResultUnwrapper(
-        partial1HiveEvaluator.evaluator.terminatePartial(buffer)))
+        partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer)))
       val unsafeRow = projection(mutableRow)
       val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
       unsafeRow.writeTo(bytes)
@@ -466,11 +484,11 @@ private[hive] case class HiveUDAFFunction(
       // returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The
       // workaround here is creating an initial `AggregationBuffer` first and then merging the
       // deserialized object into the buffer.
-      val buffer = partial2ModeEvaluator.getNewAggregationBuffer
+      val buffer = finalHiveEvaluator.evaluator.getNewAggregationBuffer
       val unsafeRow = new UnsafeRow(1)
       unsafeRow.pointTo(bytes, bytes.length)
       val partialResult = unsafeRow.get(0, partialResultDataType)
-      partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult))
+      finalHiveEvaluator.evaluator.merge(buffer, partialResultWrapper(partialResult))
       buffer
     }
   }
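The serializer above flattens a buffer with `terminatePartial()`, and since Hive provides no direct inverse, deserialization rebuilds a buffer by merging the transported value into a fresh one. A simplified roundtrip sketch of that idea (a hypothetical average buffer; none of these types come from Hive or Spark):

```scala
// Sketch of the serialize/deserialize workaround: terminatePartial() flattens a
// buffer to a transportable value; deserialization merges it into a new buffer.
object SerdeRoundtrip extends App {
  final case class AvgBuffer(var sum: Double = 0.0, var count: Long = 0L)

  // PARTIAL1 side: buffer -> inspectable partial result (here, a pair).
  def terminatePartial(b: AvgBuffer): (Double, Long) = (b.sum, b.count)

  // FINAL side: merge a partial result into an existing buffer.
  def merge(b: AvgBuffer, partial: (Double, Long)): Unit = {
    b.sum += partial._1
    b.count += partial._2
  }

  // Deserialization workaround: start from an empty buffer, then merge.
  def deserialize(partial: (Double, Long)): AvgBuffer = {
    val buffer = AvgBuffer()
    merge(buffer, partial)
    buffer
  }

  val restored = deserialize(terminatePartial(AvgBuffer(6.0, 3L)))
  println(restored.sum / restored.count) // prints 2.0
}
```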