Skip to content

Commit 35bea8b

Browse files
authored
feat: add explicit session setters to core client and driver (#728)
1 parent 09827ec commit 35bea8b

File tree

8 files changed

+110
-28
lines changed

8 files changed

+110
-28
lines changed

core/src/client.rs

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,15 @@ impl APIClient {
177177
let password = percent_decode_str(password).decode_utf8()?;
178178
client.auth = Arc::new(BasicAuth::new(u.username(), password));
179179
}
180-
let database = match u.path().trim_start_matches('/') {
181-
"" => None,
182-
s => Some(s.to_string()),
183-
};
184-
let mut role = None;
180+
181+
let mut session_state = SessionState::default();
182+
183+
let database = u.path().trim_start_matches('/');
184+
if !database.is_empty() {
185+
session_state.set_database(database);
186+
}
187+
185188
let mut scheme = "https";
186-
let mut session_settings = BTreeMap::new();
187189
for (k, v) in u.query_pairs() {
188190
match k.as_ref() {
189191
"wait_time_secs" => {
@@ -222,7 +224,7 @@ impl APIClient {
222224
"warehouse" => {
223225
client.warehouse = Mutex::new(Some(v.to_string()));
224226
}
225-
"role" => role = Some(v.to_string()),
227+
"role" => session_state.set_role(v),
226228
"sslmode" => match v.as_ref() {
227229
"disable" => scheme = "http",
228230
"require" | "enable" => scheme = "https",
@@ -273,7 +275,7 @@ impl APIClient {
273275
}
274276
}
275277
_ => {
276-
session_settings.insert(k.to_string(), v.to_string());
278+
session_state.set(k, v);
277279
}
278280
}
279281
}
@@ -286,14 +288,8 @@ impl APIClient {
286288
},
287289
};
288290
client.scheme = scheme.to_string();
289-
290291
client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
291-
client.session_state = Mutex::new(
292-
SessionState::default()
293-
.with_settings(Some(session_settings))
294-
.with_role(role)
295-
.with_database(database),
296-
);
292+
client.session_state = Mutex::new(session_state);
297293

298294
Ok(client)
299295
}
@@ -372,6 +368,26 @@ impl APIClient {
372368
guard.database.clone()
373369
}
374370

371+
pub fn set_warehouse(&self, warehouse: impl Into<String>) {
372+
let mut guard = self.warehouse.lock();
373+
*guard = Some(warehouse.into());
374+
}
375+
376+
pub fn set_database(&self, database: impl Into<String>) {
377+
let mut guard = self.session_state.lock();
378+
guard.set_database(database);
379+
}
380+
381+
pub fn set_role(&self, role: impl Into<String>) {
382+
let mut guard = self.session_state.lock();
383+
guard.set_role(role);
384+
}
385+
386+
pub fn set_session(&self, key: impl Into<String>, value: impl Into<String>) {
387+
let mut guard = self.session_state.lock();
388+
guard.set(key, value);
389+
}
390+
375391
pub async fn current_role(&self) -> Option<String> {
376392
let guard = self.session_state.lock();
377393
guard.role.clone()

core/src/request.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ mod test {
8383

8484
#[test]
8585
fn build_request() -> Result<()> {
86+
let mut session = SessionState::default();
87+
session.set_database("default");
8688
let req = QueryRequest::new("select 1")
87-
.with_session(Some(
88-
SessionState::default().with_database(Some("default".to_string())),
89-
))
89+
.with_session(Some(session))
9090
.with_pagination(Some(PaginationConfig {
9191
wait_time_secs: Some(1),
9292
max_rows_in_buffer: Some(1),

core/src/session.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,16 @@ pub struct SessionState {
4040
}
4141

4242
impl SessionState {
43-
pub fn with_settings(mut self, settings: Option<BTreeMap<String, String>>) -> Self {
44-
self.settings = settings;
45-
self
43+
pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
44+
let settings = self.settings.get_or_insert_with(BTreeMap::new);
45+
settings.insert(key.into(), value.into());
4646
}
4747

48-
pub fn with_database(mut self, database: Option<String>) -> Self {
49-
self.database = database;
50-
self
48+
pub fn set_database(&mut self, database: impl Into<String>) {
49+
self.database = Some(database.into());
5150
}
5251

53-
pub fn with_role(mut self, role: Option<String>) -> Self {
54-
self.role = role;
55-
self
52+
pub fn set_role(&mut self, role: impl Into<String>) {
53+
self.role = Some(role.into());
5654
}
5755
}

driver/src/client.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,22 @@ impl Connection {
223223
self.inner.stream_load(sql, data, method).await
224224
}
225225

226+
pub fn set_warehouse(&self, warehouse: &str) -> Result<()> {
227+
self.inner.set_warehouse(warehouse)
228+
}
229+
230+
pub fn set_database(&self, database: &str) -> Result<()> {
231+
self.inner.set_database(database)
232+
}
233+
234+
pub fn set_role(&self, role: &str) -> Result<()> {
235+
self.inner.set_role(role)
236+
}
237+
238+
pub fn set_session(&self, key: &str, value: &str) -> Result<()> {
239+
self.inner.set_session(key, value)
240+
}
241+
226242
// PUT file://<path_to_file>/<filename> internalStage|externalStage
227243
pub async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
228244
self.inner.put_files(local_file, stage).await

driver/src/conn.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,14 @@ pub trait IConnection: Send + Sync {
143143
_method: LoadMethod,
144144
) -> Result<ServerStats>;
145145

146+
fn set_warehouse(&self, warehouse: &str) -> Result<()>;
147+
148+
fn set_database(&self, database: &str) -> Result<()>;
149+
150+
fn set_role(&self, role: &str) -> Result<()>;
151+
152+
fn set_session(&self, key: &str, value: &str) -> Result<()>;
153+
146154
// PUT file://<path_to_file>/<filename> internalStage|externalStage
147155
async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
148156
let mut total_count: usize = 0;

driver/src/flight_sql.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,30 @@ impl IConnection for FlightSQLConnection {
6666
None
6767
}
6868

69+
fn set_warehouse(&self, _warehouse: &str) -> Result<()> {
70+
Err(Error::Protocol(
71+
"set_warehouse unavailable for FlightSQL".to_string(),
72+
))
73+
}
74+
75+
fn set_database(&self, _database: &str) -> Result<()> {
76+
Err(Error::Protocol(
77+
"set_database unavailable for FlightSQL".to_string(),
78+
))
79+
}
80+
81+
fn set_role(&self, _role: &str) -> Result<()> {
82+
Err(Error::Protocol(
83+
"set_role unavailable for FlightSQL".to_string(),
84+
))
85+
}
86+
87+
fn set_session(&self, _key: &str, _value: &str) -> Result<()> {
88+
Err(Error::Protocol(
89+
"set_session unavailable for FlightSQL".to_string(),
90+
))
91+
}
92+
6993
async fn exec(&self, sql: &str) -> Result<i64> {
7094
self.handshake().await?;
7195
let mut client = self.client.lock().await;

driver/src/rest_api.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,26 @@ impl IConnection for RestAPIConnection {
138138
self.client.last_query_id()
139139
}
140140

141+
fn set_warehouse(&self, warehouse: &str) -> Result<()> {
142+
self.client.set_warehouse(warehouse.to_string());
143+
Ok(())
144+
}
145+
146+
fn set_database(&self, database: &str) -> Result<()> {
147+
self.client.set_database(database.to_string());
148+
Ok(())
149+
}
150+
151+
fn set_role(&self, role: &str) -> Result<()> {
152+
self.client.set_role(role.to_string());
153+
Ok(())
154+
}
155+
156+
fn set_session(&self, key: &str, value: &str) -> Result<()> {
157+
self.client.set_session(key.to_string(), value.to_string());
158+
Ok(())
159+
}
160+
141161
async fn close(&self) -> Result<()> {
142162
self.client.close().await;
143163
Ok(())

ttc/src/server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
1818
use tokio::net::{TcpListener, TcpStream};
1919

2020
use bytes::Buf;
21-
use clap::{command, Parser};
21+
use clap::Parser;
2222

2323
#[derive(Debug, Clone, Parser, PartialEq)]
2424
#[command(name = "ttc")]

0 commit comments

Comments
 (0)