Skip to content

Commit 61089f0

Browse files
authored
feat(cubesql): Introduce CUBESQL_DISABLE_STRICT_AGG_TYPE_MATCH to avoid aggregation type checking during querying (#7316)
1 parent 30b921b commit 61089f0

File tree

9 files changed

+180
-31
lines changed

9 files changed

+180
-31
lines changed

rust/cubesql/cubesql/src/compile/context.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,8 @@ impl QueryContext {
528528
match selection {
529529
Selection::Measure(measure) => {
530530
if measure.agg_type.is_some()
531-
&& !measure.is_same_agg_type(&call_agg_type)
531+
// TODO not used
532+
&& !measure.is_same_agg_type(&call_agg_type, false)
532533
{
533534
return Err(CompilationError::user(format!(
534535
"Measure aggregation type doesn't match. The aggregation type for '{}' is '{}()' but '{}()' was provided",

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1660,8 +1660,11 @@ mod tests {
16601660
use crate::{
16611661
compile::{
16621662
rewrite::rewriter::Rewriter,
1663-
test::{get_string_cube_meta, get_test_tenant_ctx_with_meta},
1663+
test::{
1664+
get_string_cube_meta, get_test_session_with_config, get_test_tenant_ctx_with_meta,
1665+
},
16641666
},
1667+
config::{ConfigObj, ConfigObjImpl},
16651668
sql::{dataframe::batch_to_dataframe, types::StatusFlags},
16661669
};
16671670
use datafusion::{logical_plan::PlanVisitor, physical_plan::displayable};
@@ -1701,6 +1704,23 @@ mod tests {
17011704
query.unwrap()
17021705
}
17031706

1707+
async fn convert_select_to_query_plan_with_config(
1708+
query: String,
1709+
db: DatabaseProtocol,
1710+
config_obj: Arc<dyn ConfigObj>,
1711+
) -> QueryPlan {
1712+
env::set_var("TZ", "UTC");
1713+
1714+
let query = convert_sql_to_cube_query(
1715+
&query,
1716+
get_test_tenant_ctx(),
1717+
get_test_session_with_config(db, config_obj).await,
1718+
)
1719+
.await;
1720+
1721+
query.unwrap()
1722+
}
1723+
17041724
async fn convert_select_to_query_plan_with_meta(
17051725
query: String,
17061726
meta: Vec<V1CubeMeta>,
@@ -18628,6 +18648,40 @@ ORDER BY \"COUNT(count)\" DESC"
1862818648
);
1862918649
}
1863018650

18651+
#[tokio::test]
18652+
async fn test_case_wrapper_non_strict_match() {
18653+
if !Rewriter::sql_push_down_enabled() {
18654+
return;
18655+
}
18656+
init_logger();
18657+
18658+
let mut config = ConfigObjImpl::default();
18659+
18660+
config.disable_strict_agg_type_match = true;
18661+
18662+
let query_plan = convert_select_to_query_plan_with_config(
18663+
"SELECT CASE WHEN customer_gender = 'female' THEN 'f' ELSE 'm' END, SUM(avgPrice) mp FROM KibanaSampleDataEcommerce a GROUP BY 1"
18664+
.to_string(),
18665+
DatabaseProtocol::PostgreSQL,
18666+
Arc::new(config)
18667+
)
18668+
.await;
18669+
18670+
let logical_plan = query_plan.as_logical_plan();
18671+
assert!(logical_plan
18672+
.find_cube_scan_wrapper()
18673+
.wrapped_sql
18674+
.unwrap()
18675+
.sql
18676+
.contains("CASE WHEN"));
18677+
18678+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
18679+
println!(
18680+
"Physical plan: {}",
18681+
displayable(physical_plan.as_ref()).indent()
18682+
);
18683+
}
18684+
1863118685
#[tokio::test]
1863218686
async fn test_case_wrapper_ungrouped_sorted() {
1863318687
if !Rewriter::sql_push_down_enabled() {

rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1975,6 +1975,12 @@ impl MemberRules {
19751975
let cast_data_type_var = cast_data_type_var.map(|var| var!(var));
19761976
let measure_out_var = var!(measure_out_var);
19771977
let meta_context = self.cube_context.meta.clone();
1978+
let disable_strict_agg_type_match = self
1979+
.cube_context
1980+
.sessions
1981+
.server
1982+
.config_obj
1983+
.disable_strict_agg_type_match();
19781984
move |egraph, subst| {
19791985
if let Some(alias) = original_expr_name(egraph, subst[original_expr_var]) {
19801986
for measure_name in var_iter!(egraph[subst[measure_name_var]], MeasureName)
@@ -2040,6 +2046,7 @@ impl MemberRules {
20402046
cube_alias.to_string(),
20412047
subst[original_expr_var],
20422048
alias_to_cube.clone(),
2049+
disable_strict_agg_type_match,
20432050
);
20442051
return true;
20452052
}
@@ -2088,6 +2095,12 @@ impl MemberRules {
20882095
let cast_data_type_var = cast_data_type_var.map(|var| var!(var));
20892096
let measure_out_var = measure_out_var.parse().unwrap();
20902097
let meta_context = self.cube_context.meta.clone();
2098+
let disable_strict_agg_type_match = self
2099+
.cube_context
2100+
.sessions
2101+
.server
2102+
.config_obj
2103+
.disable_strict_agg_type_match();
20912104
move |egraph, subst| {
20922105
for column in measure_var
20932106
.map(|measure_var| {
@@ -2153,6 +2166,7 @@ impl MemberRules {
21532166
cube_alias,
21542167
subst[aggr_expr_var],
21552168
alias_to_cube,
2169+
disable_strict_agg_type_match,
21562170
);
21572171

21582172
return true;
@@ -2192,8 +2206,14 @@ impl MemberRules {
21922206
cube_alias: String,
21932207
expr: Id,
21942208
alias_to_cube: Vec<((String, String), String)>,
2209+
disable_strict_agg_type_match: bool,
21952210
) {
2196-
if call_agg_type.is_some() && !measure.is_same_agg_type(call_agg_type.as_ref().unwrap()) {
2211+
if call_agg_type.is_some()
2212+
&& !measure.is_same_agg_type(
2213+
call_agg_type.as_ref().unwrap(),
2214+
disable_strict_agg_type_match,
2215+
)
2216+
{
21972217
let mut agg_type = measure
21982218
.agg_type
21992219
.as_ref()

rust/cubesql/cubesql/src/compile/rewrite/rules/split.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5461,6 +5461,12 @@ impl SplitRules {
54615461
let distinct_var = distinct_var.map(|v| var!(v));
54625462
let out_expr_var = out_expr_var.map(|v| var!(v));
54635463
let meta = self.cube_context.meta.clone();
5464+
let disable_strict_agg_type_match = self
5465+
.cube_context
5466+
.sessions
5467+
.server
5468+
.config_obj
5469+
.disable_strict_agg_type_match();
54645470
move |egraph, subst| {
54655471
for alias_to_cube in var_iter!(
54665472
egraph[subst[cube_expr_var]],
@@ -5501,7 +5507,10 @@ impl SplitRules {
55015507
Some(&fun),
55025508
*distinct,
55035509
);
5504-
if !measure.is_same_agg_type(&agg_type.unwrap()) {
5510+
if !measure.is_same_agg_type(
5511+
&agg_type.unwrap(),
5512+
disable_strict_agg_type_match,
5513+
) {
55055514
if let Some(expr_name) = original_expr_name(
55065515
egraph,
55075516
subst[aggr_expr_var],

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,12 @@ impl WrapperRules {
180180
let cube_members_var = var!(cube_members_var);
181181
let measure_out_var = var!(measure_out_var);
182182
let cube_context = self.cube_context.clone();
183+
let disable_strict_agg_type_match = self
184+
.cube_context
185+
.sessions
186+
.server
187+
.config_obj
188+
.disable_strict_agg_type_match();
183189
move |egraph, subst| {
184190
if let Some(alias) = original_expr_name(egraph, subst[original_expr_var]) {
185191
for fun in fun_name_var
@@ -224,8 +230,10 @@ impl WrapperRules {
224230
cube_context.meta.find_measure_with_name(member.to_string())
225231
{
226232
if call_agg_type.is_none()
227-
|| measure
228-
.is_same_agg_type(call_agg_type.as_ref().unwrap())
233+
|| measure.is_same_agg_type(
234+
call_agg_type.as_ref().unwrap(),
235+
disable_strict_agg_type_match,
236+
)
229237
{
230238
let column_expr_column =
231239
egraph.add(LogicalPlanLanguage::ColumnExprColumn(

rust/cubesql/cubesql/src/compile/test/mod.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use datafusion::arrow::datatypes::SchemaRef;
99

1010
use crate::{
1111
compile::engine::df::{scan::MemberField, wrapper::SqlQuery},
12+
config::{ConfigObj, ConfigObjImpl},
1213
sql::{
1314
session::DatabaseProtocol, AuthContextRef, AuthenticateResponse, HttpAuthContext,
1415
ServerManager, Session, SessionManager, SqlAuthService,
@@ -265,10 +266,18 @@ pub fn get_test_tenant_ctx_with_meta(meta: Vec<V1CubeMeta>) -> Arc<MetaContext>
265266
}
266267

267268
pub async fn get_test_session(protocol: DatabaseProtocol) -> Arc<Session> {
269+
get_test_session_with_config(protocol, Arc::new(ConfigObjImpl::default())).await
270+
}
271+
272+
pub async fn get_test_session_with_config(
273+
protocol: DatabaseProtocol,
274+
config_obj: Arc<dyn ConfigObj>,
275+
) -> Arc<Session> {
268276
let server = Arc::new(ServerManager::new(
269277
get_test_auth(),
270278
get_test_transport(),
271279
None,
280+
config_obj,
272281
));
273282

274283
let db_name = match &protocol {

rust/cubesql/cubesql/src/config/mod.rs

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ use crate::{
1616
use futures::future::join_all;
1717
use log::error;
1818

19-
use mockall::automock;
20-
21-
use std::env;
19+
use std::{
20+
env,
21+
fmt::{Debug, Display},
22+
str::FromStr,
23+
};
2224

2325
use std::sync::Arc;
2426

@@ -108,15 +110,16 @@ pub struct Config {
108110
injector: Arc<Injector>,
109111
}
110112

111-
#[automock]
112-
pub trait ConfigObj: DIService {
113+
pub trait ConfigObj: DIService + Debug {
113114
fn bind_address(&self) -> &Option<String>;
114115

115116
fn postgres_bind_address(&self) -> &Option<String>;
116117

117118
fn query_timeout(&self) -> u64;
118119

119120
fn nonce(&self) -> &Option<Vec<u8>>;
121+
122+
fn disable_strict_agg_type_match(&self) -> bool;
120123
}
121124

122125
#[derive(Debug, Clone)]
@@ -126,10 +129,36 @@ pub struct ConfigObjImpl {
126129
pub nonce: Option<Vec<u8>>,
127130
pub query_timeout: u64,
128131
pub timezone: Option<String>,
132+
pub disable_strict_agg_type_match: bool,
133+
}
134+
135+
impl ConfigObjImpl {
136+
pub fn default() -> Self {
137+
let query_timeout = env::var("CUBESQL_QUERY_TIMEOUT")
138+
.ok()
139+
.map(|v| v.parse::<u64>().unwrap())
140+
.unwrap_or(120);
141+
Self {
142+
bind_address: env::var("CUBESQL_BIND_ADDR").ok().or_else(|| {
143+
env::var("CUBESQL_PORT")
144+
.ok()
145+
.map(|v| format!("0.0.0.0:{}", v.parse::<u16>().unwrap()))
146+
}),
147+
postgres_bind_address: env::var("CUBESQL_PG_PORT")
148+
.ok()
149+
.map(|port| format!("0.0.0.0:{}", port.parse::<u16>().unwrap())),
150+
nonce: None,
151+
query_timeout,
152+
timezone: Some("UTC".to_string()),
153+
disable_strict_agg_type_match: env_parse(
154+
"CUBESQL_DISABLE_STRICT_AGG_TYPE_MATCH",
155+
false,
156+
),
157+
}
158+
}
129159
}
130160

131161
crate::di_service!(ConfigObjImpl, [ConfigObj]);
132-
crate::di_service!(MockConfigObj, [ConfigObj]);
133162

134163
impl ConfigObj for ConfigObjImpl {
135164
fn bind_address(&self) -> &Option<String> {
@@ -147,6 +176,10 @@ impl ConfigObj for ConfigObjImpl {
147176
fn query_timeout(&self) -> u64 {
148177
self.query_timeout
149178
}
179+
180+
fn disable_strict_agg_type_match(&self) -> bool {
181+
self.disable_strict_agg_type_match
182+
}
150183
}
151184

152185
lazy_static! {
@@ -156,25 +189,9 @@ lazy_static! {
156189

157190
impl Config {
158191
pub fn default() -> Config {
159-
let query_timeout = env::var("CUBESQL_QUERY_TIMEOUT")
160-
.ok()
161-
.map(|v| v.parse::<u64>().unwrap())
162-
.unwrap_or(120);
163192
Config {
164193
injector: Injector::new(),
165-
config_obj: Arc::new(ConfigObjImpl {
166-
bind_address: env::var("CUBESQL_BIND_ADDR").ok().or_else(|| {
167-
env::var("CUBESQL_PORT")
168-
.ok()
169-
.map(|v| format!("0.0.0.0:{}", v.parse::<u16>().unwrap()))
170-
}),
171-
postgres_bind_address: env::var("CUBESQL_PG_PORT")
172-
.ok()
173-
.map(|port| format!("0.0.0.0:{}", port.parse::<u16>().unwrap())),
174-
nonce: None,
175-
query_timeout,
176-
timezone: Some("UTC".to_string()),
177-
}),
194+
config_obj: Arc::new(ConfigObjImpl::default()),
178195
}
179196
}
180197

@@ -189,6 +206,7 @@ impl Config {
189206
nonce: None,
190207
query_timeout,
191208
timezone,
209+
disable_strict_agg_type_match: false,
192210
}),
193211
}
194212
}
@@ -231,6 +249,7 @@ impl Config {
231249
i.get_service_typed().await,
232250
i.get_service_typed().await,
233251
config.nonce().clone(),
252+
config.clone(),
234253
))
235254
})
236255
.await;
@@ -287,4 +306,26 @@ impl Config {
287306
}
288307
}
289308

309+
pub fn env_parse<T>(name: &str, default: T) -> T
310+
where
311+
T: FromStr,
312+
T::Err: Display,
313+
{
314+
env_optparse(name).unwrap_or(default)
315+
}
316+
317+
fn env_optparse<T>(name: &str) -> Option<T>
318+
where
319+
T: FromStr,
320+
T::Err: Display,
321+
{
322+
env::var(name).ok().map(|x| match x.parse::<T>() {
323+
Ok(v) => v,
324+
Err(e) => panic!(
325+
"Could not parse environment variable '{}' with '{}' value: {}",
326+
name, x, e
327+
),
328+
})
329+
}
330+
290331
type LoopHandle = JoinHandle<Result<(), CubeError>>;

0 commit comments

Comments
 (0)