@@ -31,6 +31,7 @@ import org.apache.spark.sql.types._
31
31
import org .apache .spark .unsafe .Platform
32
32
import org .apache .spark .unsafe .array .ByteArrayMethods
33
33
import org .apache .spark .unsafe .types .{ByteArray , UTF8String }
34
+ import org .apache .spark .util .collection .OpenHashSet
34
35
35
36
/**
36
37
* Base trait for [[BinaryExpression ]]s with two arrays of the same element type and implicit
@@ -2355,3 +2356,281 @@ case class ArrayRemove(left: Expression, right: Expression)
2355
2356
2356
2357
override def prettyName : String = " array_remove"
2357
2358
}
2359
+
2360
/**
 * Removes duplicate values from the array.
 *
 * The first occurrence of every value is kept, in input order. All null elements are treated as
 * duplicates of each other, so at most one null (the first one) is retained.
 *
 * Two strategies are used, selected by [[elementTypeSupportEquals]]:
 *  - a hash-based one for atomic element types other than binary;
 *  - a quadratic, ordering-based one for every other element type.
 */
@ExpressionDescription(
  usage = "_FUNC_(array) - Removes duplicate values from the array.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3, null, 3));
       [1,2,3,null]
  """, since = "2.4.0")
case class ArrayDistinct(child: Expression)
  extends UnaryExpression with ExpectsInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

  // De-duplication never changes the element type, so the result type is the input type.
  override def dataType: DataType = child.dataType

  @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

  // Only consulted on the interpreted brute-force path, to decide element equivalence.
  @transient private lazy val ordering: Ordering[Any] =
    TypeUtils.getInterpretedOrdering(elementType)

  /**
   * In addition to the checks inherited from the parent, requires the element type to be
   * orderable, because the brute-force path compares elements through [[ordering]].
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    super.checkInputDataTypes() match {
      case f: TypeCheckResult.TypeCheckFailure => f
      case TypeCheckResult.TypeCheckSuccess =>
        TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
    }
  }

  // True when elements can be de-duplicated through a hash set / `distinct`; false selects the
  // pairwise comparison strategy (binary and all non-atomic element types).
  @transient private lazy val elementTypeSupportEquals = elementType match {
    case BinaryType => false
    case _: AtomicType => true
    case _ => false
  }

  /**
   * Interpreted evaluation.
   *
   * @param array the non-null input value, an `ArrayData`
   * @return a `GenericArrayData` holding the first occurrence of every distinct element
   */
  override def nullSafeEval(array: Any): Any = {
    val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
    if (elementTypeSupportEquals) {
      new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
    } else {
      // The distinct elements are collected into a separate buffer. Counting them and then
      // returning a prefix of `data` would hand back the first `pos` input elements, which still
      // contain duplicates and miss later distinct values.
      val distinctValues = new Array[Any](data.length)
      var pos = 0
      var foundNullElement = false
      var i = 0
      while (i < data.length) {
        val value = data(i)
        if (value == null) {
          if (!foundNullElement) {
            // Slots are null-initialised, so keeping the null only needs the position to advance.
            foundNullElement = true
            pos += 1
          }
        } else {
          // Every earlier element is equivalent to one of the kept ones, so comparing against
          // the kept elements alone is sufficient.
          var isDuplicate = false
          var j = 0
          while (j < pos && !isDuplicate) {
            val kept = distinctValues(j)
            isDuplicate = kept != null && ordering.equiv(kept, value)
            j += 1
          }
          if (!isDuplicate) {
            distinctValues(pos) = value
            pos += 1
          }
        }
        i += 1
      }
      new GenericArrayData(distinctValues.slice(0, pos))
    }
  }

  /**
   * Generates Java code in two phases: a first pass over the input that only computes the number
   * of distinct elements, followed by the code from [[genCodeForResult]] that allocates an array
   * of exactly that size and fills it.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (array) => {
      val i = ctx.freshName("i")
      val j = ctx.freshName("j")
      val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray")
      val getValue1 = CodeGenerator.getValue(array, elementType, i)
      val getValue2 = CodeGenerator.getValue(array, elementType, j)
      val foundNullElement = ctx.freshName("foundNullElement")
      val openHashSet = classOf[OpenHashSet[_]].getName
      val hs = ctx.freshName("hs")
      val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
      if (elementTypeSupportEquals) {
        // Hash-based count: number of distinct non-null values, plus one if any null was seen.
        s"""
           |int $sizeOfDistinctArray = 0;
           |boolean $foundNullElement = false;
           |$openHashSet $hs = new $openHashSet($classTag);
           |for (int $i = 0; $i < $array.numElements(); $i ++) {
           |  if ($array.isNullAt($i)) {
           |    $foundNullElement = true;
           |  } else {
           |    $hs.add($getValue1);
           |  }
           |}
           |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0);
           |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
         """.stripMargin
      } else {
        // Brute-force count: an element is counted when no earlier non-null element equals it
        // (the inner loop then runs to completion, leaving j == i).
        s"""
           |int $sizeOfDistinctArray = 0;
           |boolean $foundNullElement = false;
           |for (int $i = 0; $i < $array.numElements(); $i ++) {
           |  if ($array.isNullAt($i)) {
           |    if (!($foundNullElement)) {
           |      $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
           |      $foundNullElement = true;
           |    }
           |  } else {
           |    int $j;
           |    for ($j = 0; $j < $i; $j ++) {
           |      if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) {
           |        break;
           |      }
           |    }
           |    if ($i == $j) {
           |      $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
           |    }
           |  }
           |}
           |
           |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
         """.stripMargin
      }
    })
  }

  /**
   * Java snippet that stores a null at `pos` in the output and advances `pos`, but only for the
   * first null encountered (tracked by the generated `foundNullElement` flag).
   *
   * @param isPrimitive whether the output is written through `setNullAt` (true) or is an
   *                    `Object[]` assigned by index (false)
   */
  private def setNull(
      isPrimitive: Boolean,
      foundNullElement: String,
      distinctArray: String,
      pos: String): String = {
    val setNullValue =
      if (!isPrimitive) {
        s"$distinctArray[$pos] = null"
      } else {
        s"$distinctArray.setNullAt($pos)"
      }

    s"""
       |if (!($foundNullElement)) {
       |  $setNullValue;
       |  $pos = $pos + 1;
       |  $foundNullElement = true;
       |}
     """.stripMargin
  }

  /**
   * Java statement (without the trailing semicolon) that writes `getValue1` at `pos` in the
   * output.
   *
   * @param primitiveValueTypeName suffix of the typed setter (`set<Name>`); only used when
   *                               `isPrimitive` is true
   */
  private def setNotNullValue(isPrimitive: Boolean,
      distinctArray: String,
      pos: String,
      getValue1: String,
      primitiveValueTypeName: String): String = {
    if (!isPrimitive) {
      s"$distinctArray[$pos] = $getValue1"
    } else {
      s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)"
    }
  }

  /**
   * Java snippet for the hash-based strategy: if the value is not yet in the hash set `hs`, adds
   * it to the set, stores it at `pos` and advances `pos`.
   */
  private def setValueForFastEval(
      isPrimitive: Boolean,
      hs: String,
      distinctArray: String,
      pos: String,
      getValue1: String,
      primitiveValueTypeName: String): String = {
    val setValue = setNotNullValue(isPrimitive,
      distinctArray, pos, getValue1, primitiveValueTypeName)
    s"""
       |if (!($hs.contains($getValue1))) {
       |  $hs.add($getValue1);
       |  $setValue;
       |  $pos = $pos + 1;
       |}
     """.stripMargin
  }

  /**
   * Java snippet for the brute-force strategy: scans the input elements before index `i` and
   * stores the current value at `pos` only when none of the earlier non-null elements satisfies
   * `isEqual`.
   */
  private def setValueForBruteForceEval(
      isPrimitive: Boolean,
      i: String,
      j: String,
      inputArray: String,
      distinctArray: String,
      pos: String,
      getValue1: String,
      isEqual: String,
      primitiveValueTypeName: String): String = {
    val setValue = setNotNullValue(isPrimitive,
      distinctArray, pos, getValue1, primitiveValueTypeName)
    s"""
       |int $j;
       |for ($j = 0; $j < $i; $j ++) {
       |  if (!$inputArray.isNullAt($j) && $isEqual) {
       |    break;
       |  }
       |}
       |if ($i == $j) {
       |  $setValue;
       |  $pos = $pos + 1;
       |}
     """.stripMargin
  }

  /**
   * Generates the second phase of the code: allocates the output array with `size` slots, fills
   * it with the distinct elements of `inputArray` and assigns it to `ev.value`.
   *
   * @param inputArray name of the Java variable holding the input array
   * @param size       name of the Java variable holding the number of distinct elements
   */
  def genCodeForResult(
      ctx: CodegenContext,
      ev: ExprCode,
      inputArray: String,
      size: String): String = {
    val distinctArray = ctx.freshName("distinctArray")
    val i = ctx.freshName("i")
    val j = ctx.freshName("j")
    val pos = ctx.freshName("pos")
    val getValue1 = CodeGenerator.getValue(inputArray, elementType, i)
    val getValue2 = CodeGenerator.getValue(inputArray, elementType, j)
    val isEqual = ctx.genEqual(elementType, getValue1, getValue2)
    val foundNullElement = ctx.freshName("foundNullElement")
    val hs = ctx.freshName("hs")
    val openHashSet = classOf[OpenHashSet[_]].getName
    if (!CodeGenerator.isPrimitiveType(elementType)) {
      // Non-primitive elements: collect into an Object[] and wrap it in a GenericArrayData.
      val arrayClass = classOf[GenericArrayData].getName
      val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
      val setNullForNonPrimitive =
        setNull(isPrimitive = false, foundNullElement, distinctArray, pos)
      if (elementTypeSupportEquals) {
        val setValueForFast =
          setValueForFastEval(isPrimitive = false, hs, distinctArray, pos, getValue1, "")
        s"""
           |int $pos = 0;
           |Object[] $distinctArray = new Object[$size];
           |boolean $foundNullElement = false;
           |$openHashSet $hs = new $openHashSet($classTag);
           |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
           |  if ($inputArray.isNullAt($i)) {
           |    $setNullForNonPrimitive;
           |  } else {
           |    $setValueForFast;
           |  }
           |}
           |${ev.value} = new $arrayClass($distinctArray);
         """.stripMargin
      } else {
        val setValueForBruteForce = setValueForBruteForceEval(
          isPrimitive = false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "")
        s"""
           |int $pos = 0;
           |Object[] $distinctArray = new Object[$size];
           |boolean $foundNullElement = false;
           |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
           |  if ($inputArray.isNullAt($i)) {
           |    $setNullForNonPrimitive;
           |  } else {
           |    $setValueForBruteForce;
           |  }
           |}
           |${ev.value} = new $arrayClass($distinctArray);
         """.stripMargin
      }
    } else {
      // Primitive elements: the output comes from ctx.createUnsafeArray and is filled through
      // typed setters. Only the hash-based strategy is generated here.
      val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
      val setNullForPrimitive = setNull(isPrimitive = true, foundNullElement, distinctArray, pos)
      val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()"
      val setValueForFast = setValueForFastEval(
        isPrimitive = true, hs, distinctArray, pos, getValue1, primitiveValueTypeName)
      s"""
         |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")}
         |int $pos = 0;
         |boolean $foundNullElement = false;
         |$openHashSet $hs = new $openHashSet($classTag);
         |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
         |  if ($inputArray.isNullAt($i)) {
         |    $setNullForPrimitive;
         |  } else {
         |    $setValueForFast;
         |  }
         |}
         |${ev.value} = $distinctArray;
       """.stripMargin
    }
  }

  override def prettyName: String = "array_distinct"
}
0 commit comments