Skip to content

Commit 3730127

Browse files
Wire U2M provider into Database::new_connection()\n\nTask ID: task-3.3-database-integration
1 parent cea68cd commit 3730127

File tree

6 files changed

+268
-236
lines changed

6 files changed

+268
-236
lines changed

rust/src/auth/oauth/cache.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@ use std::path::PathBuf;
2727

2828
/// Cache key used for generating the cache filename.
2929
/// This struct is serialized to JSON and then hashed with SHA-256 to produce a unique filename.
30-
#[allow(dead_code)] // Used in Phase 3 (U2M)
30+
3131
#[derive(Debug, Serialize, Deserialize)]
3232
struct CacheKey {
3333
host: String,
3434
client_id: String,
3535
scopes: Vec<String>,
3636
}
3737

38-
#[allow(dead_code)] // Used in Phase 3 (U2M)
3938
impl CacheKey {
4039
/// Creates a new cache key from the given parameters.
4140
fn new(host: &str, client_id: &str, scopes: &[String]) -> Self {
@@ -69,10 +68,8 @@ impl CacheKey {
6968
/// set to `0o600` (owner read/write only) for security.
7069
///
7170
/// Cache I/O errors are logged as warnings and never block authentication.
72-
#[allow(dead_code)] // Used in Phase 3 (U2M)
7371
pub(crate) struct TokenCache;
7472

75-
#[allow(dead_code)] // Used in Phase 3 (U2M)
7673
impl TokenCache {
7774
/// Returns the cache directory path.
7875
/// Location: `~/.config/databricks-adbc/oauth/`

rust/src/auth/oauth/callback.rs

Lines changed: 25 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -33,80 +33,20 @@ use tokio::sync::oneshot;
3333
/// HTML response sent to the browser after successful callback.
3434
const SUCCESS_HTML: &str = r#"<!DOCTYPE html>
3535
<html>
36-
<head>
37-
<title>Authentication Successful</title>
38-
<style>
39-
body {
40-
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
41-
display: flex;
42-
justify-content: center;
43-
align-items: center;
44-
height: 100vh;
45-
margin: 0;
46-
background-color: #f5f5f5;
47-
}
48-
.message {
49-
text-align: center;
50-
padding: 2rem;
51-
background: white;
52-
border-radius: 8px;
53-
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
54-
}
55-
h1 {
56-
color: #2e7d32;
57-
margin: 0 0 1rem 0;
58-
}
59-
p {
60-
color: #666;
61-
margin: 0;
62-
}
63-
</style>
64-
</head>
36+
<head><title>Authentication Successful</title></head>
6537
<body>
66-
<div class="message">
67-
<h1>✓ Authentication Successful</h1>
68-
<p>You can close this tab and return to your application.</p>
69-
</div>
38+
<h1>Authentication Successful</h1>
39+
<p>You can close this tab and return to your application.</p>
7040
</body>
7141
</html>"#;
7242

7343
/// HTML response sent to the browser when an error occurs.
7444
const ERROR_HTML: &str = r#"<!DOCTYPE html>
7545
<html>
76-
<head>
77-
<title>Authentication Error</title>
78-
<style>
79-
body {
80-
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
81-
display: flex;
82-
justify-content: center;
83-
align-items: center;
84-
height: 100vh;
85-
margin: 0;
86-
background-color: #f5f5f5;
87-
}
88-
.message {
89-
text-align: center;
90-
padding: 2rem;
91-
background: white;
92-
border-radius: 8px;
93-
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
94-
}
95-
h1 {
96-
color: #c62828;
97-
margin: 0 0 1rem 0;
98-
}
99-
p {
100-
color: #666;
101-
margin: 0;
102-
}
103-
</style>
104-
</head>
46+
<head><title>Authentication Error</title></head>
10547
<body>
106-
<div class="message">
107-
<h1>✗ Authentication Error</h1>
108-
<p>An error occurred during authentication. You can close this tab and try again.</p>
109-
</div>
48+
<h1>Authentication Error</h1>
49+
<p>An error occurred during authentication. You can close this tab and try again.</p>
11050
</body>
11151
</html>"#;
11252

@@ -496,7 +436,7 @@ mod tests {
496436

497437
// Read the response
498438
let mut response = vec![0u8; 4096];
499-
stream.read(&mut response).await.ok();
439+
let _ = stream.read(&mut response).await;
500440
});
501441

502442
// Wait for callback
@@ -533,7 +473,7 @@ mod tests {
533473
stream.flush().await.ok();
534474

535475
let mut response = vec![0u8; 4096];
536-
stream.read(&mut response).await.ok();
476+
let _ = stream.read(&mut response).await;
537477
});
538478

539479
// Wait for callback - should fail due to state mismatch
@@ -668,9 +608,24 @@ mod tests {
668608

669609
#[tokio::test]
670610
async fn test_redirect_uri_format() {
671-
let server = CallbackServer::new(8020)
611+
// Use port 0 to let the OS assign an available port, avoiding conflicts
612+
let server = CallbackServer::new(0)
672613
.await
673614
.expect("Failed to create server");
674-
assert_eq!(server.redirect_uri(), "http://localhost:8020/callback");
615+
let uri = server.redirect_uri();
616+
assert!(
617+
uri.starts_with("http://localhost:"),
618+
"Expected URI starting with http://localhost:, got: {}",
619+
uri
620+
);
621+
// Verify the port is a valid non-zero number
622+
let port: u16 = uri
623+
.strip_prefix("http://localhost:")
624+
.unwrap()
625+
.strip_suffix("/callback")
626+
.unwrap_or_else(|| uri.strip_prefix("http://localhost:").unwrap())
627+
.parse()
628+
.unwrap();
629+
assert!(port > 0, "Expected a non-zero port, got: {}", port);
675630
}
676631
}

rust/src/auth/oauth/m2m.rs

Lines changed: 31 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -170,77 +170,6 @@ impl ClientCredentialsProvider {
170170
scopes,
171171
})
172172
}
173-
174-
/// Exchanges client credentials for an access token.
175-
///
176-
/// This method performs the OAuth 2.0 client credentials grant:
177-
/// ```text
178-
/// POST {token_endpoint}
179-
/// Authorization: Basic base64(client_id:client_secret)
180-
/// Content-Type: application/x-www-form-urlencoded
181-
///
182-
/// grant_type=client_credentials&scope=all-apis
183-
/// ```
184-
///
185-
/// Uses the DatabricksHttpClient's inner reqwest::Client to execute the request.
186-
/// The oauth2 crate adds Basic authentication header automatically.
187-
///
188-
/// Note: This method is currently unused but kept for potential future use
189-
/// (e.g., direct testing or alternative refresh strategies).
190-
#[allow(dead_code)]
191-
async fn fetch_token(&self) -> Result<OAuthToken> {
192-
// Build the OAuth client
193-
let oauth_client = BasicClient::new(ClientId::new(self.client_id.clone()))
194-
.set_client_secret(ClientSecret::new(self.client_secret.clone()))
195-
.set_auth_uri(AuthUrl::new(self.auth_endpoint.clone()).map_err(|e| {
196-
DatabricksErrorHelper::invalid_argument()
197-
.message(format!("Invalid authorization endpoint URL: {}", e))
198-
})?)
199-
.set_token_uri(TokenUrl::new(self.token_endpoint.clone()).map_err(|e| {
200-
DatabricksErrorHelper::invalid_argument()
201-
.message(format!("Invalid token endpoint URL: {}", e))
202-
})?);
203-
204-
// Build the token request with scopes
205-
let mut token_request = oauth_client.exchange_client_credentials();
206-
207-
for scope in &self.scopes {
208-
token_request = token_request.add_scope(Scope::new(scope.clone()));
209-
}
210-
211-
// Execute the token exchange using the inner reqwest client
212-
// The oauth2 crate's reqwest::Client implements AsyncHttpClient
213-
let token_response = token_request
214-
.request_async(self.http_client.inner())
215-
.await
216-
.map_err(|e| {
217-
DatabricksErrorHelper::io()
218-
.message(format!("M2M token exchange failed: {}", e))
219-
.context("client credentials grant")
220-
})?;
221-
222-
// Convert oauth2 token response to our OAuthToken
223-
let access_token = token_response.access_token().secret().to_string();
224-
let token_type = token_response.token_type().as_ref().to_string();
225-
let expires_in = token_response
226-
.expires_in()
227-
.map(|d| d.as_secs() as i64)
228-
.unwrap_or(3600); // Default to 1 hour if not specified
229-
230-
let scopes = token_response
231-
.scopes()
232-
.map(|s| s.iter().map(|scope| scope.to_string()).collect())
233-
.unwrap_or_else(|| self.scopes.clone());
234-
235-
// M2M tokens have no refresh_token
236-
Ok(OAuthToken::new(
237-
access_token,
238-
token_type,
239-
expires_in,
240-
None, // No refresh token for M2M
241-
scopes,
242-
))
243-
}
244173
}
245174

246175
impl AuthProvider for ClientCredentialsProvider {
@@ -367,20 +296,18 @@ mod tests {
367296
// Mock OIDC discovery
368297
Mock::given(method("GET"))
369298
.and(path("/oidc/.well-known/oauth-authorization-server"))
370-
.respond_with(
371-
ResponseTemplate::new(200).set_body_json(&serde_json::json!({
372-
"authorization_endpoint": format!("{}/oidc/v1/authorize", mock_server.uri()),
373-
"token_endpoint": token_endpoint,
374-
})),
375-
)
299+
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
300+
"authorization_endpoint": format!("{}/oidc/v1/authorize", mock_server.uri()),
301+
"token_endpoint": token_endpoint,
302+
})))
376303
.mount(&mock_server)
377304
.await;
378305

379306
// Mock token endpoint - verify grant_type and Basic auth
380307
Mock::given(method("POST"))
381308
.and(path("/oidc/v1/token"))
382309
.and(header("content-type", "application/x-www-form-urlencoded"))
383-
.respond_with(ResponseTemplate::new(200).set_body_json(&mock_token_response_body()))
310+
.respond_with(ResponseTemplate::new(200).set_body_json(mock_token_response_body()))
384311
.expect(1)
385312
.mount(&mock_server)
386313
.await;
@@ -416,12 +343,10 @@ mod tests {
416343
// Mock OIDC discovery
417344
Mock::given(method("GET"))
418345
.and(path("/oidc/.well-known/oauth-authorization-server"))
419-
.respond_with(
420-
ResponseTemplate::new(200).set_body_json(&serde_json::json!({
421-
"authorization_endpoint": format!("{}/oidc/v1/authorize", mock_server.uri()),
422-
"token_endpoint": token_endpoint,
423-
})),
424-
)
346+
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
347+
"authorization_endpoint": format!("{}/oidc/v1/authorize", mock_server.uri()),
348+
"token_endpoint": token_endpoint,
349+
})))
425350
.mount(&mock_server)
426351
.await;
427352

@@ -436,15 +361,15 @@ mod tests {
436361
let count = call_count_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
437362
if count == 0 {
438363
// First call - return short-lived token
439-
ResponseTemplate::new(200).set_body_json(&serde_json::json!({
364+
ResponseTemplate::new(200).set_body_json(serde_json::json!({
440365
"access_token": "initial-token",
441366
"token_type": "Bearer",
442367
"expires_in": 1, // Very short expiry to trigger refresh
443368
"scope": "all-apis"
444369
}))
445370
} else {
446371
// Subsequent calls - return long-lived token
447-
ResponseTemplate::new(200).set_body_json(&serde_json::json!({
372+
ResponseTemplate::new(200).set_body_json(serde_json::json!({
448373
"access_token": "refreshed-token",
449374
"token_type": "Bearer",
450375
"expires_in": 3600,
@@ -493,19 +418,17 @@ mod tests {
493418
// Mock OIDC discovery with specific endpoints
494419
Mock::given(method("GET"))
495420
.and(path("/oidc/.well-known/oauth-authorization-server"))
496-
.respond_with(
497-
ResponseTemplate::new(200).set_body_json(&serde_json::json!({
498-
"authorization_endpoint": "https://custom.example.com/auth",
499-
"token_endpoint": "https://custom.example.com/token",
500-
})),
501-
)
421+
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
422+
"authorization_endpoint": "https://custom.example.com/auth",
423+
"token_endpoint": "https://custom.example.com/token",
424+
})))
502425
.mount(&mock_server)
503426
.await;
504427

505428
// Mock token endpoint
506429
Mock::given(method("POST"))
507430
.and(path("/token"))
508-
.respond_with(ResponseTemplate::new(200).set_body_json(&mock_token_response_body()))
431+
.respond_with(ResponseTemplate::new(200).set_body_json(mock_token_response_body()))
509432
.mount(&mock_server)
510433
.await;
511434

@@ -572,19 +495,17 @@ mod tests {
572495
// Mock OIDC discovery
573496
Mock::given(method("GET"))
574497
.and(path("/oidc/.well-known/oauth-authorization-server"))
575-
.respond_with(
576-
ResponseTemplate::new(200).set_body_json(&serde_json::json!({
577-
"authorization_endpoint": format!("{}/oidc/v1/authorize", mock_server.uri()),
578-
"token_endpoint": token_endpoint,
579-
})),
580-
)
498+
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
499+
"authorization_endpoint": format!("{}/oidc/v1/authorize", mock_server.uri()),
500+
"token_endpoint": token_endpoint,
501+
})))
581502
.mount(&mock_server)
582503
.await;
583504

584505
// Mock token endpoint - should only be called once despite concurrent requests
585506
Mock::given(method("POST"))
586507
.and(path("/oidc/v1/token"))
587-
.respond_with(ResponseTemplate::new(200).set_body_json(&mock_token_response_body()))
508+
.respond_with(ResponseTemplate::new(200).set_body_json(mock_token_response_body()))
588509
.expect(1) // Verify only one token fetch occurs
589510
.mount(&mock_server)
590511
.await;
@@ -640,26 +561,22 @@ mod tests {
640561
// Mock OIDC discovery
641562
Mock::given(method("GET"))
642563
.and(path("/oidc/.well-known/oauth-authorization-server"))
643-
.respond_with(
644-
ResponseTemplate::new(200).set_body_json(&serde_json::json!({
645-
"authorization_endpoint": format!("{}/oidc/v1/authorize", mock_server.uri()),
646-
"token_endpoint": token_endpoint,
647-
})),
648-
)
564+
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
565+
"authorization_endpoint": format!("{}/oidc/v1/authorize", mock_server.uri()),
566+
"token_endpoint": token_endpoint,
567+
})))
649568
.mount(&mock_server)
650569
.await;
651570

652571
// Mock token endpoint
653572
Mock::given(method("POST"))
654573
.and(path("/oidc/v1/token"))
655-
.respond_with(
656-
ResponseTemplate::new(200).set_body_json(&serde_json::json!({
657-
"access_token": "test-token-custom-scopes",
658-
"token_type": "Bearer",
659-
"expires_in": 3600,
660-
"scope": "custom-scope-1 custom-scope-2"
661-
})),
662-
)
574+
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
575+
"access_token": "test-token-custom-scopes",
576+
"token_type": "Bearer",
577+
"expires_in": 3600,
578+
"scope": "custom-scope-1 custom-scope-2"
579+
})))
663580
.mount(&mock_server)
664581
.await;
665582

0 commit comments

Comments
 (0)