|
1 | | -use std::sync::Arc; |
| 1 | +use std::{any::Any, collections::HashMap, sync::Arc}; |
2 | 2 |
|
| 3 | +use async_trait::async_trait; |
| 4 | +use cubeclient::models::V1CubeMeta; |
3 | 5 | use datafusion::{ |
4 | | - arrow::datatypes::DataType, |
5 | | - error::Result, |
6 | | - scalar::ScalarValue, |
7 | | - variable::{VarProvider, VarType}, |
| 6 | + arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}, |
| 7 | + datasource::{self, TableProvider}, |
| 8 | + error::DataFusionError, |
| 9 | + execution::context::SessionState as DFSessionState, |
| 10 | + logical_plan::Expr, |
| 11 | + physical_plan::{udaf::AggregateUDF, udf::ScalarUDF, udtf::TableUDF, ExecutionPlan}, |
| 12 | + sql::planner::ContextProvider, |
8 | 13 | }; |
9 | | -use log::warn; |
10 | 14 |
|
11 | | -use crate::sql::{session::DatabaseProtocol, ServerManager, SessionState}; |
| 15 | +use crate::{ |
| 16 | + compile::{DatabaseProtocolDetails, MetaContext}, |
| 17 | + sql::{ColumnType, SessionManager, SessionState}, |
| 18 | + transport::V1CubeMetaExt, |
| 19 | + CubeError, |
| 20 | +}; |
12 | 21 |
|
13 | | -pub struct VariablesProvider { |
14 | | - session: Arc<SessionState>, |
15 | | - server: Arc<ServerManager>, |
| 22 | +#[derive(Clone)] |
| 23 | +pub struct CubeContext { |
| 24 | + /// Internal state for the context (default) |
| 25 | + pub state: Arc<DFSessionState>, |
| 26 | + /// References |
| 27 | + pub meta: Arc<MetaContext>, |
| 28 | + pub sessions: Arc<SessionManager>, |
| 29 | + pub session_state: Arc<SessionState>, |
16 | 30 | } |
17 | 31 |
|
18 | | -impl VariablesProvider { |
19 | | - pub fn new(session: Arc<SessionState>, server: Arc<ServerManager>) -> Self { |
20 | | - Self { session, server } |
| 32 | +impl CubeContext { |
| 33 | + pub fn new( |
| 34 | + state: Arc<DFSessionState>, |
| 35 | + meta: Arc<MetaContext>, |
| 36 | + sessions: Arc<SessionManager>, |
| 37 | + session_state: Arc<SessionState>, |
| 38 | + ) -> Self { |
| 39 | + Self { |
| 40 | + state, |
| 41 | + meta, |
| 42 | + sessions, |
| 43 | + session_state, |
| 44 | + } |
| 45 | + } |
| 46 | + |
| 47 | + pub fn table_name_by_table_provider( |
| 48 | + &self, |
| 49 | + table_provider: Arc<dyn datasource::TableProvider>, |
| 50 | + ) -> Result<String, CubeError> { |
| 51 | + self.session_state |
| 52 | + .protocol |
| 53 | + .table_name_by_table_provider(table_provider) |
21 | 54 | } |
22 | 55 |
|
23 | | - fn get_session_value(&self, identifier: Vec<String>, var_type: VarType) -> Result<ScalarValue> { |
24 | | - let key = if identifier.len() > 1 { |
25 | | - let ignore_first = identifier[0].to_ascii_lowercase() == "@@session".to_owned(); |
26 | | - if ignore_first { |
27 | | - identifier[1..].concat() |
28 | | - } else { |
29 | | - identifier.concat()[1..].to_string() |
30 | | - } |
31 | | - } else { |
32 | | - identifier.concat()[1..].to_string() |
33 | | - }; |
34 | | - |
35 | | - if let Some(var) = self.session.get_variable(&key) { |
36 | | - if var.var_type == var_type { |
37 | | - return Ok(var.value.clone()); |
38 | | - } |
| 56 | + pub fn get_function<T>(&self, name: &str, udfs: &HashMap<String, Arc<T>>) -> Option<Arc<T>> { |
| 57 | + if name.starts_with("pg_catalog.") { |
| 58 | + return udfs.get(&format!("{}", &name[11..name.len()])).cloned(); |
39 | 59 | } |
40 | 60 |
|
41 | | - warn!("Unknown session variable: {}", key); |
| 61 | + udfs.get(name).cloned() |
| 62 | + } |
| 63 | +} |
42 | 64 |
|
43 | | - Ok(ScalarValue::Utf8(None)) |
| 65 | +impl ContextProvider for CubeContext { |
| 66 | + fn get_table_provider( |
| 67 | + &self, |
| 68 | + tr: datafusion::catalog::TableReference, |
| 69 | + ) -> Option<Arc<dyn TableProvider>> { |
| 70 | + return self.session_state.protocol.get_provider(&self.clone(), tr); |
44 | 71 | } |
45 | 72 |
|
46 | | - fn get_global_value(&self, identifier: Vec<String>) -> Result<ScalarValue> { |
47 | | - let key = if identifier.len() > 1 { |
48 | | - let ignore_first = identifier[0].to_ascii_lowercase() == "@@global".to_owned(); |
49 | | - |
50 | | - if ignore_first { |
51 | | - identifier[1..].concat() |
52 | | - } else { |
53 | | - identifier.concat()[2..].to_string() |
54 | | - } |
55 | | - } else { |
56 | | - identifier.concat()[2..].to_string() |
57 | | - }; |
58 | | - |
59 | | - if let Some(var) = self |
60 | | - .server |
61 | | - .read_variables(DatabaseProtocol::MySQL) |
62 | | - .get(&key) |
63 | | - { |
64 | | - if var.var_type == VarType::System { |
65 | | - return Ok(var.value.clone()); |
66 | | - } |
67 | | - } |
| 73 | + fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> { |
| 74 | + self.get_function(name, &self.state.scalar_functions) |
| 75 | + } |
| 76 | + |
| 77 | + fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> { |
| 78 | + self.get_function(name, &self.state.aggregate_functions) |
| 79 | + } |
68 | 80 |
|
69 | | - warn!("Unknown system variable: {}", key); |
| 81 | + fn get_table_function_meta(&self, name: &str) -> Option<Arc<TableUDF>> { |
| 82 | + self.get_function(name, &self.state.table_functions) |
| 83 | + } |
70 | 84 |
|
71 | | - Ok(ScalarValue::Utf8(None)) |
| 85 | + fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> { |
| 86 | + Some(DataType::Utf8) |
72 | 87 | } |
73 | 88 | } |
74 | 89 |
|
75 | | -impl VarProvider for VariablesProvider { |
76 | | - /// get variable value |
77 | | - fn get_value(&self, identifier: Vec<String>) -> Result<ScalarValue> { |
78 | | - let first_word_vec: Vec<char> = identifier[0].chars().collect(); |
79 | | - if first_word_vec.len() < 2 { |
80 | | - return Ok(ScalarValue::Utf8(None)); |
81 | | - } |
/// Implemented by types that can report the name of the table they represent.
pub trait TableName {
    /// Returns the table name as a borrowed string slice.
    fn table_name(&self) -> &str;
}
| 93 | + |
| 94 | +pub struct CubeTableProvider { |
| 95 | + cube: V1CubeMeta, |
| 96 | +} |
82 | 97 |
|
83 | | - match (&first_word_vec[0], &first_word_vec[1]) { |
84 | | - ('@', '@') => { |
85 | | - if identifier.len() > 1 |
86 | | - && identifier[0].to_ascii_lowercase() == "@@session".to_owned() |
87 | | - { |
88 | | - return self.get_session_value(identifier, VarType::System); |
89 | | - } |
90 | | - |
91 | | - return self.get_global_value(identifier); |
92 | | - } |
93 | | - ('@', _) => return self.get_session_value(identifier, VarType::UserDefined), |
94 | | - (_, _) => return Ok(ScalarValue::Utf8(None)), |
95 | | - }; |
| 98 | +impl CubeTableProvider { |
| 99 | + pub fn new(cube: V1CubeMeta) -> Self { |
| 100 | + Self { cube } |
96 | 101 | } |
| 102 | +} |
97 | 103 |
|
98 | | - fn get_type(&self, _var_names: &[String]) -> Option<DataType> { |
99 | | - Some(DataType::Utf8) |
| 104 | +impl TableName for CubeTableProvider { |
| 105 | + fn table_name(&self) -> &str { |
| 106 | + &self.cube.name |
| 107 | + } |
| 108 | +} |
| 109 | + |
| 110 | +#[async_trait] |
| 111 | +impl TableProvider for CubeTableProvider { |
| 112 | + fn as_any(&self) -> &dyn Any { |
| 113 | + self |
| 114 | + } |
| 115 | + |
| 116 | + fn schema(&self) -> SchemaRef { |
| 117 | + Arc::new(Schema::new( |
| 118 | + self.cube |
| 119 | + .get_columns() |
| 120 | + .into_iter() |
| 121 | + .map(|c| { |
| 122 | + Field::new( |
| 123 | + c.get_name(), |
| 124 | + match c.get_column_type() { |
| 125 | + ColumnType::Date(large) => { |
| 126 | + if large { |
| 127 | + DataType::Date64 |
| 128 | + } else { |
| 129 | + DataType::Date32 |
| 130 | + } |
| 131 | + } |
| 132 | + ColumnType::Interval(unit) => DataType::Interval(unit), |
| 133 | + ColumnType::String => DataType::Utf8, |
| 134 | + ColumnType::VarStr => DataType::Utf8, |
| 135 | + ColumnType::Boolean => DataType::Boolean, |
| 136 | + ColumnType::Double => DataType::Float64, |
| 137 | + ColumnType::Int8 => DataType::Int64, |
| 138 | + ColumnType::Int32 => DataType::Int64, |
| 139 | + ColumnType::Int64 => DataType::Int64, |
| 140 | + ColumnType::Blob => DataType::Utf8, |
| 141 | + ColumnType::Decimal(p, s) => DataType::Decimal(p, s), |
| 142 | + ColumnType::List(field) => DataType::List(field.clone()), |
| 143 | + ColumnType::Timestamp => { |
| 144 | + DataType::Timestamp(TimeUnit::Millisecond, None) |
| 145 | + } |
| 146 | + }, |
| 147 | + true, |
| 148 | + ) |
| 149 | + }) |
| 150 | + .collect(), |
| 151 | + )) |
| 152 | + } |
| 153 | + |
| 154 | + async fn scan( |
| 155 | + &self, |
| 156 | + _projection: &Option<Vec<usize>>, |
| 157 | + _filters: &[Expr], |
| 158 | + _limit: Option<usize>, |
| 159 | + ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> { |
| 160 | + Err(DataFusionError::Plan(format!( |
| 161 | + "Not rewritten table scan node for '{}' cube", |
| 162 | + self.cube.name |
| 163 | + ))) |
100 | 164 | } |
101 | 165 | } |
0 commit comments