
Commit bc6f191

huaxingao authored and HyukjinKwon committed
[SPARK-24779][R] Add map_concat / map_from_entries / an option in months_between UDF to disable rounding-off
## What changes were proposed in this pull request?

Add the R version of map_concat / map_from_entries / an option in months_between UDF to disable rounding-off.

## How was this patch tested?

Add test in test_sparkSQL.R.

Closes apache#21835 from huaxingao/spark-24779.

Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 0e2c487 commit bc6f191

File tree

4 files changed: +87 -7 lines changed


R/pkg/NAMESPACE

Lines changed: 2 additions & 0 deletions
@@ -312,8 +312,10 @@ exportMethods("%<=>%",
               "lower",
               "lpad",
               "ltrim",
+              "map_concat",
               "map_entries",
               "map_from_arrays",
+              "map_from_entries",
               "map_keys",
               "map_values",
               "max",

R/pkg/R/functions.R

Lines changed: 54 additions & 6 deletions
@@ -80,6 +80,11 @@ NULL
 #' \item \code{from_utc_timestamp}, \code{to_utc_timestamp}: time zone to use.
 #' \item \code{next_day}: day of the week string.
 #' }
+#' @param ... additional argument(s).
+#' \itemize{
+#' \item \code{months_between}, this contains an optional parameter to specify whether
+#'       the result is rounded off to 8 digits.
+#' }
 #'
 #' @name column_datetime_diff_functions
 #' @rdname column_datetime_diff_functions
@@ -217,6 +222,7 @@ NULL
 #' additional named properties to control how it is converted and accepts the
 #' same options as the CSV data source.
 #' \item \code{arrays_zip}, this contains additional Columns of arrays to be merged.
+#' \item \code{map_concat}, this contains additional Columns of maps to be unioned.
 #' }
 #' @name column_collection_functions
 #' @rdname column_collection_functions
@@ -229,7 +235,7 @@ NULL
 #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1), shuffle(tmp$v1)))
 #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1)))
 #' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1)))
-#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1), array_remove(tmp$v1, 21)))
+#' head(select(tmp, reverse(tmp$v1), array_remove(tmp$v1, 21)))
 #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1))
 #' head(tmp2)
 #' head(select(tmp, posexplode(tmp$v1)))
@@ -238,15 +244,21 @@ NULL
 #' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))
 #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl))
 #' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3)))
-#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))
+#' head(select(tmp3, element_at(tmp3$v3, "Valiant"), map_concat(tmp3$v3, tmp3$v3)))
 #' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp))
 #' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5)))
 #' head(select(tmp4, array_except(tmp4$v4, tmp4$v5), array_intersect(tmp4$v4, tmp4$v5)))
 #' head(select(tmp4, array_union(tmp4$v4, tmp4$v5)))
-#' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5), map_from_arrays(tmp4$v4, tmp4$v5)))
+#' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5)))
 #' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))
 #' tmp5 <- mutate(df, v6 = create_array(df$model, df$model))
-#' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))}
+#' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))
+#' tmp6 <- mutate(df, v7 = create_array(create_array(df$model, df$model)))
+#' head(select(tmp6, flatten(tmp6$v7)))
+#' tmp7 <- mutate(df, v8 = create_array(df$model, df$cyl), v9 = create_array(df$model, df$hp))
+#' head(select(tmp7, map_from_arrays(tmp7$v8, tmp7$v9)))
+#' tmp8 <- mutate(df, v10 = create_array(struct(df$model, df$cyl)))
+#' head(select(tmp8, map_from_entries(tmp8$v10)))}
 NULL
 
 #' Window functions for Column operations
@@ -2074,15 +2086,21 @@ setMethod("levenshtein", signature(y = "Column"),
 #' are on the same day of month, or both are the last day of month, time of day will be ignored.
 #' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits.
 #'
+#' @param roundOff an optional parameter to specify if the result is rounded off to 8 digits
 #' @rdname column_datetime_diff_functions
 #' @aliases months_between months_between,Column-method
 #' @note months_between since 1.5.0
 setMethod("months_between", signature(y = "Column"),
-          function(y, x) {
+          function(y, x, roundOff = NULL) {
             if (class(x) == "Column") {
               x <- x@jc
             }
-            jc <- callJStatic("org.apache.spark.sql.functions", "months_between", y@jc, x)
+            jc <- if (is.null(roundOff)) {
+              callJStatic("org.apache.spark.sql.functions", "months_between", y@jc, x)
+            } else {
+              callJStatic("org.apache.spark.sql.functions", "months_between", y@jc, x,
+                          as.logical(roundOff))
+            }
             column(jc)
           })
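With this change the rounding behaviour is controllable from SparkR. A minimal usage sketch, assuming an active SparkR session; the data frame mirrors the new test case and the column names are illustrative:

    df <- createDataFrame(list(list(a = as.Date("1997-02-28"), b = as.Date("1996-10-30"))))
    # Default behaviour is unchanged: the result is rounded to 8 digits
    head(select(df, months_between(df$a, df$b)))
    # Pass roundOff = FALSE to keep the unrounded value
    head(select(df, months_between(df$a, df$b, FALSE)))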

@@ -3448,6 +3466,23 @@ setMethod("flatten",
             column(jc)
           })
 
+#' @details
+#' \code{map_concat}: Returns the union of all the given maps.
+#'
+#' @rdname column_collection_functions
+#' @aliases map_concat map_concat,Column-method
+#' @note map_concat since 3.0.0
+setMethod("map_concat",
+          signature(x = "Column"),
+          function(x, ...) {
+            jcols <- lapply(list(x, ...), function(arg) {
+              stopifnot(class(arg) == "Column")
+              arg@jc
+            })
+            jc <- callJStatic("org.apache.spark.sql.functions", "map_concat", jcols)
+            column(jc)
+          })
+
 #' @details
 #' \code{map_entries}: Returns an unordered array of all entries in the given map.
 #'
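A minimal usage sketch for the new map_concat wrapper, assuming an active SparkR session; the input maps mirror the new test case in test_sparkSQL.R:

    df <- createDataFrame(list(list(map1 = as.environment(list(x = 1, y = 2)),
                                    map2 = as.environment(list(a = 3, b = 4)))))
    # Union of the two map columns: {x -> 1, y -> 2, a -> 3, b -> 4}
    head(select(df, map_concat(df$map1, df$map2)))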
@@ -3476,6 +3511,19 @@ setMethod("map_from_arrays",
             column(jc)
           })
 
+#' @details
+#' \code{map_from_entries}: Returns a map created from the given array of entries.
+#'
+#' @rdname column_collection_functions
+#' @aliases map_from_entries map_from_entries,Column-method
+#' @note map_from_entries since 3.0.0
+setMethod("map_from_entries",
+          signature(x = "Column"),
+          function(x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "map_from_entries", x@jc)
+            column(jc)
+          })
+
 #' @details
 #' \code{map_keys}: Returns an unordered array containing the keys of the map.
 #'
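A minimal usage sketch for map_from_entries, assuming an active SparkR session and a df with model and cyl columns as in the roxygen examples above; it repeats the tmp8 example, where each (model, cyl) struct in the array becomes one map entry:

    tmp8 <- mutate(df, v10 = create_array(struct(df$model, df$cyl)))
    head(select(tmp8, map_from_entries(tmp8$v10)))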

R/pkg/R/generics.R

Lines changed: 9 additions & 1 deletion
@@ -1078,6 +1078,10 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") })
 #' @name NULL
 setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") })
 
+#' @rdname column_collection_functions
+#' @name NULL
+setGeneric("map_concat", function(x, ...) { standardGeneric("map_concat") })
+
 #' @rdname column_collection_functions
 #' @name NULL
 setGeneric("map_entries", function(x) { standardGeneric("map_entries") })
@@ -1086,6 +1090,10 @@ setGeneric("map_entries", function(x) { standardGeneric("map_entries") })
 #' @name NULL
 setGeneric("map_from_arrays", function(x, y) { standardGeneric("map_from_arrays") })
 
+#' @rdname column_collection_functions
+#' @name NULL
+setGeneric("map_from_entries", function(x) { standardGeneric("map_from_entries") })
+
 #' @rdname column_collection_functions
 #' @name NULL
 setGeneric("map_keys", function(x) { standardGeneric("map_keys") })
@@ -1113,7 +1121,7 @@ setGeneric("month", function(x) { standardGeneric("month") })
 
 #' @rdname column_datetime_diff_functions
 #' @name NULL
-setGeneric("months_between", function(y, x) { standardGeneric("months_between") })
+setGeneric("months_between", function(y, x, ...) { standardGeneric("months_between") })
 
 #' @rdname count
 setGeneric("n", function(x) { standardGeneric("n") })

R/pkg/tests/fulltests/test_sparkSQL.R

Lines changed: 22 additions & 0 deletions
@@ -1497,6 +1497,14 @@ test_that("column functions", {
   df5 <- createDataFrame(list(list(a = "010101")))
   expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15")
 
+  # Test months_between()
+  df <- createDataFrame(list(list(a = as.Date("1997-02-28"),
+                                  b = as.Date("1996-10-30"))))
+  result1 <- collect(select(df, alias(months_between(df[[1]], df[[2]]), "month")))[[1]]
+  expect_equal(result1, 3.93548387)
+  result2 <- collect(select(df, alias(months_between(df[[1]], df[[2]], FALSE), "month")))[[1]]
+  expect_equal(result2, 3.935483870967742)
+
   # Test array_contains(), array_max(), array_min(), array_position(), element_at() and reverse()
   df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L))))
   result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]]
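As a hand check of the months_between expectations added above, following the 31-days-per-month rule quoted in the function's documentation: 1996-10-30 to 1997-02-28 spans 4 calendar months, adjusted by (28 - 30)/31 of a month:

    4 + (28 - 30) / 31            # 3.935483870967742, the unrounded result
    round(4 + (28 - 30) / 31, 8)  # 3.93548387, the default (rounded) result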
@@ -1542,6 +1550,13 @@ test_that("column functions", {
   expected_entries <- list(as.environment(list(x = 1, y = 2)))
   expect_equal(result, expected_entries)
 
+  # Test map_from_entries()
+  df <- createDataFrame(list(list(list(listToStruct(list(c1 = "x", c2 = 1L)),
+                                       listToStruct(list(c1 = "y", c2 = 2L))))))
+  result <- collect(select(df, map_from_entries(df[[1]])))[[1]]
+  expected_entries <- list(as.environment(list(x = 1L, y = 2L)))
+  expect_equal(result, expected_entries)
+
   # Test array_repeat()
   df <- createDataFrame(list(list("a", 3L), list("b", 2L)))
   result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]]
@@ -1600,6 +1615,13 @@ test_that("column functions", {
   result <- collect(select(df, flatten(df[[1]])))[[1]]
   expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L)))
 
+  # Test map_concat
+  df <- createDataFrame(list(list(map1 = as.environment(list(x = 1, y = 2)),
+                                  map2 = as.environment(list(a = 3, b = 4)))))
+  result <- collect(select(df, map_concat(df[[1]], df[[2]])))[[1]]
+  expected_entries <- list(as.environment(list(x = 1, y = 2, a = 3, b = 4)))
+  expect_equal(result, expected_entries)
+
   # Test map_entries(), map_keys(), map_values() and element_at()
   df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2)))))
   result <- collect(select(df, map_entries(df$map)))[[1]]
