@@ -3344,31 +3344,59 @@ is_constantish = function(q, check_singleton=FALSE) {
33443344
33453345.gforce_ok = function (q , x , envir = parent.frame(2L )) {
33463346 if (is.N(q )) return (TRUE ) # For #334
3347+ if (! is.call(q )) return (FALSE ) # plain columns are not gforce-able since they might not aggregate (see test 104.1)
3348+
33473349 q1 = .get_gcall(q )
3348- if (is.null(q1 )) return (FALSE )
3349- q2 = if (.is_type_conversion(q [[2L ]]) && is.symbol(q [[2L ]][[2L ]])) q [[2L ]][[2L ]] else q [[2L ]]
3350- if (! q2 %chin % names(x ) && q2 != " .I" ) return (FALSE ) # 875
3351- if (length(q )== 2L || (.arg_is_narm(q ) && is_constantish(q [[3L ]]) &&
3352- ! (is.symbol(q [[3L ]]) && q [[3L ]] %chin % names(x )))) return (TRUE )
3353- switch (as.character(q1 ),
3354- " shift" = .gshift_ok(q ),
3355- " weighted.mean" = .gweighted.mean_ok(q , x ),
3356- " tail" = , " head" = .ghead_ok(q ),
3357- " [[" = , " [" = `.g[_ok`(q , x , envir ),
3358- FALSE
3359- )
3350+ if (! is.null(q1 )) {
3351+ q2 = if (.is_type_conversion(q [[2L ]]) && is.symbol(q [[2L ]][[2L ]])) q [[2L ]][[2L ]] else q [[2L ]]
3352+ if (! q2 %chin % names(x ) && q2 != " .I" ) return (FALSE ) # 875
3353+ if (length(q )== 2L || (.arg_is_narm(q ) && is_constantish(q [[3L ]]) &&
3354+ ! (is.symbol(q [[3L ]]) && q [[3L ]] %chin % names(x )))) return (TRUE )
3355+ return (switch (as.character(q1 ),
3356+ " shift" = .gshift_ok(q ),
3357+ " weighted.mean" = .gweighted.mean_ok(q , x ),
3358+ " tail" = , " head" = .ghead_ok(q ),
3359+ " [[" = , " [" = `.g[_ok`(q , x , envir ),
3360+ FALSE
3361+ ))
3362+ }
3363+
3364+ # check if arithmetic operator -> recursively validate ALL branches (like in AST)
3365+ if (is.symbol(q [[1L ]]) && q [[1L ]] %chin % c(" +" , " -" , " *" , " /" , " ^" , " %%" , " %/%" )) {
3366+ for (i in 2 : length(q )) {
3367+ if (! .gforce_ok(q [[i ]], x , envir )) return (FALSE )
3368+ }
3369+ return (TRUE )
3370+ }
3371+
3372+ FALSE
33603373}
33613374
33623375.gforce_jsub = function (q , names_x , envir = parent.frame(2L )) {
3363- call_name = if (is.symbol(q [[1L ]])) q [[1L ]] else q [[1L ]][[3L ]] # latter is like data.table::shift, #5942. .gshift_ok checked this will work.
3364- q [[1L ]] = as.name(paste0(" g" , call_name ))
3365- # gforce needs to evaluate arguments before calling C part TODO: move the evaluation into gforce_ok
3366- # do not evaluate vars present as columns in x
3367- if (length(q ) > = 3L ) {
3368- for (i in 3 : length(q )) {
3369- if (is.symbol(q [[i ]]) && ! (q [[i ]] %chin % names_x )) q [[i ]] = eval(q [[i ]], envir ) # tests 1187.2 & 1187.4
3376+ if (! is.call(q )) return (q )
3377+
3378+ q1 = .get_gcall(q )
3379+ if (! is.null(q1 )) {
3380+ call_name = if (is.symbol(q [[1L ]])) q [[1L ]] else q [[1L ]][[3L ]] # latter is like data.table::shift, #5942. .gshift_ok checked this will work.
3381+ q [[1L ]] = as.name(paste0(" g" , call_name ))
3382+ # gforce needs to evaluate arguments before calling C part TODO: move the evaluation into gforce_ok
3383+ # do not evaluate vars present as columns in x
3384+ if (length(q ) > = 3L ) {
3385+ for (i in 3 : length(q )) {
3386+ if (is.symbol(q [[i ]]) && ! (q [[i ]] %chin % names_x )) q [[i ]] = eval(q [[i ]], envir ) # tests 1187.2 & 1187.4
3387+ }
3388+ }
3389+ return (q )
3390+ }
3391+
3392+ # if arithmetic operator, recursively substitute its operands. we know what branches are valid from .gforce_ok
3393+ if (is.symbol(q [[1L ]]) && q [[1L ]] %chin % c(" +" , " -" , " *" , " /" , " ^" , " %%" , " %/%" )) {
3394+ for (i in 2 : length(q )) {
3395+ q [[i ]] = .gforce_jsub(q [[i ]], names_x , envir )
33703396 }
3397+ return (q )
33713398 }
3399+ # should not reach here since .gforce_ok
33723400 q
33733401}
33743402
0 commit comments