Skip to content

Commit dcd4ab1

Browse files
authored
INTPYTHON-667 Support Azure OpenAI in tests (#159)
1 parent 5d157b1 commit dcd4ab1

File tree

6 files changed

+33
-17
lines changed

6 files changed

+33
-17
lines changed

libs/langchain-mongodb/tests/integration_tests/conftest.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from langchain_core.documents import Document
77
from langchain_core.embeddings import Embeddings
88
from langchain_ollama.embeddings import OllamaEmbeddings
9-
from langchain_openai import OpenAIEmbeddings
9+
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
1010
from pymongo import MongoClient
1111

1212
from ..utils import CONNECTION_STRING
@@ -34,12 +34,14 @@ def embedding() -> Embeddings:
3434
openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa
3535
model="text-embedding-3-small",
3636
)
37+
if os.environ.get("AZURE_OPENAI_ENDPOINT"):
38+
return AzureOpenAIEmbeddings(model="text-embedding-3-small")
3739

3840
return OllamaEmbeddings(model="all-minilm:l6-v2")
3941

4042

4143
@pytest.fixture(scope="session")
4244
def dimensions() -> int:
43-
if os.environ.get("OPENAI_API_KEY"):
45+
if os.environ.get("OPENAI_API_KEY") or os.environ.get("AZURE_OPENAI_ENDPOINT"):
4446
return 1536
4547
return 384

libs/langchain-mongodb/tests/integration_tests/test_agent_toolkit.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
import requests
66
from flaky import flaky # type:ignore[import-untyped]
7-
from langchain_openai import ChatOpenAI
7+
from langchain_openai import AzureChatOpenAI, ChatOpenAI
88
from langgraph.prebuilt import create_react_agent
99
from pymongo import MongoClient
1010

@@ -47,20 +47,24 @@ def db(client: MongoClient) -> MongoDBDatabase:
4747

4848
@flaky(max_runs=5, min_passes=4)
4949
@pytest.mark.skipif(
50-
"OPENAI_API_KEY" not in os.environ, reason="test requires OpenAI for chat responses"
50+
"OPENAI_API_KEY" not in os.environ and "AZURE_OPENAI_ENDPOINT" not in os.environ,
51+
reason="test requires OpenAI for chat responses",
5152
)
5253
def test_toolkit_response(db):
5354
db_wrapper = MongoDBDatabase.from_connection_string(
5455
CONNECTION_STRING, database=DB_NAME
5556
)
56-
llm = ChatOpenAI(model="gpt-4o-mini", timeout=60)
57+
if "AZURE_OPENAI_ENDPOINT" in os.environ:
58+
llm = AzureChatOpenAI(model="gpt-4o-mini", timeout=60)
59+
else:
60+
llm = ChatOpenAI(model="gpt-4o-mini", timeout=60)
5761

5862
toolkit = MongoDBDatabaseToolkit(db=db_wrapper, llm=llm)
5963

60-
system_message = MONGODB_AGENT_SYSTEM_PROMPT.format(top_k=5)
64+
prompt = MONGODB_AGENT_SYSTEM_PROMPT.format(top_k=5)
6165

6266
test_query = "Which country's customers spent the most?"
63-
agent = create_react_agent(llm, toolkit.get_tools(), state_modifier=system_message)
67+
agent = create_react_agent(llm, toolkit.get_tools(), prompt=prompt)
6468
agent.step_timeout = 60
6569
events = agent.stream(
6670
{"messages": [("user", test_query)]},

libs/langchain-mongodb/tests/integration_tests/test_chain_example.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from langchain_core.output_parsers.string import StrOutputParser
1111
from langchain_core.prompts.chat import ChatPromptTemplate
1212
from langchain_core.runnables import RunnablePassthrough
13-
from langchain_openai import ChatOpenAI
13+
from langchain_openai import AzureChatOpenAI, ChatOpenAI
14+
from langchain_openai.chat_models.base import BaseChatOpenAI
1415
from pymongo import MongoClient
1516
from pymongo.collection import Collection
1617

@@ -50,7 +51,7 @@ def collection(client: MongoClient) -> Collection:
5051

5152

5253
@pytest.mark.skipif(
53-
not os.environ.get("OPENAI_API_KEY"),
54+
not os.environ.get("OPENAI_API_KEY") and "AZURE_OPENAI_ENDPOINT" not in os.environ,
5455
reason="Requires OpenAI for chat responses.",
5556
)
5657
def test_chain(
@@ -120,7 +121,10 @@ def test_chain(
120121
"""
121122
prompt = ChatPromptTemplate.from_template(template)
122123

123-
model = ChatOpenAI()
124+
if "AZURE_OPENAI_ENDPOINT" in os.environ:
125+
model: BaseChatOpenAI = AzureChatOpenAI(model="o4-mini")
126+
else:
127+
model = ChatOpenAI()
124128

125129
chain = (
126130
{"context": retriever, "question": RunnablePassthrough()} # type: ignore

libs/langchain-mongodb/tests/integration_tests/test_retriever_selfquerying.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from langchain.chains.query_constructor.schema import AttributeInfo
88
from langchain.retrievers.self_query.base import SelfQueryRetriever
99
from langchain_core.documents import Document
10-
from langchain_openai import ChatOpenAI
10+
from langchain_openai import AzureChatOpenAI, ChatOpenAI
11+
from langchain_openai.chat_models.base import BaseChatOpenAI
1112

1213
from langchain_mongodb import MongoDBAtlasVectorSearch, index
1314
from langchain_mongodb.retrievers import MongoDBAtlasSelfQueryRetriever
@@ -17,7 +18,7 @@
1718
COLLECTION_NAME = "test_self_querying_retriever"
1819
TIMEOUT = 120
1920

20-
if "OPENAI_API_KEY" not in os.environ:
21+
if "OPENAI_API_KEY" not in os.environ and "AZURE_OPENAI_ENDPOINT" not in os.environ:
2122
pytest.skip("Requires OpenAI for chat responses.", allow_module_level=True)
2223

2324

@@ -161,8 +162,10 @@ def vectorstore(
161162

162163

163164
@pytest.fixture
164-
def llm() -> ChatOpenAI:
165+
def llm() -> BaseChatOpenAI:
165166
"""Model used for interpreting query."""
167+
if "AZURE_OPENAI_ENDPOINT" in os.environ:
168+
return AzureChatOpenAI(model="gpt-4o", temperature=0.0, cache=False)
166169
return ChatOpenAI(model="gpt-4o", temperature=0.0, cache=False)
167170

168171

libs/langchain-mongodb/tests/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from langchain_core.outputs import ChatGeneration, ChatResult
2121
from langchain_ollama import ChatOllama
22-
from langchain_openai import ChatOpenAI
22+
from langchain_openai import AzureChatOpenAI, ChatOpenAI
2323
from pydantic import model_validator
2424
from pymongo import MongoClient
2525
from pymongo.collection import Collection
@@ -46,6 +46,8 @@ def create_database() -> MongoDBDatabase:
4646

4747

4848
def create_llm() -> BaseChatModel:
49+
if os.environ.get("AZURE_OPENAI_ENDPOINT"):
50+
return AzureChatOpenAI(model="o4-mini", timeout=60, cache=False)
4951
if os.environ.get("OPENAI_API_KEY"):
5052
return ChatOpenAI(model="gpt-4o-mini", timeout=60, cache=False)
5153
return ChatOllama(model="llama3:8b", cache=False)

libs/langgraph-checkpoint-mongodb/tests/integration_tests/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,23 @@
33
import pytest
44
from langchain_core.embeddings import Embeddings
55
from langchain_ollama.embeddings import OllamaEmbeddings
6-
from langchain_openai import OpenAIEmbeddings
6+
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
77

88

99
@pytest.fixture(scope="session")
1010
def embedding() -> Embeddings:
11+
if os.environ.get("AZURE_OPENAI_ENDPOINT"):
12+
return AzureOpenAIEmbeddings(model="text-embedding-3-small")
1113
if os.environ.get("OPENAI_API_KEY"):
1214
return OpenAIEmbeddings(
1315
openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa
1416
model="text-embedding-3-small",
1517
)
16-
1718
return OllamaEmbeddings(model="all-minilm:l6-v2")
1819

1920

2021
@pytest.fixture(scope="session")
2122
def dimensions() -> int:
22-
if os.environ.get("OPENAI_API_KEY"):
23+
if os.environ.get("OPENAI_API_KEY") or os.environ.get("AZURE_OPENAI_ENDPOINT"):
2324
return 1536
2425
return 384

0 commit comments

Comments
 (0)