Skip to content

Commit ec7600d

Browse files
Wire U2M provider into Database::new_connection()\n\nTask ID: task-3.3-database-integration
1 parent eefef6e commit ec7600d

File tree

6 files changed

+236
-179
lines changed

6 files changed

+236
-179
lines changed

rust/src/auth/oauth/cache.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@ use std::path::PathBuf;
2727

2828
/// Cache key used for generating the cache filename.
2929
/// This struct is serialized to JSON and then hashed with SHA-256 to produce a unique filename.
30-
#[allow(dead_code)] // Used in Phase 3 (U2M)
30+
3131
#[derive(Debug, Serialize, Deserialize)]
3232
struct CacheKey {
3333
host: String,
3434
client_id: String,
3535
scopes: Vec<String>,
3636
}
3737

38-
#[allow(dead_code)] // Used in Phase 3 (U2M)
3938
impl CacheKey {
4039
/// Creates a new cache key from the given parameters.
4140
fn new(host: &str, client_id: &str, scopes: &[String]) -> Self {
@@ -69,10 +68,8 @@ impl CacheKey {
6968
/// set to `0o600` (owner read/write only) for security.
7069
///
7170
/// Cache I/O errors are logged as warnings and never block authentication.
72-
#[allow(dead_code)] // Used in Phase 3 (U2M)
7371
pub(crate) struct TokenCache;
7472

75-
#[allow(dead_code)] // Used in Phase 3 (U2M)
7673
impl TokenCache {
7774
/// Returns the cache directory path.
7875
/// Location: `~/.config/databricks-adbc/oauth/`

rust/src/auth/oauth/callback.rs

Lines changed: 23 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -33,80 +33,20 @@ use tokio::sync::oneshot;
3333
/// HTML response sent to the browser after successful callback.
3434
const SUCCESS_HTML: &str = r#"<!DOCTYPE html>
3535
<html>
36-
<head>
37-
<title>Authentication Successful</title>
38-
<style>
39-
body {
40-
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
41-
display: flex;
42-
justify-content: center;
43-
align-items: center;
44-
height: 100vh;
45-
margin: 0;
46-
background-color: #f5f5f5;
47-
}
48-
.message {
49-
text-align: center;
50-
padding: 2rem;
51-
background: white;
52-
border-radius: 8px;
53-
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
54-
}
55-
h1 {
56-
color: #2e7d32;
57-
margin: 0 0 1rem 0;
58-
}
59-
p {
60-
color: #666;
61-
margin: 0;
62-
}
63-
</style>
64-
</head>
36+
<head><title>Authentication Successful</title></head>
6537
<body>
66-
<div class="message">
67-
<h1>✓ Authentication Successful</h1>
68-
<p>You can close this tab and return to your application.</p>
69-
</div>
38+
<h1>Authentication Successful</h1>
39+
<p>You can close this tab and return to your application.</p>
7040
</body>
7141
</html>"#;
7242

7343
/// HTML response sent to the browser when an error occurs.
7444
const ERROR_HTML: &str = r#"<!DOCTYPE html>
7545
<html>
76-
<head>
77-
<title>Authentication Error</title>
78-
<style>
79-
body {
80-
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
81-
display: flex;
82-
justify-content: center;
83-
align-items: center;
84-
height: 100vh;
85-
margin: 0;
86-
background-color: #f5f5f5;
87-
}
88-
.message {
89-
text-align: center;
90-
padding: 2rem;
91-
background: white;
92-
border-radius: 8px;
93-
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
94-
}
95-
h1 {
96-
color: #c62828;
97-
margin: 0 0 1rem 0;
98-
}
99-
p {
100-
color: #666;
101-
margin: 0;
102-
}
103-
</style>
104-
</head>
46+
<head><title>Authentication Error</title></head>
10547
<body>
106-
<div class="message">
107-
<h1>✗ Authentication Error</h1>
108-
<p>An error occurred during authentication. You can close this tab and try again.</p>
109-
</div>
48+
<h1>Authentication Error</h1>
49+
<p>An error occurred during authentication. You can close this tab and try again.</p>
11050
</body>
11151
</html>"#;
11252

@@ -668,9 +608,24 @@ mod tests {
668608

669609
#[tokio::test]
670610
async fn test_redirect_uri_format() {
671-
let server = CallbackServer::new(8020)
611+
// Use port 0 to let the OS assign an available port, avoiding conflicts
612+
let server = CallbackServer::new(0)
672613
.await
673614
.expect("Failed to create server");
674-
assert_eq!(server.redirect_uri(), "http://localhost:8020/callback");
615+
let uri = server.redirect_uri();
616+
assert!(
617+
uri.starts_with("http://localhost:"),
618+
"Expected URI starting with http://localhost:, got: {}",
619+
uri
620+
);
621+
// Verify the port is a valid non-zero number
622+
let port: u16 = uri
623+
.strip_prefix("http://localhost:")
624+
.unwrap()
625+
.strip_suffix("/callback")
626+
.unwrap_or_else(|| uri.strip_prefix("http://localhost:").unwrap())
627+
.parse()
628+
.unwrap();
629+
assert!(port > 0, "Expected a non-zero port, got: {}", port);
675630
}
676631
}

rust/src/auth/oauth/m2m.rs

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -170,77 +170,6 @@ impl ClientCredentialsProvider {
170170
scopes,
171171
})
172172
}
173-
174-
/// Exchanges client credentials for an access token.
175-
///
176-
/// This method performs the OAuth 2.0 client credentials grant:
177-
/// ```text
178-
/// POST {token_endpoint}
179-
/// Authorization: Basic base64(client_id:client_secret)
180-
/// Content-Type: application/x-www-form-urlencoded
181-
///
182-
/// grant_type=client_credentials&scope=all-apis
183-
/// ```
184-
///
185-
/// Uses the DatabricksHttpClient's inner reqwest::Client to execute the request.
186-
/// The oauth2 crate adds Basic authentication header automatically.
187-
///
188-
/// Note: This method is currently unused but kept for potential future use
189-
/// (e.g., direct testing or alternative refresh strategies).
190-
#[allow(dead_code)]
191-
async fn fetch_token(&self) -> Result<OAuthToken> {
192-
// Build the OAuth client
193-
let oauth_client = BasicClient::new(ClientId::new(self.client_id.clone()))
194-
.set_client_secret(ClientSecret::new(self.client_secret.clone()))
195-
.set_auth_uri(AuthUrl::new(self.auth_endpoint.clone()).map_err(|e| {
196-
DatabricksErrorHelper::invalid_argument()
197-
.message(format!("Invalid authorization endpoint URL: {}", e))
198-
})?)
199-
.set_token_uri(TokenUrl::new(self.token_endpoint.clone()).map_err(|e| {
200-
DatabricksErrorHelper::invalid_argument()
201-
.message(format!("Invalid token endpoint URL: {}", e))
202-
})?);
203-
204-
// Build the token request with scopes
205-
let mut token_request = oauth_client.exchange_client_credentials();
206-
207-
for scope in &self.scopes {
208-
token_request = token_request.add_scope(Scope::new(scope.clone()));
209-
}
210-
211-
// Execute the token exchange using the inner reqwest client
212-
// The oauth2 crate's reqwest::Client implements AsyncHttpClient
213-
let token_response = token_request
214-
.request_async(self.http_client.inner())
215-
.await
216-
.map_err(|e| {
217-
DatabricksErrorHelper::io()
218-
.message(format!("M2M token exchange failed: {}", e))
219-
.context("client credentials grant")
220-
})?;
221-
222-
// Convert oauth2 token response to our OAuthToken
223-
let access_token = token_response.access_token().secret().to_string();
224-
let token_type = token_response.token_type().as_ref().to_string();
225-
let expires_in = token_response
226-
.expires_in()
227-
.map(|d| d.as_secs() as i64)
228-
.unwrap_or(3600); // Default to 1 hour if not specified
229-
230-
let scopes = token_response
231-
.scopes()
232-
.map(|s| s.iter().map(|scope| scope.to_string()).collect())
233-
.unwrap_or_else(|| self.scopes.clone());
234-
235-
// M2M tokens have no refresh_token
236-
Ok(OAuthToken::new(
237-
access_token,
238-
token_type,
239-
expires_in,
240-
None, // No refresh token for M2M
241-
scopes,
242-
))
243-
}
244173
}
245174

246175
impl AuthProvider for ClientCredentialsProvider {

rust/src/auth/oauth/token_store.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ use std::sync::{Arc, RwLock};
7373
/// ))
7474
/// })?;
7575
/// ```
76-
#[allow(dead_code)] // Used in Phase 2 (M2M) and Phase 3 (U2M)
76+
7777
#[derive(Debug)]
7878
pub(crate) struct TokenStore {
7979
/// The current token, protected by a read-write lock.
@@ -83,7 +83,6 @@ pub(crate) struct TokenStore {
8383
refreshing: Arc<AtomicBool>,
8484
}
8585

86-
#[allow(dead_code)] // Used in Phase 2 (M2M) and Phase 3 (U2M)
8786
impl TokenStore {
8887
/// Creates a new empty token store.
8988
pub fn new() -> Self {
@@ -284,7 +283,6 @@ impl TokenStore {
284283
}
285284

286285
/// Internal enum representing the token's current state.
287-
#[allow(dead_code)] // Used internally by TokenStore
288286
#[derive(Debug)]
289287
enum TokenState {
290288
/// No token is present in the store.

rust/src/auth/oauth/u2m.rs

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ pub struct AuthorizationCodeProvider {
111111
callback_port: u16,
112112
/// Callback timeout
113113
callback_timeout: Duration,
114+
/// Whether to open browser automatically (set to false in tests)
115+
open_browser: bool,
114116
}
115117

116118
impl AuthorizationCodeProvider {
@@ -179,9 +181,37 @@ impl AuthorizationCodeProvider {
179181
callback_port: u16,
180182
callback_timeout: Duration,
181183
) -> Result<Self> {
182-
// Discover OIDC endpoints
184+
Self::new_with_full_config(
185+
host,
186+
client_id,
187+
http_client,
188+
scopes,
189+
callback_port,
190+
callback_timeout,
191+
None, // No token_endpoint override
192+
)
193+
.await
194+
}
195+
196+
/// Creates a new U2M provider with full configuration including optional token endpoint override.
197+
///
198+
/// This is used by Database when a custom token endpoint is configured.
199+
/// Most users should use `new()` or `new_with_config()` which use OIDC discovery.
200+
pub async fn new_with_full_config(
201+
host: &str,
202+
client_id: &str,
203+
http_client: Arc<DatabricksHttpClient>,
204+
scopes: Vec<String>,
205+
callback_port: u16,
206+
callback_timeout: Duration,
207+
token_endpoint_override: Option<String>,
208+
) -> Result<Self> {
209+
// Discover OIDC endpoints (unless token_endpoint is overridden, we still need auth endpoint)
183210
let endpoints = OidcEndpoints::discover(host, &http_client).await?;
184211

212+
// Use override if provided, otherwise use discovered endpoint
213+
let token_endpoint = token_endpoint_override.unwrap_or(endpoints.token_endpoint);
214+
185215
// Create token store
186216
let token_store = TokenStore::new();
187217

@@ -206,17 +236,29 @@ impl AuthorizationCodeProvider {
206236
Ok(Self {
207237
host: host.to_string(),
208238
client_id: client_id.to_string(),
209-
token_endpoint: endpoints.token_endpoint,
239+
token_endpoint,
210240
auth_endpoint: endpoints.authorization_endpoint,
211241
token_store,
212242
http_client,
213243
scopes,
214244
callback_port,
215245
callback_timeout,
246+
open_browser: true,
216247
})
217248
}
218249
}
219250

251+
impl AuthorizationCodeProvider {
252+
/// Disables automatic browser opening during the U2M flow.
253+
///
254+
/// This is primarily useful for testing scenarios where the browser
255+
/// flow fallback should not launch a real browser.
256+
#[doc(hidden)]
257+
pub fn suppress_browser(&mut self) {
258+
self.open_browser = false;
259+
}
260+
}
261+
220262
impl AuthProvider for AuthorizationCodeProvider {
221263
/// Returns a valid Bearer token for authentication.
222264
///
@@ -246,6 +288,7 @@ impl AuthProvider for AuthorizationCodeProvider {
246288
let scopes = self.scopes.clone();
247289
let callback_port = self.callback_port;
248290
let callback_timeout = self.callback_timeout;
291+
let open_browser = self.open_browser;
249292
let http_client = self.http_client.clone();
250293

251294
// Get or refresh token via TokenStore
@@ -366,10 +409,14 @@ impl AuthProvider for AuthorizationCodeProvider {
366409

367410
let (auth_url, csrf_state) = auth_url_builder.url();
368411

369-
// Launch browser
370-
tracing::info!("Opening browser for OAuth authorization: {}", auth_url);
371-
if let Err(e) = open::that(auth_url.as_str()) {
372-
tracing::warn!("Failed to automatically open browser: {}. Please manually navigate to: {}", e, auth_url);
412+
// Launch browser (unless suppressed for testing)
413+
if open_browser {
414+
tracing::info!("Opening browser for OAuth authorization: {}", auth_url);
415+
if let Err(e) = open::that(auth_url.as_str()) {
416+
tracing::warn!("Failed to automatically open browser: {}. Please manually navigate to: {}", e, auth_url);
417+
}
418+
} else {
419+
tracing::debug!("Browser opening suppressed, authorization URL: {}", auth_url);
373420
}
374421

375422
// Wait for callback
@@ -607,17 +654,18 @@ mod tests {
607654
&cached_token,
608655
);
609656

610-
// Create provider
611-
let provider = AuthorizationCodeProvider::new_with_config(
657+
// Create provider with browser opening suppressed
658+
let mut provider = AuthorizationCodeProvider::new_with_config(
612659
&mock_server.uri(),
613660
"test-client-id",
614661
http_client,
615662
vec!["all-apis".to_string(), "offline_access".to_string()],
616-
8021, // Different port to avoid conflicts
663+
0, // Use port 0 to avoid conflicts
617664
Duration::from_secs(5), // Short timeout for test
618665
)
619666
.await
620667
.expect("Failed to create provider");
668+
provider.suppress_browser();
621669

622670
// Wait to ensure token expires
623671
tokio::time::sleep(Duration::from_secs(2)).await;

0 commit comments

Comments
 (0)