11import os
22
33import pytest
4+ from azure .ai .inference import ChatCompletionsClient
45from openai import AzureOpenAI , OpenAI
56
67from codemodder .context import CodemodExecutionContext as Context
@@ -90,7 +91,7 @@ def test_failed_dependency_description(self, mocker):
9091 in description
9192 )
9293
93- def test_setup_llm_client_no_env_vars (self , mocker ):
94+ def test_setup_llm_clients_no_env_vars (self , mocker ):
9495 mocker .patch .dict (os .environ , clear = True )
9596 context = Context (
9697 mocker .Mock (),
@@ -102,7 +103,8 @@ def test_setup_llm_client_no_env_vars(self, mocker):
102103 [],
103104 [],
104105 )
105- assert context .llm_client is None
106+ assert context .openai_llm_client is None
107+ assert context .azure_llama_llm_client is None
106108
107109 def test_setup_openai_llm_client (self , mocker ):
108110 mocker .patch .dict (os .environ , {"CODEMODDER_OPENAI_API_KEY" : "test" })
@@ -116,7 +118,29 @@ def test_setup_openai_llm_client(self, mocker):
116118 [],
117119 [],
118120 )
119- assert isinstance (context .llm_client , OpenAI )
121+ assert isinstance (context .openai_llm_client , OpenAI )
122+
123+ def test_setup_both_llm_clients (self , mocker ):
124+ mocker .patch .dict (
125+ os .environ ,
126+ {
127+ "CODEMODDER_OPENAI_API_KEY" : "test" ,
128+ "CODEMODDER_AZURE_LLAMA_API_KEY" : "test" ,
129+ "CODEMODDER_AZURE_LLAMA_ENDPOINT" : "test" ,
130+ },
131+ )
132+ context = Context (
133+ mocker .Mock (),
134+ True ,
135+ False ,
136+ load_registered_codemods (),
137+ None ,
138+ PythonRepoManager (mocker .Mock ()),
139+ [],
140+ [],
141+ )
142+ assert isinstance (context .openai_llm_client , OpenAI )
143+ assert isinstance (context .azure_llama_llm_client , ChatCompletionsClient )
120144
121145 def test_setup_azure_llm_client (self , mocker ):
122146 mocker .patch .dict (
@@ -136,8 +160,10 @@ def test_setup_azure_llm_client(self, mocker):
136160 [],
137161 [],
138162 )
139- assert isinstance (context .llm_client , AzureOpenAI )
140- assert context .llm_client ._api_version == DEFAULT_AZURE_OPENAI_API_VERSION
163+ assert isinstance (context .openai_llm_client , AzureOpenAI )
164+ assert (
165+ context .openai_llm_client ._api_version == DEFAULT_AZURE_OPENAI_API_VERSION
166+ )
141167
142168 @pytest .mark .parametrize (
143169 "env_var" ,
@@ -157,6 +183,44 @@ def test_setup_azure_llm_client_missing_one(self, mocker, env_var):
157183 [],
158184 )
159185
186+ def test_setup_azure_llama_llm_client (self , mocker ):
187+ mocker .patch .dict (
188+ os .environ ,
189+ {
190+ "CODEMODDER_AZURE_LLAMA_API_KEY" : "test" ,
191+ "CODEMODDER_AZURE_LLAMA_ENDPOINT" : "test" ,
192+ },
193+ )
194+ context = Context (
195+ mocker .Mock (),
196+ True ,
197+ False ,
198+ load_registered_codemods (),
199+ None ,
200+ PythonRepoManager (mocker .Mock ()),
201+ [],
202+ [],
203+ )
204+ assert isinstance (context .azure_llama_llm_client , ChatCompletionsClient )
205+
206+ @pytest .mark .parametrize (
207+ "env_var" ,
208+ ["CODEMODDER_AZURE_LLAMA_API_KEY" , "CODEMODDER_AZURE_LLAMA_ENDPOINT" ],
209+ )
210+ def test_setup_azure_llama_llm_client_missing_one (self , mocker , env_var ):
211+ mocker .patch .dict (os .environ , {env_var : "test" })
212+ with pytest .raises (MisconfiguredAIClient ):
213+ Context (
214+ mocker .Mock (),
215+ True ,
216+ False ,
217+ load_registered_codemods (),
218+ None ,
219+ PythonRepoManager (mocker .Mock ()),
220+ [],
221+ [],
222+ )
223+
160224 def test_get_api_version_from_env (self , mocker ):
161225 version = "fake-version"
162226 mocker .patch .dict (
@@ -177,5 +241,5 @@ def test_get_api_version_from_env(self, mocker):
177241 [],
178242 [],
179243 )
180- assert isinstance (context .llm_client , AzureOpenAI )
181- assert context .llm_client ._api_version == version
244+ assert isinstance (context .openai_llm_client , AzureOpenAI )
245+ assert context .openai_llm_client ._api_version == version
0 commit comments