Skip to content

Commit e80cde1

Browse files
feat: Support expressions in count() and add_count() (#346)
1 parent 2a5cf25 commit e80cde1

File tree

8 files changed

+314
-117
lines changed

8 files changed

+314
-117
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414
* Better error message in `filter()` when a condition uses `=` instead of `==` (#341).
1515

16+
* `count()` and `add_count()` now work with expressions, e.g. `count(mtcars, mpg + 1)`
17+
(#346).
18+
1619
## Bug fixes
1720

1821
* Fix `NA` handling in `cummin()`, `cumprod()`, `cumsum()` (@Yousa-Mirage, #326).

R/count.R

Lines changed: 180 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,100 @@ count.polars_data_frame <- function(
3939
mo <- attributes(x)$maintain_grp_order %||% FALSE
4040
is_grouped <- !is.null(grps)
4141

42-
disallow_named_expressions(...)
43-
vars <- tidyselect_dots(x, ...)
44-
vars <- c(grps, vars)
45-
out <- count_(x, vars, sort = sort, name = name, new_col = FALSE)
42+
polars_exprs <- translate_dots(
43+
x,
44+
...,
45+
env = rlang::current_env(),
46+
caller = rlang::caller_env()
47+
)
48+
49+
# Only unnamed inputs
50+
if (!is.null(names(polars_exprs))) {
51+
polars_exprs <- lapply(polars_exprs, \(x) {
52+
lapply(x, function(y) {
53+
if (length(y) == 0) {
54+
cli_abort(
55+
"{.pkg tidypolars} doesn't support both named and unnamed inputs in {.fn count}.",
56+
call = rlang::caller_env(4)
57+
)
58+
}
59+
y
60+
})
61+
})
62+
names(polars_exprs) <- NULL
63+
polars_exprs <- unlist(polars_exprs, recursive = FALSE)
64+
}
65+
66+
name <- check_count_name(x, names(x), name)
67+
68+
if (length(polars_exprs) == 0) {
69+
if (is_grouped) {
70+
out <- x$group_by(grps)$len()$rename(len = name)
71+
if (isTRUE(sort)) {
72+
out <- out$sort(
73+
name,
74+
!!!grps,
75+
descending = c(TRUE, rep(FALSE, length(grps)))
76+
)
77+
} else {
78+
out <- out$sort(grps)
79+
}
80+
out <- group_by(out, all_of(grps), maintain_order = mo)
81+
} else {
82+
out <- x$group_by(`__tidypolars_grp__` = pl$lit(1))$len()$drop(
83+
"__tidypolars_grp__"
84+
)$rename(len = name)
85+
}
86+
87+
return(add_tidypolars_class(out))
88+
}
89+
90+
if (is.null(names(polars_exprs))) {
91+
new_names <- enexprs(...)
92+
new_names <- lapply(new_names, expr_deparse)
93+
names(polars_exprs) <- unlist(new_names, use.names = FALSE)
94+
}
95+
96+
if (is_grouped) {
97+
# If there are some duplicates in grps and names(polars_exprs), we want to
98+
# favor the value in names(polars_exprs), but the column order of the
99+
# output should follow the order of grps and then names(polars_exprs).
100+
grps2 <- grps[!(grps %in% names(polars_exprs))]
101+
names_polars_exprs2 <- names(polars_exprs)[!(names(polars_exprs) %in% grps)]
102+
if (length(grps2) > 0) {
103+
out <- x$group_by(grps2, !!!polars_exprs)$len()$rename(len = name)$select(
104+
grps,
105+
names_polars_exprs2,
106+
name
107+
)
108+
} else {
109+
out <- x$group_by(!!!polars_exprs)$len()$rename(len = name)
110+
}
111+
} else {
112+
out <- x$group_by(!!!polars_exprs)$len()$rename(len = name)
113+
}
46114

47-
out <- if (is_grouped) {
48-
group_by(out, all_of(grps), maintain_order = mo)
115+
if (isTRUE(sort)) {
116+
if (is_grouped) {
117+
out <- out$sort(
118+
name,
119+
grps,
120+
!!!names(polars_exprs),
121+
descending = c(TRUE, rep(FALSE, length(grps) + length(polars_exprs)))
122+
)
123+
} else {
124+
out <- out$sort(
125+
name,
126+
!!!names(polars_exprs),
127+
descending = c(TRUE, rep(FALSE, length(polars_exprs)))
128+
)
129+
}
49130
} else {
50-
out
131+
out <- out$sort(grps, !!!names(polars_exprs))
132+
}
133+
134+
if (is_grouped) {
135+
out <- group_by(out, all_of(grps), maintain_order = mo)
51136
}
52137

53138
add_tidypolars_class(out)
@@ -59,14 +144,15 @@ tally.polars_data_frame <- function(x, wt = NULL, sort = FALSE, name = "n") {
59144
if (!missing(wt)) {
60145
check_unsupported_arg(wt = quo_text(enquo(wt)))
61146
}
62-
grps <- attributes(x)$pl_grps
63-
mo <- attributes(x)$maintain_grp_order %||% FALSE
64-
is_grouped <- !is.null(grps)
65-
out <- count_(x, grps, sort = sort, name = name, new_col = FALSE)
147+
out <- count(x, sort = sort, name = name)
148+
grps <- attributes(out)$pl_grps
149+
mo <- attributes(out)$maintain_grp_order %||% FALSE
66150

67151
if (length(grps) > 1) {
68152
grps <- grps[-length(grps)]
69153
out <- group_by(out, all_of(grps), maintain_order = mo)
154+
} else if (length(grps) == 1) {
155+
out <- ungroup(out)
70156
}
71157

72158
add_tidypolars_class(out)
@@ -96,95 +182,109 @@ add_count.polars_data_frame <- function(
96182
mo <- attributes(x)$maintain_grp_order %||% FALSE
97183
is_grouped <- !is.null(grps)
98184

99-
vars <- tidyselect_dots(x, ...)
100-
vars <- c(grps, vars)
101-
out <- count_(
185+
polars_exprs <- translate_dots(
102186
x,
103-
vars,
104-
sort = sort,
105-
name = name,
106-
new_col = TRUE,
107-
missing_name = missing(name)
187+
...,
188+
env = rlang::current_env(),
189+
caller = rlang::caller_env()
108190
)
109191

110-
out <- if (is_grouped) {
111-
group_by(out, all_of(grps), maintain_order = mo)
112-
} else {
113-
out
192+
# Only unnamed inputs
193+
if (!is.null(names(polars_exprs))) {
194+
polars_exprs <- lapply(polars_exprs, \(x) {
195+
lapply(x, function(y) {
196+
if (length(y) == 0) {
197+
cli_abort(
198+
"{.pkg tidypolars} doesn't support both named and unnamed inputs in {.fn add_count}.",
199+
call = rlang::caller_env(4)
200+
)
201+
}
202+
y
203+
})
204+
})
205+
names(polars_exprs) <- NULL
206+
polars_exprs <- unlist(polars_exprs, recursive = FALSE)
114207
}
115208

116-
add_tidypolars_class(out)
117-
}
209+
if (length(polars_exprs) == 0) {
210+
if (is_grouped) {
211+
out <- x$with_columns(pl$len()$over(!!!grps)$alias(name))
212+
if (isTRUE(sort)) {
213+
out <- out$sort(
214+
name,
215+
grps,
216+
descending = c(TRUE, rep(FALSE, length(grps)))
217+
)
218+
}
219+
out <- group_by(out, all_of(grps), maintain_order = mo)
220+
} else {
221+
out <- x$with_columns(pl$len()$alias(name))
222+
}
118223

119-
#' @rdname count.polars_data_frame
120-
#' @export
121-
add_count.polars_lazy_frame <- add_count.polars_data_frame
224+
return(add_tidypolars_class(out))
225+
}
122226

123-
count_ <- function(x, vars, sort, name, new_col = FALSE, missing_name = FALSE) {
124-
name <- check_count_name(x, vars, name, missing_name)
125-
if (isTRUE(new_col)) {
126-
if (length(vars) == 0) {
127-
out <- x$with_columns(
128-
pl$len()$alias(name)
129-
)
227+
if (is.null(names(polars_exprs))) {
228+
new_names <- enexprs(...)
229+
new_names <- lapply(new_names, expr_deparse)
230+
names(polars_exprs) <- unlist(new_names, use.names = FALSE)
231+
}
232+
233+
name <- check_count_name(x, names(x), name)
234+
235+
x <- x$with_columns(!!!polars_exprs)
236+
237+
if (is_grouped) {
238+
grps2 <- grps[!(grps %in% names(polars_exprs))]
239+
if (length(grps2) > 0) {
240+
out <- x$with_columns(pl$len()$over(grps2, !!!names(polars_exprs))$alias(
241+
name
242+
))
130243
} else {
131-
out <- x$with_columns(
132-
pl$len()$alias(name)$over(!!!vars)
133-
)
244+
out <- x$with_columns(pl$len()$over(!!!names(polars_exprs))$alias(name))
134245
}
135246
} else {
136-
if (length(vars) == 0) {
137-
out <- x$select(
138-
pl$len()$alias(name)
139-
)
140-
} else {
141-
# https://github.com/etiennebacher/tidypolars/issues/193
142-
vars <- unique(vars)
143-
out <- x$group_by(vars, .maintain_order = FALSE)$agg(
144-
pl$len()$alias(name)
145-
)
146-
}
247+
out <- x$with_columns(pl$len()$over(!!!names(polars_exprs))$alias(name))
147248
}
148249

149250
if (isTRUE(sort)) {
150-
if (isFALSE(new_col) && length(vars) > 0) {
151-
out$sort(name, !!!vars, descending = c(TRUE, rep(FALSE, length(vars))))
251+
if (is_grouped) {
252+
out <- out$sort(
253+
name,
254+
grps,
255+
!!!names(polars_exprs),
256+
descending = c(TRUE, rep(FALSE, length(grps) + length(polars_exprs)))
257+
)
152258
} else {
153-
out$sort(name, descending = TRUE)
259+
out <- out$sort(name, descending = TRUE)
154260
}
155-
} else if (isFALSE(new_col) && length(vars) > 0) {
156-
out$sort(vars)
157-
} else {
158-
out
159261
}
262+
263+
if (is_grouped) {
264+
out <- group_by(out, all_of(grps), maintain_order = mo)
265+
}
266+
267+
add_tidypolars_class(out)
160268
}
161269

162-
check_count_name <- function(x, vars, name, missing_name) {
270+
#' @rdname count.polars_data_frame
271+
#' @export
272+
add_count.polars_lazy_frame <- add_count.polars_data_frame
273+
274+
275+
check_count_name <- function(x, vars, name) {
163276
new_name <- name
164-
if (isTRUE(missing_name)) {
165-
while (new_name %in% names(x)) {
166-
new_name <- paste0(new_name, "n")
167-
}
168-
if (new_name != name) {
169-
cli_inform(
170-
c(
171-
"Storing counts in {.code {new_name}}, as {.code n} already present in input.",
172-
"i" = "Use {.code name = \"new_name\"} to pick a new name."
173-
)
174-
)
175-
}
176-
} else {
177-
while (new_name %in% vars) {
178-
new_name <- paste0(new_name, "n")
179-
}
180-
if (new_name != name) {
181-
cli_inform(
182-
c(
183-
"Storing counts in {.code {new_name}}, as {.code n} already present in input.",
184-
"i" = "Use {.code name = \"new_name\"} to pick a new name."
185-
)
277+
278+
while (new_name %in% vars) {
279+
new_name <- paste0(new_name, "n")
280+
}
281+
if (new_name != name) {
282+
cli_inform(
283+
c(
284+
"Storing counts in {.code {new_name}}, as {.code n} already present in input.",
285+
"i" = "Use {.code name = \"new_name\"} to pick a new name."
186286
)
187-
}
287+
)
188288
}
189289

190290
new_name

R/join.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,6 @@ eval_inequality_join <- function(x, y, how, by, suffix) {
537537

538538
by3 <- lapply(seq_along(by$condition), function(i) {
539539
if (by$condition[i] == "==") {
540-
# flir-ignore
541540
to_drop <<- append(to_drop, by2$y[[i]])
542541
by2$x[[i]]$eq(by2$y[[i]])
543542
} else if (by$condition[i] == ">") {

R/utils-expr.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ translate_dots <- function(.data, ..., env, caller) {
5252
env = env,
5353
caller = caller
5454
)
55-
# flir-ignore
5655
new_vars <<- c(new_vars, names(dots)[x])
5756
tmp
5857
})
Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1-
# count() doesn't support named expressions, #233
1+
# count works with expressions
22

33
Code
44
current$collect()
55
Condition
66
Error in `count()`:
7-
! tidypolars doesn't support named expressions in `count()`.
7+
! tidypolars doesn't support both named and unnamed inputs in `count()`.
8+
9+
# add_count works with expressions
10+
11+
Code
12+
current$collect()
13+
Condition
14+
Error in `add_count()`:
15+
! tidypolars doesn't support both named and unnamed inputs in `add_count()`.
816

tests/testthat/_snaps/count.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1-
# count() doesn't support named expressions, #233
1+
# count works with expressions
22

33
Code
4-
count(as_polars_df(iris), is_present = !is.na(Sepal.Length))
4+
count(test_pl, foo = mpg > 20, vs == 1)
55
Condition
66
Error in `count()`:
7-
! tidypolars doesn't support named expressions in `count()`.
7+
! tidypolars doesn't support both named and unnamed inputs in `count()`.
8+
9+
# add_count works with expressions
10+
11+
Code
12+
add_count(test_pl, foo = mpg > 20, vs == 1)
13+
Condition
14+
Error in `add_count()`:
15+
! tidypolars doesn't support both named and unnamed inputs in `add_count()`.
816

0 commit comments

Comments
 (0)