
Commit 7a41a84

Add Azure Llama client support (#872)
* breaking api: rename to openai_llm_client
* setup azure llama client
* export func
* fix pyproject.toml
1 parent c0c3563 · commit 7a41a84

File tree

4 files changed: +113 -11 lines

pyproject.toml

Lines changed: 5 additions & 0 deletions
@@ -49,6 +49,7 @@ get-hashes = 'codemodder.scripts.get_hashes:main'
 
 [project.optional-dependencies]
 test = [
+    "azure-ai-inference>=1.0.0b1,<2.0",
     "coverage>=7.6,<7.7",
     "coverage-threshold~=0.4",
     "defusedxml==0.7.1",
@@ -86,6 +87,10 @@ complexity = [
 openai = [
     "openai>=1.50,<1.52",
 ]
+azure = [
+    "azure-ai-inference>=1.0.0b1,<2.0",
+]
+
 all = [
     "codemodder[test]",
     "codemodder[complexity]",

src/codemodder/context.py

Lines changed: 3 additions & 2 deletions
@@ -17,7 +17,7 @@
     build_failed_dependency_notification,
 )
 from codemodder.file_context import FileContext
-from codemodder.llm import setup_llm_client
+from codemodder.llm import setup_azure_llama_llm_client, setup_openai_llm_client
 from codemodder.logging import log_list, logger
 from codemodder.project_analysis.file_parsers.package_store import PackageStore
 from codemodder.project_analysis.python_repo_manager import PythonRepoManager
@@ -82,7 +82,8 @@ def __init__(
         self.max_workers = max_workers
         self.tool_result_files_map = tool_result_files_map or {}
         self.semgrep_prefilter_results = None
-        self.llm_client = setup_llm_client()
+        self.openai_llm_client = setup_openai_llm_client()
+        self.azure_llama_llm_client = setup_azure_llama_llm_client()
 
     def add_changesets(self, codemod_name: str, change_sets: List[ChangeSet]):
         self._changesets_by_codemod.setdefault(codemod_name, []).extend(change_sets)
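
CodemodExecutionContext now exposes two independent attributes, openai_llm_client and azure_llama_llm_client, instead of the single llm_client. A hedged sketch of how downstream code might choose between them; the helper function is hypothetical, only the attribute names come from this diff:

# Hypothetical helper, not part of this commit: prefer the (Azure) OpenAI client
# and fall back to the Azure Llama client. Either attribute may be None when its
# environment variables are not configured.
def pick_llm_client(context):
    if context.openai_llm_client is not None:
        return context.openai_llm_client
    return context.azure_llama_llm_client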

src/codemodder/llm.py

Lines changed: 34 additions & 2 deletions
@@ -9,15 +9,24 @@
     OpenAI = None
     AzureOpenAI = None
 
+try:
+    from azure.ai.inference import ChatCompletionsClient
+    from azure.core.credentials import AzureKeyCredential
+except ImportError:
+    ChatCompletionsClient = None
+    AzureKeyCredential = None
 
 if TYPE_CHECKING:
     from openai import OpenAI
+    from azure.ai.inference import ChatCompletionsClient
+    from azure.core.credentials import AzureKeyCredential
 
 from codemodder.logging import logger
 
 __all__ = [
     "MODELS",
-    "setup_llm_client",
+    "setup_openai_llm_client",
+    "setup_azure_llama_llm_client",
     "MisconfiguredAIClient",
 ]
 
@@ -46,7 +55,8 @@ def __getattr__(self, name):
 MODELS = ModelRegistry(models)
 
 
-def setup_llm_client() -> OpenAI | None:
+def setup_openai_llm_client() -> OpenAI | None:
+    """Configure either the Azure OpenAI LLM client or the OpenAI client, in that order."""
     if not AzureOpenAI:
         logger.info("Azure OpenAI API client not available")
         return None
@@ -81,5 +91,27 @@ def setup_llm_client() -> OpenAI | None:
     return OpenAI(api_key=api_key)
 
 
+def setup_azure_llama_llm_client() -> ChatCompletionsClient | None:
+    """Configure the Azure Llama LLM client."""
+    if not ChatCompletionsClient:
+        logger.info("Azure API client not available")
+        return None
+
+    azure_llama_key = os.getenv("CODEMODDER_AZURE_LLAMA_API_KEY")
+    azure_llama_endpoint = os.getenv("CODEMODDER_AZURE_LLAMA_ENDPOINT")
+    if bool(azure_llama_key) ^ bool(azure_llama_endpoint):
+        raise MisconfiguredAIClient(
+            "Azure Llama API key and endpoint must both be set or unset"
+        )
+
+    if azure_llama_key and azure_llama_endpoint:
+        logger.info("Using Azure Llama API client")
+        return ChatCompletionsClient(
+            credential=AzureKeyCredential(azure_llama_key),
+            endpoint=azure_llama_endpoint,
+        )
+    return None
+
+
 class MisconfiguredAIClient(ValueError):
     pass
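
setup_azure_llama_llm_client reads CODEMODDER_AZURE_LLAMA_API_KEY and CODEMODDER_AZURE_LLAMA_ENDPOINT, raises MisconfiguredAIClient when exactly one of the two is set (the bool(...) ^ bool(...) XOR check), returns a ChatCompletionsClient when both are set, and returns None otherwise. A minimal usage sketch, assuming both variables point at a real Azure AI inference deployment; the endpoint, prompt, and the complete() call follow the azure-ai-inference chat API and are illustrative only:

# Illustrative usage, not part of this commit; key and endpoint are placeholders.
import os

from codemodder.llm import setup_azure_llama_llm_client

os.environ["CODEMODDER_AZURE_LLAMA_API_KEY"] = "<api-key>"
os.environ["CODEMODDER_AZURE_LLAMA_ENDPOINT"] = "https://<deployment>.inference.ai.azure.com"

client = setup_azure_llama_llm_client()
if client is not None:
    # messages may also be passed as typed objects from azure.ai.inference.models
    response = client.complete(messages=[{"role": "user", "content": "Say hello."}])
    print(response.choices[0].message.content)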

tests/test_context.py

Lines changed: 71 additions & 7 deletions
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+from azure.ai.inference import ChatCompletionsClient
 from openai import AzureOpenAI, OpenAI
 
 from codemodder.context import CodemodExecutionContext as Context
@@ -90,7 +91,7 @@ def test_failed_dependency_description(self, mocker):
             in description
         )
 
-    def test_setup_llm_client_no_env_vars(self, mocker):
+    def test_setup_llm_clients_no_env_vars(self, mocker):
         mocker.patch.dict(os.environ, clear=True)
         context = Context(
             mocker.Mock(),
@@ -102,7 +103,8 @@ def test_setup_llm_client_no_env_vars(self, mocker):
             [],
             [],
         )
-        assert context.llm_client is None
+        assert context.openai_llm_client is None
+        assert context.azure_llama_llm_client is None
 
     def test_setup_openai_llm_client(self, mocker):
         mocker.patch.dict(os.environ, {"CODEMODDER_OPENAI_API_KEY": "test"})
@@ -116,7 +118,29 @@ def test_setup_openai_llm_client(self, mocker):
             [],
             [],
         )
-        assert isinstance(context.llm_client, OpenAI)
+        assert isinstance(context.openai_llm_client, OpenAI)
+
+    def test_setup_both_llm_clients(self, mocker):
+        mocker.patch.dict(
+            os.environ,
+            {
+                "CODEMODDER_OPENAI_API_KEY": "test",
+                "CODEMODDER_AZURE_LLAMA_API_KEY": "test",
+                "CODEMODDER_AZURE_LLAMA_ENDPOINT": "test",
+            },
+        )
+        context = Context(
+            mocker.Mock(),
+            True,
+            False,
+            load_registered_codemods(),
+            None,
+            PythonRepoManager(mocker.Mock()),
+            [],
+            [],
+        )
+        assert isinstance(context.openai_llm_client, OpenAI)
+        assert isinstance(context.azure_llama_llm_client, ChatCompletionsClient)
 
     def test_setup_azure_llm_client(self, mocker):
         mocker.patch.dict(
@@ -136,8 +160,10 @@ def test_setup_azure_llm_client(self, mocker):
             [],
             [],
         )
-        assert isinstance(context.llm_client, AzureOpenAI)
-        assert context.llm_client._api_version == DEFAULT_AZURE_OPENAI_API_VERSION
+        assert isinstance(context.openai_llm_client, AzureOpenAI)
+        assert (
+            context.openai_llm_client._api_version == DEFAULT_AZURE_OPENAI_API_VERSION
+        )
 
     @pytest.mark.parametrize(
         "env_var",
@@ -157,6 +183,44 @@ def test_setup_azure_llm_client_missing_one(self, mocker, env_var):
                 [],
             )
 
+    def test_setup_azure_llama_llm_client(self, mocker):
+        mocker.patch.dict(
+            os.environ,
+            {
+                "CODEMODDER_AZURE_LLAMA_API_KEY": "test",
+                "CODEMODDER_AZURE_LLAMA_ENDPOINT": "test",
+            },
+        )
+        context = Context(
+            mocker.Mock(),
+            True,
+            False,
+            load_registered_codemods(),
+            None,
+            PythonRepoManager(mocker.Mock()),
+            [],
+            [],
+        )
+        assert isinstance(context.azure_llama_llm_client, ChatCompletionsClient)
+
+    @pytest.mark.parametrize(
+        "env_var",
+        ["CODEMODDER_AZURE_LLAMA_API_KEY", "CODEMODDER_AZURE_LLAMA_ENDPOINT"],
+    )
+    def test_setup_azure_llama_llm_client_missing_one(self, mocker, env_var):
+        mocker.patch.dict(os.environ, {env_var: "test"})
+        with pytest.raises(MisconfiguredAIClient):
+            Context(
+                mocker.Mock(),
+                True,
+                False,
+                load_registered_codemods(),
+                None,
+                PythonRepoManager(mocker.Mock()),
+                [],
+                [],
+            )
+
     def test_get_api_version_from_env(self, mocker):
         version = "fake-version"
         mocker.patch.dict(
@@ -177,5 +241,5 @@ def test_get_api_version_from_env(self, mocker):
             [],
             [],
         )
-        assert isinstance(context.llm_client, AzureOpenAI)
-        assert context.llm_client._api_version == version
+        assert isinstance(context.openai_llm_client, AzureOpenAI)
+        assert context.openai_llm_client._api_version == version
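
The new tests cover the Llama client on its own, both clients configured together, and the misconfiguration path where only one of the two Llama environment variables is set. They can presumably be run in isolation with something like pytest tests/test_context.py -k azure_llama once the azure extra is installed.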
