Skip to content

Commit a66333c

Browse files
authored
Enable use of Azure OpenAI client (#592)
* Add support for Azure OpenAI client
* Add support for model/deployment parameters
* Handle Azure API version via environment
* Add openai as test dependency
1 parent 5811e8e commit a66333c

File tree

4 files changed

+186
-21
lines changed

4 files changed

+186
-21
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ test = [
5454
"Jinja2~=3.1.2",
5555
"jsonschema~=4.22.0",
5656
"lxml>=4.9.3,<5.3.0",
57+
"openai~=1.23.0",
5758
"mock==5.1.*",
5859
"pre-commit<4",
5960
"Pyjwt~=2.8.0",

src/codemodder/codemodder.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from codemodder.codemods.api import BaseCodemod
1313
from codemodder.codemods.semgrep import SemgrepRuleDetector
1414
from codemodder.codetf import CodeTF
15-
from codemodder.context import CodemodExecutionContext
15+
from codemodder.context import CodemodExecutionContext, MisconfiguredAIClient
1616
from codemodder.dependency import Dependency
1717
from codemodder.logging import configure_logger, log_list, log_section, logger
1818
from codemodder.project_analysis.file_parsers.package_store import PackageStore
@@ -166,17 +166,22 @@ def run(original_args) -> int:
166166
tool_result_files_map["defectdojo"] = argv.defectdojo_findings_json or []
167167

168168
repo_manager = PythonRepoManager(Path(argv.directory))
169-
context = CodemodExecutionContext(
170-
Path(argv.directory),
171-
argv.dry_run,
172-
argv.verbose,
173-
codemod_registry,
174-
repo_manager,
175-
argv.path_include,
176-
argv.path_exclude,
177-
tool_result_files_map,
178-
argv.max_workers,
179-
)
169+
170+
try:
171+
context = CodemodExecutionContext(
172+
Path(argv.directory),
173+
argv.dry_run,
174+
argv.verbose,
175+
codemod_registry,
176+
repo_manager,
177+
argv.path_include,
178+
argv.path_exclude,
179+
tool_result_files_map,
180+
argv.max_workers,
181+
)
182+
except MisconfiguredAIClient as e:
183+
logger.error(e)
184+
return 3 # Codemodder instructions conflicted (according to spec)
180185

181186
repo_manager.parse_project()
182187

src/codemodder/context.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,26 @@
2323
from codemodder.utils.timer import Timer
2424

2525
try:
26-
from openai import Client
26+
from openai import AzureOpenAI, OpenAI
2727
except ImportError:
28-
Client = None
28+
OpenAI = None
29+
AzureOpenAI = None
2930

3031

3132
if TYPE_CHECKING:
32-
from openai import Client
33+
from openai import OpenAI
3334

3435
from codemodder.codemods.base_codemod import BaseCodemod
3536

3637

38+
class MisconfiguredAIClient(ValueError):
39+
pass
40+
41+
42+
MODELS = ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13"]
43+
DEFAULT_AZURE_OPENAI_API_VERSION = "2024-02-01"
44+
45+
3746
class CodemodExecutionContext:
3847
_failures_by_codemod: dict[str, list[Path]] = {}
3948
_dependency_update_by_codemod: dict[str, PackageStore | None] = {}
@@ -49,7 +58,7 @@ class CodemodExecutionContext:
4958
path_exclude: list[str]
5059
max_workers: int = 1
5160
tool_result_files_map: dict[str, list[str]]
52-
llm_client: Client | None = None
61+
llm_client: OpenAI | None = None
5362

5463
def __init__(
5564
self,
@@ -80,16 +89,39 @@ def __init__(
8089
self.tool_result_files_map = tool_result_files_map or {}
8190
self.llm_client = self._setup_llm_client()
8291

83-
def _setup_llm_client(self) -> Client | None:
84-
if not Client:
85-
logger.debug("OpenAI API client not available")
92+
def _setup_llm_client(self) -> OpenAI | None:
93+
if not AzureOpenAI:
94+
logger.info("Azure OpenAI API client not available")
95+
return None
96+
97+
azure_openapi_key = os.getenv("CODEMODDER_AZURE_OPENAI_API_KEY")
98+
azure_openapi_endpoint = os.getenv("CODEMODDER_AZURE_OPENAI_ENDPOINT")
99+
if bool(azure_openapi_key) ^ bool(azure_openapi_endpoint):
100+
raise MisconfiguredAIClient(
101+
"Azure OpenAI API key and endpoint must both be set or unset"
102+
)
103+
104+
if azure_openapi_key and azure_openapi_endpoint:
105+
logger.info("Using Azure OpenAI API client")
106+
return AzureOpenAI(
107+
api_key=azure_openapi_key,
108+
api_version=os.getenv(
109+
"CODEMODDER_AZURE_OPENAI_API_VERSION",
110+
DEFAULT_AZURE_OPENAI_API_VERSION,
111+
),
112+
azure_endpoint=azure_openapi_endpoint,
113+
)
114+
115+
if not OpenAI:
116+
logger.info("OpenAI API client not available")
86117
return None
87118

88119
if not (api_key := os.getenv("CODEMODDER_OPENAI_API_KEY")):
89-
logger.debug("OpenAI API key not found")
120+
logger.info("OpenAI API key not found")
90121
return None
91122

92-
return Client(api_key=api_key)
123+
logger.info("Using OpenAI API client")
124+
return OpenAI(api_key=api_key)
93125

94126
def add_changesets(self, codemod_name: str, change_sets: List[ChangeSet]):
95127
self._changesets_by_codemod.setdefault(codemod_name, []).extend(change_sets)
@@ -212,3 +244,8 @@ def log_changes(self, codemod_id: str):
212244
for change in changes:
213245
logger.info(" - %s", change.path)
214246
logger.debug(" diff:\n%s", indent(change.diff, " " * 6))
247+
248+
def __getattribute__(self, attr: str):
249+
if (name := attr.replace("_", "-")) in MODELS:
250+
return os.getenv(f"CODEMODDER_AZURE_OPENAI_{name.upper()}_DEPLOYMENT", name)
251+
return super().__getattribute__(attr)

tests/test_context.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
import os
2+
13
import pytest
4+
from openai import AzureOpenAI, OpenAI
25

6+
from codemodder.context import DEFAULT_AZURE_OPENAI_API_VERSION
37
from codemodder.context import CodemodExecutionContext as Context
8+
from codemodder.context import MisconfiguredAIClient
49
from codemodder.dependency import Security
510
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
611
from codemodder.registry import load_registered_codemods
@@ -77,3 +82,120 @@ def test_failed_dependency_description(self, mocker):
7782
```"""
7883
in description
7984
)
85+
86+
def test_setup_llm_client_no_env_vars(self, mocker):
87+
mocker.patch.dict(os.environ, clear=True)
88+
context = Context(
89+
mocker.Mock(),
90+
True,
91+
False,
92+
load_registered_codemods(),
93+
PythonRepoManager(mocker.Mock()),
94+
[],
95+
[],
96+
)
97+
assert context.llm_client is None
98+
99+
def test_setup_openai_llm_client(self, mocker):
100+
mocker.patch.dict(os.environ, {"CODEMODDER_OPENAI_API_KEY": "test"})
101+
context = Context(
102+
mocker.Mock(),
103+
True,
104+
False,
105+
load_registered_codemods(),
106+
PythonRepoManager(mocker.Mock()),
107+
[],
108+
[],
109+
)
110+
assert isinstance(context.llm_client, OpenAI)
111+
112+
def test_setup_azure_llm_client(self, mocker):
113+
mocker.patch.dict(
114+
os.environ,
115+
{
116+
"CODEMODDER_AZURE_OPENAI_API_KEY": "test",
117+
"CODEMODDER_AZURE_OPENAI_ENDPOINT": "test",
118+
},
119+
)
120+
context = Context(
121+
mocker.Mock(),
122+
True,
123+
False,
124+
load_registered_codemods(),
125+
PythonRepoManager(mocker.Mock()),
126+
[],
127+
[],
128+
)
129+
assert isinstance(context.llm_client, AzureOpenAI)
130+
assert context.llm_client._api_version == DEFAULT_AZURE_OPENAI_API_VERSION
131+
132+
@pytest.mark.parametrize(
133+
"env_var",
134+
["CODEMODDER_AZURE_OPENAI_API_KEY", "CODEMODDER_AZURE_OPENAI_ENDPOINT"],
135+
)
136+
def test_setup_azure_llm_client_missing_one(self, mocker, env_var):
137+
mocker.patch.dict(os.environ, {env_var: "test"})
138+
with pytest.raises(MisconfiguredAIClient):
139+
Context(
140+
mocker.Mock(),
141+
True,
142+
False,
143+
load_registered_codemods(),
144+
PythonRepoManager(mocker.Mock()),
145+
[],
146+
[],
147+
)
148+
149+
def test_get_model_name(self, mocker):
150+
context = Context(
151+
mocker.Mock(),
152+
True,
153+
False,
154+
load_registered_codemods(),
155+
PythonRepoManager(mocker.Mock()),
156+
[],
157+
[],
158+
)
159+
assert context.gpt_4_turbo_2024_04_09 == "gpt-4-turbo-2024-04-09"
160+
161+
@pytest.mark.parametrize("model", ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13"])
162+
def test_model_get_name_from_env(self, mocker, model):
163+
name = "my-awesome-deployment"
164+
mocker.patch.dict(
165+
os.environ,
166+
{
167+
f"CODEMODDER_AZURE_OPENAI_{model.upper()}_DEPLOYMENT": name,
168+
},
169+
)
170+
context = Context(
171+
mocker.Mock(),
172+
True,
173+
False,
174+
load_registered_codemods(),
175+
PythonRepoManager(mocker.Mock()),
176+
[],
177+
[],
178+
)
179+
assert getattr(context, model.replace("-", "_")) == name
180+
181+
def test_get_api_version_from_env(self, mocker):
182+
version = "fake-version"
183+
mocker.patch.dict(
184+
os.environ,
185+
{
186+
"CODEMODDER_AZURE_OPENAI_API_KEY": "test",
187+
"CODEMODDER_AZURE_OPENAI_ENDPOINT": "test",
188+
"CODEMODDER_AZURE_OPENAI_API_VERSION": version,
189+
},
190+
)
191+
context = Context(
192+
mocker.Mock(),
193+
True,
194+
False,
195+
load_registered_codemods(),
196+
PythonRepoManager(mocker.Mock()),
197+
[],
198+
[],
199+
)
200+
assert isinstance(context.llm_client, AzureOpenAI)
201+
assert context.llm_client._api_version == version

0 commit comments

Comments
 (0)