Skip to content

Commit 199d696

Browse files
authored
Enable library callers to disable AI clients (#907)
1 parent 65129f7 commit 199d696

File tree

3 files changed

+62
-2
lines changed

3 files changed

+62
-2
lines changed

src/codemodder/codemodder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def run(
135135
original_cli_args: list[str] | None = None,
136136
codemod_registry: registry.CodemodRegistry | None = None,
137137
sast_only: bool = False,
138+
ai_client: bool = True,
138139
) -> tuple[CodeTF | None, int]:
139140
start = datetime.datetime.now()
140141

@@ -173,6 +174,7 @@ def run(
173174
path_exclude,
174175
tool_result_files_map,
175176
max_workers,
177+
ai_client,
176178
)
177179
except MisconfiguredAIClient as e:
178180
logger.error(e)

src/codemodder/context.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
path_exclude: list[str] | None = None,
6767
tool_result_files_map: dict[str, list[Path]] | None = None,
6868
max_workers: int = 1,
69+
ai_client: bool = True,
6970
):
7071
self.directory = directory
7172
self.dry_run = dry_run
@@ -84,8 +85,10 @@ def __init__(
8485
self.max_workers = max_workers
8586
self.tool_result_files_map = tool_result_files_map or {}
8687
self.semgrep_prefilter_results = None
87-
self.openai_llm_client = setup_openai_llm_client()
88-
self.azure_llama_llm_client = setup_azure_llama_llm_client()
88+
self.openai_llm_client = setup_openai_llm_client() if ai_client else None
89+
self.azure_llama_llm_client = (
90+
setup_azure_llama_llm_client() if ai_client else None
91+
)
8992

9093
def add_changesets(self, codemod_name: str, change_sets: List[ChangeSet]):
9194
self._changesets_by_codemod.setdefault(codemod_name, []).extend(change_sets)

tests/test_context.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,58 @@ def test_get_api_version_from_env(self, mocker):
243243
)
244244
assert isinstance(context.openai_llm_client, AzureOpenAI)
245245
assert context.openai_llm_client._api_version == version
246+
247+
def test_disable_ai_client_openai(self, mocker):
248+
mocker.patch.dict(os.environ, {"CODEMODDER_OPENAI_API_KEY": "test"})
249+
context = Context(
250+
mocker.Mock(),
251+
True,
252+
False,
253+
load_registered_codemods(),
254+
None,
255+
PythonRepoManager(mocker.Mock()),
256+
[],
257+
[],
258+
ai_client=False,
259+
)
260+
assert context.openai_llm_client is None
261+
262+
def test_disable_ai_client_azure(self, mocker):
263+
mocker.patch.dict(
264+
os.environ,
265+
{
266+
"CODEMODDER_AZURE_OPENAI_API_KEY": "test",
267+
"CODEMODDER_AZURE_OPENAI_ENDPOINT": "test",
268+
},
269+
)
270+
context = Context(
271+
mocker.Mock(),
272+
True,
273+
False,
274+
load_registered_codemods(),
275+
None,
276+
PythonRepoManager(mocker.Mock()),
277+
[],
278+
[],
279+
ai_client=False,
280+
)
281+
assert context.openai_llm_client is None
282+
283+
@pytest.mark.parametrize(
284+
"env_var",
285+
["CODEMODDER_AZURE_OPENAI_API_KEY", "CODEMODDER_AZURE_OPENAI_ENDPOINT"],
286+
)
287+
def test_no_misconfiguration_ai_client_disabled(self, mocker, env_var):
288+
mocker.patch.dict(os.environ, {env_var: "test"})
289+
context = Context(
290+
mocker.Mock(),
291+
True,
292+
False,
293+
load_registered_codemods(),
294+
None,
295+
PythonRepoManager(mocker.Mock()),
296+
[],
297+
[],
298+
ai_client=False,
299+
)
300+
assert context.openai_llm_client is None

0 commit comments

Comments
 (0)