1 change: 1 addition & 0 deletions rust/CLAUDE.md
@@ -174,6 +174,7 @@ All files must have Apache 2.0 license headers:
- Unit tests go in `#[cfg(test)] mod tests { }` at the bottom of each file
- Integration tests go in `tests/` directory
- Test names: `test_<function>_<scenario>`
- E2E tests that require a real Databricks connection should be marked with `#[ignore]`

### Error Handling

23 changes: 21 additions & 2 deletions rust/docs/designs/oauth-u2m-m2m-design.md
@@ -681,8 +681,27 @@ Tests in `client/http.rs` `#[cfg(test)]` verifying the `OnceLock`-based auth pro

### End-to-End Tests

- `test_m2m_end_to_end` -- real Databricks workspace with service principal credentials (requires env vars, `#[ignore]` by default)
- `test_u2m_end_to_end` -- manual test only (`#[ignore]`), requires interactive browser
End-to-end tests in `tests/oauth_e2e.rs` verify the complete OAuth implementation against real Databricks workspaces:

**M2M (Service Principal) Test:**
- `test_m2m_end_to_end` -- Connects to a real Databricks workspace using service principal credentials
- Requires environment variables: `DATABRICKS_HOST`, `DATABRICKS_CLIENT_ID`, `DATABRICKS_CLIENT_SECRET`, `DATABRICKS_WAREHOUSE_ID`
- Marked with `#[ignore]` to prevent running in CI by default
- Verifies: (1) Database config with `databricks.auth.type` set to `oauth_m2m`, (2) OIDC discovery, (3) Client credentials token exchange, (4) Connection creation, (5) Query execution (SELECT 1), (6) Result validation
- Run with: `cargo test --test oauth_e2e test_m2m_end_to_end -- --ignored --nocapture`

**U2M (Browser-Based) Test:**
- `test_u2m_end_to_end` -- Manual test requiring interactive browser authentication
- Requires environment variables: `DATABRICKS_HOST`, `DATABRICKS_WAREHOUSE_ID`, `DATABRICKS_CLIENT_ID` (optional, defaults to "databricks-cli")
- Always marked with `#[ignore]` (manual test only)
- Verifies: (1) Database config with `databricks.auth.type` set to `oauth_u2m`, (2) OIDC discovery, (3) Token cache check, (4) Browser launch (if needed), (5) Authorization code exchange, (6) Connection creation, (7) Query execution (SELECT 1), (8) Result validation, (9) Token caching to disk
- Run with: `cargo test --test oauth_e2e test_u2m_end_to_end -- --ignored --nocapture`
- Note: Launches the default web browser so the user can complete authentication

**Configuration Validation Test:**
- `test_oauth_config_validation` -- Verifies proper error messages for missing/invalid OAuth configuration
- Runs in standard test suite (not ignored)
- Tests: missing `databricks.auth.type`, missing `client_secret` for M2M, invalid auth type values
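
The environment-variable gating described for the M2M test could be sketched roughly as follows. This is only an illustration: `M2mEnv` and `load_m2m_env` are hypothetical names that do not come from `tests/oauth_e2e.rs`, and the real test may read its settings differently. The lookup is passed in as a closure so the missing-variable path can be exercised without touching the process environment.

```rust
/// Hypothetical holder for the settings the M2M test reads from the environment.
#[derive(Debug, PartialEq)]
struct M2mEnv {
    host: String,
    client_id: String,
    client_secret: String,
    warehouse_id: String,
}

/// Builds the settings from a lookup function and reports the first missing variable.
fn load_m2m_env(lookup: impl Fn(&str) -> Option<String>) -> Result<M2mEnv, String> {
    let get = |name: &str| lookup(name).ok_or_else(|| format!("{name} is not set"));
    Ok(M2mEnv {
        host: get("DATABRICKS_HOST")?,
        client_id: get("DATABRICKS_CLIENT_ID")?,
        client_secret: get("DATABRICKS_CLIENT_SECRET")?,
        warehouse_id: get("DATABRICKS_WAREHOUSE_ID")?,
    })
}

// Skipped by default because it needs real credentials; run with `-- --ignored`.
#[test]
#[ignore]
fn test_m2m_end_to_end() {
    let cfg = load_m2m_env(|name| std::env::var(name).ok()).expect("missing env var");
    // Connection setup and the `SELECT 1` round trip would follow here.
    assert!(!cfg.host.is_empty());
}
```

Failing fast with the name of the missing variable makes a misconfigured manual run easier to diagnose than a later connection error.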

---

230 changes: 230 additions & 0 deletions rust/src/auth/config.rs
@@ -0,0 +1,230 @@
// Copyright (c) 2025 ADBC Drivers Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Authentication configuration types for the Databricks ADBC driver.
//!
//! This module defines the enums and configuration struct used to configure
//! authentication when creating a new database connection.

use crate::error::DatabricksErrorHelper;
use driverbase::error::ErrorHelper;

/// Authentication type -- single selector for the authentication method.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthType {
    /// Personal access token.
    AccessToken,
    /// M2M: client credentials grant for service principals.
    OAuthM2m,
    /// U2M: browser-based authorization code + PKCE.
    OAuthU2m,
}

impl TryFrom<&str> for AuthType {
    type Error = crate::error::Error;

    fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
        match value {
            "access_token" => Ok(AuthType::AccessToken),
            "oauth_m2m" => Ok(AuthType::OAuthM2m),
            "oauth_u2m" => Ok(AuthType::OAuthU2m),
            _ => Err(DatabricksErrorHelper::invalid_argument().message(format!(
                "Invalid auth type: '{}'. Valid values: 'access_token', 'oauth_m2m', 'oauth_u2m'",
                value
            ))),
        }
    }
}

impl std::fmt::Display for AuthType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AuthType::AccessToken => write!(f, "access_token"),
            AuthType::OAuthM2m => write!(f, "oauth_m2m"),
            AuthType::OAuthU2m => write!(f, "oauth_u2m"),
        }
    }
}

/// Authentication configuration parsed from Database options.
///
/// This struct collects all auth-related options set via `Database::set_option()`.
/// It is used by `Database::new_connection()` to validate the configuration and
/// create the appropriate `AuthProvider`.
#[derive(Debug, Default, Clone)]
pub struct AuthConfig {
    pub auth_type: Option<AuthType>,
    pub client_id: Option<String>,
    pub client_secret: Option<String>,
    pub scopes: Option<String>,
    pub token_endpoint: Option<String>,
    pub redirect_port: Option<u16>,
}

impl AuthConfig {
    /// Validates the auth configuration and returns the auth type.
    ///
    /// This checks that:
    /// - An auth type is specified
    /// - Required fields are present for the chosen auth type
    pub fn validate(&self, access_token: &Option<String>) -> crate::error::Result<AuthType> {
        let auth_type = self.auth_type.ok_or_else(|| {
            DatabricksErrorHelper::invalid_argument().message(
                "databricks.auth.type is required. Valid values: 'access_token', 'oauth_m2m', 'oauth_u2m'",
            )
        })?;

        match auth_type {
            AuthType::AccessToken => {
                if access_token.is_none() {
                    return Err(DatabricksErrorHelper::invalid_argument().message(
                        "databricks.access_token is required when auth type is 'access_token'",
                    ));
                }
            }
            AuthType::OAuthM2m => {
                if self.client_id.is_none() {
                    return Err(DatabricksErrorHelper::invalid_argument().message(
                        "databricks.auth.client_id is required when auth type is 'oauth_m2m'",
                    ));
                }
                if self.client_secret.is_none() {
                    return Err(DatabricksErrorHelper::invalid_argument().message(
                        "databricks.auth.client_secret is required when auth type is 'oauth_m2m'",
                    ));
                }
            }
            AuthType::OAuthU2m => {
                // U2M flow has no required fields - all parameters have defaults:
                // - client_id defaults to "databricks-cli"
                // - scopes defaults to "all-apis offline_access"
                // - redirect_port defaults to 8020
            }
        }

        Ok(auth_type)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_type_valid() {
        assert_eq!(
            AuthType::try_from("access_token").unwrap(),
            AuthType::AccessToken
        );
        assert_eq!(AuthType::try_from("oauth_m2m").unwrap(), AuthType::OAuthM2m);
        assert_eq!(AuthType::try_from("oauth_u2m").unwrap(), AuthType::OAuthU2m);
    }

    #[test]
    fn test_auth_type_invalid() {
        assert!(AuthType::try_from("pat").is_err());
        assert!(AuthType::try_from("oauth").is_err());
        assert!(AuthType::try_from("0").is_err());
        assert!(AuthType::try_from("11").is_err());
        assert!(AuthType::try_from("").is_err());
    }

    #[test]
    fn test_auth_type_display() {
        assert_eq!(AuthType::AccessToken.to_string(), "access_token");
        assert_eq!(AuthType::OAuthM2m.to_string(), "oauth_m2m");
        assert_eq!(AuthType::OAuthU2m.to_string(), "oauth_u2m");
    }

    #[test]
    fn test_validate_missing_auth_type() {
        let config = AuthConfig::default();
        let result = config.validate(&None);
        assert!(result.is_err());
        let err_msg = format!("{:?}", result.unwrap_err());
        assert!(err_msg.contains("databricks.auth.type is required"));
    }

    #[test]
    fn test_validate_access_token_missing_token() {
        let config = AuthConfig {
            auth_type: Some(AuthType::AccessToken),
            ..Default::default()
        };
        let result = config.validate(&None);
        assert!(result.is_err());
        let err_msg = format!("{:?}", result.unwrap_err());
        assert!(err_msg.contains("databricks.access_token is required"));
    }

    #[test]
    fn test_validate_access_token_with_token() {
        let config = AuthConfig {
            auth_type: Some(AuthType::AccessToken),
            ..Default::default()
        };
        let result = config.validate(&Some("token".to_string()));
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), AuthType::AccessToken);
    }

    #[test]
    fn test_validate_oauth_m2m_missing_client_id() {
        let config = AuthConfig {
            auth_type: Some(AuthType::OAuthM2m),
            ..Default::default()
        };
        let result = config.validate(&None);
        assert!(result.is_err());
        let err_msg = format!("{:?}", result.unwrap_err());
        assert!(err_msg.contains("databricks.auth.client_id is required"));
    }

    #[test]
    fn test_validate_oauth_m2m_missing_secret() {
        let config = AuthConfig {
            auth_type: Some(AuthType::OAuthM2m),
            client_id: Some("id".to_string()),
            ..Default::default()
        };
        let result = config.validate(&None);
        assert!(result.is_err());
        let err_msg = format!("{:?}", result.unwrap_err());
        assert!(err_msg.contains("databricks.auth.client_secret is required"));
    }

    #[test]
    fn test_validate_oauth_m2m_valid() {
        let config = AuthConfig {
            auth_type: Some(AuthType::OAuthM2m),
            client_id: Some("id".to_string()),
            client_secret: Some("secret".to_string()),
            ..Default::default()
        };
        let result = config.validate(&None);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), AuthType::OAuthM2m);
    }

    #[test]
    fn test_validate_oauth_u2m_no_required_fields() {
        let config = AuthConfig {
            auth_type: Some(AuthType::OAuthU2m),
            ..Default::default()
        };
        let result = config.validate(&None);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), AuthType::OAuthU2m);
    }
}
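
The comment in the `OAuthU2m` arm of `validate` lists three defaults but the hunk does not show where they are applied. One way the fallback could look is sketched below. `U2mOptions`, `ResolvedU2m`, and `resolve_u2m` are hypothetical stand-ins rather than driver types, and splitting the scope string on whitespace is an assumption; only the default values themselves come from the comment above.

```rust
/// Hypothetical stand-in for the `AuthConfig` fields that have U2M defaults.
#[derive(Debug, Default, Clone)]
struct U2mOptions {
    client_id: Option<String>,
    scopes: Option<String>,
    redirect_port: Option<u16>,
}

/// Fully resolved U2M settings after defaults are applied.
#[derive(Debug, PartialEq)]
struct ResolvedU2m {
    client_id: String,
    scopes: Vec<String>,
    redirect_port: u16,
}

/// Fills in the documented defaults for any option the caller left unset.
fn resolve_u2m(opts: &U2mOptions) -> ResolvedU2m {
    ResolvedU2m {
        client_id: opts
            .client_id
            .clone()
            .unwrap_or_else(|| "databricks-cli".to_string()),
        // Assumption: scopes are supplied as one space-separated string.
        scopes: opts
            .scopes
            .as_deref()
            .unwrap_or("all-apis offline_access")
            .split_whitespace()
            .map(str::to_string)
            .collect(),
        redirect_port: opts.redirect_port.unwrap_or(8020),
    }
}
```

Keeping the defaults in one resolver function, instead of scattering `unwrap_or` calls across providers, would let `validate` stay a pure presence check as it is in this hunk.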
3 changes: 3 additions & 0 deletions rust/src/auth/mod.rs
@@ -14,9 +14,12 @@

//! Authentication mechanisms for the Databricks ADBC driver.

pub mod config;
pub mod oauth;
pub mod pat;

pub use config::{AuthConfig, AuthType};
pub use oauth::{AuthorizationCodeProvider, ClientCredentialsProvider};
pub use pat::PersonalAccessToken;

use crate::error::Result;
7 changes: 2 additions & 5 deletions rust/src/auth/oauth/cache.rs
@@ -27,15 +27,14 @@ use std::path::PathBuf;

/// Cache key used for generating the cache filename.
/// This struct is serialized to JSON and then hashed with SHA-256 to produce a unique filename.
#[allow(dead_code)] // Used in Phase 3 (U2M)

#[derive(Debug, Serialize, Deserialize)]
struct CacheKey {
    host: String,
    client_id: String,
    scopes: Vec<String>,
}

#[allow(dead_code)] // Used in Phase 3 (U2M)
impl CacheKey {
    /// Creates a new cache key from the given parameters.
    fn new(host: &str, client_id: &str, scopes: &[String]) -> Self {
@@ -69,10 +68,8 @@ impl CacheKey {
/// set to `0o600` (owner read/write only) for security.
///
/// Cache I/O errors are logged as warnings and never block authentication.
#[allow(dead_code)] // Used in Phase 3 (U2M)
pub(crate) struct TokenCache;
pub struct TokenCache;

#[allow(dead_code)] // Used in Phase 3 (U2M)
impl TokenCache {
    /// Returns the cache directory path.
    /// Location: `~/.config/databricks-adbc/oauth/`
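
The doc comment on `TokenCache` says cache files are created with `0o600` permissions. The rest of `cache.rs` is not visible in this diff, so the following is only a sketch of one standard-library way to get that behavior on Unix. `write_private_file` is a hypothetical name and is not taken from the driver.

```rust
use std::fs::OpenOptions;
use std::io::{self, Write};
use std::path::Path;

#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;

/// Writes `contents` to `path`, creating the file as owner read/write only on Unix.
fn write_private_file(path: &Path, contents: &[u8]) -> io::Result<()> {
    let mut options = OpenOptions::new();
    options.write(true).create(true).truncate(true);
    // The mode only applies when the file is newly created; an existing
    // file keeps whatever permissions it already has.
    #[cfg(unix)]
    options.mode(0o600);
    let mut file = options.open(path)?;
    file.write_all(contents)
}
```

Setting the mode at creation time avoids the short window in which a file written with default permissions and tightened afterwards would be readable by other users.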