Skip to content

Commit 84aefc5

Browse files
committed
feat: add an ability to override scopes for auth
1 parent 049e63e commit 84aefc5

File tree

7 files changed

+94
-43
lines changed

7 files changed

+94
-43
lines changed

Cargo.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,5 @@ open = "5"
4444
chrono = {version = "0.4", features = ["serde"]}
4545

4646
[workspace.package]
47-
version = "0.1.3"
47+
version = "0.1.4"
4848
edition = "2024"

client/src/config.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ impl Config {
6767
auth_url: None,
6868
token_url: None,
6969
callback_port: None,
70+
scopes: vec![],
7071
},
7172
};
7273

common/src/auth.rs

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ pub struct AuthManager {
9191
http_client: reqwest::Client,
9292
keyring_user: String,
9393
callback_port: u16,
94+
scopes: Vec<String>,
9495
}
9596

9697
impl AuthManager {
@@ -121,6 +122,7 @@ impl AuthManager {
121122
http_client,
122123
keyring_user,
123124
callback_port,
125+
scopes: config.scopes.clone(),
124126
})
125127
}
126128

@@ -221,11 +223,22 @@ impl AuthManager {
221223
// Generate PKCE challenge
222224
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
223225

224-
// Generate authorization URL with offline_access scope for refresh tokens
225-
let (auth_url, csrf_token) = self.client
226-
.authorize_url(CsrfToken::new_random)
227-
.add_scope(Scope::new("openid".to_string()))
228-
.add_scope(Scope::new("offline_access".to_string()))
226+
// Generate authorization URL
227+
// If scopes are configured, use only those (all-or-nothing override)
228+
// Otherwise use defaults: openid + offline_access
229+
let mut auth_request = self.client.authorize_url(CsrfToken::new_random);
230+
231+
if self.scopes.is_empty() {
232+
auth_request = auth_request
233+
.add_scope(Scope::new("openid".to_string()))
234+
.add_scope(Scope::new("offline_access".to_string()));
235+
} else {
236+
for scope in &self.scopes {
237+
auth_request = auth_request.add_scope(Scope::new(scope.clone()));
238+
}
239+
}
240+
241+
let (auth_url, csrf_token) = auth_request
229242
.set_pkce_challenge(pkce_challenge)
230243
.url();
231244

common/src/config.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ pub struct OAuthConfig {
1717
/// Port for OAuth callback server. Defaults to 19284.
1818
#[serde(skip_serializing_if = "Option::is_none")]
1919
pub callback_port: Option<u16>,
20+
/// OAuth scopes to request. If empty, defaults to `openid` and `offline_access`.
21+
/// If provided, these scopes are used instead (all-or-nothing override).
22+
/// For Azure AD, typically include `openid`, `offline_access`, and
23+
/// `{client_id}/.default` to get an access token for your API.
24+
#[serde(default, skip_serializing_if = "Vec::is_empty")]
25+
pub scopes: Vec<String>,
2026
}
2127

2228
#[derive(Debug, Deserialize)]

mcp-client/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ pub mod config {
8585
auth_url: None,
8686
token_url: None,
8787
callback_port: None,
88+
scopes: vec![],
8889
},
8990
};
9091

tests/src/mcp.rs

Lines changed: 61 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ use std::time::Duration;
1010

1111
use chrono::Utc;
1212
use mcp_client::{
13+
JsonRpcRequest, RequestLogEntry, SharedState,
1314
config::{Config, OAuthConfig},
14-
handle_request, new_shared_state, JsonRpcRequest, RequestLogEntry,
15-
SharedState,
15+
handle_request, new_shared_state,
1616
};
17-
use serde_json::{json, Value};
17+
use serde_json::{Value, json};
1818
use server::config::Config as ServerConfig;
1919

20-
use test_helpers::{get_test_token, MockLocalServer};
20+
use test_helpers::{MockLocalServer, get_test_token};
2121

2222
// Test configuration
2323
const KEYCLOAK_ISSUER: &str = "http://localhost:8180/realms/relay";
@@ -27,9 +27,7 @@ const TEST_PASSWORD: &str = "testpass";
2727
const JWT_AUDIENCE: &str = "webhook-relay-cli";
2828

2929
fn init_tracing() {
30-
let _ = tracing_subscriber::fmt()
31-
.with_env_filter("info")
32-
.try_init();
30+
let _ = tracing_subscriber::fmt().with_env_filter("info").try_init();
3331
}
3432

3533
fn create_server_config(http_port: u16, grpc_port: u16) -> ServerConfig {
@@ -56,6 +54,7 @@ fn create_mcp_config(grpc_addr: &str, local_endpoint: &str) -> Config {
5654
auth_url: None,
5755
token_url: None,
5856
callback_port: None,
57+
scopes: vec![],
5958
},
6059
}
6160
}
@@ -162,12 +161,12 @@ async fn test_mcp_list_tools() {
162161
let tools = result["tools"].as_array().expect("No tools array");
163162

164163
// Verify expected tools exist
165-
let tool_names: Vec<&str> = tools
166-
.iter()
167-
.map(|t| t["name"].as_str().unwrap())
168-
.collect();
164+
let tool_names: Vec<&str> = tools.iter().map(|t| t["name"].as_str().unwrap()).collect();
169165

170-
assert!(tool_names.contains(&"get_config"), "Missing get_config tool");
166+
assert!(
167+
tool_names.contains(&"get_config"),
168+
"Missing get_config tool"
169+
);
171170
assert!(
172171
tool_names.contains(&"get_request_log"),
173172
"Missing get_request_log tool"
@@ -256,7 +255,10 @@ async fn test_mcp_get_config_connected_status() {
256255
.expect("Should have text content");
257256

258257
// Verify connected status is displayed
259-
assert!(content.contains("Connected"), "Should show connected status");
258+
assert!(
259+
content.contains("Connected"),
260+
"Should show connected status"
261+
);
260262
assert!(
261263
content.contains("http://relay.example.com/webhook/abc123"),
262264
"Should show webhook endpoint"
@@ -349,8 +351,14 @@ async fn test_mcp_get_request_log_with_entries() {
349351
.expect("Should have text content");
350352

351353
// Verify entries are displayed (newest first)
352-
assert!(content.contains("req-002"), "Should contain second request ID");
353-
assert!(content.contains("req-001"), "Should contain first request ID");
354+
assert!(
355+
content.contains("req-002"),
356+
"Should contain second request ID"
357+
);
358+
assert!(
359+
content.contains("req-001"),
360+
"Should contain first request ID"
361+
);
354362
assert!(content.contains("POST"), "Should show method");
355363
assert!(
356364
content.contains("/webhook/test"),
@@ -360,7 +368,10 @@ async fn test_mcp_get_request_log_with_entries() {
360368
content.contains("/webhook/another"),
361369
"Should show second request path"
362370
);
363-
assert!(content.contains("2 total entries"), "Should show total count");
371+
assert!(
372+
content.contains("2 total entries"),
373+
"Should show total count"
374+
);
364375

365376
// Verify newest first ordering (req-002 should appear before req-001)
366377
let pos_002 = content.find("req-002").expect("req-002 not found");
@@ -419,8 +430,14 @@ async fn test_mcp_get_request_log_pagination() {
419430
// Page 0 should have newest 2 entries (req-005, req-004)
420431
assert!(content.contains("req-005"), "Page 0 should have req-005");
421432
assert!(content.contains("req-004"), "Page 0 should have req-004");
422-
assert!(!content.contains("req-003"), "Page 0 should not have req-003");
423-
assert!(content.contains("Page 1 of 3"), "Should show correct page info");
433+
assert!(
434+
!content.contains("req-003"),
435+
"Page 0 should not have req-003"
436+
);
437+
assert!(
438+
content.contains("Page 1 of 3"),
439+
"Should show correct page info"
440+
);
424441

425442
// Get page 1 with page_size 2
426443
let request = make_request(
@@ -446,8 +463,14 @@ async fn test_mcp_get_request_log_pagination() {
446463
// Page 1 should have req-003, req-002
447464
assert!(content.contains("req-003"), "Page 1 should have req-003");
448465
assert!(content.contains("req-002"), "Page 1 should have req-002");
449-
assert!(!content.contains("req-005"), "Page 1 should not have req-005");
450-
assert!(content.contains("Page 2 of 3"), "Should show correct page info");
466+
assert!(
467+
!content.contains("req-005"),
468+
"Page 1 should not have req-005"
469+
);
470+
assert!(
471+
content.contains("Page 2 of 3"),
472+
"Should show correct page info"
473+
);
451474

452475
// Get page 2 with page_size 2
453476
let request = make_request(
@@ -472,7 +495,10 @@ async fn test_mcp_get_request_log_pagination() {
472495

473496
// Page 2 should have only req-001
474497
assert!(content.contains("req-001"), "Page 2 should have req-001");
475-
assert!(!content.contains("req-002"), "Page 2 should not have req-002");
498+
assert!(
499+
!content.contains("req-002"),
500+
"Page 2 should not have req-002"
501+
);
476502

477503
tracing::info!("Pagination test passed!");
478504
}
@@ -589,9 +615,14 @@ async fn test_mcp_full_webhook_flow() {
589615
tracing::info!("Relay server HTTP: {}, gRPC: {}", http_addr, grpc_addr);
590616

591617
// 3. Get OAuth token
592-
let token = get_test_token(KEYCLOAK_ISSUER, KEYCLOAK_CLIENT_ID, TEST_USERNAME, TEST_PASSWORD)
593-
.await
594-
.expect("Failed to get test token");
618+
let token = get_test_token(
619+
KEYCLOAK_ISSUER,
620+
KEYCLOAK_CLIENT_ID,
621+
TEST_USERNAME,
622+
TEST_PASSWORD,
623+
)
624+
.await
625+
.expect("Failed to get test token");
595626

596627
// 4. Create MCP state with config
597628
let mcp_config = create_mcp_config(&grpc_addr, &local_endpoint);
@@ -635,8 +666,7 @@ async fn test_mcp_full_webhook_flow() {
635666
let path = http_request.path.clone();
636667
let query = http_request.query.clone();
637668
let request_headers = http_request.headers.clone();
638-
let request_body =
639-
String::from_utf8_lossy(&http_request.body).to_string();
669+
let request_body = String::from_utf8_lossy(&http_request.body).to_string();
640670

641671
// Forward request to local endpoint
642672
let url = format!("{}{}", local_endpoint.trim_end_matches('/'), path);
@@ -711,13 +741,13 @@ async fn test_mcp_full_webhook_flow() {
711741
content.contains("Connected"),
712742
"Should show connected status"
713743
);
714-
assert!(
715-
content.contains(&endpoint),
716-
"Should show webhook endpoint"
717-
);
744+
assert!(content.contains(&endpoint), "Should show webhook endpoint");
718745

719746
// 8. Send a webhook to the relay server
720-
let route = endpoint.rsplit('/').next().expect("Invalid endpoint format");
747+
let route = endpoint
748+
.rsplit('/')
749+
.next()
750+
.expect("Invalid endpoint format");
721751
let webhook_url = format!("{}/{}/test-path", http_addr, route);
722752
let webhook_body = r#"{"event": "mcp_test", "data": "hello"}"#;
723753

0 commit comments

Comments
 (0)