Skip to content

Commit 641ba8d

Browse files
authored
chore(cubesql): Initial support for custom protocols (#8572)
1 parent d8423f8 commit 641ba8d

35 files changed

+594
-477
lines changed

packages/cubejs-backend-native/src/node_export.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
use cubesql::compile::DatabaseProtocol;
12
use cubesql::compile::{convert_sql_to_cube_query, get_df_batches};
23
use cubesql::config::processing_loop::ShutdownMode;
34
use cubesql::config::ConfigObj;
4-
use cubesql::sql::{DatabaseProtocol, SessionManager};
5+
use cubesql::sql::SessionManager;
56
use cubesql::transport::TransportService;
67
use futures::StreamExt;
78

Lines changed: 141 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,101 +1,165 @@
1-
use std::sync::Arc;
1+
use std::{any::Any, collections::HashMap, sync::Arc};
22

3+
use async_trait::async_trait;
4+
use cubeclient::models::V1CubeMeta;
35
use datafusion::{
4-
arrow::datatypes::DataType,
5-
error::Result,
6-
scalar::ScalarValue,
7-
variable::{VarProvider, VarType},
6+
arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit},
7+
datasource::{self, TableProvider},
8+
error::DataFusionError,
9+
execution::context::SessionState as DFSessionState,
10+
logical_plan::Expr,
11+
physical_plan::{udaf::AggregateUDF, udf::ScalarUDF, udtf::TableUDF, ExecutionPlan},
12+
sql::planner::ContextProvider,
813
};
9-
use log::warn;
1014

11-
use crate::sql::{session::DatabaseProtocol, ServerManager, SessionState};
15+
use crate::{
16+
compile::{DatabaseProtocolDetails, MetaContext},
17+
sql::{ColumnType, SessionManager, SessionState},
18+
transport::V1CubeMetaExt,
19+
CubeError,
20+
};
1221

13-
pub struct VariablesProvider {
14-
session: Arc<SessionState>,
15-
server: Arc<ServerManager>,
22+
#[derive(Clone)]
23+
pub struct CubeContext {
24+
/// Internal state for the context (default)
25+
pub state: Arc<DFSessionState>,
26+
/// References
27+
pub meta: Arc<MetaContext>,
28+
pub sessions: Arc<SessionManager>,
29+
pub session_state: Arc<SessionState>,
1630
}
1731

18-
impl VariablesProvider {
19-
pub fn new(session: Arc<SessionState>, server: Arc<ServerManager>) -> Self {
20-
Self { session, server }
32+
impl CubeContext {
33+
pub fn new(
34+
state: Arc<DFSessionState>,
35+
meta: Arc<MetaContext>,
36+
sessions: Arc<SessionManager>,
37+
session_state: Arc<SessionState>,
38+
) -> Self {
39+
Self {
40+
state,
41+
meta,
42+
sessions,
43+
session_state,
44+
}
45+
}
46+
47+
pub fn table_name_by_table_provider(
48+
&self,
49+
table_provider: Arc<dyn datasource::TableProvider>,
50+
) -> Result<String, CubeError> {
51+
self.session_state
52+
.protocol
53+
.table_name_by_table_provider(table_provider)
2154
}
2255

23-
fn get_session_value(&self, identifier: Vec<String>, var_type: VarType) -> Result<ScalarValue> {
24-
let key = if identifier.len() > 1 {
25-
let ignore_first = identifier[0].to_ascii_lowercase() == "@@session".to_owned();
26-
if ignore_first {
27-
identifier[1..].concat()
28-
} else {
29-
identifier.concat()[1..].to_string()
30-
}
31-
} else {
32-
identifier.concat()[1..].to_string()
33-
};
34-
35-
if let Some(var) = self.session.get_variable(&key) {
36-
if var.var_type == var_type {
37-
return Ok(var.value.clone());
38-
}
56+
pub fn get_function<T>(&self, name: &str, udfs: &HashMap<String, Arc<T>>) -> Option<Arc<T>> {
57+
if name.starts_with("pg_catalog.") {
58+
return udfs.get(&format!("{}", &name[11..name.len()])).cloned();
3959
}
4060

41-
warn!("Unknown session variable: {}", key);
61+
udfs.get(name).cloned()
62+
}
63+
}
4264

43-
Ok(ScalarValue::Utf8(None))
65+
impl ContextProvider for CubeContext {
66+
fn get_table_provider(
67+
&self,
68+
tr: datafusion::catalog::TableReference,
69+
) -> Option<Arc<dyn TableProvider>> {
70+
return self.session_state.protocol.get_provider(&self.clone(), tr);
4471
}
4572

46-
fn get_global_value(&self, identifier: Vec<String>) -> Result<ScalarValue> {
47-
let key = if identifier.len() > 1 {
48-
let ignore_first = identifier[0].to_ascii_lowercase() == "@@global".to_owned();
49-
50-
if ignore_first {
51-
identifier[1..].concat()
52-
} else {
53-
identifier.concat()[2..].to_string()
54-
}
55-
} else {
56-
identifier.concat()[2..].to_string()
57-
};
58-
59-
if let Some(var) = self
60-
.server
61-
.read_variables(DatabaseProtocol::MySQL)
62-
.get(&key)
63-
{
64-
if var.var_type == VarType::System {
65-
return Ok(var.value.clone());
66-
}
67-
}
73+
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
74+
self.get_function(name, &self.state.scalar_functions)
75+
}
76+
77+
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
78+
self.get_function(name, &self.state.aggregate_functions)
79+
}
6880

69-
warn!("Unknown system variable: {}", key);
81+
fn get_table_function_meta(&self, name: &str) -> Option<Arc<TableUDF>> {
82+
self.get_function(name, &self.state.table_functions)
83+
}
7084

71-
Ok(ScalarValue::Utf8(None))
85+
fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
86+
Some(DataType::Utf8)
7287
}
7388
}
7489

/// Implemented by table providers that can report the name of the table
/// they expose.
pub trait TableName {
    /// Name of the table backing this provider.
    fn table_name(&self) -> &str;
}
93+
94+
pub struct CubeTableProvider {
95+
cube: V1CubeMeta,
96+
}
8297

83-
match (&first_word_vec[0], &first_word_vec[1]) {
84-
('@', '@') => {
85-
if identifier.len() > 1
86-
&& identifier[0].to_ascii_lowercase() == "@@session".to_owned()
87-
{
88-
return self.get_session_value(identifier, VarType::System);
89-
}
90-
91-
return self.get_global_value(identifier);
92-
}
93-
('@', _) => return self.get_session_value(identifier, VarType::UserDefined),
94-
(_, _) => return Ok(ScalarValue::Utf8(None)),
95-
};
98+
impl CubeTableProvider {
99+
pub fn new(cube: V1CubeMeta) -> Self {
100+
Self { cube }
96101
}
102+
}
97103

98-
fn get_type(&self, _var_names: &[String]) -> Option<DataType> {
99-
Some(DataType::Utf8)
104+
impl TableName for CubeTableProvider {
105+
fn table_name(&self) -> &str {
106+
&self.cube.name
107+
}
108+
}
109+
110+
#[async_trait]
111+
impl TableProvider for CubeTableProvider {
112+
fn as_any(&self) -> &dyn Any {
113+
self
114+
}
115+
116+
fn schema(&self) -> SchemaRef {
117+
Arc::new(Schema::new(
118+
self.cube
119+
.get_columns()
120+
.into_iter()
121+
.map(|c| {
122+
Field::new(
123+
c.get_name(),
124+
match c.get_column_type() {
125+
ColumnType::Date(large) => {
126+
if large {
127+
DataType::Date64
128+
} else {
129+
DataType::Date32
130+
}
131+
}
132+
ColumnType::Interval(unit) => DataType::Interval(unit),
133+
ColumnType::String => DataType::Utf8,
134+
ColumnType::VarStr => DataType::Utf8,
135+
ColumnType::Boolean => DataType::Boolean,
136+
ColumnType::Double => DataType::Float64,
137+
ColumnType::Int8 => DataType::Int64,
138+
ColumnType::Int32 => DataType::Int64,
139+
ColumnType::Int64 => DataType::Int64,
140+
ColumnType::Blob => DataType::Utf8,
141+
ColumnType::Decimal(p, s) => DataType::Decimal(p, s),
142+
ColumnType::List(field) => DataType::List(field.clone()),
143+
ColumnType::Timestamp => {
144+
DataType::Timestamp(TimeUnit::Millisecond, None)
145+
}
146+
},
147+
true,
148+
)
149+
})
150+
.collect(),
151+
))
152+
}
153+
154+
async fn scan(
155+
&self,
156+
_projection: &Option<Vec<usize>>,
157+
_filters: &[Expr],
158+
_limit: Option<usize>,
159+
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
160+
Err(DataFusionError::Plan(format!(
161+
"Not rewritten table scan node for '{}' cube",
162+
self.cube.name
163+
)))
100164
}
101165
}
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
use std::sync::Arc;
2+
3+
use datafusion::datasource::{self, TableProvider};
4+
5+
use super::information_schema::mysql::{
6+
collations::InfoSchemaCollationsProvider as MySqlSchemaCollationsProvider,
7+
columns::InfoSchemaColumnsProvider as MySqlSchemaColumnsProvider,
8+
key_column_usage::InfoSchemaKeyColumnUsageProvider as MySqlSchemaKeyColumnUsageProvider,
9+
processlist::InfoSchemaProcesslistProvider as MySqlSchemaProcesslistProvider,
10+
referential_constraints::InfoSchemaReferentialConstraintsProvider as MySqlSchemaReferentialConstraintsProvider,
11+
schemata::InfoSchemaSchemataProvider as MySqlSchemaSchemataProvider,
12+
statistics::InfoSchemaStatisticsProvider as MySqlSchemaStatisticsProvider,
13+
tables::InfoSchemaTableProvider as MySqlSchemaTableProvider,
14+
variables::PerfSchemaVariablesProvider as MySqlPerfSchemaVariablesProvider,
15+
};
16+
use crate::{
17+
compile::{
18+
engine::{CubeContext, CubeTableProvider, TableName},
19+
DatabaseProtocol,
20+
},
21+
CubeError,
22+
};
23+
24+
impl DatabaseProtocol {
25+
pub fn get_mysql_table_name(
26+
&self,
27+
table_provider: Arc<dyn TableProvider>,
28+
) -> Result<String, CubeError> {
29+
let any = table_provider.as_any();
30+
Ok(if let Some(t) = any.downcast_ref::<CubeTableProvider>() {
31+
t.table_name().to_string()
32+
} else if let Some(t) = any.downcast_ref::<MySqlSchemaTableProvider>() {
33+
t.table_name().to_string()
34+
} else if let Some(t) = any.downcast_ref::<MySqlSchemaColumnsProvider>() {
35+
t.table_name().to_string()
36+
} else if let Some(t) = any.downcast_ref::<MySqlSchemaStatisticsProvider>() {
37+
t.table_name().to_string()
38+
} else if let Some(t) = any.downcast_ref::<MySqlSchemaKeyColumnUsageProvider>() {
39+
t.table_name().to_string()
40+
} else if let Some(t) = any.downcast_ref::<MySqlSchemaSchemataProvider>() {
41+
t.table_name().to_string()
42+
} else if let Some(t) = any.downcast_ref::<MySqlSchemaReferentialConstraintsProvider>() {
43+
t.table_name().to_string()
44+
} else if let Some(t) = any.downcast_ref::<MySqlSchemaCollationsProvider>() {
45+
t.table_name().to_string()
46+
} else if let Some(t) = any.downcast_ref::<MySqlPerfSchemaVariablesProvider>() {
47+
t.table_name().to_string()
48+
} else if let Some(t) = any.downcast_ref::<MySqlSchemaProcesslistProvider>() {
49+
t.table_name().to_string()
50+
} else {
51+
return Err(CubeError::internal(format!(
52+
"Unknown table provider with schema: {:?}",
53+
table_provider.schema()
54+
)));
55+
})
56+
}
57+
58+
pub(crate) fn get_mysql_provider(
59+
&self,
60+
context: &CubeContext,
61+
tr: datafusion::catalog::TableReference,
62+
) -> Option<std::sync::Arc<dyn datasource::TableProvider>> {
63+
let (db, table) = match tr {
64+
datafusion::catalog::TableReference::Partial { schema, table, .. } => {
65+
(schema.to_ascii_lowercase(), table.to_ascii_lowercase())
66+
}
67+
datafusion::catalog::TableReference::Full {
68+
catalog: _,
69+
schema,
70+
table,
71+
} => (schema.to_ascii_lowercase(), table.to_ascii_lowercase()),
72+
datafusion::catalog::TableReference::Bare { table } => {
73+
("db".to_string(), table.to_ascii_lowercase())
74+
}
75+
};
76+
77+
match db.as_str() {
78+
"db" => {
79+
if let Some(cube) = context
80+
.meta
81+
.cubes
82+
.iter()
83+
.find(|c| c.name.eq_ignore_ascii_case(&table))
84+
{
85+
// TODO .clone()
86+
return Some(Arc::new(CubeTableProvider::new(cube.clone())));
87+
} else {
88+
return None;
89+
}
90+
}
91+
"information_schema" => match table.as_str() {
92+
"tables" => {
93+
return Some(Arc::new(MySqlSchemaTableProvider::new(
94+
context.meta.clone(),
95+
)))
96+
}
97+
"columns" => {
98+
return Some(Arc::new(MySqlSchemaColumnsProvider::new(
99+
context.meta.clone(),
100+
)))
101+
}
102+
"statistics" => return Some(Arc::new(MySqlSchemaStatisticsProvider::new())),
103+
"key_column_usage" => {
104+
return Some(Arc::new(MySqlSchemaKeyColumnUsageProvider::new()))
105+
}
106+
"schemata" => return Some(Arc::new(MySqlSchemaSchemataProvider::new())),
107+
"processlist" => {
108+
return Some(Arc::new(MySqlSchemaProcesslistProvider::new(
109+
context.sessions.clone(),
110+
)))
111+
}
112+
"referential_constraints" => {
113+
return Some(Arc::new(MySqlSchemaReferentialConstraintsProvider::new()))
114+
}
115+
"collations" => return Some(Arc::new(MySqlSchemaCollationsProvider::new())),
116+
_ => return None,
117+
},
118+
"performance_schema" => match table.as_str() {
119+
"global_variables" => {
120+
return Some(Arc::new(MySqlPerfSchemaVariablesProvider::new(
121+
"performance_schema.global_variables".to_string(),
122+
context
123+
.sessions
124+
.server
125+
.all_variables(context.session_state.protocol.clone()),
126+
)))
127+
}
128+
"session_variables" => {
129+
return Some(Arc::new(MySqlPerfSchemaVariablesProvider::new(
130+
"performance_schema.session_variables".to_string(),
131+
context.session_state.all_variables(),
132+
)))
133+
}
134+
_ => return None,
135+
},
136+
_ => return None,
137+
}
138+
}
139+
}

0 commit comments

Comments
 (0)