|
2 | 2 |
|
3 | 3 | import itertools |
4 | 4 | import logging |
5 | | -import os |
6 | 5 | from pathlib import Path |
7 | 6 | from textwrap import indent |
8 | 7 | from typing import TYPE_CHECKING, Iterator, List |
|
16 | 15 | build_failed_dependency_notification, |
17 | 16 | ) |
18 | 17 | from codemodder.file_context import FileContext |
| 18 | +from codemodder.llm import setup_llm_client |
19 | 19 | from codemodder.logging import log_list, logger |
20 | 20 | from codemodder.project_analysis.file_parsers.package_store import PackageStore |
21 | 21 | from codemodder.project_analysis.python_repo_manager import PythonRepoManager |
22 | 22 | from codemodder.registry import CodemodRegistry |
23 | 23 | from codemodder.utils.timer import Timer |
24 | 24 |
|
25 | | -try: |
26 | | - from openai import AzureOpenAI, OpenAI |
27 | | -except ImportError: |
28 | | - OpenAI = None |
29 | | - AzureOpenAI = None |
30 | | - |
31 | | - |
32 | 25 | if TYPE_CHECKING: |
33 | 26 | from openai import OpenAI |
34 | 27 |
|
35 | 28 | from codemodder.codemods.base_codemod import BaseCodemod |
36 | 29 |
|
37 | 30 |
|
38 | | -class MisconfiguredAIClient(ValueError): |
39 | | - pass |
40 | | - |
41 | | - |
42 | | -MODELS = ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13"] |
43 | | -DEFAULT_AZURE_OPENAI_API_VERSION = "2024-02-01" |
44 | | - |
45 | | - |
46 | 31 | class CodemodExecutionContext: |
47 | 32 | _failures_by_codemod: dict[str, list[Path]] = {} |
48 | 33 | _dependency_update_by_codemod: dict[str, PackageStore | None] = {} |
@@ -87,41 +72,7 @@ def __init__( |
87 | 72 | self.path_exclude = path_exclude |
88 | 73 | self.max_workers = max_workers |
89 | 74 | self.tool_result_files_map = tool_result_files_map or {} |
90 | | - self.llm_client = self._setup_llm_client() |
91 | | - |
92 | | - def _setup_llm_client(self) -> OpenAI | None: |
93 | | - if not AzureOpenAI: |
94 | | - logger.info("Azure OpenAI API client not available") |
95 | | - return None |
96 | | - |
97 | | - azure_openapi_key = os.getenv("CODEMODDER_AZURE_OPENAI_API_KEY") |
98 | | - azure_openapi_endpoint = os.getenv("CODEMODDER_AZURE_OPENAI_ENDPOINT") |
99 | | - if bool(azure_openapi_key) ^ bool(azure_openapi_endpoint): |
100 | | - raise MisconfiguredAIClient( |
101 | | - "Azure OpenAI API key and endpoint must both be set or unset" |
102 | | - ) |
103 | | - |
104 | | - if azure_openapi_key and azure_openapi_endpoint: |
105 | | - logger.info("Using Azure OpenAI API client") |
106 | | - return AzureOpenAI( |
107 | | - api_key=azure_openapi_key, |
108 | | - api_version=os.getenv( |
109 | | - "CODEMODDER_AZURE_OPENAI_API_VERSION", |
110 | | - DEFAULT_AZURE_OPENAI_API_VERSION, |
111 | | - ), |
112 | | - azure_endpoint=azure_openapi_endpoint, |
113 | | - ) |
114 | | - |
115 | | - if not OpenAI: |
116 | | - logger.info("OpenAI API client not available") |
117 | | - return None |
118 | | - |
119 | | - if not (api_key := os.getenv("CODEMODDER_OPENAI_API_KEY")): |
120 | | - logger.info("OpenAI API key not found") |
121 | | - return None |
122 | | - |
123 | | - logger.info("Using OpenAI API client") |
124 | | - return OpenAI(api_key=api_key) |
| 75 | + self.llm_client = setup_llm_client() |
125 | 76 |
|
126 | 77 | def add_changesets(self, codemod_name: str, change_sets: List[ChangeSet]): |
127 | 78 | self._changesets_by_codemod.setdefault(codemod_name, []).extend(change_sets) |
@@ -244,8 +195,3 @@ def log_changes(self, codemod_id: str): |
244 | 195 | for change in changes: |
245 | 196 | logger.info(" - %s", change.path) |
246 | 197 | logger.debug(" diff:\n%s", indent(change.diff, " " * 6)) |
247 | | - |
248 | | - def __getattribute__(self, attr: str): |
249 | | - if (name := attr.replace("_", "-")) in MODELS: |
250 | | - return os.getenv(f"CODEMODDER_AZURE_OPENAI_{name.upper()}_DEPLOYMENT", name) |
251 | | - return super().__getattribute__(attr) |
0 commit comments