Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,15 @@ impl APIClient {
let password = percent_decode_str(password).decode_utf8()?;
client.auth = Arc::new(BasicAuth::new(u.username(), password));
}
let database = match u.path().trim_start_matches('/') {
"" => None,
s => Some(s.to_string()),
};
let mut role = None;

let mut session_state = SessionState::default();

let database = u.path().trim_start_matches('/');
if !database.is_empty() {
session_state.set_database(database);
}

let mut scheme = "https";
let mut session_settings = BTreeMap::new();
for (k, v) in u.query_pairs() {
match k.as_ref() {
"wait_time_secs" => {
Expand Down Expand Up @@ -222,7 +224,7 @@ impl APIClient {
"warehouse" => {
client.warehouse = Mutex::new(Some(v.to_string()));
}
"role" => role = Some(v.to_string()),
"role" => session_state.set_role(v),
"sslmode" => match v.as_ref() {
"disable" => scheme = "http",
"require" | "enable" => scheme = "https",
Expand Down Expand Up @@ -273,7 +275,7 @@ impl APIClient {
}
}
_ => {
session_settings.insert(k.to_string(), v.to_string());
session_state.set(k, v);
}
}
}
Expand All @@ -286,14 +288,8 @@ impl APIClient {
},
};
client.scheme = scheme.to_string();

client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
client.session_state = Mutex::new(
SessionState::default()
.with_settings(Some(session_settings))
.with_role(role)
.with_database(database),
);
client.session_state = Mutex::new(session_state);

Ok(client)
}
Expand Down Expand Up @@ -372,6 +368,26 @@ impl APIClient {
guard.database.clone()
}

pub fn set_warehouse(&self, warehouse: impl Into<String>) {
let mut guard = self.warehouse.lock();
*guard = Some(warehouse.into());
}

pub fn set_database(&self, database: impl Into<String>) {
let mut guard = self.session_state.lock();
guard.set_database(database);
}

pub fn set_role(&self, role: impl Into<String>) {
let mut guard = self.session_state.lock();
guard.set_role(role);
}

pub fn set_session(&self, key: impl Into<String>, value: impl Into<String>) {
let mut guard = self.session_state.lock();
guard.set(key, value);
}

pub async fn current_role(&self) -> Option<String> {
let guard = self.session_state.lock();
guard.role.clone()
Expand Down
6 changes: 3 additions & 3 deletions core/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ mod test {

#[test]
fn build_request() -> Result<()> {
let mut session = SessionState::default();
session.set_database("default");
let req = QueryRequest::new("select 1")
.with_session(Some(
SessionState::default().with_database(Some("default".to_string())),
))
.with_session(Some(session))
.with_pagination(Some(PaginationConfig {
wait_time_secs: Some(1),
max_rows_in_buffer: Some(1),
Expand Down
16 changes: 7 additions & 9 deletions core/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,16 @@ pub struct SessionState {
}

impl SessionState {
pub fn with_settings(mut self, settings: Option<BTreeMap<String, String>>) -> Self {
self.settings = settings;
self
pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
let settings = self.settings.get_or_insert_with(BTreeMap::new);
settings.insert(key.into(), value.into());
}

pub fn with_database(mut self, database: Option<String>) -> Self {
self.database = database;
self
pub fn set_database(&mut self, database: impl Into<String>) {
self.database = Some(database.into());
}

pub fn with_role(mut self, role: Option<String>) -> Self {
self.role = role;
self
pub fn set_role(&mut self, role: impl Into<String>) {
self.role = Some(role.into());
}
}
16 changes: 16 additions & 0 deletions driver/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,22 @@ impl Connection {
self.inner.stream_load(sql, data, method).await
}

pub fn set_warehouse(&self, warehouse: &str) -> Result<()> {
self.inner.set_warehouse(warehouse)
}

pub fn set_database(&self, database: &str) -> Result<()> {
self.inner.set_database(database)
}

pub fn set_role(&self, role: &str) -> Result<()> {
self.inner.set_role(role)
}

pub fn set_session(&self, key: &str, value: &str) -> Result<()> {
self.inner.set_session(key, value)
}

// PUT file://<path_to_file>/<filename> internalStage|externalStage
pub async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
self.inner.put_files(local_file, stage).await
Expand Down
8 changes: 8 additions & 0 deletions driver/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ pub trait IConnection: Send + Sync {
_method: LoadMethod,
) -> Result<ServerStats>;

fn set_warehouse(&self, warehouse: &str) -> Result<()>;

fn set_database(&self, database: &str) -> Result<()>;

fn set_role(&self, role: &str) -> Result<()>;

fn set_session(&self, key: &str, value: &str) -> Result<()>;

// PUT file://<path_to_file>/<filename> internalStage|externalStage
async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
let mut total_count: usize = 0;
Expand Down
24 changes: 24 additions & 0 deletions driver/src/flight_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,30 @@ impl IConnection for FlightSQLConnection {
None
}

fn set_warehouse(&self, _warehouse: &str) -> Result<()> {
Err(Error::Protocol(
"set_warehouse unavailable for FlightSQL".to_string(),
))
}

fn set_database(&self, _database: &str) -> Result<()> {
Err(Error::Protocol(
"set_database unavailable for FlightSQL".to_string(),
))
}

fn set_role(&self, _role: &str) -> Result<()> {
Err(Error::Protocol(
"set_role unavailable for FlightSQL".to_string(),
))
}

fn set_session(&self, _key: &str, _value: &str) -> Result<()> {
Err(Error::Protocol(
"set_session unavailable for FlightSQL".to_string(),
))
}

async fn exec(&self, sql: &str) -> Result<i64> {
self.handshake().await?;
let mut client = self.client.lock().await;
Expand Down
20 changes: 20 additions & 0 deletions driver/src/rest_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,26 @@ impl IConnection for RestAPIConnection {
self.client.last_query_id()
}

fn set_warehouse(&self, warehouse: &str) -> Result<()> {
self.client.set_warehouse(warehouse.to_string());
Ok(())
}

fn set_database(&self, database: &str) -> Result<()> {
self.client.set_database(database.to_string());
Ok(())
}

fn set_role(&self, role: &str) -> Result<()> {
self.client.set_role(role.to_string());
Ok(())
}

fn set_session(&self, key: &str, value: &str) -> Result<()> {
self.client.set_session(key.to_string(), value.to_string());
Ok(())
}

async fn close(&self) -> Result<()> {
self.client.close().await;
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion ttc/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};

use bytes::Buf;
use clap::{command, Parser};
use clap::Parser;

#[derive(Debug, Clone, Parser, PartialEq)]
#[command(name = "ttc")]
Expand Down
Loading