Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 171 additions & 13 deletions crates/catalog/rest/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::future::Future;
use std::str::FromStr;
use std::sync::Arc;
Expand All @@ -45,7 +46,7 @@ use crate::client::{
use crate::types::{
CatalogConfig, CommitTableRequest, CommitTableResponse, CreateTableRequest,
ListNamespaceResponse, ListTableResponse, LoadCredentialsResponse, LoadTableResponse,
NamespaceSerde, RegisterTableRequest, RenameTableRequest,
NamespaceSerde, RegisterTableRequest, RenameTableRequest, StorageCredential,
};

/// REST catalog URI
Expand All @@ -70,6 +71,8 @@ impl Default for RestCatalogBuilder {
props: HashMap::new(),
client: None,
authenticator: None,
storage_credentials_loader: None,
refresh_credentials: false,
})
}
}
Expand Down Expand Up @@ -112,6 +115,11 @@ impl CatalogBuilder for RestCatalogBuilder {
ErrorKind::DataInvalid,
"Catalog uri is required",
))
} else if self.0.refresh_credentials && self.0.storage_credentials_loader.is_none() {
Err(Error::new(
ErrorKind::DataInvalid,
"storage_credentials_loader is required when refresh_credentials is true",
))
} else {
Ok(RestCatalog::new(self.0))
}
Expand Down Expand Up @@ -147,6 +155,19 @@ impl RestCatalogBuilder {
}
}

/// Trait for custom storage credential loader.
///
/// Implement this trait to provide custom storage credential loading logic
/// instead of passing credentials directly or expecting them to be vended from the catalog.
///
/// A loader is stored in the catalog configuration as an
/// `Option<Arc<dyn StorageCredentialsLoader>>`, which is why implementations must be
/// `Send + Sync`. The `Debug` bound lets the configuration struct holding the loader
/// derive `Debug`.
///
/// The catalog builder rejects a configuration in which `refresh_credentials` is `true`
/// but no loader has been provided. When `refresh_credentials` is enabled, the loader is
/// also attached to the table's `FileIO` as an extension.
#[async_trait::async_trait]
pub trait StorageCredentialsLoader: Send + Sync + Debug {
/// Load storage credentials using custom user-defined logic.
///
/// `existing_credentials` is the credential the catalog has already resolved, if any.
/// When a table is loaded this is the vended storage credential whose prefix is the
/// longest match for the table's metadata location, or `None` when no vended credential
/// matched.
///
/// The `config` entries of the returned credential are merged into the configuration
/// used to build the table's `FileIO` after all other sources, so they take precedence
/// over entries with the same key.
///
/// # Errors
///
/// An error returned here is propagated unchanged by the catalog operation that
/// invoked the loader.
async fn load_credentials(
&self,
existing_credentials: Option<&StorageCredential>,
) -> Result<StorageCredential>;
}

/// Rest catalog configuration.
#[derive(Clone, Debug, TypedBuilder)]
pub(crate) struct RestCatalogConfig {
Expand All @@ -166,6 +187,12 @@ pub(crate) struct RestCatalogConfig {

#[builder(default)]
authenticator: Option<Arc<dyn CustomAuthenticator>>,

#[builder(default)]
storage_credentials_loader: Option<Arc<dyn StorageCredentialsLoader>>,

#[builder(default)]
refresh_credentials: bool,
}

impl RestCatalogConfig {
Expand Down Expand Up @@ -416,6 +443,7 @@ impl RestCatalog {
&self,
metadata_location: Option<&str>,
extra_config: Option<HashMap<String, String>>,
storage_credential: Option<StorageCredential>,
) -> Result<FileIO> {
let mut props = self.context().await?.config.props.clone();
if let Some(config) = extra_config {
Expand All @@ -431,10 +459,22 @@ impl RestCatalog {
};

let file_io = match metadata_location.or(warehouse_path) {
Some(url) => FileIO::from_path(url)?
.with_props(props)
.with_extensions(self.file_io_extensions.clone())
.build()?,
Some(url) => {
let mut file_io_builder = FileIO::from_path(url)?
.with_props(props)
.with_extensions(self.file_io_extensions.clone());

if self.user_config.refresh_credentials {
if let Some(cred) = storage_credential {
file_io_builder = file_io_builder.with_extension(cred);
}
if let Some(loader) = &self.user_config.storage_credentials_loader {
file_io_builder = file_io_builder.with_extension(loader.clone());
}
}

file_io_builder.build()?
}
None => {
return Err(Error::new(
ErrorKind::Unexpected,
Expand Down Expand Up @@ -510,14 +550,51 @@ impl RestCatalog {
// Per the OpenAPI spec: "Clients must first check whether the respective credentials
// exist in the storage-credentials field before checking the config for credentials."
// When vended-credentials header is set, credentials are returned in storage_credentials field.
if let Some(storage_credentials) = response.storage_credentials {
for cred in storage_credentials {
config.extend(cred.config);
let matched_credential = if let Some(storage_credentials) = response.storage_credentials {
// Find the credential with the longest prefix that matches the metadata_location
let mut best_match: Option<&StorageCredential> = None;
let mut longest_prefix_len = 0;

if let Some(ref metadata_location) = response.metadata_location {
for cred in &storage_credentials {
if metadata_location.starts_with(&cred.prefix)
&& cred.prefix.len() > longest_prefix_len
{
longest_prefix_len = cred.prefix.len();
best_match = Some(cred);
}
}
}
}

// Extend config with the best match
if let Some(cred) = best_match {
config.extend(cred.config.clone());
}

best_match.cloned()
} else {
None
};

// Finally, use custom storage credential loader if set, giving it a chance to override the previous configurations.
let final_credential = if let Some(storage_credentials_loader) =
&self.user_config.storage_credentials_loader
{
let credential = storage_credentials_loader
.load_credentials(matched_credential.as_ref())
.await?;
config.extend(credential.config.clone());
Some(credential)
} else {
matched_credential
};

let file_io = self
.load_file_io(response.metadata_location.as_deref(), Some(config))
.load_file_io(
response.metadata_location.as_deref(),
Some(config),
final_credential,
)
.await?;

let table_builder = Table::builder()
Expand Down Expand Up @@ -834,8 +911,9 @@ impl Catalog for RestCatalog {
.chain(self.user_config.props.clone())
.collect();

// TODO @vustef: Do we support vended credentials here?
let file_io = self
.load_file_io(Some(metadata_location), Some(config))
.load_file_io(Some(metadata_location), Some(config), None)
.await?;

let table_builder = Table::builder()
Expand Down Expand Up @@ -975,7 +1053,10 @@ impl Catalog for RestCatalog {
"Metadata location missing in `register_table` response!",
))?;

let file_io = self.load_file_io(Some(metadata_location), None).await?;
// TODO @vustef: Do we support vended credentials here?
let file_io = self
.load_file_io(Some(metadata_location), None, None)
.await?;

Table::builder()
.identifier(table_ident.clone())
Expand Down Expand Up @@ -1039,8 +1120,9 @@ impl Catalog for RestCatalog {
_ => return Err(deserialize_unexpected_catalog_error(http_response).await),
};

// TODO @vustef: Do we support vended credentials here?
let file_io = self
.load_file_io(Some(&response.metadata_location), None)
.load_file_io(Some(&response.metadata_location), None, None)
.await?;

Table::builder()
Expand Down Expand Up @@ -2962,4 +3044,80 @@ mod tests {
}
}
}

/// Loading a table through the REST catalog must consult the configured
/// `StorageCredentialsLoader`.
///
/// The catalog is pointed at a mock server that serves the canned
/// `load_table_response.json` fixture; the loader stub only records that it ran.
#[tokio::test]
async fn test_load_table_with_custom_credential_loader() {
    use std::sync::atomic::{AtomicBool, Ordering};

    // Loader stub: flips a shared flag when invoked and hands back a fixed credential.
    #[derive(Debug)]
    struct DummyCredentialLoader {
        invoked: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl StorageCredentialsLoader for DummyCredentialLoader {
        async fn load_credentials(
            &self,
            _existing_credentials: Option<&StorageCredential>,
        ) -> Result<StorageCredential> {
            self.invoked.store(true, Ordering::SeqCst);
            Ok(StorageCredential {
                prefix: "custom".to_string(),
                config: HashMap::from([(
                    "custom.key".to_string(),
                    "custom.value".to_string(),
                )]),
            })
        }
    }

    // Mock REST endpoints: catalog config plus the table under test.
    let mut server = Server::new_async().await;
    let config_mock = create_config_mock(&mut server).await;

    let fixture_path = format!(
        "{}/testdata/{}",
        env!("CARGO_MANIFEST_DIR"),
        "load_table_response.json"
    );
    let load_table_mock = server
        .mock("GET", "/v1/namespaces/ns1/tables/test1")
        .with_status(200)
        .with_body_from_file(fixture_path)
        .create_async()
        .await;

    // The flag is shared between the test body and the loader owned by the catalog.
    let invoked = Arc::new(AtomicBool::new(false));
    let loader: Arc<dyn StorageCredentialsLoader> = Arc::new(DummyCredentialLoader {
        invoked: Arc::clone(&invoked),
    });

    let catalog_config = RestCatalogConfig::builder()
        .uri(server.url())
        .storage_credentials_loader(Some(loader))
        .build();
    let catalog = RestCatalog::new(catalog_config);

    let requested_ident =
        TableIdent::new(NamespaceIdent::new("ns1".to_string()), "test1".to_string());
    let table = catalog.load_table(&requested_ident).await.unwrap();

    let expected_ident = TableIdent::from_strs(vec!["ns1", "test1"]).unwrap();
    assert_eq!(&expected_ident, table.identifier());

    // Verify that the custom credential loader was called
    let loader_ran = invoked.load(Ordering::SeqCst);
    assert!(
        loader_ran,
        "Custom credential loader should have been called"
    );

    // Both mocked endpoints must have been hit.
    config_mock.assert_async().await;
    load_table_mock.assert_async().await;
}
}
Loading