@@ -313,6 +313,15 @@ def _obtain_token(http_client, managed_identity, resource):
313313 managed_identity ,
314314 resource ,
315315 )
316+ if "MSI_ENDPOINT" in os .environ and "MSI_SECRET" in os .environ :
317+ # Back ported from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.15.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py
318+ return _obtain_token_on_machine_learning (
319+ http_client ,
320+ os .environ ["MSI_ENDPOINT" ],
321+ os .environ ["MSI_SECRET" ],
322+ managed_identity ,
323+ resource ,
324+ )
316325 if "IDENTITY_ENDPOINT" in os .environ and "IMDS_ENDPOINT" in os .environ :
317326 if ManagedIdentity .is_user_assigned (managed_identity ):
318327 raise ValueError ( # Note: Azure Identity for Python raised exception too
@@ -329,6 +338,7 @@ def _obtain_token(http_client, managed_identity, resource):
329338
330339
331340def _adjust_param (params , managed_identity ):
341+ # Modify the params dict in place
332342 id_name = ManagedIdentity ._types_mapping .get (
333343 managed_identity .get (ManagedIdentity .ID_TYPE ))
334344 if id_name :
@@ -405,6 +415,36 @@ def _obtain_token_on_app_service(
405415 logger .debug ("IMDS emits unexpected payload: %s" , resp .text )
406416 raise
407417
418+ def _obtain_token_on_machine_learning (
419+ http_client , endpoint , secret , managed_identity , resource ,
420+ ):
421+ # Could not find protocol docs from https://docs.microsoft.com/en-us/azure/machine-learning
422+ # The following implementation is back ported from Azure Identity 1.15.0
423+ logger .debug ("Obtaining token via managed identity on Azure Machine Learning" )
424+ params = {"api-version" : "2017-09-01" , "resource" : resource }
425+ _adjust_param (params , managed_identity )
426+ resp = http_client .get (
427+ endpoint ,
428+ params = params ,
429+ headers = {"secret" : secret },
430+ )
431+ try :
432+ payload = json .loads (resp .text )
433+ if payload .get ("access_token" ) and payload .get ("expires_on" ):
434+ return { # Normalizing the payload into OAuth2 format
435+ "access_token" : payload ["access_token" ],
436+ "expires_in" : int (payload ["expires_on" ]) - int (time .time ()),
437+ "resource" : payload .get ("resource" ),
438+ "token_type" : payload .get ("token_type" , "Bearer" ),
439+ }
440+ return {
441+ "error" : "invalid_scope" , # TODO: To be tested
442+ "error_description" : "{}" .format (payload ),
443+ }
444+ except ValueError :
445+ logger .debug ("IMDS emits unexpected payload: %s" , resp .text )
446+ raise
447+
408448
409449def _obtain_token_on_service_fabric (
410450 http_client , endpoint , identity_header , server_thumbprint , resource ,
0 commit comments