Skip to content

Commit a481794

Browse files
huaxingao authored and Felix Cheung committed
[SPARK-25007][R] Add array_intersect/array_except/array_union/shuffle to SparkR
## What changes were proposed in this pull request?

Add the R version of array_intersect/array_except/array_union/shuffle.

## How was this patch tested?

Add test in test_sparkSQL.R.

Author: Huaxin Gao <[email protected]>

Closes apache#22291 from huaxingao/spark-25007.
1 parent a3dccd2 commit a481794

File tree

4 files changed

+97
-1
lines changed

4 files changed

+97
-1
lines changed

R/pkg/NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -204,6 +204,8 @@ exportMethods("%<=>%",
204204
"approxQuantile",
205205
"array_contains",
206206
"array_distinct",
207+
"array_except",
208+
"array_intersect",
207209
"array_join",
208210
"array_max",
209211
"array_min",
@@ -212,6 +214,7 @@ exportMethods("%<=>%",
212214
"array_repeat",
213215
"array_sort",
214216
"arrays_overlap",
217+
"array_union",
215218
"arrays_zip",
216219
"asc",
217220
"ascii",
@@ -355,6 +358,7 @@ exportMethods("%<=>%",
355358
"shiftLeft",
356359
"shiftRight",
357360
"shiftRightUnsigned",
361+
"shuffle",
358362
"sd",
359363
"sign",
360364
"signum",

R/pkg/R/functions.R

Lines changed: 58 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -208,7 +208,7 @@ NULL
208208
#' # Dataframe used throughout this doc
209209
#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
210210
#' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp))
211-
#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1)))
211+
#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1), shuffle(tmp$v1)))
212212
#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1)))
213213
#' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1)))
214214
#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1), array_remove(tmp$v1, 21)))
@@ -223,6 +223,8 @@ NULL
223223
#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))
224224
#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp))
225225
#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5)))
226+
#' head(select(tmp4, array_except(tmp4$v4, tmp4$v5), array_intersect(tmp4$v4, tmp4$v5)))
227+
#' head(select(tmp4, array_union(tmp4$v4, tmp4$v5)))
226228
#' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5), map_from_arrays(tmp4$v4, tmp4$v5)))
227229
#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))
228230
#' tmp5 <- mutate(df, v6 = create_array(df$model, df$model))
@@ -3024,6 +3026,34 @@ setMethod("array_distinct",
30243026
column(jc)
30253027
})
30263028

3029+
#' @details
3030+
#' \code{array_except}: Returns an array of the elements in the first array but not in the second
3031+
#' array, without duplicates. The order of elements in the result is not guaranteed.
3032+
#'
3033+
#' @rdname column_collection_functions
3034+
#' @aliases array_except array_except,Column-method
3035+
#' @note array_except since 2.4.0
3036+
setMethod("array_except",
3037+
signature(x = "Column", y = "Column"),
3038+
function(x, y) {
3039+
jc <- callJStatic("org.apache.spark.sql.functions", "array_except", x@jc, y@jc)
3040+
column(jc)
3041+
})
3042+
3043+
#' @details
3044+
#' \code{array_intersect}: Returns an array of the elements in the intersection of the given two
3045+
#' arrays, without duplicates.
3046+
#'
3047+
#' @rdname column_collection_functions
3048+
#' @aliases array_intersect array_intersect,Column-method
3049+
#' @note array_intersect since 2.4.0
3050+
setMethod("array_intersect",
3051+
signature(x = "Column", y = "Column"),
3052+
function(x, y) {
3053+
jc <- callJStatic("org.apache.spark.sql.functions", "array_intersect", x@jc, y@jc)
3054+
column(jc)
3055+
})
3056+
30273057
#' @details
30283058
#' \code{array_join}: Concatenates the elements of column using the delimiter.
30293059
#' Null values are replaced with nullReplacement if set, otherwise they are ignored.
@@ -3149,6 +3179,20 @@ setMethod("arrays_overlap",
31493179
column(jc)
31503180
})
31513181

3182+
#' @details
3183+
#' \code{array_union}: Returns an array of the elements in the union of the given two arrays,
3184+
#' without duplicates.
3185+
#'
3186+
#' @rdname column_collection_functions
3187+
#' @aliases array_union array_union,Column-method
3188+
#' @note array_union since 2.4.0
3189+
setMethod("array_union",
3190+
signature(x = "Column", y = "Column"),
3191+
function(x, y) {
3192+
jc <- callJStatic("org.apache.spark.sql.functions", "array_union", x@jc, y@jc)
3193+
column(jc)
3194+
})
3195+
31523196
#' @details
31533197
#' \code{arrays_zip}: Returns a merged array of structs in which the N-th struct contains all N-th
31543198
#' values of input arrays.
@@ -3167,6 +3211,19 @@ setMethod("arrays_zip",
31673211
column(jc)
31683212
})
31693213

3214+
#' @details
3215+
#' \code{shuffle}: Returns a random permutation of the given array.
3216+
#'
3217+
#' @rdname column_collection_functions
3218+
#' @aliases shuffle shuffle,Column-method
3219+
#' @note shuffle since 2.4.0
3220+
setMethod("shuffle",
3221+
signature(x = "Column"),
3222+
function(x) {
3223+
jc <- callJStatic("org.apache.spark.sql.functions", "shuffle", x@jc)
3224+
column(jc)
3225+
})
3226+
31703227
#' @details
31713228
#' \code{flatten}: Creates a single array from an array of arrays.
31723229
#' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed.

R/pkg/R/generics.R

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -767,6 +767,14 @@ setGeneric("array_contains", function(x, value) { standardGeneric("array_contain
767767
#' @name NULL
768768
setGeneric("array_distinct", function(x) { standardGeneric("array_distinct") })
769769

770+
#' @rdname column_collection_functions
771+
#' @name NULL
772+
setGeneric("array_except", function(x, y) { standardGeneric("array_except") })
773+
774+
#' @rdname column_collection_functions
775+
#' @name NULL
776+
setGeneric("array_intersect", function(x, y) { standardGeneric("array_intersect") })
777+
770778
#' @rdname column_collection_functions
771779
#' @name NULL
772780
setGeneric("array_join", function(x, delimiter, ...) { standardGeneric("array_join") })
@@ -799,6 +807,10 @@ setGeneric("array_sort", function(x) { standardGeneric("array_sort") })
799807
#' @name NULL
800808
setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") })
801809

810+
#' @rdname column_collection_functions
811+
#' @name NULL
812+
setGeneric("array_union", function(x, y) { standardGeneric("array_union") })
813+
802814
#' @rdname column_collection_functions
803815
#' @name NULL
804816
setGeneric("arrays_zip", function(x, ...) { standardGeneric("arrays_zip") })
@@ -1220,6 +1232,10 @@ setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") })
12201232
#' @name NULL
12211233
setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") })
12221234

1235+
#' @rdname column_collection_functions
1236+
#' @name NULL
1237+
setGeneric("shuffle", function(x) { standardGeneric("shuffle") })
1238+
12231239
#' @rdname column_math_functions
12241240
#' @name NULL
12251241
setGeneric("signum", function(x) { standardGeneric("signum") })

R/pkg/tests/fulltests/test_sparkSQL.R

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1598,6 +1598,25 @@ test_that("column functions", {
15981598
result <- collect(select(df, element_at(df$map, "y")))[[1]]
15991599
expect_equal(result, 2)
16001600

1601+
# Test array_except(), array_intersect() and array_union()
1602+
df <- createDataFrame(list(list(list(1L, 2L, 3L), list(3L, 1L)),
1603+
list(list(1L, 2L), list(3L, 4L)),
1604+
list(list(1L, 2L, 3L), list(3L, 4L))))
1605+
result1 <- collect(select(df, array_except(df[[1]], df[[2]])))[[1]]
1606+
expect_equal(result1, list(list(2L), list(1L, 2L), list(1L, 2L)))
1607+
1608+
result2 <- collect(select(df, array_intersect(df[[1]], df[[2]])))[[1]]
1609+
expect_equal(result2, list(list(1L, 3L), list(), list(3L)))
1610+
1611+
result3 <- collect(select(df, array_union(df[[1]], df[[2]])))[[1]]
1612+
expect_equal(result3, list(list(1L, 2L, 3L), list(1L, 2L, 3L, 4L), list(1L, 2L, 3L, 4L)))
1613+
1614+
# Test shuffle()
1615+
df <- createDataFrame(list(list(list(1L, 20L, 3L, 5L)), list(list(4L, 5L, 6L, 7L))))
1616+
result <- collect(select(df, shuffle(df[[1]])))[[1]]
1617+
expect_true(setequal(result[[1]], c(1L, 20L, 3L, 5L)))
1618+
expect_true(setequal(result[[2]], c(4L, 5L, 6L, 7L)))
1619+
16011620
# Test that stats::lag is working
16021621
expect_equal(length(lag(ldeaths, 12)), 72)
16031622

0 commit comments

Comments
 (0)