Skip to content

Commit e30c4da

Browse files
authored
feat(cubesql): EXTRACT SQL push down (#7151)
* feat(cubesql): `EXTRACT` SQL push down
* Update datafusion
1 parent 6b9ae70 commit e30c4da

File tree

14 files changed

+445
-61
lines changed

14 files changed

+445
-61
lines changed

packages/cubejs-schema-compiler/src/adapter/BaseQuery.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2444,8 +2444,9 @@ class BaseQuery {
24442444
expressions: {
24452445
column_aliased: '{{expr}} {{quoted_alias}}',
24462446
case: 'CASE {% if expr %}{{ expr }} {% endif %}{% for when, then in when_then %}WHEN {{ when }} THEN {{ then }}{% endfor %}{% if else_expr %} ELSE {{ else_expr }}{% endif %} END',
2447-
binary: '{{ left }} {{ op }} {{ right }}',
2447+
binary: '({{ left }} {{ op }} {{ right }})',
24482448
sort: '{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST{% endif %}',
2449+
cast: 'CAST({{ expr }} AS {{ data_type }})',
24492450
},
24502451
quotes: {
24512452
identifiers: '"',

packages/cubejs-schema-compiler/src/adapter/BigqueryQuery.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,10 @@ export class BigqueryQuery extends BaseQuery {
149149
const templates = super.sqlTemplates();
150150
templates.quotes.identifiers = '`';
151151
templates.quotes.escape = '\\`';
152+
templates.functions.DATETRUNC = 'DATETIME_TRUNC(CAST({{ args[1] }} AS DATETIME), {{ date_part }})';
153+
templates.expressions.binary = '{% if op == \'%\' %}MOD({{ left }}, {{ right }}){% else %}({{ left }} {{ op }} {{ right }}){% endif %}';
154+
templates.expressions.interval = 'INTERVAL {{ interval }}';
155+
templates.expressions.extract = 'EXTRACT({% if date_part == \'DOW\' %}DAYOFWEEK{% else %}{{ date_part }}{% endif %} FROM {{ expr }})';
152156
return templates;
153157
}
154158
}

packages/cubejs-schema-compiler/src/adapter/PostgresQuery.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ export class PostgresQuery extends BaseQuery {
5050
templates.params.param = '${{ param_index + 1 }}';
5151
templates.functions.DATETRUNC = 'DATE_TRUNC({{ args_concat }})';
5252
templates.functions.CONCAT = 'CONCAT({% for arg in args %}CAST({{arg}} AS TEXT){% if not loop.last %},{% endif %}{% endfor %})';
53+
templates.functions.DATEPART = 'DATE_PART({{ args_concat }})';
54+
templates.expressions.interval = 'INTERVAL \'{{ interval }}\'';
55+
templates.expressions.extract = 'EXTRACT({{ date_part }} FROM {{ expr }})';
56+
5357
return templates;
5458
}
5559
}

packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ export class PrestodbQuery extends BaseQuery {
112112
sqlTemplates() {
113113
const templates = super.sqlTemplates();
114114
templates.functions.DATETRUNC = 'DATE_TRUNC({{ args_concat }})';
115+
templates.functions.DATEPART = 'DATE_PART({{ args_concat }})';
116+
templates.expressions.extract = 'EXTRACT({{ date_part }} FROM {{ expr }})';
117+
templates.expressions.interval = 'INTERVAL \'{{ num }}\' {{ date_part }}';
115118
return templates;
116119
}
117120
}

packages/cubejs-schema-compiler/src/adapter/SnowflakeQuery.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ export class SnowflakeQuery extends BaseQuery {
5353
sqlTemplates() {
5454
const templates = super.sqlTemplates();
5555
templates.functions.DATETRUNC = 'DATE_TRUNC({{ args_concat }})';
56+
templates.functions.DATEPART = 'DATE_PART({{ args_concat }})';
57+
templates.expressions.extract = 'EXTRACT({{ date_part }} FROM {{ expr }})';
58+
templates.expressions.interval = 'INTERVAL \'{{ interval }}\'';
5659
return templates;
5760
}
5861
}

rust/cubesql/Cargo.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/cubesql/cubesql/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ documentation = "https://cube.dev/docs"
99
homepage = "https://cube.dev"
1010

1111
[dependencies]
12-
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "915a600ea4d3b66161cd77ff94747960f840816e", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
12+
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
1313
anyhow = "1.0"
1414
thiserror = "1.0"
1515
cubeclient = { path = "../cubeclient" }

rust/cubesql/cubesql/src/compile/engine/df/scan.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,6 +979,7 @@ pub fn transform_response<V: ValueObject>(
979979
(FieldValue::String(s), builder) => {
980980
let timestamp = NaiveDateTime::parse_from_str(s.as_str(), "%Y-%m-%dT%H:%M:%S.%f")
981981
.or_else(|_| NaiveDateTime::parse_from_str(s.as_str(), "%Y-%m-%d %H:%M:%S.%f"))
982+
.or_else(|_| NaiveDateTime::parse_from_str(s.as_str(), "%Y-%m-%dT%H:%M:%S"))
982983
.map_err(|e| {
983984
DataFusionError::Execution(format!(
984985
"Can't parse timestamp: '{}': {}",

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

Lines changed: 151 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{
22
compile::{
3-
engine::df::scan::{CubeScanNode, MemberField, WrappedSelectNode},
3+
engine::df::scan::{CubeScanNode, DataType, MemberField, WrappedSelectNode},
44
rewrite::WrappedSelectType,
55
},
66
sql::AuthContextRef,
@@ -16,7 +16,7 @@ use datafusion::{
1616
plan::Extension, replace_col, replace_col_to_expr, Column, DFSchema, DFSchemaRef, Expr,
1717
LogicalPlan, UserDefinedLogicalNode,
1818
},
19-
physical_plan::aggregates::AggregateFunction,
19+
physical_plan::{aggregates::AggregateFunction, functions::BuiltinScalarFunction},
2020
scalar::ScalarValue,
2121
};
2222
use itertools::Itertools;
@@ -144,6 +144,10 @@ pub struct SqlGenerationResult {
144144
pub request: V1LoadRequestQuery,
145145
}
146146

147+
lazy_static! {
148+
static ref DATE_PART_REGEX: Regex = Regex::new("^[A-Za-z_ ]+$").unwrap();
149+
}
150+
147151
impl CubeScanWrapperNode {
148152
pub async fn generate_sql(
149153
&self,
@@ -934,7 +938,61 @@ impl CubeScanWrapperNode {
934938
);
935939
Ok((resulting_sql, sql_query))
936940
}
937-
// Expr::Cast { .. } => {}
941+
Expr::Cast { expr, data_type } => {
942+
let (expr, sql_query) = Self::generate_sql_for_expr(
943+
plan.clone(),
944+
sql_query,
945+
sql_generator.clone(),
946+
*expr,
947+
ungrouped_scan_node.clone(),
948+
)
949+
.await?;
950+
let data_type = match data_type {
951+
DataType::Null => "NULL",
952+
DataType::Boolean => "BOOLEAN",
953+
DataType::Int8 => "INTEGER",
954+
DataType::Int16 => "INTEGER",
955+
DataType::Int32 => "INTEGER",
956+
DataType::Int64 => "INTEGER",
957+
DataType::UInt8 => "INTEGER",
958+
DataType::UInt16 => "INTEGER",
959+
DataType::UInt32 => "INTEGER",
960+
DataType::UInt64 => "INTEGER",
961+
DataType::Float16 => "FLOAT",
962+
DataType::Float32 => "FLOAT",
963+
DataType::Float64 => "DOUBLE",
964+
DataType::Timestamp(_, _) => "TIMESTAMP",
965+
DataType::Date32 => "DATE",
966+
DataType::Date64 => "DATE",
967+
DataType::Time32(_) => "TIME",
968+
DataType::Time64(_) => "TIME",
969+
DataType::Duration(_) => "INTERVAL",
970+
DataType::Interval(_) => "INTERVAL",
971+
DataType::Binary => "BYTEA",
972+
DataType::FixedSizeBinary(_) => "BYTEA",
973+
DataType::Utf8 => "TEXT",
974+
DataType::LargeUtf8 => "TEXT",
975+
x => {
976+
return Err(DataFusionError::Execution(format!(
977+
"Can't generate SQL for cast: type isn't supported: {:?}",
978+
x
979+
)));
980+
}
981+
};
982+
let resulting_sql = Self::escape_interpolation_quotes(
983+
sql_generator
984+
.get_sql_templates()
985+
.cast_expr(expr, data_type.to_string())
986+
.map_err(|e| {
987+
DataFusionError::Internal(format!(
988+
"Can't generate SQL for cast: {}",
989+
e
990+
))
991+
})?,
992+
ungrouped_scan_node.is_some(),
993+
);
994+
Ok((resulting_sql, sql_query))
995+
}
938996
// Expr::TryCast { .. } => {}
939997
Expr::Sort {
940998
expr,
@@ -1024,7 +1082,41 @@ impl CubeScanWrapperNode {
10241082
// ScalarValue::TimestampMicrosecond(_, _) => {}
10251083
// ScalarValue::TimestampNanosecond(_, _) => {}
10261084
// ScalarValue::IntervalYearMonth(_) => {}
1027-
// ScalarValue::IntervalDayTime(_) => {}
1085+
ScalarValue::IntervalDayTime(x) => {
1086+
if let Some(x) = x {
1087+
let days = x >> 32;
1088+
let millis = x & 0xFFFFFFFF;
1089+
if days > 0 && millis > 0 {
1090+
return Err(DataFusionError::Internal(format!(
1091+
"Can't generate SQL for interval: mixed intervals aren't supported: {} days {} millis encoded as {}",
1092+
days, millis, x
1093+
)));
1094+
}
1095+
let (num, date_part) = if days > 0 {
1096+
(days, "DAY")
1097+
} else {
1098+
(millis, "MILLISECOND")
1099+
};
1100+
let interval = format!("{} {}", num, date_part);
1101+
(
1102+
Self::escape_interpolation_quotes(
1103+
sql_generator
1104+
.get_sql_templates()
1105+
.interval_expr(interval, num, date_part.to_string())
1106+
.map_err(|e| {
1107+
DataFusionError::Internal(format!(
1108+
"Can't generate SQL for interval: {}",
1109+
e
1110+
))
1111+
})?,
1112+
ungrouped_scan_node.is_some(),
1113+
),
1114+
sql_query,
1115+
)
1116+
} else {
1117+
("NULL".to_string(), sql_query)
1118+
}
1119+
}
10281120
// ScalarValue::IntervalMonthDayNano(_) => {}
10291121
// ScalarValue::Struct(_, _) => {}
10301122
x => {
@@ -1036,6 +1128,60 @@ impl CubeScanWrapperNode {
10361128
})
10371129
}
10381130
Expr::ScalarFunction { fun, args } => {
1131+
if let BuiltinScalarFunction::DatePart = &fun {
1132+
if args.len() >= 2 {
1133+
match &args[0] {
1134+
Expr::Literal(ScalarValue::Utf8(Some(date_part))) => {
1135+
// Security check to prevent SQL injection
1136+
if !DATE_PART_REGEX.is_match(date_part) {
1137+
return Err(DataFusionError::Internal(format!(
1138+
"Can't generate SQL for scalar function: date part '{}' is not supported",
1139+
date_part
1140+
)));
1141+
}
1142+
let (arg_sql, query) = Self::generate_sql_for_expr(
1143+
plan.clone(),
1144+
sql_query,
1145+
sql_generator.clone(),
1146+
args[1].clone(),
1147+
ungrouped_scan_node.clone(),
1148+
)
1149+
.await?;
1150+
return Ok((
1151+
Self::escape_interpolation_quotes(
1152+
sql_generator
1153+
.get_sql_templates()
1154+
.extract_expr(date_part.to_string(), arg_sql)
1155+
.map_err(|e| {
1156+
DataFusionError::Internal(format!(
1157+
"Can't generate SQL for scalar function: {}",
1158+
e
1159+
))
1160+
})?,
1161+
ungrouped_scan_node.is_some(),
1162+
),
1163+
query,
1164+
));
1165+
}
1166+
_ => {}
1167+
}
1168+
}
1169+
}
1170+
let date_part = if let BuiltinScalarFunction::DateTrunc = &fun {
1171+
match &args[0] {
1172+
Expr::Literal(ScalarValue::Utf8(Some(date_part))) => {
1173+
// Security check to prevent SQL injection
1174+
if DATE_PART_REGEX.is_match(date_part) {
1175+
Some(date_part.to_string())
1176+
} else {
1177+
None
1178+
}
1179+
}
1180+
_ => None,
1181+
}
1182+
} else {
1183+
None
1184+
};
10391185
let mut sql_args = Vec::new();
10401186
for arg in args {
10411187
let (sql, query) = Self::generate_sql_for_expr(
@@ -1053,7 +1199,7 @@ impl CubeScanWrapperNode {
10531199
Self::escape_interpolation_quotes(
10541200
sql_generator
10551201
.get_sql_templates()
1056-
.scalar_function(fun, sql_args)
1202+
.scalar_function(fun, sql_args, date_part)
10571203
.map_err(|e| {
10581204
DataFusionError::Internal(format!(
10591205
"Can't generate SQL for scalar function: {}",

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14292,7 +14292,7 @@ ORDER BY \"COUNT(count)\" DESC"
1429214292
)
1429314293
.await;
1429414294

14295-
query.unwrap_err();
14295+
query.unwrap();
1429614296
}
1429714297

1429814298
#[tokio::test]
@@ -18534,6 +18534,35 @@ ORDER BY \"COUNT(count)\" DESC"
1853418534
);
1853518535
}
1853618536

18537+
#[tokio::test]
18538+
async fn test_wrapper_tableau_sunday_week() {
18539+
if !Rewriter::sql_push_down_enabled() {
18540+
return;
18541+
}
18542+
init_logger();
18543+
18544+
let query_plan = convert_select_to_query_plan(
18545+
"SELECT (CAST(DATE_TRUNC('day', CAST(order_date AS TIMESTAMP)) AS DATE) - (((7 + CAST(EXTRACT(DOW FROM order_date) AS BIGINT) - 1) % 7) * INTERVAL '1 DAY')) AS \"twk:date:ok\", AVG(avgPrice) mp FROM KibanaSampleDataEcommerce a GROUP BY 1 ORDER BY 1 DESC"
18546+
.to_string(),
18547+
DatabaseProtocol::PostgreSQL,
18548+
)
18549+
.await;
18550+
18551+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
18552+
println!(
18553+
"Physical plan: {}",
18554+
displayable(physical_plan.as_ref()).indent()
18555+
);
18556+
18557+
let logical_plan = query_plan.as_logical_plan();
18558+
assert!(logical_plan
18559+
.find_cube_scan_wrapper()
18560+
.wrapped_sql
18561+
.unwrap()
18562+
.sql
18563+
.contains("EXTRACT"));
18564+
}
18565+
1853718566
#[tokio::test]
1853818567
async fn test_thoughtspot_pg_date_trunc_year() {
1853918568
init_logger();

0 commit comments

Comments
 (0)