Skip to content

Commit 8ec5e1c

Browse files
authored
fix: Fix sample() reporting identical values in the entire column (#338)
1 parent d022e08 commit 8ec5e1c

File tree

9 files changed

+219
-46
lines changed

9 files changed

+219
-46
lines changed

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020
* In `arrange()`, if the data was grouped, the order was never maintained even if
2121
`maintain_order = TRUE` was passed in `group_by()`. This is now fixed (#332).
2222

23+
* When exporting to CSV, `null_values` alone did not apply and could override explicitly
24+
provided `null_value`. This is now fixed (@Yousa-Mirage, #334).
25+
26+
* Fix `sample()` to make it work correctly (@Yousa-Mirage, #338).
27+
2328
# tidypolars 0.17.0
2429

2530
`tidypolars` requires `polars` >= 1.9.0 and `dplyr` >= 1.2.0.

R/funs-default.R

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,9 +355,20 @@ pl_round <- function(x, digits = 0, ...) {
355355

356356
pl_sample <- function(x, size = NULL, replace = FALSE, ...) {
357357
check_empty_dots(...)
358-
# TODO: how should I handle seed, given that R sample() doesn't have this arg
358+
# WARNING: random seed is not supported and cannot take effect.
359+
if (missing(size)) {
360+
size <- x$len()
361+
}
362+
if (!is_polars_expr(size)) {
363+
if (!is.numeric(size) || size <= 0 || size %% 1 != 0) {
364+
cli_abort("{.code size} must be a positive integer.")
365+
}
366+
size <- as.integer(size)
367+
}
368+
359369
out <- x$sample(n = size, with_replacement = replace, shuffle = TRUE)
360-
if (is.null(size) || size == 1) {
370+
371+
if (!is_polars_expr(size) && size == 1L) {
361372
out <- out$first()
362373
}
363374
out

tests/testthat/_snaps/funs_default-lazy.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@
88
Caused by error:
99
! lengths don't match: unable to add a column of length 4 to a DataFrame of height 5
1010

11+
# sample() validates size
12+
13+
Code
14+
current$collect()
15+
Condition
16+
Error in `mutate()`:
17+
! Error while running function `sample()` in Polars.
18+
x `size` must be a positive integer.
19+
1120
# seq_len() works
1221

1322
Code

tests/testthat/_snaps/funs_default.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@
1010
Caused by error:
1111
! lengths don't match: unable to add a column of length 4 to a DataFrame of height 5
1212

13+
# sample() validates size
14+
15+
Code
16+
mutate(test_pl, y = sample(x, size = 1.5))
17+
Condition
18+
Error in `mutate()`:
19+
! Error while running function `sample()` in Polars.
20+
x `size` must be a positive integer.
21+
1322
# seq_len() works
1423

1524
Code

tests/testthat/test-funs_default-lazy.R

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,96 @@ test_that("round() works", {
175175
)
176176
})
177177

178+
test_that("sample() works with default size and n() size", {
179+
test_df <- tibble(x = 1:5)
180+
test_pl <- as_polars_lf(test_df)
181+
182+
foo <- test_pl |>
183+
mutate(y = sample(x)) |>
184+
pull(y)
185+
res <- test_df |>
186+
mutate(y = sample(x)) |>
187+
pull(y)
188+
189+
expect_equal_lazy(sort(foo), sort(res))
190+
191+
foo_replace <- test_pl |>
192+
mutate(y = sample(x, replace = TRUE)) |>
193+
pull(y)
194+
res_replace <- test_df |>
195+
mutate(y = sample(x, replace = TRUE)) |>
196+
pull(y)
197+
198+
expect_true(all(foo_replace %in% 1:5))
199+
expect_true(all(res_replace %in% 1:5))
200+
201+
foo_1 <- test_pl |>
202+
mutate(y = sample(x, size = 1)) |>
203+
pull(y)
204+
res_1 <- test_df |>
205+
mutate(y = sample(x, size = 1)) |>
206+
pull(y)
207+
208+
expect_true(unique(foo_1) %in% 1:5)
209+
expect_true(unique(res_1) %in% 1:5)
210+
211+
foo_n <- test_pl |>
212+
mutate(y = sample(x, size = n())) |>
213+
pull(y)
214+
res_n <- test_df |>
215+
mutate(y = sample(x, size = n())) |>
216+
pull(y)
217+
218+
expect_equal_lazy(sort(foo_n), sort(res_n))
219+
})
220+
221+
test_that("sample() warns on unsupported args", {
222+
test_df <- tibble(x = 1:5)
223+
test_pl <- as_polars_lf(test_df)
224+
225+
expect_warning(
226+
mutate(test_pl, y = sample(x, prob = 0.5)),
227+
"doesn't know how to use some arguments"
228+
)
229+
})
230+
231+
test_that("sample() validates size", {
232+
test_df <- tibble(x = 1:5)
233+
test_pl <- as_polars_lf(test_df)
234+
235+
expect_both_error(
236+
mutate(test_pl, y = sample(x, size = -1)),
237+
mutate(test_df, y = sample(x, size = -1))
238+
)
239+
240+
expect_both_error(
241+
mutate(test_pl, y = sample(x, size = 0)),
242+
mutate(test_df, y = sample(x, size = 0))
243+
)
244+
245+
expect_both_error(
246+
mutate(test_pl, y = sample(x, size = NULL)),
247+
mutate(test_df, y = sample(x, size = NULL))
248+
)
249+
250+
expect_both_error(
251+
mutate(test_pl, y = sample(x, size = 3)),
252+
mutate(test_df, y = sample(x, size = 3))
253+
)
254+
255+
expect_both_error(
256+
mutate(test_pl, y = sample(x, size = 100, replace = FALSE)),
257+
mutate(test_df, y = sample(x, size = 100, replace = FALSE))
258+
)
259+
260+
# `mutate(test_df, y = sample(x, size = 1.5))` has a weird behavior
261+
# when size is a double in [1, 2)
262+
expect_snapshot_lazy(
263+
mutate(test_pl, y = sample(x, size = 1.5)),
264+
error = TRUE
265+
)
266+
})
267+
178268
test_that("stats::lag() is not supported", {
179269
test_df <- tibble(x = c(10, 20, 30, 40, 10, 20, 30, 40))
180270
test_pl <- as_polars_lf(test_df)

tests/testthat/test-funs_default.R

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,96 @@ test_that("round() works", {
171171
)
172172
})
173173

174+
test_that("sample() works with default size and n() size", {
175+
test_df <- tibble(x = 1:5)
176+
test_pl <- as_polars_df(test_df)
177+
178+
foo <- test_pl |>
179+
mutate(y = sample(x)) |>
180+
pull(y)
181+
res <- test_df |>
182+
mutate(y = sample(x)) |>
183+
pull(y)
184+
185+
expect_equal(sort(foo), sort(res))
186+
187+
foo_replace <- test_pl |>
188+
mutate(y = sample(x, replace = TRUE)) |>
189+
pull(y)
190+
res_replace <- test_df |>
191+
mutate(y = sample(x, replace = TRUE)) |>
192+
pull(y)
193+
194+
expect_true(all(foo_replace %in% 1:5))
195+
expect_true(all(res_replace %in% 1:5))
196+
197+
foo_1 <- test_pl |>
198+
mutate(y = sample(x, size = 1)) |>
199+
pull(y)
200+
res_1 <- test_df |>
201+
mutate(y = sample(x, size = 1)) |>
202+
pull(y)
203+
204+
expect_true(unique(foo_1) %in% 1:5)
205+
expect_true(unique(res_1) %in% 1:5)
206+
207+
foo_n <- test_pl |>
208+
mutate(y = sample(x, size = n())) |>
209+
pull(y)
210+
res_n <- test_df |>
211+
mutate(y = sample(x, size = n())) |>
212+
pull(y)
213+
214+
expect_equal(sort(foo_n), sort(res_n))
215+
})
216+
217+
test_that("sample() warns on unsupported args", {
218+
test_df <- tibble(x = 1:5)
219+
test_pl <- as_polars_df(test_df)
220+
221+
expect_warning(
222+
mutate(test_pl, y = sample(x, prob = 0.5)),
223+
"doesn't know how to use some arguments"
224+
)
225+
})
226+
227+
test_that("sample() validates size", {
228+
test_df <- tibble(x = 1:5)
229+
test_pl <- as_polars_df(test_df)
230+
231+
expect_both_error(
232+
mutate(test_pl, y = sample(x, size = -1)),
233+
mutate(test_df, y = sample(x, size = -1))
234+
)
235+
236+
expect_both_error(
237+
mutate(test_pl, y = sample(x, size = 0)),
238+
mutate(test_df, y = sample(x, size = 0))
239+
)
240+
241+
expect_both_error(
242+
mutate(test_pl, y = sample(x, size = NULL)),
243+
mutate(test_df, y = sample(x, size = NULL))
244+
)
245+
246+
expect_both_error(
247+
mutate(test_pl, y = sample(x, size = 3)),
248+
mutate(test_df, y = sample(x, size = 3))
249+
)
250+
251+
expect_both_error(
252+
mutate(test_pl, y = sample(x, size = 100, replace = FALSE)),
253+
mutate(test_df, y = sample(x, size = 100, replace = FALSE))
254+
)
255+
256+
# `mutate(test_df, y = sample(x, size = 1.5))` has a weird behavior
257+
# when size is a double in [1, 2)
258+
expect_snapshot(
259+
mutate(test_pl, y = sample(x, size = 1.5)),
260+
error = TRUE
261+
)
262+
})
263+
174264
test_that("stats::lag() is not supported", {
175265
test_df <- tibble(x = c(10, 20, 30, 40, 10, 20, 30, 40))
176266
test_pl <- as_polars_df(test_df)

tests/testthat/test-funs_math-lazy.R

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -296,28 +296,6 @@ test_that("rank() works on various input types", {
296296
)
297297
})
298298

299-
test_that("warns if unknown args", {
300-
test_df <- tibble(
301-
x1 = c("a", "a", "b", "a", "c"),
302-
x2 = c(2, 1, 5, 3, 1),
303-
value = sample(11:15),
304-
value_trigo = seq(0, 0.4, 0.1),
305-
value_mix = -2:2,
306-
value_with_NA = c(-2, -1, NA, 1, 2)
307-
)
308-
test_pl <- as_polars_lf(test_df)
309-
foo <- test_pl |>
310-
mutate(x = sample(x2)) |>
311-
pull(x)
312-
313-
expect_true(all(foo %in% c(1, 2, 3, 5)))
314-
315-
expect_warning(
316-
test_pl |> mutate(x = sample(x2, prob = 0.5)),
317-
"doesn't know how to use some arguments"
318-
)
319-
})
320-
321299
test_that("%% and %/% work", {
322300
test_df <- tibble(
323301
x1 = c("a", "a", "b", "a", "c"),

tests/testthat/test-funs_math.R

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -292,28 +292,6 @@ test_that("rank() works on various input types", {
292292
)
293293
})
294294

295-
test_that("warns if unknown args", {
296-
test_df <- tibble(
297-
x1 = c("a", "a", "b", "a", "c"),
298-
x2 = c(2, 1, 5, 3, 1),
299-
value = sample(11:15),
300-
value_trigo = seq(0, 0.4, 0.1),
301-
value_mix = -2:2,
302-
value_with_NA = c(-2, -1, NA, 1, 2)
303-
)
304-
test_pl <- as_polars_df(test_df)
305-
foo <- test_pl |>
306-
mutate(x = sample(x2)) |>
307-
pull(x)
308-
309-
expect_true(all(foo %in% c(1, 2, 3, 5)))
310-
311-
expect_warning(
312-
test_pl |> mutate(x = sample(x2, prob = 0.5)),
313-
"doesn't know how to use some arguments"
314-
)
315-
})
316-
317295
test_that("%% and %/% work", {
318296
test_df <- tibble(
319297
x1 = c("a", "a", "b", "a", "c"),

vignettes/supported-functions.Rmd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ out <- tribble(
6868
"`base`", "`rank`",
6969
"`base`", "`rev`",
7070
"`base`", "`round`",
71+
"`base`", "`sample`",
7172
"`base`", "`seq`",
7273
"`base`", "`seq_len`",
7374
"`base`", "`sin`",
@@ -187,6 +188,8 @@ out <- tribble(
187188
"In `tidypolars`, `na.last = NA` is not supported.",
188189
Package == "`base`" & Function == "`sort`" ~
189190
"In `tidypolars`, `na.last` must be explicitly supplied as `TRUE` or `FALSE`.",
191+
Package == "`base`" & Function == "`sample`" ~
192+
"`set.seed()` is not supported. Randomness is handled by Polars and does not use R's RNG state.",
190193
Package == "`lubridate`" & Function %in% c("`rollbackward`", "`rollback`", "`rollforward`") ~
191194
"While time zone handling should mimick the behaviour of `lubridate` in most cases, it is possible that Polars errors if rolling back/forward leads to am ambiguous datetime. It is also possible to have some differences in hours/minutes/seconds when converting between Polars and R.",
192195
Package == "`lubridate`" & Function == "`wday`" ~

0 commit comments

Comments
 (0)