Skip to content

Commit 861f13e

Browse files
authored
chore(cubesql): SessionManager - support extra_id for Session (#8612)
1 parent 0f6701f commit 861f13e

File tree

7 files changed

+112
-77
lines changed

7 files changed

+112
-77
lines changed

packages/cubejs-backend-native/src/node_export.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ async fn handle_sql_query(
216216
};
217217

218218
let session = session_manager
219-
.create_session(DatabaseProtocol::PostgreSQL, host, port)
220-
.await;
219+
.create_session(DatabaseProtocol::PostgreSQL, host, port, None)
220+
.await?;
221221

222222
session
223223
.state

rust/cubesql/cubesql/src/compile/engine/information_schema/mysql/processlist.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ impl TableProvider for InfoSchemaProcesslistProvider {
134134
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
135135
let mut builder = InformationSchemaProcesslistBuilder::new();
136136

137-
for process_list in self.sessions.process_list().await {
137+
for process_list in self.sessions.map_sessions::<SessionProcessList>().await {
138138
builder.add_row(process_list);
139139
}
140140

rust/cubesql/cubesql/src/compile/engine/information_schema/postgres/pg_stat_activity.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ impl PgStatActivityBuilder {
7272
self.oid.append_value(session.oid).unwrap();
7373
self.datname.append_option(session.datname).unwrap();
7474
self.pid.append_value(session.pid).unwrap();
75-
self.leader_pid.append_null().unwrap();
76-
self.usesysid.append_null().unwrap();
75+
self.leader_pid.append_option(session.leader_pid).unwrap();
76+
self.usesysid.append_option(session.usesysid).unwrap();
7777
self.usename.append_option(session.usename).unwrap();
7878
self.application_name
7979
.append_option(session.application_name)
@@ -205,7 +205,7 @@ impl TableProvider for PgCatalogStatActivityProvider {
205205
_filters: &[Expr],
206206
_limit: Option<usize>,
207207
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
208-
let sessions = self.sessions.stat_activity().await;
208+
let sessions = self.sessions.map_sessions::<SessionStatActivity>().await;
209209
let mut builder = PgStatActivityBuilder::new(sessions.len());
210210

211211
for session in sessions {

rust/cubesql/cubesql/src/compile/test/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,8 +565,9 @@ async fn get_test_session_with_config_and_transport(
565565
};
566566
let session_manager = Arc::new(SessionManager::new(server.clone()));
567567
let session = session_manager
568-
.create_session(protocol, "127.0.0.1".to_string(), 1234)
569-
.await;
568+
.create_session(protocol, "127.0.0.1".to_string(), 1234, None)
569+
.await
570+
.unwrap();
570571

571572
// Populate like shims
572573
session.state.set_database(Some(db_name.to_string()));

rust/cubesql/cubesql/src/sql/postgres/service.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use tokio::{
77
};
88
use tokio_util::sync::CancellationToken;
99

10+
use super::shim::AsyncPostgresShim;
1011
use crate::{
1112
compile::DatabaseProtocol,
1213
config::processing_loop::{ProcessingLoop, ShutdownMode},
@@ -15,8 +16,6 @@ use crate::{
1516
CubeError,
1617
};
1718

18-
use super::shim::AsyncPostgresShim;
19-
2019
pub struct PostgresServer {
2120
// options
2221
address: String,
@@ -98,10 +97,18 @@ impl ProcessingLoop for PostgresServer {
9897
}
9998
};
10099

101-
let session = self
100+
let session = match self
102101
.session_manager
103-
.create_session(DatabaseProtocol::PostgreSQL, client_addr, client_port)
104-
.await;
102+
.create_session(DatabaseProtocol::PostgreSQL, client_addr, client_port, None)
103+
.await
104+
{
105+
Ok(r) => r,
106+
Err(err) => {
107+
error!("Session creation error: {}", err);
108+
continue;
109+
}
110+
};
111+
105112
let logger = Arc::new(SessionLogger::new(session.state.clone()));
106113

107114
trace!("[pg] New connection {}", session.state.connection_id);
@@ -147,7 +154,7 @@ impl ProcessingLoop for PostgresServer {
147154

148155
// Close the listening socket (so we _visibly_ stop accepting incoming connections) before
149156
// we wait for the outstanding connection tasks finish.
150-
std::mem::drop(listener);
157+
drop(listener);
151158

152159
// Now that we've had the stop signal, wait for outstanding connection tasks to finish
153160
// cleanly.

rust/cubesql/cubesql/src/sql/session.rs

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ pub enum QueryState {
6262
pub struct SessionState {
6363
// connection id, immutable
6464
pub connection_id: u32,
65+
// Can be UUID or anything else. MDX uses UUID
66+
pub extra_id: Option<String>,
6567
// secret for this session
6668
pub secret: u32,
6769
// client ip, immutable
@@ -95,6 +97,7 @@ pub struct SessionState {
9597
impl SessionState {
9698
pub fn new(
9799
connection_id: u32,
100+
extra_id: Option<String>,
98101
client_ip: String,
99102
client_port: u16,
100103
protocol: DatabaseProtocol,
@@ -106,6 +109,7 @@ impl SessionState {
106109

107110
Self {
108111
connection_id,
112+
extra_id,
109113
secret: rng.gen(),
110114
client_ip,
111115
client_port,
@@ -399,46 +403,7 @@ pub struct Session {
399403
pub state: Arc<SessionState>,
400404
}
401405

402-
impl Session {
403-
// For PostgreSQL
404-
pub fn to_stat_activity(self: &Arc<Self>) -> SessionStatActivity {
405-
let query = self.state.current_query();
406-
407-
let application_name = if let Some(v) = self.state.get_variable("application_name") {
408-
match v.value {
409-
ScalarValue::Utf8(r) => r,
410-
_ => None,
411-
}
412-
} else {
413-
None
414-
};
415-
416-
SessionStatActivity {
417-
oid: self.state.connection_id,
418-
datname: self.state.database(),
419-
pid: self.state.connection_id,
420-
leader_pid: None,
421-
usesysid: 0,
422-
usename: self.state.user(),
423-
application_name,
424-
client_addr: self.state.client_ip.clone(),
425-
client_hostname: None,
426-
client_port: self.state.client_port.clone(),
427-
query,
428-
}
429-
}
430-
431-
// For MySQL
432-
pub fn to_process_list(self: &Arc<Self>) -> SessionProcessList {
433-
SessionProcessList {
434-
id: self.state.connection_id,
435-
host: self.state.client_ip.clone(),
436-
user: self.state.user(),
437-
database: self.state.database(),
438-
}
439-
}
440-
}
441-
406+
/// Specific representation of session for MySQL
442407
#[derive(Debug)]
443408
pub struct SessionProcessList {
444409
pub id: u32,
@@ -447,17 +412,58 @@ pub struct SessionProcessList {
447412
pub database: Option<String>,
448413
}
449414

415+
impl From<&Arc<Session>> for SessionProcessList {
416+
fn from(session: &Arc<Session>) -> Self {
417+
Self {
418+
id: session.state.connection_id,
419+
host: session.state.client_ip.clone(),
420+
user: session.state.user(),
421+
database: session.state.database(),
422+
}
423+
}
424+
}
425+
426+
/// Specific representation of session for PostgreSQL
450427
#[derive(Debug)]
451428
pub struct SessionStatActivity {
452429
pub oid: u32,
453430
pub datname: Option<String>,
454431
pub pid: u32,
455432
pub leader_pid: Option<u32>,
456-
pub usesysid: u32,
433+
pub usesysid: Option<u32>,
457434
pub usename: Option<String>,
458435
pub application_name: Option<String>,
459436
pub client_addr: String,
460437
pub client_hostname: Option<String>,
461438
pub client_port: u16,
462439
pub query: Option<String>,
463440
}
441+
442+
impl From<&Arc<Session>> for SessionStatActivity {
443+
fn from(session: &Arc<Session>) -> Self {
444+
let query = session.state.current_query();
445+
446+
let application_name = if let Some(v) = session.state.get_variable("application_name") {
447+
match v.value {
448+
ScalarValue::Utf8(r) => r,
449+
_ => None,
450+
}
451+
} else {
452+
None
453+
};
454+
455+
Self {
456+
oid: session.state.connection_id,
457+
datname: session.state.database(),
458+
pid: session.state.connection_id,
459+
leader_pid: None,
460+
usesysid: None,
461+
usename: session.state.user(),
462+
application_name,
463+
client_addr: session.state.client_ip.clone(),
464+
client_hostname: None,
465+
client_port: session.state.client_port.clone(),
466+
query,
467+
}
468+
}
469+
}

rust/cubesql/cubesql/src/sql/session_manager.rs

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,21 @@ use std::{
1010

1111
use super::{
1212
server_manager::ServerManager,
13-
session::{Session, SessionProcessList, SessionStatActivity, SessionState},
13+
session::{Session, SessionState},
1414
};
1515
use crate::compile::DatabaseProtocol;
1616

17+
#[derive(Debug)]
18+
struct SessionManagerInner {
19+
sessions: HashMap<u32, Arc<Session>>,
20+
uid_to_session: HashMap<String, Arc<Session>>,
21+
}
22+
1723
#[derive(Debug)]
1824
pub struct SessionManager {
1925
// Sessions
2026
last_id: AtomicU32,
21-
sessions: RWLockAsync<HashMap<u32, Arc<Session>>>,
27+
sessions: RWLockAsync<SessionManagerInner>,
2228
pub temp_table_size: AtomicUsize,
2329
// Backref
2430
pub server: Arc<ServerManager>,
@@ -30,7 +36,10 @@ impl SessionManager {
3036
pub fn new(server: Arc<ServerManager>) -> Self {
3137
Self {
3238
last_id: AtomicU32::new(1),
33-
sessions: RWLockAsync::new(HashMap::new()),
39+
sessions: RWLockAsync::new(SessionManagerInner {
40+
sessions: HashMap::new(),
41+
uid_to_session: HashMap::new(),
42+
}),
3443
temp_table_size: AtomicUsize::new(0),
3544
server,
3645
}
@@ -41,60 +50,72 @@ impl SessionManager {
4150
protocol: DatabaseProtocol,
4251
client_addr: String,
4352
client_port: u16,
44-
) -> Arc<Session> {
53+
extra_id: Option<String>,
54+
) -> Result<Arc<Session>, CubeError> {
4555
let connection_id = self.last_id.fetch_add(1, Ordering::SeqCst);
4656

47-
let sess = Session {
57+
let session_ref = Arc::new(Session {
4858
session_manager: self.clone(),
4959
server: self.server.clone(),
5060
state: Arc::new(SessionState::new(
5161
connection_id,
62+
extra_id.clone(),
5263
client_addr,
5364
client_port,
5465
protocol,
5566
None,
5667
Duration::from_secs(self.server.config_obj.auth_expire_secs()),
5768
Arc::downgrade(self),
5869
)),
59-
};
60-
61-
let session_ref = Arc::new(sess);
70+
});
6271

6372
let mut guard = self.sessions.write().await;
6473

65-
guard.insert(connection_id, session_ref.clone());
74+
if let Some(extra_id) = extra_id {
75+
if guard.uid_to_session.contains_key(&extra_id) {
76+
return Err(CubeError::user(format!(
77+
"Session cannot be created, because extra_id: {} already exists",
78+
extra_id
79+
)));
80+
}
6681

67-
session_ref
68-
}
82+
guard.uid_to_session.insert(extra_id, session_ref.clone());
83+
}
6984

70-
pub async fn stat_activity(self: &Arc<Self>) -> Vec<SessionStatActivity> {
71-
let guard = self.sessions.read().await;
85+
guard.sessions.insert(connection_id, session_ref.clone());
7286

73-
guard
74-
.values()
75-
.map(Session::to_stat_activity)
76-
.collect::<Vec<SessionStatActivity>>()
87+
Ok(session_ref)
7788
}
7889

79-
pub async fn process_list(self: &Arc<Self>) -> Vec<SessionProcessList> {
90+
pub async fn map_sessions<T: for<'a> From<&'a Arc<Session>>>(self: &Arc<Self>) -> Vec<T> {
8091
let guard = self.sessions.read().await;
8192

8293
guard
94+
.sessions
8395
.values()
84-
.map(Session::to_process_list)
85-
.collect::<Vec<SessionProcessList>>()
96+
.map(|session| T::from(session))
97+
.collect::<Vec<T>>()
8698
}
8799

88100
pub async fn get_session(&self, connection_id: u32) -> Option<Arc<Session>> {
89101
let guard = self.sessions.read().await;
90102

91-
guard.get(&connection_id).map(|s| s.clone())
103+
guard.sessions.get(&connection_id).map(|s| s.clone())
104+
}
105+
106+
pub async fn get_session_by_extra_id(&self, extra_id: String) -> Option<Arc<Session>> {
107+
let guard = self.sessions.read().await;
108+
guard.uid_to_session.get(&extra_id).map(|s| s.clone())
92109
}
93110

94111
pub async fn drop_session(&self, connection_id: u32) {
95112
let mut guard = self.sessions.write().await;
96113

97-
if let Some(connection) = guard.remove(&connection_id) {
114+
if let Some(connection) = guard.sessions.remove(&connection_id) {
115+
if let Some(extra_id) = &connection.state.extra_id {
116+
guard.uid_to_session.remove(extra_id);
117+
}
118+
98119
self.temp_table_size.fetch_sub(
99120
connection.state.temp_tables().physical_size(),
100121
Ordering::SeqCst,

0 commit comments

Comments
 (0)