Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,9 +433,9 @@ impl Client {
let (store, crypto_material_provider, stage_prefix, extension) = build_store_for_snowflake_stage(map, config.retry_config.clone()).await?;

let prefix = match (stage_prefix, config.prefix) {
(s, Some(u)) if s.ends_with("/") => Some(format!("{s}{u}")),
(s, Some(u)) => Some(format!("{s}/{u}")),
(s, None) => Some(s)
(Some(s), Some(u)) if s.ends_with("/") => Some(format!("{s}{u}")),
(Some(s), Some(u)) => Some(format!("{s}/{u}")),
(s, u) => s.or(u)
};

config.prefix = prefix;
Expand Down
85 changes: 78 additions & 7 deletions src/snowflake/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use std::{collections::HashMap, sync::Arc, time::{Duration, Instant, SystemTime,
use tokio::sync::Mutex;
use zeroize::Zeroize;
use moka::future::Cache;
use crate::{duration_on_drop, error::{Error, RetryState}, metrics};
use crate::{duration_on_drop, error::{Error, RetryState, Kind as ErrorKind}, metrics};
use crate::util::{deserialize_str, deserialize_slice};
// use anyhow::anyhow;
use crate::encryption::Key;


#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -74,12 +74,42 @@ pub(crate) struct SnowflakeQueryData {

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub(crate) struct SnowflakeStageCreds {
pub(crate) struct SnowflakeStageAwsCreds {
pub aws_key_id: String,
pub aws_secret_key: String,
pub aws_token: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub(crate) struct SnowflakeStageAzureCreds {
pub azure_sas_token: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub(crate) enum SnowflakeStageCreds {
Aws(SnowflakeStageAwsCreds),
Azure(SnowflakeStageAzureCreds),
}

impl SnowflakeStageCreds {
pub(crate) fn as_aws(&self) -> crate::Result<&SnowflakeStageAwsCreds> {
match self {
SnowflakeStageCreds::Aws(creds) => Ok(creds),
SnowflakeStageCreds::Azure(_) => Err(Error::invalid_response("Expected AWS credentials, but got Azure ones")),
}
}

pub(crate) fn as_azure(&self) -> crate::Result<&SnowflakeStageAzureCreds> {
match self {
SnowflakeStageCreds::Azure(creds) => Ok(creds),
SnowflakeStageCreds::Aws(_) => Err(Error::invalid_response("Expected Azure credentials, but got AWS ones")),
}
}
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeStageInfo {
Expand Down Expand Up @@ -118,6 +148,7 @@ pub(crate) enum NormalizedStageInfo {
storage_account: String,
container: String,
prefix: String,
azure_sas_token: String,
#[serde(skip_serializing_if = "Option::is_none")]
end_point: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -132,18 +163,34 @@ impl TryFrom<&SnowflakeStageInfo> for NormalizedStageInfo {
if value.location_type == "S3" {
let (bucket, prefix) = value.location.split_once('/')
.ok_or_else(|| Error::invalid_response("Stage information from snowflake is missing the bucket name"))?;
let creds = value.creds.as_aws()?;
return Ok(NormalizedStageInfo::S3 {
bucket: bucket.to_string(),
prefix: prefix.to_string(),
region: value.region.clone(),
aws_key_id: value.creds.aws_key_id.clone(),
aws_secret_key: value.creds.aws_secret_key.clone(),
aws_token: value.creds.aws_token.clone(),
aws_key_id: creds.aws_key_id.clone(),
aws_secret_key: creds.aws_secret_key.clone(),
aws_token: creds.aws_token.clone(),
end_point: value.end_point.clone(),
test_endpoint: value.test_endpoint.clone()
})
} else if value.location_type == "AZURE" {
let (container, prefix) = value.location.split_once('/')
.ok_or_else(|| Error::invalid_response("Stage information from snowflake is missing the container name"))?;
let creds = value.creds.as_azure()?;
let storage_account = value.storage_account
.clone()
.ok_or_else(|| Error::invalid_response("Stage information from snowflake is missing the storage account name"))?;
return Ok(NormalizedStageInfo::BlobStorage {
storage_account: storage_account,
container: container.to_string(),
prefix: prefix.to_string(),
azure_sas_token: creds.azure_sas_token.clone(),
end_point: value.end_point.clone(),
test_endpoint: value.test_endpoint.clone()
})
} else {
return Err(Error::not_implemented("Azure BlobStorage is not implemented"));
return Err(Error::not_implemented(format!("Location type {} is not implemented", value.location_type)));
}
}
}
Expand Down Expand Up @@ -614,6 +661,30 @@ impl SnowflakeClient {
}).await?;
Ok(stage_info)
}
pub(crate) async fn get_master_key(
&self,
query_id: String,
path: &str,
stage: &str,
keyring: &Cache<String, Key>,
) -> crate::Result<Key> {
let master_key = keyring.try_get_with(query_id, async {
let info = self.fetch_path_info(stage, path).await?;
let position = info.src_locations.iter().position(|l| l == path)
.ok_or_else(|| Error::invalid_response("path not found"))?;
let encryption_material = info.encryption_material.get(position)
.cloned()
.ok_or_else(|| Error::invalid_response("src locations and encryption material length mismatch"))?
.ok_or_else(|| Error::invalid_response("path not encrypted"))?;

let master_key = Key::from_base64(&encryption_material.query_stage_master_key)
.map_err(ErrorKind::MaterialDecode)?;
counter!(metrics::total_keyring_miss).increment(1);
Ok::<_, Error>(master_key)
}).await?;
counter!(metrics::total_keyring_get).increment(1);
Ok(master_key)
}
}

#[cfg(test)]
Expand Down
Loading