Skip to content

Commit 13465f0

Browse files
authored
fix: openai env var load after init and before score also (#316)
1 parent 7ffc2f0 commit 13465f0

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

src/ragas/embeddings/base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ def __init__(self, api_key: str = NO_KEY):
4040

4141
def validate_api_key(self):
4242
if self.openai_api_key == NO_KEY:
43-
raise OpenAIKeyNotFound
43+
os_env_key = os.getenv("OPENAI_API_KEY", NO_KEY)
44+
if os_env_key != NO_KEY:
45+
self.api_key = os_env_key
46+
else:
47+
raise OpenAIKeyNotFound
4448

4549

4650
class AzureOpenAIEmbeddings(BaseAzureOpenAIEmbeddings, RagasEmbeddings):
@@ -73,7 +77,11 @@ def __init__(
7377

7478
def validate_api_key(self):
7579
if self.openai_api_key == NO_KEY:
76-
raise AzureOpenAIKeyNotFound
80+
os_env_key = os.getenv("AZURE_OPENAI_API_KEY", NO_KEY)
81+
if os_env_key != NO_KEY:
82+
self.api_key = os_env_key
83+
else:
84+
raise AzureOpenAIKeyNotFound
7785

7886

7987
@dataclass

src/ragas/llms/openai.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,11 @@ def validate_api_key(self):
196196
if api_key != NO_KEY:
197197
self._client.api_key = api_key
198198
if self.llm.api_key == NO_KEY:
199-
raise OpenAIKeyNotFound
199+
os_env_key = os.getenv(self._api_key_env_var, NO_KEY)
200+
if os_env_key != NO_KEY:
201+
self.api_key = os_env_key
202+
else:
203+
raise OpenAIKeyNotFound
200204

201205

202206
@dataclass
@@ -221,4 +225,8 @@ def _client_init(self):
221225

222226
def validate_api_key(self):
223227
if self.llm.api_key == NO_KEY:
224-
raise AzureOpenAIKeyNotFound
228+
os_env_key = os.getenv(self._api_key_env_var, NO_KEY)
229+
if os_env_key != NO_KEY:
230+
self.api_key = os_env_key
231+
else:
232+
raise AzureOpenAIKeyNotFound

tests/unit/test_llm.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,11 @@ def test_validate_api_key_for_different_llms(
136136
obj, api_key = factory(with_api_key=True)
137137
assert obj.validate_api_key
138138
assert obj.api_key == api_key
139+
140+
# assert loading key from environment variables after instantiation
141+
if environ_key in os.environ:
142+
os.environ.pop(environ_key)
143+
obj = factory(with_api_key=False)
144+
os.environ[environ_key] = "random-key-102848595"
145+
assert obj.validate_api_key() is None
146+
assert obj.api_key == "random-key-102848595"

0 commit comments

Comments
 (0)