Skip to content

Commit adbdf9d

Browse files
Merge pull request #14470 from BerriAI/litellm_dev_09_11_2025_p1
AzureAD Default credentials - select credential type based on environment
2 parents c4022ad + 12a1d08 commit adbdf9d

File tree

5 files changed

+116
-37
lines changed

5 files changed

+116
-37
lines changed

litellm/llms/azure/common_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,11 @@ def get_azure_ad_token(
365365
azure_ad_token_provider = get_azure_ad_token_provider(azure_scope=scope)
366366
except ValueError:
367367
verbose_logger.debug("Azure AD Token Provider could not be used.")
368+
except Exception as e:
369+
verbose_logger.error(
370+
f"Error calling Azure AD token provider: {str(e)}. Follow docs - https://docs.litellm.ai/docs/providers/azure/#azure-ad-token-refresh---defaultazurecredential"
371+
)
372+
raise e
368373

369374
#########################################################
370375
# If litellm.enable_azure_ad_token_refresh is True and no other token provider is available,
@@ -561,7 +566,9 @@ def initialize_azure_sdk_client(
561566
"Using Azure AD token provider based on Service Principal with Secret workflow for Azure Auth"
562567
)
563568
try:
564-
azure_ad_token_provider = get_azure_ad_token_provider(azure_scope=scope)
569+
azure_ad_token_provider = get_azure_ad_token_provider(
570+
azure_scope=scope,
571+
)
565572
except ValueError:
566573
verbose_logger.debug("Azure AD Token Provider could not be used.")
567574
if api_version is None:

litellm/secret_managers/get_azure_ad_token_provider.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,36 @@
11
import os
22
from typing import Any, Callable, Optional, Union
33

4+
from litellm._logging import verbose_logger
45
from litellm.types.secret_managers.get_azure_ad_token_provider import (
56
AzureCredentialType,
67
)
78

89

10+
def infer_credential_type_from_environment() -> AzureCredentialType:
    """Infer which Azure credential type to use from environment variables.

    Priority order:
      1. Service principal with secret  — ``AZURE_CLIENT_ID`` +
         ``AZURE_CLIENT_SECRET`` + ``AZURE_TENANT_ID``.
      2. Certificate-based credential   — any of ``AZURE_CERTIFICATE_PATH``
         or ``AZURE_CERTIFICATE_PASSWORD`` is present.
      3. Managed identity               — ``AZURE_CLIENT_ID`` alone.
      4. ``DefaultAzureCredential``     — nothing recognizable is set.

    Returns:
        AzureCredentialType: the credential type matching the environment.
    """
    client_id = os.environ.get("AZURE_CLIENT_ID")
    tenant_id = os.environ.get("AZURE_TENANT_ID")
    cert_path = os.environ.get("AZURE_CERTIFICATE_PATH")
    cert_password = os.environ.get("AZURE_CERTIFICATE_PASSWORD")

    if client_id and os.environ.get("AZURE_CLIENT_SECRET") and tenant_id:
        return AzureCredentialType.ClientSecretCredential
    # BUGFIX: certificate detection must run BEFORE the bare client-id check.
    # Previously `elif AZURE_CLIENT_ID:` matched first, so the branch requiring
    # client id + tenant + certificate path/password was unreachable and a
    # certificate-configured service principal fell through to
    # ManagedIdentityCredential. Any certificate hint now wins over the
    # managed-identity fallback.
    if cert_path or cert_password:
        return AzureCredentialType.CertificateCredential
    if client_id:
        return AzureCredentialType.ManagedIdentityCredential
    return AzureCredentialType.DefaultAzureCredential
32+
33+
934
def get_azure_ad_token_provider(
1035
azure_scope: Optional[str] = None,
1136
azure_credential: Optional[AzureCredentialType] = None,
@@ -42,9 +67,14 @@ def get_azure_ad_token_provider(
4267
)
4368

4469
cred: str = (
45-
azure_credential.value if azure_credential else None
46-
or os.environ.get("AZURE_CREDENTIAL", AzureCredentialType.ClientSecretCredential)
47-
or AzureCredentialType.ClientSecretCredential
70+
azure_credential.value
71+
if azure_credential
72+
else None
73+
or os.environ.get("AZURE_CREDENTIAL")
74+
or infer_credential_type_from_environment()
75+
)
76+
verbose_logger.info(
77+
f"For Azure AD Token Provider, choosing credential type: {cred}"
4878
)
4979
credential: Optional[
5080
Union[

tests/llm_translation/test_gemini.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,11 @@ def test_gemini_image_generation():
267267
assert len(response.choices[0].message.images) > 0
268268
assert response.choices[0].message.images[0]["image_url"] is not None
269269
assert response.choices[0].message.images[0]["image_url"]["url"] is not None
270-
assert response.choices[0].message.images[0]["image_url"]["url"].startswith("data:image/png;base64,")
270+
assert (
271+
response.choices[0]
272+
.message.images[0]["image_url"]["url"]
273+
.startswith("data:image/png;base64,")
274+
)
271275

272276

273277
def test_gemini_2_5_flash_image_preview():
@@ -772,7 +776,8 @@ def test_system_message_with_no_user_message():
772776
assert response is not None
773777

774778
assert response.choices[0].message.content is not None
775-
779+
780+
776781
def get_current_weather(location, unit="fahrenheit"):
777782
"""Get the current weather in a given location"""
778783
if "tokyo" in location.lower():
@@ -889,9 +894,9 @@ def test_gemini_reasoning_effort_minimal():
889894

890895
# Test with different Gemini models to verify model-specific mapping
891896
test_cases = [
892-
("gemini/gemini-2.5-flash", 1), # Flash: minimum 1 token
893-
("gemini/gemini-2.5-pro", 128), # Pro: minimum 128 tokens
894-
("gemini/gemini-2.5-flash-lite", 512), # Flash-Lite: minimum 512 tokens
897+
("gemini/gemini-2.5-flash", 1), # Flash: minimum 1 token
898+
("gemini/gemini-2.5-pro", 128), # Pro: minimum 128 tokens
899+
("gemini/gemini-2.5-flash-lite", 512), # Flash-Lite: minimum 512 tokens
895900
]
896901

897902
for model, expected_min_budget in test_cases:
@@ -904,24 +909,32 @@ def test_gemini_reasoning_effort_minimal():
904909
"reasoning_effort": "minimal",
905910
},
906911
)
907-
912+
908913
# Verify that the thinking config is set correctly
909914
request_body = raw_request["raw_request_body"]
910-
assert "generationConfig" in request_body, f"Model {model} should have generationConfig"
911-
915+
assert (
916+
"generationConfig" in request_body
917+
), f"Model {model} should have generationConfig"
918+
912919
generation_config = request_body["generationConfig"]
913-
assert "thinkingConfig" in generation_config, f"Model {model} should have thinkingConfig"
914-
920+
assert (
921+
"thinkingConfig" in generation_config
922+
), f"Model {model} should have thinkingConfig"
923+
915924
thinking_config = generation_config["thinkingConfig"]
916-
assert "thinkingBudget" in thinking_config, f"Model {model} should have thinkingBudget"
917-
925+
assert (
926+
"thinkingBudget" in thinking_config
927+
), f"Model {model} should have thinkingBudget"
928+
918929
actual_budget = thinking_config["thinkingBudget"]
919-
assert actual_budget == expected_min_budget, \
920-
f"Model {model} should map 'minimal' to {expected_min_budget} tokens, got {actual_budget}"
921-
930+
assert (
931+
actual_budget == expected_min_budget
932+
), f"Model {model} should map 'minimal' to {expected_min_budget} tokens, got {actual_budget}"
933+
922934
# Verify that includeThoughts is True for minimal reasoning effort
923-
assert thinking_config.get("includeThoughts", True), \
924-
f"Model {model} should have includeThoughts=True for minimal reasoning effort"
935+
assert thinking_config.get(
936+
"includeThoughts", True
937+
), f"Model {model} should have includeThoughts=True for minimal reasoning effort"
925938

926939
# Test with unknown model (should use generic fallback)
927940
try:
@@ -933,13 +946,14 @@ def test_gemini_reasoning_effort_minimal():
933946
"reasoning_effort": "minimal",
934947
},
935948
)
936-
949+
937950
request_body = raw_request["raw_request_body"]
938951
generation_config = request_body["generationConfig"]
939952
thinking_config = generation_config["thinkingConfig"]
940953
# Should use generic fallback (128 tokens)
941-
assert thinking_config["thinkingBudget"] == 128, \
942-
"Unknown model should use generic fallback of 128 tokens"
954+
assert (
955+
thinking_config["thinkingBudget"] == 128
956+
), "Unknown model should use generic fallback of 128 tokens"
943957
except Exception as e:
944958
# If return_raw_request doesn't work for unknown models, that's okay
945959
# The important part is that our known models work correctly

tests/local_testing/test_amazing_vertex_completion.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ async def test_async_vertexai_response():
397397
| litellm.vertex_text_models
398398
| litellm.vertex_code_text_models
399399
)
400-
400+
401401
test_models = random.sample(list(test_models), 1)
402402
test_models += list(litellm.vertex_language_models) # always test gemini-pro
403403
for model in test_models:
@@ -504,7 +504,6 @@ async def test_async_vertexai_streaming_response():
504504
pytest.fail(f"An exception occurred: {e}")
505505

506506

507-
508507
@pytest.mark.parametrize("load_pdf", [False]) # True,
509508
@pytest.mark.flaky(retries=3, delay=1)
510509
def test_completion_function_plus_pdf(load_pdf):
@@ -547,6 +546,7 @@ def test_completion_function_plus_pdf(load_pdf):
547546
except Exception as e:
548547
pytest.fail("Got={}".format(str(e)))
549548

549+
550550
def encode_image(image_path):
551551
import base64
552552

@@ -910,7 +910,10 @@ async def test_partner_models_httpx(model, region, sync_mode):
910910
[
911911
("vertex_ai/meta/llama-4-scout-17b-16e-instruct-maas", "us-east5"),
912912
("vertex_ai/qwen/qwen3-coder-480b-a35b-instruct-maas", "us-south1"),
913-
("vertex_ai/mistral-large-2411", "us-central1"), # critical - we had this issue: https://github.com/BerriAI/litellm/issues/13888
913+
(
914+
"vertex_ai/mistral-large-2411",
915+
"us-central1",
916+
), # critical - we had this issue: https://github.com/BerriAI/litellm/issues/13888
914917
("vertex_ai/openai/gpt-oss-20b-maas", "us-central1"),
915918
],
916919
)
@@ -3773,7 +3776,7 @@ def test_vertex_ai_gemini_audio_ogg():
37733776
@pytest.mark.asyncio
37743777
async def test_vertex_ai_deepseek():
37753778
"""Test that deepseek models use the correct v1 API endpoint instead of v1beta1."""
3776-
#load_vertex_ai_credentials()
3779+
# load_vertex_ai_credentials()
37773780
litellm._turn_on_debug()
37783781
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
37793782

@@ -3786,21 +3789,17 @@ async def test_vertex_ai_deepseek():
37863789
{
37873790
"message": {
37883791
"role": "assistant",
3789-
"content": "Hello! How can I help you today?"
3792+
"content": "Hello! How can I help you today?",
37903793
},
37913794
"index": 0,
3792-
"finish_reason": "stop"
3795+
"finish_reason": "stop",
37933796
}
37943797
],
3795-
"usage": {
3796-
"prompt_tokens": 10,
3797-
"completion_tokens": 20,
3798-
"total_tokens": 30
3799-
},
3800-
"model": "deepseek-ai/deepseek-r1-0528-maas"
3798+
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
3799+
"model": "deepseek-ai/deepseek-r1-0528-maas",
38013800
}
38023801
mock_response.status_code = 200
3803-
3802+
38043803
with patch.object(client, "post", return_value=mock_response) as mock_post:
38053804
response = await acompletion(
38063805
model="vertex_ai/deepseek-ai/deepseek-r1-0528-maas",

tests/test_litellm/secret_managers/test_get_azure_ad_token_provider.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,32 @@ def test_get_azure_ad_token_provider_default_azure_credential(
214214
# Test that the returned callable works
215215
token = result()
216216
assert token == "mock-certificate-token"
217+
218+
@patch.dict(os.environ, {}, clear=True)  # Clear all environment variables
@patch("azure.identity.get_bearer_token_provider")
@patch("azure.identity.DefaultAzureCredential")
def test_get_azure_ad_token_provider_defaults_to_default_azure_credential(
    self, mock_default_azure_credential, mock_get_bearer_token_provider
):
    """Test get_azure_ad_token_provider defaults to DefaultAzureCredential when no credentials are present."""
    # Wire the mocked DefaultAzureCredential to hand back a stub credential,
    # and the mocked bearer-token factory to hand back a stub provider.
    stub_credential = MagicMock()
    stub_provider = MagicMock(return_value="mock-default-token")
    mock_default_azure_credential.return_value = stub_credential
    mock_get_bearer_token_provider.return_value = stub_provider

    # With the environment cleared, the function should fall back to
    # DefaultAzureCredential and the default cognitive-services scope.
    provider = get_azure_ad_token_provider()

    assert callable(provider)
    mock_default_azure_credential.assert_called_once_with()
    mock_get_bearer_token_provider.assert_called_once_with(
        stub_credential, "https://cognitiveservices.azure.com/.default"
    )

    # Invoking the returned provider yields the stubbed token.
    assert provider() == "mock-default-token"

0 commit comments

Comments
 (0)