57 changes: 52 additions & 5 deletions src/snowflake/client.rs
@@ -74,12 +74,42 @@ pub(crate) struct SnowflakeQueryData {

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub(crate) struct SnowflakeStageCreds {
pub(crate) struct SnowflakeStageAwsCreds {
pub aws_key_id: String,
pub aws_secret_key: String,
pub aws_token: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub(crate) struct SnowflakeStageAzureCreds {
pub azure_sas_token: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub(crate) enum SnowflakeStageCreds {
Aws(SnowflakeStageAwsCreds),
Azure(SnowflakeStageAzureCreds),
}

impl SnowflakeStageCreds {
pub(crate) fn as_aws(&self) -> crate::Result<&SnowflakeStageAwsCreds> {
match self {
SnowflakeStageCreds::Aws(creds) => Ok(creds),
SnowflakeStageCreds::Azure(_) => Err(Error::invalid_response("Expected AWS credentials, but got Azure ones")),
}
}

pub(crate) fn as_azure(&self) -> crate::Result<&SnowflakeStageAzureCreds> {
match self {
SnowflakeStageCreds::Azure(creds) => Ok(creds),
SnowflakeStageCreds::Aws(_) => Err(Error::invalid_response("Expected Azure credentials, but got AWS ones")),
}
}
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeStageInfo {
@@ -118,6 +148,7 @@ pub(crate) enum NormalizedStageInfo {
storage_account: String,
container: String,
prefix: String,
azure_sas_token: String,
#[serde(skip_serializing_if = "Option::is_none")]
end_point: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
@@ -132,18 +163,34 @@ impl TryFrom<&SnowflakeStageInfo> for NormalizedStageInfo {
if value.location_type == "S3" {
let (bucket, prefix) = value.location.split_once('/')
.ok_or_else(|| Error::invalid_response("Stage information from snowflake is missing the bucket name"))?;
let creds = value.creds.as_aws()?;
return Ok(NormalizedStageInfo::S3 {
bucket: bucket.to_string(),
prefix: prefix.to_string(),
region: value.region.clone(),
aws_key_id: value.creds.aws_key_id.clone(),
aws_secret_key: value.creds.aws_secret_key.clone(),
aws_token: value.creds.aws_token.clone(),
aws_key_id: creds.aws_key_id.clone(),
aws_secret_key: creds.aws_secret_key.clone(),
aws_token: creds.aws_token.clone(),
end_point: value.end_point.clone(),
test_endpoint: value.test_endpoint.clone()
})
} else if value.location_type == "AZURE" {
let (container, prefix) = value.location.split_once('/')
.ok_or_else(|| Error::invalid_response("Stage information from snowflake is missing the container name"))?;
let creds = value.creds.as_azure()?;
let storage_account = value.storage_account
.clone()
.ok_or_else(|| Error::invalid_response("Stage information from snowflake is missing the storage account name"))?;
return Ok(NormalizedStageInfo::BlobStorage {
storage_account,
container: container.to_string(),
prefix: prefix.to_string(),
azure_sas_token: creds.azure_sas_token.clone(),
end_point: value.end_point.clone(),
test_endpoint: value.test_endpoint.clone()
})
} else {
return Err(Error::not_implemented("Azure BlobStorage is not implemented"));
return Err(Error::not_implemented(format!("Location type {} is not implemented", value.location_type)));
}
}
}
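
For reviewers unfamiliar with untagged serde enums: below is a minimal, self-contained sketch (not part of this PR; trimmed-down copies of the structs above, assuming serde and serde_json as dependencies) showing how a SnowflakeStageCreds-style enum picks the AWS or Azure variant purely from which SCREAMING_SNAKE_CASE keys Snowflake returns.

use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
struct AwsCreds { aws_key_id: String, aws_secret_key: String, aws_token: String }

#[derive(Debug, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
struct AzureCreds { azure_sas_token: String }

// With `untagged`, serde tries each variant in declaration order and keeps the first
// one whose required fields are all present in the JSON.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum StageCreds { Aws(AwsCreds), Azure(AzureCreds) }

fn main() {
    // Lacking the AWS_* keys, this payload fails the Aws variant and lands on Azure.
    let azure: StageCreds =
        serde_json::from_str(r#"{"AZURE_SAS_TOKEN": "?sv=...&sig=..."}"#).unwrap();
    assert!(matches!(azure, StageCreds::Azure(_)));

    // A payload carrying all three AWS_* keys matches the Aws variant.
    let aws: StageCreds = serde_json::from_str(
        r#"{"AWS_KEY_ID": "AKIA...", "AWS_SECRET_KEY": "...", "AWS_TOKEN": "..."}"#,
    ).unwrap();
    assert!(matches!(aws, StageCreds::Aws(_)));
}
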
210 changes: 203 additions & 7 deletions src/snowflake/kms.rs
@@ -37,17 +37,17 @@ impl Default for SnowflakeStageKmsConfig {
}

#[derive(Clone)]
pub(crate) struct SnowflakeStageKms {
pub(crate) struct SnowflakeStageS3Kms {
client: Arc<SnowflakeClient>,
stage: String,
prefix: String,
config: SnowflakeStageKmsConfig,
keyring: Cache<String, Key>
}

impl std::fmt::Debug for SnowflakeStageKms {
impl std::fmt::Debug for SnowflakeStageS3Kms {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SnowflakeStageKms")
f.debug_struct("SnowflakeStageS3Kms")
.field("client", &self.client)
.field("stage", &self.stage)
.field("config", &self.config)
@@ -56,14 +56,14 @@ impl std::fmt::Debug for SnowflakeStageKms {
}
}

impl SnowflakeStageKms {
impl SnowflakeStageS3Kms {
pub(crate) fn new(
client: Arc<SnowflakeClient>,
stage: impl Into<String>,
prefix: impl Into<String>,
config: SnowflakeStageKmsConfig
) -> SnowflakeStageKms {
SnowflakeStageKms {
) -> SnowflakeStageS3Kms {
SnowflakeStageS3Kms {
client,
stage: stage.into(),
prefix: prefix.into(),
@@ -77,7 +77,7 @@ impl SnowflakeStageKms {
}

#[async_trait::async_trait]
impl CryptoMaterialProvider for SnowflakeStageKms {
impl CryptoMaterialProvider for SnowflakeStageS3Kms {
async fn material_for_write(&self, _path: &str, data_len: Option<usize>) -> crate::Result<(ContentCryptoMaterial, Attributes)> {
let _guard = duration_on_drop!(metrics::material_for_write_duration);
let info = self.client.current_upload_info(&self.stage).await?;
@@ -193,3 +193,199 @@ impl CryptoMaterialProvider for SnowflakeStageKms {
Ok(content_material)
}
}

#[derive(Clone)]
pub(crate) struct SnowflakeStageAzureKms {
client: Arc<SnowflakeClient>,
stage: String,
prefix: String,
config: SnowflakeStageKmsConfig,
keyring: Cache<String, Key>,
}

impl std::fmt::Debug for SnowflakeStageAzureKms {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SnowflakeStageAzureKms")
.field("client", &self.client)
.field("stage", &self.stage)
.field("config", &self.config)
.field("keyring", &"redacted")
.finish()
}
}

impl SnowflakeStageAzureKms {
pub(crate) fn new(
client: Arc<SnowflakeClient>,
stage: impl Into<String>,
prefix: impl Into<String>,
config: SnowflakeStageKmsConfig
) -> SnowflakeStageAzureKms {
SnowflakeStageAzureKms {
client,
stage: stage.into(),
prefix: prefix.into(),
keyring: Cache::builder()
.max_capacity(config.keyring_capacity as u64)
.time_to_live(config.keyring_ttl)
.build(),
config
}
}
}

const AZURE_MATDESC_KEY: &str = "matdesc";
const AZURE_ENCDATA_KEY: &str = "encryptiondata";

#[async_trait::async_trait]
impl CryptoMaterialProvider for SnowflakeStageAzureKms {
async fn material_for_write(&self, _path: &str, _data_len: Option<usize>) -> crate::Result<(ContentCryptoMaterial, Attributes)> {
let _guard = duration_on_drop!(metrics::material_for_write_duration);
let info = self.client.current_upload_info(&self.stage).await?;

let encryption_material = info.encryption_material.as_ref()
.ok_or_else(|| ErrorKind::StorageNotEncrypted(self.stage.clone()))?;

let description = MaterialDescription {
smk_id: encryption_material.smk_id.to_string(),
query_id: encryption_material.query_id.clone(),
key_size: "128".to_string()
};
let master_key = Key::from_base64(&encryption_material.query_stage_master_key)
.map_err(ErrorKind::MaterialDecode)?;

let scheme = self.config.crypto_scheme;
let material = ContentCryptoMaterial::generate(scheme);
let encrypted_cek = material.cek.clone().encrypt_aes_128_ecb(&master_key)
.map_err(ErrorKind::MaterialCrypt)?;
// TODO: should this be AES_256 or AES_128 for Azure? The metadata says 256, but the
// master key observed here is 128 bits.

let mut attributes = Attributes::new();

// TODO: do we need to add aad?

// We hardcode most of these values as the Go Snowflake client does (see
// https://github.com/snowflakedb/gosnowflake/blob/099708d318689634a558f705ccc19b3b7b278972/azure_storage_client.go#L152)
let encryption_data = EncryptionData {
encryption_mode: "FullBlob".to_string(),
wrapped_content_key: WrappedContentKey {
key_id: "symmKey1".to_string(),
encrypted_key: encrypted_cek.as_base64(),
algorithm: "AES_CBC_256".to_string(),
},
encryption_agent: EncryptionAgent {
protocol: "1.0".to_string(),
encryption_algorithm: "AES_CBC_128".to_string(),
},
content_encryption_i_v: material.iv.as_base64(),
key_wrapping_metadata: KeyWrappingMetadata {
encryption_library: "Java 5.3.0".to_string(),
},
};

attributes.insert(
Attribute::Metadata(AZURE_ENCDATA_KEY.into()),
AttributeValue::from(serde_json::to_string(&encryption_data).context("failed to encode encryption data").to_err()?)
);

attributes.insert(
Attribute::Metadata(AZURE_MATDESC_KEY.into()),
AttributeValue::from(serde_json::to_string(&description).context("failed to encode matdesc").to_err()?)
);

// TODO: try to attach the (unencrypted) content length to the file somehow
// TODO: try to attach a hash of the file

Ok((material, attributes))
}

async fn material_from_metadata(&self, path: &str, attr: &Attributes) -> crate::Result<ContentCryptoMaterial> {
// TODO: factor out code that is shared with S3 variant?

let _guard = duration_on_drop!(metrics::material_from_metadata_duration);
let path = path.strip_prefix(&self.prefix).unwrap_or(path);
let required_attribute = |key: &'static str| {
let v: &str = attr.get(&Attribute::Metadata(key.into()))
.ok_or_else(|| Error::required_config(format!("missing required attribute `{}`", key)))?
.as_ref();
Ok::<_, Error>(v)
};

let material_description: MaterialDescription = deserialize_str(required_attribute(AZURE_MATDESC_KEY)?)
.map_err(Error::deserialize_response_err("failed to deserialize matdesc"))?;

let master_key = self.keyring.try_get_with(material_description.query_id, async {
let info = self.client.fetch_path_info(&self.stage, path).await?;
let position = info.src_locations.iter().position(|l| l == path)
.ok_or_else(|| Error::invalid_response("path not found"))?;
let encryption_material = info.encryption_material.get(position)
.cloned()
.ok_or_else(|| Error::invalid_response("src locations and encryption material length mismatch"))?
.ok_or_else(|| Error::invalid_response("path not encrypted"))?;

let master_key = Key::from_base64(&encryption_material.query_stage_master_key)
.map_err(ErrorKind::MaterialDecode)?;
counter!(metrics::total_keyring_miss).increment(1);
Ok::<_, Error>(master_key)
}).await?;
counter!(metrics::total_keyring_get).increment(1);

let encryption_data: EncryptionData = deserialize_str(required_attribute(AZURE_ENCDATA_KEY)?)
.map_err(Error::deserialize_response_err("failed to deserialize encryption data"))?;

let cek = EncryptedKey::from_base64(&encryption_data.wrapped_content_key.encrypted_key)
.map_err(ErrorKind::MaterialDecode)?;
let cek = cek.decrypt_aes_128_ecb(&master_key)
.map_err(ErrorKind::MaterialCrypt)?;
let iv = Iv::from_base64(&encryption_data.content_encryption_i_v)
.map_err(ErrorKind::MaterialDecode)?;

let scheme = match encryption_data.encryption_agent.encryption_algorithm.as_str() {
"AES_GCM_256" => CryptoScheme::Aes256Gcm,
"AES_CBC_128" => CryptoScheme::Aes128Cbc,
v => unimplemented!("encryption algorithm `{}` not implemented", v)
};

let content_material = ContentCryptoMaterial {
scheme,
cek,
iv,
aad: None
};

Ok(content_material)
}
}


#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct EncryptionData {
encryption_mode: String,
wrapped_content_key: WrappedContentKey,
content_encryption_i_v: String,
encryption_agent: EncryptionAgent,
key_wrapping_metadata: KeyWrappingMetadata,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct WrappedContentKey {
key_id: String,
encrypted_key: String,
algorithm: String, // alg for encrypting the key
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct EncryptionAgent {
protocol: String,
encryption_algorithm: String, // alg for encrypting the content
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct KeyWrappingMetadata {
encryption_library: String,
}
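
To make the Azure metadata shape easier to review: a minimal, self-contained sketch (not part of this PR; KeyWrappingMetadata omitted for brevity, assuming serde and serde_json as dependencies) of what the PascalCase-renamed EncryptionData above serializes to. The snake_case field `content_encryption_i_v` is spelled that way so the rename yields the `ContentEncryptionIV` key used by Azure-style client-side encryption metadata.

use serde::Serialize;

#[derive(Serialize)]
#[serde(rename_all = "PascalCase")]
struct EncryptionData {
    encryption_mode: String,
    wrapped_content_key: WrappedContentKey,
    content_encryption_i_v: String, // renames to "ContentEncryptionIV"
    encryption_agent: EncryptionAgent,
}

#[derive(Serialize)]
#[serde(rename_all = "PascalCase")]
struct WrappedContentKey { key_id: String, encrypted_key: String, algorithm: String }

#[derive(Serialize)]
#[serde(rename_all = "PascalCase")]
struct EncryptionAgent { protocol: String, encryption_algorithm: String }

fn main() {
    let data = EncryptionData {
        encryption_mode: "FullBlob".to_string(),
        wrapped_content_key: WrappedContentKey {
            key_id: "symmKey1".to_string(),
            encrypted_key: "<base64-encoded encrypted CEK>".to_string(),
            algorithm: "AES_CBC_256".to_string(),
        },
        content_encryption_i_v: "<base64-encoded IV>".to_string(),
        encryption_agent: EncryptionAgent {
            protocol: "1.0".to_string(),
            encryption_algorithm: "AES_CBC_128".to_string(),
        },
    };
    // Prints keys such as "EncryptionMode", "WrappedContentKey" and "ContentEncryptionIV",
    // i.e. the shape stored under the blob's `encryptiondata` metadata key.
    println!("{}", serde_json::to_string_pretty(&data).unwrap());
}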