2
2
3
3
use base64:: { engine:: general_purpose, Engine as _} ;
4
4
use reqwest:: {
5
- header:: { HeaderMap , HeaderValue , ACCEPT , USER_AGENT } ,
5
+ header:: { HeaderMap , HeaderValue , ACCEPT , AUTHORIZATION , USER_AGENT } ,
6
6
Client ,
7
7
} ;
8
8
use serde:: Deserialize ;
@@ -18,6 +18,10 @@ use std::fs;
18
18
use std:: path:: PathBuf ;
19
19
use std:: sync:: RwLock ;
20
20
21
+ // Azure Identity imports for MSI authentication
22
+ use azure_core:: credentials:: TokenCredential ;
23
+ use azure_identity:: { ManagedIdentityCredential , ManagedIdentityCredentialOptions , UserAssignedId } ;
24
+
21
25
/// Authentication methods for the Geneva Config Client.
22
26
///
23
27
/// The client supports two authentication methods:
@@ -53,25 +57,29 @@ pub enum AuthMethod {
53
57
/// * `path` - Path to the PKCS#12 (.p12) certificate file
54
58
/// * `password` - Password to decrypt the PKCS#12 file
55
59
Certificate { path : PathBuf , password : String } ,
56
- /// Azure Managed Identity authentication
57
- ///
58
- /// Note(TODO): This is not yet implemented.
59
- ManagedIdentity ,
60
+ /// System-assigned managed identity (auto-detected)
61
+ SystemManagedIdentity ,
62
+ /// User-assigned managed identity by client ID
63
+ UserManagedIdentity { client_id : String } ,
64
+ /// User-assigned managed identity by object ID
65
+ UserManagedIdentityByObjectId { object_id : String } ,
66
+ /// User-assigned managed identity by resource ID
67
+ UserManagedIdentityByResourceId { resource_id : String } ,
60
68
#[ cfg( feature = "mock_auth" ) ]
61
69
MockAuth , // No authentication, used for testing purposes
62
70
}
63
71
64
72
#[ derive( Debug , Error ) ]
65
73
pub ( crate ) enum GenevaConfigClientError {
66
74
// Authentication-related errors
67
- #[ error( "Authentication method not implemented: {0}" ) ]
68
- AuthMethodNotImplemented ( String ) ,
69
75
#[ error( "Missing Auth Info: {0}" ) ]
70
76
AuthInfoNotFound ( String ) ,
71
77
#[ error( "Invalid or malformed JWT token: {0}" ) ]
72
78
JwtTokenError ( String ) ,
73
79
#[ error( "Certificate error: {0}" ) ]
74
80
Certificate ( String ) ,
81
+ #[ error( "MSI authentication error: {0}" ) ]
82
+ MsiAuth ( String ) ,
75
83
76
84
// Networking / HTTP / TLS
77
85
#[ error( "HTTP error: {0}" ) ]
@@ -129,6 +137,7 @@ pub(crate) struct GenevaConfigClientConfig {
129
137
pub ( crate ) region : String ,
130
138
pub ( crate ) config_major_version : u32 ,
131
139
pub ( crate ) auth_method : AuthMethod , // agent_identity and agent_version are hardcoded for now
140
+ pub ( crate ) msi_resource : Option < String > , // Required when using any Managed Identity variant
132
141
}
133
142
134
143
#[ allow( dead_code) ]
@@ -246,10 +255,10 @@ impl GenevaConfigClient {
246
255
. map_err ( |e| GenevaConfigClientError :: Certificate ( e. to_string ( ) ) ) ?;
247
256
client_builder = client_builder. use_preconfigured_tls ( tls_connector) ;
248
257
}
249
- AuthMethod :: ManagedIdentity => {
250
- return Err ( GenevaConfigClientError :: AuthMethodNotImplemented (
251
- "Managed Identity authentication is not implemented yet" . into ( ) ,
252
- ) ) ;
258
+ AuthMethod :: SystemManagedIdentity
259
+ | AuthMethod :: UserManagedIdentity { .. }
260
+ | AuthMethod :: UserManagedIdentityByObjectId { .. }
261
+ | AuthMethod :: UserManagedIdentityByResourceId { .. } => { /* no special HTTP client changes needed */
253
262
}
254
263
#[ cfg( feature = "mock_auth" ) ]
255
264
AuthMethod :: MockAuth => {
@@ -268,11 +277,24 @@ impl GenevaConfigClient {
268
277
let encoded_identity = general_purpose:: STANDARD . encode ( & identity) ;
269
278
let version_str = format ! ( "Ver{0}v0" , config. config_major_version) ;
270
279
280
+ // Use different API endpoints based on authentication method
281
+ // Certificate auth uses "api", MSI auth uses "userapi"
282
+ let api_path = match & config. auth_method {
283
+ AuthMethod :: Certificate { .. } => "api" ,
284
+ AuthMethod :: SystemManagedIdentity
285
+ | AuthMethod :: UserManagedIdentity { .. }
286
+ | AuthMethod :: UserManagedIdentityByObjectId { .. }
287
+ | AuthMethod :: UserManagedIdentityByResourceId { .. } => "userapi" ,
288
+ #[ cfg( feature = "mock_auth" ) ]
289
+ AuthMethod :: MockAuth => "api" , // treat mock like certificate path for URL shape
290
+ } ;
291
+
271
292
let mut pre_url = String :: with_capacity ( config. endpoint . len ( ) + 200 ) ;
272
293
write ! (
273
294
& mut pre_url,
274
- "{}/api /agent/v3/{}/{}/MonitoringStorageKeys/?Namespace={}&Region={}&Identity={}&OSType={}&ConfigMajorVersion={}" ,
295
+ "{}/{} /agent/v3/{}/{}/MonitoringStorageKeys/?Namespace={}&Region={}&Identity={}&OSType={}&ConfigMajorVersion={}" ,
275
296
config. endpoint. trim_end_matches( '/' ) ,
297
+ api_path,
276
298
config. environment,
277
299
config. account,
278
300
config. namespace,
@@ -310,6 +332,66 @@ impl GenevaConfigClient {
310
332
headers
311
333
}
312
334
335
+ /// Get MSI token for GCS authentication
336
+ async fn get_msi_token ( & self ) -> Result < String > {
337
+ let resource = self . config . msi_resource . as_ref ( ) . ok_or_else ( || {
338
+ GenevaConfigClientError :: MsiAuth (
339
+ "msi_resource not set in config (required for Managed Identity auth)" . to_string ( ) ,
340
+ )
341
+ } ) ?;
342
+
343
+ // Normalize resource (strip trailing "/.default" if provided by user)
344
+ let base = resource. trim_end_matches ( "/.default" ) . trim_end_matches ( '/' ) ;
345
+
346
+ // Candidate scopes tried with Azure Identity
347
+ let mut scope_candidates: Vec < String > = vec ! [ format!( "{base}/.default" ) , base. to_string( ) ] ;
348
+ // Add variant with trailing slash if not already present
349
+ if !base. ends_with ( '/' ) {
350
+ scope_candidates. push ( format ! ( "{base}/" ) ) ;
351
+ }
352
+
353
+ // Build credential based on selector
354
+ let user_assigned_id = match & self . config . auth_method {
355
+ AuthMethod :: SystemManagedIdentity => None ,
356
+ AuthMethod :: UserManagedIdentity { client_id } => {
357
+ Some ( UserAssignedId :: ClientId ( client_id. clone ( ) ) )
358
+ }
359
+ AuthMethod :: UserManagedIdentityByObjectId { object_id } => {
360
+ Some ( UserAssignedId :: ObjectId ( object_id. clone ( ) ) )
361
+ }
362
+ AuthMethod :: UserManagedIdentityByResourceId { resource_id } => {
363
+ Some ( UserAssignedId :: ResourceId ( resource_id. clone ( ) ) )
364
+ }
365
+ _ => {
366
+ return Err ( GenevaConfigClientError :: MsiAuth (
367
+ "get_msi_token called but auth method is not a managed identity variant"
368
+ . to_string ( ) ,
369
+ ) )
370
+ }
371
+ } ;
372
+
373
+ let options = ManagedIdentityCredentialOptions {
374
+ user_assigned_id,
375
+ ..Default :: default ( )
376
+ } ;
377
+ let credential = ManagedIdentityCredential :: new ( Some ( options) ) . map_err ( |e| {
378
+ GenevaConfigClientError :: MsiAuth ( format ! ( "Failed to create MSI credential: {e}" ) )
379
+ } ) ?;
380
+
381
+ let mut last_err: Option < String > = None ;
382
+ for scope in & scope_candidates {
383
+ match credential. get_token ( & [ scope. as_str ( ) ] , None ) . await {
384
+ Ok ( token) => return Ok ( token. token . secret ( ) . to_string ( ) ) ,
385
+ Err ( e) => last_err = Some ( e. to_string ( ) ) ,
386
+ }
387
+ }
388
+ let detail = last_err. unwrap_or_else ( || "no error detail" . into ( ) ) ;
389
+ Err ( GenevaConfigClientError :: MsiAuth ( format ! (
390
+ "Managed Identity token acquisition failed. Scopes tried: {scopes}. Last error: {detail}. IMDS fallback intentionally disabled." ,
391
+ scopes = scope_candidates. join( ", " )
392
+ ) ) )
393
+ }
394
+
313
395
/// Retrieves ingestion gateway information from the Geneva Config Service.
314
396
///
315
397
/// # HTTP API Details
@@ -381,7 +463,16 @@ impl GenevaConfigClient {
381
463
GenevaConfigClientError :: InternalError ( "Failed to parse token expiry" . into ( ) )
382
464
} ) ?;
383
465
384
- let token_endpoint = extract_endpoint_from_token ( & fresh_ingestion_gateway_info. auth_token ) ?;
466
+ let token_endpoint =
467
+ match extract_endpoint_from_token ( & fresh_ingestion_gateway_info. auth_token ) {
468
+ Ok ( ep) => ep,
469
+ Err ( err) => {
470
+ // Fallback: some tokens legitimately omit the Endpoint claim; use server endpoint.
471
+ #[ cfg( debug_assertions) ]
472
+ eprintln ! ( "[geneva][debug] token Endpoint claim missing or unparsable: {err}" ) ;
473
+ fresh_ingestion_gateway_info. endpoint . clone ( )
474
+ }
475
+ } ;
385
476
386
477
// Now update the cache with exclusive write access
387
478
let mut guard = self
@@ -432,10 +523,29 @@ impl GenevaConfigClient {
432
523
. headers ( self . static_headers . clone ( ) ) ; // Clone only cheap references
433
524
434
525
request = request. header ( "x-ms-client-request-id" , req_id) ;
435
- let response = request
436
- . send ( )
437
- . await
438
- . map_err ( GenevaConfigClientError :: Http ) ?;
526
+
527
+ // Add MSI authentication for managed identity auth method
528
+ match & self . config . auth_method {
529
+ AuthMethod :: SystemManagedIdentity
530
+ | AuthMethod :: UserManagedIdentity { .. }
531
+ | AuthMethod :: UserManagedIdentityByObjectId { .. }
532
+ | AuthMethod :: UserManagedIdentityByResourceId { .. } => {
533
+ let msi_token = self . get_msi_token ( ) . await ?;
534
+ request = request. header ( AUTHORIZATION , format ! ( "Bearer {}" , msi_token) ) ;
535
+ }
536
+ AuthMethod :: Certificate { .. } => { /* mTLS only */ }
537
+ #[ cfg( feature = "mock_auth" ) ]
538
+ AuthMethod :: MockAuth => { /* no auth header */ }
539
+ }
540
+
541
+ // Log the request details for debugging
542
+ let response = match request. send ( ) . await {
543
+ Ok ( response) => response,
544
+ Err ( e) => {
545
+ return Err ( GenevaConfigClientError :: Http ( e) ) ;
546
+ }
547
+ } ;
548
+
439
549
// Check if the response is successful
440
550
let status = response. status ( ) ;
441
551
let body = response. text ( ) . await ?;
@@ -506,12 +616,18 @@ fn extract_endpoint_from_token(token: &str) -> Result<String> {
506
616
_ => payload. to_string ( ) ,
507
617
} ;
508
618
509
- // Decode the Base64-encoded payload into raw bytes
510
- let decoded = general_purpose:: URL_SAFE_NO_PAD
511
- . decode ( payload)
512
- . map_err ( |e| {
513
- GenevaConfigClientError :: JwtTokenError ( format ! ( "Failed to decode JWT: {e}" ) )
514
- } ) ?;
619
+ // Decode the Base64-encoded payload into raw bytes with a more tolerant approach.
620
+ let decoded = match general_purpose:: URL_SAFE_NO_PAD . decode ( & payload) {
621
+ Ok ( b) => b,
622
+ Err ( e_url) => match general_purpose:: STANDARD . decode ( & payload) {
623
+ Ok ( b) => b,
624
+ Err ( e_std) => {
625
+ return Err ( GenevaConfigClientError :: JwtTokenError ( format ! (
626
+ "Failed to decode JWT (url_safe and standard): url_err={e_url}; std_err={e_std}"
627
+ ) ) )
628
+ }
629
+ } ,
630
+ } ;
515
631
516
632
// Convert the raw bytes into a UTF-8 string
517
633
let decoded_str = String :: from_utf8 ( decoded) . map_err ( |e| {
@@ -522,15 +638,12 @@ fn extract_endpoint_from_token(token: &str) -> Result<String> {
522
638
let payload_json: serde_json:: Value =
523
639
serde_json:: from_str ( & decoded_str) . map_err ( GenevaConfigClientError :: SerdeJson ) ?;
524
640
525
- // Extract "Endpoint" from JWT payload as a string, or fail if missing or invalid.
526
- let endpoint = payload_json[ "Endpoint" ]
527
- . as_str ( )
528
- . ok_or_else ( || {
529
- GenevaConfigClientError :: JwtTokenError ( "No Endpoint claim in JWT token" . to_string ( ) )
530
- } ) ?
531
- . to_string ( ) ;
532
-
533
- Ok ( endpoint)
641
+ if let Some ( ep) = payload_json[ "Endpoint" ] . as_str ( ) {
642
+ return Ok ( ep. to_string ( ) ) ;
643
+ }
644
+ Err ( GenevaConfigClientError :: JwtTokenError (
645
+ "No Endpoint claim in JWT token" . to_string ( ) ,
646
+ ) )
534
647
}
535
648
536
649
#[ cfg( feature = "self_signed_certs" ) ]
0 commit comments