Skip to content

Commit 074e6db

Browse files
authored
feat(FlightSQL): support do_get_tables. (#10858)
1 parent b01ffcc commit 074e6db

File tree

3 files changed

+144
-2
lines changed

3 files changed

+144
-2
lines changed
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// Copyright 2023 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::sync::Arc;
16+
17+
use arrow_array::builder::StringBuilder;
18+
use arrow_array::ArrayRef;
19+
use arrow_array::RecordBatch;
20+
use arrow_flight::utils::batches_to_flight_data;
21+
use arrow_schema::DataType;
22+
use arrow_schema::Field;
23+
use arrow_schema::Schema;
24+
use common_catalog::catalog::Catalog;
25+
use common_catalog::catalog::CatalogManager;
26+
use common_catalog::table_context::TableContext;
27+
use common_exception::ErrorCode;
28+
use futures_util::stream;
29+
use tonic::Status;
30+
31+
use crate::servers::flight_sql::flight_sql_service::DoGetStream;
32+
33+
pub(super) struct CatalogInfoProvider {}
34+
35+
impl CatalogInfoProvider {
36+
fn batch_to_get_stream(batch: RecordBatch) -> Result<DoGetStream, Status> {
37+
let schema = (*batch.schema()).clone();
38+
let batches = vec![batch];
39+
let flight_data = batches_to_flight_data(schema, batches)
40+
.map_err(|e| Status::internal(format!("{e:?}")))?
41+
.into_iter()
42+
.map(Ok);
43+
let stream = stream::iter(flight_data);
44+
Ok(Box::pin(stream))
45+
}
46+
47+
async fn get_tables_internal(
48+
ctx: Arc<dyn TableContext>,
49+
catalog_name: Option<String>,
50+
database_name: Option<String>,
51+
) -> common_exception::Result<(Vec<String>, Vec<String>, Vec<String>, Vec<String>)> {
52+
let tenant = ctx.get_tenant();
53+
let catalog_mgr = CatalogManager::instance();
54+
let catalogs: Vec<(String, Arc<dyn Catalog>)> = if let Some(catalog_name) = catalog_name {
55+
vec![(
56+
catalog_name.clone(),
57+
catalog_mgr.get_catalog(&catalog_name)?,
58+
)]
59+
} else {
60+
catalog_mgr
61+
.catalogs
62+
.iter()
63+
.map(|r| (r.key().to_string(), r.value().clone()))
64+
.collect()
65+
};
66+
67+
let mut catalog_names = vec![];
68+
let mut database_names = vec![];
69+
let mut table_names = vec![];
70+
let mut table_types = vec![];
71+
let table_type = "table".to_string();
72+
for (catalog_name, catalog) in catalogs.into_iter() {
73+
let dbs = if let Some(database_name) = &database_name {
74+
vec![catalog.get_database(tenant.as_str(), database_name).await?]
75+
} else {
76+
catalog.list_databases(tenant.as_str()).await?
77+
};
78+
for db in dbs {
79+
let db_name = db.name().to_string().into_boxed_str();
80+
let db_name: &str = Box::leak(db_name);
81+
let tables = match catalog.list_tables(tenant.as_str(), db_name).await {
82+
Ok(tables) => tables,
83+
Err(err) if err.code() == ErrorCode::EMPTY_SHARE_ENDPOINT_CONFIG => {
84+
tracing::warn!("list tables failed on db {}: {}", db.name(), err);
85+
continue;
86+
}
87+
Err(err) => return Err(err),
88+
};
89+
for table in tables {
90+
catalog_names.push(catalog_name.clone());
91+
database_names.push(db_name.to_string());
92+
table_names.push(table.name().to_string());
93+
table_types.push(table_type.clone());
94+
}
95+
}
96+
}
97+
Ok((catalog_names, database_names, table_names, table_types))
98+
}
99+
100+
pub(crate) async fn get_tables(
101+
ctx: Arc<dyn TableContext>,
102+
catalog_name: Option<String>,
103+
database_name: Option<String>,
104+
) -> Result<DoGetStream, Status> {
105+
let schema = Arc::new(Schema::new(vec![
106+
Field::new("catalog_name", DataType::Utf8, false),
107+
Field::new("db_schema_name", DataType::Utf8, false),
108+
Field::new("table_name", DataType::Utf8, false),
109+
Field::new("table_type", DataType::Utf8, false),
110+
]));
111+
let (catalog_name, db_schema_name, table_name, table_type) =
112+
Self::get_tables_internal(ctx.clone(), catalog_name, database_name)
113+
.await
114+
.map_err(|e| Status::internal(format!("{e:?}")))?;
115+
let batch = RecordBatch::try_new(schema, vec![
116+
Self::string_array(catalog_name),
117+
Self::string_array(db_schema_name),
118+
Self::string_array(table_name),
119+
Self::string_array(table_type),
120+
])
121+
.map_err(|e| Status::internal(format!("RecordBatch::try_new fail {:?}", e)))?;
122+
Self::batch_to_get_stream(batch)
123+
}
124+
125+
fn string_array(values: Vec<String>) -> ArrayRef {
126+
let mut builder = StringBuilder::new();
127+
for v in &values {
128+
builder.append_value(v);
129+
}
130+
Arc::new(builder.finish())
131+
}
132+
}

src/query/service/src/servers/flight_sql/flight_sql_service/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
// The servers module used for external communication with user, such as MySQL wired protocol, etc.
1616

17+
mod catalog;
1718
mod query;
1819
mod service;
1920
mod session;
@@ -23,6 +24,7 @@ use std::pin::Pin;
2324
use std::sync::Arc;
2425

2526
use arrow_flight::FlightData;
27+
use catalog::CatalogInfoProvider;
2628
use common_sql::plans::Plan;
2729
use common_sql::PlanExtras;
2830
use dashmap::DashMap;

src/query/service/src/servers/flight_sql/flight_sql_service/service.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,10 +377,18 @@ impl FlightSqlService for FlightSqlServiceImpl {
377377
async fn do_get_tables(
378378
&self,
379379
query: CommandGetTables,
380-
_request: Request<Ticket>,
380+
request: Request<Ticket>,
381381
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
382382
tracing::info!("do_get_tables({query:?})");
383-
Err(Status::unimplemented("do_get_tables not implemented"))
383+
let session = self.get_session(&request)?;
384+
let context = session
385+
.create_query_context()
386+
.await
387+
.map_err(|e| status!("Could not create_query_context", e))?;
388+
Ok(Response::new(
389+
super::CatalogInfoProvider::get_tables(context.clone(), query.catalog.clone(), None)
390+
.await?,
391+
))
384392
}
385393

386394
#[async_backtrace::framed]

0 commit comments

Comments
 (0)