Skip to content

Commit c47ec27

Browse files
committed
add AST walker and update tests
1 parent 62f1c48 commit c47ec27

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

R/data.table.R

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3342,9 +3342,12 @@ is_constantish = function(q, check_singleton=FALSE) {
33423342
"as.complex", "as.logical", "as.Date", "as.POSIXct", "as.factor")
33433343
}
33443344

3345+
.gforce_ops = c("+", "-", "*", "/", "^", "%%", "%/%")
3346+
33453347
.gforce_ok = function(q, x, envir=parent.frame(2L)) {
33463348
if (is.N(q)) return(TRUE) # For #334
33473349
if (!is.call(q)) return(FALSE) # plain columns are not gforce-able since they might not aggregate (see test 104.1)
3350+
if (q %iscall% "(") return(.gforce_ok(q[[2L]], x, envir))
33483351

33493352
q1 = .get_gcall(q)
33503353
if (!is.null(q1)) {
@@ -3362,7 +3365,7 @@ is_constantish = function(q, check_singleton=FALSE) {
33623365
}
33633366

33643367
# arithmetic operator: recursively validate every operand (each branch of the AST); a single non-GForce-able operand makes the whole expression not OK
3365-
if (is.symbol(q[[1L]]) && q[[1L]] %chin% c("+", "-", "*", "/", "^", "%%", "%/%")) {
3368+
if (is.symbol(q[[1L]]) && q[[1L]] %chin% .gforce_ops) {
33663369
for (i in 2:length(q)) {
33673370
if (!.gforce_ok(q[[i]], x, envir)) return(FALSE)
33683371
}
@@ -3374,6 +3377,10 @@ is_constantish = function(q, check_singleton=FALSE) {
33743377

33753378
.gforce_jsub = function(q, names_x, envir=parent.frame(2L)) {
33763379
if (!is.call(q)) return(q)
3380+
if (q %iscall% "(") {
3381+
q[[2L]] = .gforce_jsub(q[[2L]], names_x, envir)
3382+
return(q)
3383+
}
33773384

33783385
q1 = .get_gcall(q)
33793386
if (!is.null(q1)) {
@@ -3390,7 +3397,7 @@ is_constantish = function(q, check_singleton=FALSE) {
33903397
}
33913398

33923399
# arithmetic operator: recursively substitute each operand; which branches are valid is already known from .gforce_ok
3393-
if (is.symbol(q[[1L]]) && q[[1L]] %chin% c("+", "-", "*", "/", "^", "%%", "%/%")) {
3400+
if (is.symbol(q[[1L]]) && q[[1L]] %chin% .gforce_ops) {
33943401
for (i in 2:length(q)) {
33953402
q[[i]] = .gforce_jsub(q[[i]], names_x, envir)
33963403
}

inst/tests/tests.Rraw

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21219,20 +21219,20 @@ test(2344.04, key(DT[, .(V4 = c("b", "a"), V2, V5 = c("y", "x"), V1)]), c("V1",
2121921219
# conversions should not turn gforce off #2934
2122021220
# lapply gforce should also work without .SD #5032
2122121221
# support arithmetic in j with gforce #3815
21222-
dt = data.table(a=1:4, b=1:2)
2122321222
out = c("GForce FALSE", "GForce FALSE", "GForce TRUE")
21223+
dt = data.table(a=1:4, b=1:2)
2122421224
test(2345.01, levels=0:2, dt[, max(as.character(a)), by=b, verbose=TRUE], data.table(b=1:2, V1=c("3","4")), output=out)
2122521225
test(2345.02, levels=0:2, dt[, max(as.numeric(a)), by=b, verbose=TRUE], data.table(b=1:2, V1=c(3,4)), output=out)
21226-
test(2345.03, levels=0:2, dt[, Map(sum, .SD), b, verbose=TRUE], dt[, lapply(.SD, sum), b], output=out)
21226+
dt = data.table(a=1:4, b=1:2)
21227+
test(2345.11, levels=0:2, dt[, Map(sum, .SD), b, verbose=TRUE], dt[, lapply(.SD, sum), b], output=out)
21228+
test(2345.12, levels=0:2, dt[, Map(sum, .SD, .SD), by=b, verbose=TRUE], output="GForce FALSE")
2122721229
dt = data.table(a = NA_integer_, b = 1:2, c = c(TRUE, FALSE))
21228-
test(2345.04, levels=0:2, dt[, Map(weighted.mean, .SD, na.rm=c), b, .SDcols="a", verbose=TRUE], data.table(b=1:2, a=c(NaN, NA_real_)), output="GForce FALSE")
21229-
test(2345.05, levels=0:2, dt[,list(weighted.mean(a, na.rm=c)), b, verbose=TRUE], data.table(b=1:2, V1=c(NaN, NA_real_)), output="GForce FALSE")
21230+
test(2345.13, levels=0:2, dt[, Map(weighted.mean, .SD, na.rm=c), b, .SDcols="a", verbose=TRUE], data.table(b=1:2, a=c(NaN, NA_real_)), output="GForce FALSE")
21231+
test(2345.14, levels=0:2, dt[,list(weighted.mean(a, na.rm=c)), b, verbose=TRUE], data.table(b=1:2, V1=c(NaN, NA_real_)), output="GForce FALSE")
2123021232
dt = data.table(a=1:2, b=1, c=1:4)
21231-
out = c("GForce FALSE", "lapply optimization changed j", "GForce TRUE")
21232-
test(2345.06, levels=0:2, dt[, lapply(list(b, c), sum), by=a, verbose=TRUE], output=out)
21233-
test(2345.07, levels=0:2, dt[, c(list(sum(b), sum(c))), by=a, verbose=TRUE], output=out)
21234-
test(2345.08, levels=0:2, names(dt[, lapply(list(b, c), sum), by=a]))
21235-
dt = data.table(a=1:4, b=1:2)
21236-
out = c("GForce FALSE", "GForce FALSE", "GForce TRUE")
21237-
test(2345.09, levels=0:2, dt[, .(max(a)-min(a)), by=b, verbose=TRUE], output=out)
21233+
test(2345.21, levels=0:2, dt[, lapply(list(b, c), sum), by=a, verbose=TRUE], output=out)
21234+
test(2345.22, levels=0:2, dt[, c(list(sum(b), sum(c))), by=a, verbose=TRUE], output=out)
21235+
test(2345.23, levels=0:2, names(dt[, lapply(list(b, c), sum), by=a]))
2123821236
dt = data.table(a=1:4, b=1:2)
21237+
test(2345.31, levels=0:2, dt[, .(max(a)-min(a)), by=b, verbose=TRUE], output=out)
21238+
test(2345.32, levels=0:2, dt[, .((max(a) - min(a)) / (max(a) + min(a))), by=b, verbose=TRUE], data.table(b=1:2, V1=c(0.5, 1/3)), output=out)

0 commit comments

Comments
 (0)