@@ -330,6 +330,15 @@ def _obtain_token(http_client, managed_identity, resource):
330330 managed_identity ,
331331 resource ,
332332 )
333+ if "MSI_ENDPOINT" in os .environ and "MSI_SECRET" in os .environ :
334+ # Back ported from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.15.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py
335+ return _obtain_token_on_machine_learning (
336+ http_client ,
337+ os .environ ["MSI_ENDPOINT" ],
338+ os .environ ["MSI_SECRET" ],
339+ managed_identity ,
340+ resource ,
341+ )
333342 if "IDENTITY_ENDPOINT" in os .environ and "IMDS_ENDPOINT" in os .environ :
334343 if ManagedIdentity .is_user_assigned (managed_identity ):
335344 raise ManagedIdentityError ( # Note: Azure Identity for Python raised exception too
@@ -346,6 +355,7 @@ def _obtain_token(http_client, managed_identity, resource):
346355
347356
348357def _adjust_param (params , managed_identity ):
358+ # Modify the params dict in place
349359 id_name = ManagedIdentity ._types_mapping .get (
350360 managed_identity .get (ManagedIdentity .ID_TYPE ))
351361 if id_name :
@@ -422,6 +432,39 @@ def _obtain_token_on_app_service(
422432 logger .debug ("IMDS emits unexpected payload: %s" , resp .text )
423433 raise
424434
435+ def _obtain_token_on_machine_learning (
436+ http_client , endpoint , secret , managed_identity , resource ,
437+ ):
438+ # Could not find protocol docs from https://docs.microsoft.com/en-us/azure/machine-learning
439+ # The following implementation is back ported from Azure Identity 1.15.0
440+ logger .debug ("Obtaining token via managed identity on Azure Machine Learning" )
441+ params = {"api-version" : "2017-09-01" , "resource" : resource }
442+ _adjust_param (params , managed_identity )
443+ if params ["api-version" ] == "2017-09-01" and "client_id" in params :
444+ # Workaround for a known bug in Azure ML 2017 API
445+ params ["clientid" ] = params .pop ("client_id" )
446+ resp = http_client .get (
447+ endpoint ,
448+ params = params ,
449+ headers = {"secret" : secret },
450+ )
451+ try :
452+ payload = json .loads (resp .text )
453+ if payload .get ("access_token" ) and payload .get ("expires_on" ):
454+ return { # Normalizing the payload into OAuth2 format
455+ "access_token" : payload ["access_token" ],
456+ "expires_in" : int (payload ["expires_on" ]) - int (time .time ()),
457+ "resource" : payload .get ("resource" ),
458+ "token_type" : payload .get ("token_type" , "Bearer" ),
459+ }
460+ return {
461+ "error" : "invalid_scope" , # TODO: To be tested
462+ "error_description" : "{}" .format (payload ),
463+ }
464+ except json .decoder .JSONDecodeError :
465+ logger .debug ("IMDS emits unexpected payload: %s" , resp .text )
466+ raise
467+
425468
426469def _obtain_token_on_service_fabric (
427470 http_client , endpoint , identity_header , server_thumbprint , resource ,
0 commit comments