Skip to content

Commit 571e809

Browse files
authored
LLM and test refactor (#623)
* move setup llm and models to own module * limit what is imported * add differror to catch it
1 parent 88959ea commit 571e809

File tree

7 files changed

+134
-97
lines changed

7 files changed

+134
-97
lines changed

src/codemodder/codemodder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
from codemodder.codemods.api import BaseCodemod
1313
from codemodder.codemods.semgrep import SemgrepRuleDetector
1414
from codemodder.codetf import CodeTF
15-
from codemodder.context import CodemodExecutionContext, MisconfiguredAIClient
15+
from codemodder.context import CodemodExecutionContext
1616
from codemodder.dependency import Dependency
17+
from codemodder.llm import MisconfiguredAIClient
1718
from codemodder.logging import configure_logger, log_list, log_section, logger
1819
from codemodder.project_analysis.file_parsers.package_store import PackageStore
1920
from codemodder.project_analysis.python_repo_manager import PythonRepoManager

src/codemodder/codemods/test/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
BaseDjangoCodemodTest,
66
BaseSASTCodemodTest,
77
BaseSemgrepCodemodTest,
8+
DiffError,
89
)

src/codemodder/codemods/test/utils.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,19 @@
1111
from codemodder.semgrep import run as semgrep_run
1212

1313

14+
class DiffError(Exception):
15+
"""Custom exception to raise when output code != expected output code."""
16+
17+
def __init__(self, expected, actual):
18+
self.expected = expected
19+
self.actual = actual
20+
21+
def __str__(self):
22+
return (
23+
f"\nExpected:\n\n{self.expected}\n does NOT match actual:\n\n{self.actual}"
24+
)
25+
26+
1427
class BaseCodemodTest:
1528
codemod: ClassVar = NotImplemented
1629

@@ -74,20 +87,25 @@ def run_and_assert(
7487
)
7588

7689
def assert_changes(self, root, file_path, input_code, expected, changes):
90+
assert os.path.relpath(file_path, root) == changes.path
91+
assert all(change.description for change in changes.changes)
92+
7793
expected_diff = create_diff(
7894
dedent(input_code).splitlines(keepends=True),
7995
dedent(expected).splitlines(keepends=True),
8096
)
81-
82-
assert expected_diff == changes.diff
83-
assert os.path.relpath(file_path, root) == changes.path
97+
try:
98+
assert expected_diff == changes.diff
99+
except AssertionError:
100+
raise DiffError(expected_diff, changes.diff)
84101

85102
with open(file_path, "r", encoding="utf-8") as tmp_file:
86103
output_code = tmp_file.read()
87104

88-
assert output_code == dedent(expected)
89-
# All changes must have non-empty descriptions
90-
assert all(change.description for change in changes.changes)
105+
try:
106+
assert output_code == (format_expected := dedent(expected))
107+
except AssertionError:
108+
raise DiffError(format_expected, output_code)
91109

92110
def run_and_assert_filepath(
93111
self,

src/codemodder/context.py

Lines changed: 2 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import itertools
44
import logging
5-
import os
65
from pathlib import Path
76
from textwrap import indent
87
from typing import TYPE_CHECKING, Iterator, List
@@ -16,33 +15,19 @@
1615
build_failed_dependency_notification,
1716
)
1817
from codemodder.file_context import FileContext
18+
from codemodder.llm import setup_llm_client
1919
from codemodder.logging import log_list, logger
2020
from codemodder.project_analysis.file_parsers.package_store import PackageStore
2121
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
2222
from codemodder.registry import CodemodRegistry
2323
from codemodder.utils.timer import Timer
2424

25-
try:
26-
from openai import AzureOpenAI, OpenAI
27-
except ImportError:
28-
OpenAI = None
29-
AzureOpenAI = None
30-
31-
3225
if TYPE_CHECKING:
3326
from openai import OpenAI
3427

3528
from codemodder.codemods.base_codemod import BaseCodemod
3629

3730

38-
class MisconfiguredAIClient(ValueError):
39-
pass
40-
41-
42-
MODELS = ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13"]
43-
DEFAULT_AZURE_OPENAI_API_VERSION = "2024-02-01"
44-
45-
4631
class CodemodExecutionContext:
4732
_failures_by_codemod: dict[str, list[Path]] = {}
4833
_dependency_update_by_codemod: dict[str, PackageStore | None] = {}
@@ -87,41 +72,7 @@ def __init__(
8772
self.path_exclude = path_exclude
8873
self.max_workers = max_workers
8974
self.tool_result_files_map = tool_result_files_map or {}
90-
self.llm_client = self._setup_llm_client()
91-
92-
def _setup_llm_client(self) -> OpenAI | None:
93-
if not AzureOpenAI:
94-
logger.info("Azure OpenAI API client not available")
95-
return None
96-
97-
azure_openapi_key = os.getenv("CODEMODDER_AZURE_OPENAI_API_KEY")
98-
azure_openapi_endpoint = os.getenv("CODEMODDER_AZURE_OPENAI_ENDPOINT")
99-
if bool(azure_openapi_key) ^ bool(azure_openapi_endpoint):
100-
raise MisconfiguredAIClient(
101-
"Azure OpenAI API key and endpoint must both be set or unset"
102-
)
103-
104-
if azure_openapi_key and azure_openapi_endpoint:
105-
logger.info("Using Azure OpenAI API client")
106-
return AzureOpenAI(
107-
api_key=azure_openapi_key,
108-
api_version=os.getenv(
109-
"CODEMODDER_AZURE_OPENAI_API_VERSION",
110-
DEFAULT_AZURE_OPENAI_API_VERSION,
111-
),
112-
azure_endpoint=azure_openapi_endpoint,
113-
)
114-
115-
if not OpenAI:
116-
logger.info("OpenAI API client not available")
117-
return None
118-
119-
if not (api_key := os.getenv("CODEMODDER_OPENAI_API_KEY")):
120-
logger.info("OpenAI API key not found")
121-
return None
122-
123-
logger.info("Using OpenAI API client")
124-
return OpenAI(api_key=api_key)
75+
self.llm_client = setup_llm_client()
12576

12677
def add_changesets(self, codemod_name: str, change_sets: List[ChangeSet]):
12778
self._changesets_by_codemod.setdefault(codemod_name, []).extend(change_sets)
@@ -244,8 +195,3 @@ def log_changes(self, codemod_id: str):
244195
for change in changes:
245196
logger.info(" - %s", change.path)
246197
logger.debug(" diff:\n%s", indent(change.diff, " " * 6))
247-
248-
def __getattribute__(self, attr: str):
249-
if (name := attr.replace("_", "-")) in MODELS:
250-
return os.getenv(f"CODEMODDER_AZURE_OPENAI_{name.upper()}_DEPLOYMENT", name)
251-
return super().__getattribute__(attr)

src/codemodder/llm.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import os
2+
from typing import TYPE_CHECKING
3+
4+
try:
5+
from openai import AzureOpenAI, OpenAI
6+
except ImportError:
7+
OpenAI = None
8+
AzureOpenAI = None
9+
10+
11+
if TYPE_CHECKING:
12+
from openai import OpenAI
13+
14+
from codemodder.logging import logger
15+
16+
__all__ = [
17+
"MODELS",
18+
"setup_llm_client",
19+
"MisconfiguredAIClient",
20+
]
21+
22+
models = ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13"]
23+
DEFAULT_AZURE_OPENAI_API_VERSION = "2024-02-01"
24+
25+
26+
class ModelRegistry(dict):
27+
def __init__(self, models):
28+
super().__init__()
29+
self.models = models
30+
for model in models:
31+
attribute_name = model.replace("-", "_")
32+
self[attribute_name] = model
33+
34+
def __getattr__(self, name):
35+
if name in self:
36+
return os.getenv(
37+
f"CODEMODDER_AZURE_OPENAI_{self[name].upper()}_DEPLOYMENT", self[name]
38+
)
39+
raise AttributeError(
40+
f"'{self.__class__.__name__}' object has no attribute '{name}'"
41+
)
42+
43+
44+
MODELS = ModelRegistry(models)
45+
46+
47+
def setup_llm_client() -> OpenAI | None:
48+
if not AzureOpenAI:
49+
logger.info("Azure OpenAI API client not available")
50+
return None
51+
52+
azure_openapi_key = os.getenv("CODEMODDER_AZURE_OPENAI_API_KEY")
53+
azure_openapi_endpoint = os.getenv("CODEMODDER_AZURE_OPENAI_ENDPOINT")
54+
if bool(azure_openapi_key) ^ bool(azure_openapi_endpoint):
55+
raise MisconfiguredAIClient(
56+
"Azure OpenAI API key and endpoint must both be set or unset"
57+
)
58+
59+
if azure_openapi_key and azure_openapi_endpoint:
60+
logger.info("Using Azure OpenAI API client")
61+
return AzureOpenAI(
62+
api_key=azure_openapi_key,
63+
api_version=os.getenv(
64+
"CODEMODDER_AZURE_OPENAI_API_VERSION",
65+
DEFAULT_AZURE_OPENAI_API_VERSION,
66+
),
67+
azure_endpoint=azure_openapi_endpoint,
68+
)
69+
70+
if not OpenAI:
71+
logger.info("OpenAI API client not available")
72+
return None
73+
74+
if not (api_key := os.getenv("CODEMODDER_OPENAI_API_KEY")):
75+
logger.info("OpenAI API key not found")
76+
return None
77+
78+
logger.info("Using OpenAI API client")
79+
return OpenAI(api_key=api_key)
80+
81+
82+
class MisconfiguredAIClient(ValueError):
83+
pass

tests/test_context.py

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
import pytest
44
from openai import AzureOpenAI, OpenAI
55

6-
from codemodder.context import DEFAULT_AZURE_OPENAI_API_VERSION
76
from codemodder.context import CodemodExecutionContext as Context
8-
from codemodder.context import MisconfiguredAIClient
97
from codemodder.dependency import Security
8+
from codemodder.llm import DEFAULT_AZURE_OPENAI_API_VERSION, MisconfiguredAIClient
109
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
1110
from codemodder.registry import load_registered_codemods
1211

@@ -146,38 +145,6 @@ def test_setup_azure_llm_client_missing_one(self, mocker, env_var):
146145
[],
147146
)
148147

149-
def test_get_model_name(self, mocker):
150-
context = Context(
151-
mocker.Mock(),
152-
True,
153-
False,
154-
load_registered_codemods(),
155-
PythonRepoManager(mocker.Mock()),
156-
[],
157-
[],
158-
)
159-
assert context.gpt_4_turbo_2024_04_09 == "gpt-4-turbo-2024-04-09"
160-
161-
@pytest.mark.parametrize("model", ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13"])
162-
def test_model_get_name_from_env(self, mocker, model):
163-
name = "my-awesome-deployment"
164-
mocker.patch.dict(
165-
os.environ,
166-
{
167-
f"CODEMODDER_AZURE_OPENAI_{model.upper()}_DEPLOYMENT": name,
168-
},
169-
)
170-
context = Context(
171-
mocker.Mock(),
172-
True,
173-
False,
174-
load_registered_codemods(),
175-
PythonRepoManager(mocker.Mock()),
176-
[],
177-
[],
178-
)
179-
assert getattr(context, model.replace("-", "_")) == name
180-
181148
def test_get_api_version_from_env(self, mocker):
182149
version = "fake-version"
183150
mocker.patch.dict(

tests/test_llm.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import os
2+
3+
import pytest
4+
5+
from codemodder.llm import MODELS
6+
7+
8+
class TestModels:
9+
def test_get_model_name(self):
10+
assert MODELS.gpt_4_turbo_2024_04_09 == "gpt-4-turbo-2024-04-09"
11+
12+
@pytest.mark.parametrize("model", ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13"])
13+
def test_model_get_name_from_env(self, mocker, model):
14+
name = "my-awesome-deployment"
15+
mocker.patch.dict(
16+
os.environ,
17+
{
18+
f"CODEMODDER_AZURE_OPENAI_{model.upper()}_DEPLOYMENT": name,
19+
},
20+
)
21+
assert getattr(MODELS, model.replace("-", "_")) == name

0 commit comments

Comments
 (0)