diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index c784f32039..cf10a6eca3 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::collections::HashMap; +use std::fmt::Debug; use std::future::Future; use std::str::FromStr; use std::sync::Arc; @@ -45,7 +46,7 @@ use crate::client::{ use crate::types::{ CatalogConfig, CommitTableRequest, CommitTableResponse, CreateTableRequest, ListNamespaceResponse, ListTableResponse, LoadCredentialsResponse, LoadTableResponse, - NamespaceSerde, RegisterTableRequest, RenameTableRequest, + NamespaceSerde, RegisterTableRequest, RenameTableRequest, StorageCredential, }; /// REST catalog URI @@ -70,6 +71,8 @@ impl Default for RestCatalogBuilder { props: HashMap::new(), client: None, authenticator: None, + storage_credentials_loader: None, + refresh_credentials: false, }) } } @@ -112,6 +115,11 @@ impl CatalogBuilder for RestCatalogBuilder { ErrorKind::DataInvalid, "Catalog uri is required", )) + } else if self.0.refresh_credentials && self.0.storage_credentials_loader.is_none() { + Err(Error::new( + ErrorKind::DataInvalid, + "storage_credentials_loader is required when refresh_credentials is true", + )) } else { Ok(RestCatalog::new(self.0)) } @@ -147,6 +155,19 @@ impl RestCatalogBuilder { } } +/// Trait for custom storage credential loader. +/// +/// Implement this trait to provide custom storage credential loading logic +/// instead of passing credentials directly or expecting them to be vended from the catalog. +#[async_trait::async_trait] +pub trait StorageCredentialsLoader: Send + Sync + Debug { + /// Load storage credentials using custom user-defined logic. + async fn load_credentials( + &self, + existing_credentials: Option<&StorageCredential>, + ) -> Result; +} + /// Rest catalog configuration. #[derive(Clone, Debug, TypedBuilder)] pub(crate) struct RestCatalogConfig { @@ -166,6 +187,12 @@ pub(crate) struct RestCatalogConfig { #[builder(default)] authenticator: Option>, + + #[builder(default)] + storage_credentials_loader: Option>, + + #[builder(default)] + refresh_credentials: bool, } impl RestCatalogConfig { @@ -416,6 +443,7 @@ impl RestCatalog { &self, metadata_location: Option<&str>, extra_config: Option>, + storage_credential: Option, ) -> Result { let mut props = self.context().await?.config.props.clone(); if let Some(config) = extra_config { @@ -431,10 +459,22 @@ impl RestCatalog { }; let file_io = match metadata_location.or(warehouse_path) { - Some(url) => FileIO::from_path(url)? - .with_props(props) - .with_extensions(self.file_io_extensions.clone()) - .build()?, + Some(url) => { + let mut file_io_builder = FileIO::from_path(url)? + .with_props(props) + .with_extensions(self.file_io_extensions.clone()); + + if self.user_config.refresh_credentials { + if let Some(cred) = storage_credential { + file_io_builder = file_io_builder.with_extension(cred); + } + if let Some(loader) = &self.user_config.storage_credentials_loader { + file_io_builder = file_io_builder.with_extension(loader.clone()); + } + } + + file_io_builder.build()? + } None => { return Err(Error::new( ErrorKind::Unexpected, @@ -510,14 +550,51 @@ impl RestCatalog { // Per the OpenAPI spec: "Clients must first check whether the respective credentials // exist in the storage-credentials field before checking the config for credentials." // When vended-credentials header is set, credentials are returned in storage_credentials field. - if let Some(storage_credentials) = response.storage_credentials { - for cred in storage_credentials { - config.extend(cred.config); + let matched_credential = if let Some(storage_credentials) = response.storage_credentials { + // Find the credential with the longest prefix that matches the metadata_location + let mut best_match: Option<&StorageCredential> = None; + let mut longest_prefix_len = 0; + + if let Some(ref metadata_location) = response.metadata_location { + for cred in &storage_credentials { + if metadata_location.starts_with(&cred.prefix) + && cred.prefix.len() > longest_prefix_len + { + longest_prefix_len = cred.prefix.len(); + best_match = Some(cred); + } + } } - } + + // Extend config with the best match + if let Some(cred) = best_match { + config.extend(cred.config.clone()); + } + + best_match.cloned() + } else { + None + }; + + // Finally, use custom storage credential loader if set, giving it a chance to override the previous configurations. + let final_credential = if let Some(storage_credentials_loader) = + &self.user_config.storage_credentials_loader + { + let credential = storage_credentials_loader + .load_credentials(matched_credential.as_ref()) + .await?; + config.extend(credential.config.clone()); + Some(credential) + } else { + matched_credential + }; let file_io = self - .load_file_io(response.metadata_location.as_deref(), Some(config)) + .load_file_io( + response.metadata_location.as_deref(), + Some(config), + final_credential, + ) .await?; let table_builder = Table::builder() @@ -834,8 +911,9 @@ impl Catalog for RestCatalog { .chain(self.user_config.props.clone()) .collect(); + // TODO @vustef: Do we support vended credentials here? let file_io = self - .load_file_io(Some(metadata_location), Some(config)) + .load_file_io(Some(metadata_location), Some(config), None) .await?; let table_builder = Table::builder() @@ -975,7 +1053,10 @@ impl Catalog for RestCatalog { "Metadata location missing in `register_table` response!", ))?; - let file_io = self.load_file_io(Some(metadata_location), None).await?; + // TODO @vustef: Do we support vended credentials here? + let file_io = self + .load_file_io(Some(metadata_location), None, None) + .await?; Table::builder() .identifier(table_ident.clone()) @@ -1039,8 +1120,9 @@ impl Catalog for RestCatalog { _ => return Err(deserialize_unexpected_catalog_error(http_response).await), }; + // TODO @vustef: Do we support vended credentials here? let file_io = self - .load_file_io(Some(&response.metadata_location), None) + .load_file_io(Some(&response.metadata_location), None, None) .await?; Table::builder() @@ -2962,4 +3044,80 @@ mod tests { } } } + + #[tokio::test] + async fn test_load_table_with_custom_credential_loader() { + use std::sync::atomic::{AtomicBool, Ordering}; + + // Dummy credential loader that just marks that it was called + #[derive(Debug)] + struct DummyCredentialLoader { + was_called: Arc, + } + + #[async_trait::async_trait] + impl StorageCredentialsLoader for DummyCredentialLoader { + async fn load_credentials( + &self, + _existing_credentials: Option<&StorageCredential>, + ) -> Result { + self.was_called.store(true, Ordering::SeqCst); + let mut config = HashMap::new(); + config.insert("custom.key".to_string(), "custom.value".to_string()); + Ok(StorageCredential { + prefix: "custom".to_string(), + config, + }) + } + } + + let mut server = Server::new_async().await; + + let config_mock = create_config_mock(&mut server).await; + + let load_table_mock = server + .mock("GET", "/v1/namespaces/ns1/tables/test1") + .with_status(200) + .with_body_from_file(format!( + "{}/testdata/{}", + env!("CARGO_MANIFEST_DIR"), + "load_table_response.json" + )) + .create_async() + .await; + + let was_called = Arc::new(AtomicBool::new(false)); + let loader = Arc::new(DummyCredentialLoader { + was_called: was_called.clone(), + }); + + let catalog = RestCatalog::new( + RestCatalogConfig::builder() + .uri(server.url()) + .storage_credentials_loader(Some(loader)) + .build(), + ); + + let table = catalog + .load_table(&TableIdent::new( + NamespaceIdent::new("ns1".to_string()), + "test1".to_string(), + )) + .await + .unwrap(); + + assert_eq!( + &TableIdent::from_strs(vec!["ns1", "test1"]).unwrap(), + table.identifier() + ); + + // Verify that the custom credential loader was called + assert!( + was_called.load(Ordering::SeqCst), + "Custom credential loader should have been called" + ); + + config_mock.assert_async().await; + load_table_mock.assert_async().await; + } }