Skip to content

Commit 5388482

Browse files
authored
Merge pull request #206 from marginal23326/refactor/consolidate-llm-tests
refactor: simplify LLM tests and remove duplication
2 parents dc41476 + d0b4f4c commit 5388482

File tree

1 file changed

+89
-125
lines changed

1 file changed

+89
-125
lines changed

tests/test_llm_api.py

Lines changed: 89 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -1,162 +1,126 @@
11
import os
import pdb
from dataclasses import dataclass
from typing import Optional

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_ollama import ChatOllama
58

69
load_dotenv()
710

811
import sys
912

1013
sys.path.append(".")
1114

12-
13-
def test_openai_model():
14-
from langchain_core.messages import HumanMessage
15+
@dataclass
class LLMConfig:
    """Configuration for one LLM provider/model under test.

    ``base_url`` and ``api_key`` default to None, meaning "resolve from the
    provider's environment variables at call time" (see ``get_env_value``).
    """

    provider: str
    model_name: str
    temperature: float = 0.8
    # Fix: the original annotated these as plain `str` with a None default
    # (implicit Optional), which modern type checkers reject.
    base_url: Optional[str] = None
    api_key: Optional[str] = None
22+
23+
def create_message_content(text, image_path=None):
    """Build a LangChain message-content list: one text part, plus an
    optional base64-encoded image part when *image_path* is given."""
    parts = [{"type": "text", "text": text}]

    # Truthiness check on purpose: an empty path string also skips the image.
    if image_path:
        from src.utils import utils

        encoded = utils.encode_image(image_path)
        parts.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
            }
        )

    return parts
35+
36+
def get_env_value(key, provider):
    """Resolve *key* ("api_key" or "base_url") for *provider* from the
    environment; returns "" when no mapping or no variable is set."""
    mappings = {
        "openai": {"api_key": "OPENAI_API_KEY", "base_url": "OPENAI_ENDPOINT"},
        "azure_openai": {"api_key": "AZURE_OPENAI_API_KEY", "base_url": "AZURE_OPENAI_ENDPOINT"},
        "gemini": {"api_key": "GOOGLE_API_KEY"},
        "deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"},
    }

    var_name = mappings.get(provider, {}).get(key)
    return os.getenv(var_name, "") if var_name is not None else ""
47+
48+
def test_llm(config, query, image_path=None, system_message=None):
    """Smoke-test the chat model described by *config* with a single query.

    Ollama models are constructed directly; every other provider goes
    through ``utils.get_llm_model``, with credentials falling back to
    provider-specific environment variables.
    """
    from src.utils import utils

    # Ollama-based models bypass get_llm_model entirely.
    if config.provider == "ollama":
        if "deepseek-r1" in config.model_name:
            from src.utils.llm import DeepSeekR1ChatOllama

            model = DeepSeekR1ChatOllama(model=config.model_name)
        else:
            model = ChatOllama(model=config.model_name)

        reply = model.invoke(query)
        print(reply.content)
        if "deepseek-r1" in config.model_name:
            # Deliberate: drop into the debugger for manual inspection.
            pdb.set_trace()
        return

    # Standard path: explicit config values win, environment fills the gaps.
    llm = utils.get_llm_model(
        provider=config.provider,
        model_name=config.model_name,
        temperature=config.temperature,
        base_url=config.base_url or get_env_value("base_url", config.provider),
        api_key=config.api_key or get_env_value("api_key", config.provider),
    )

    messages = []
    if system_message:
        messages.append(SystemMessage(content=create_message_content(system_message)))
    messages.append(HumanMessage(content=create_message_content(query, image_path)))

    reply = llm.invoke(messages)

    # Reasoning-capable models expose their chain of thought separately.
    if hasattr(reply, "reasoning_content"):
        print(reply.reasoning_content)
    print(reply.content)

    if config.provider == "deepseek" and "deepseek-reasoner" in config.model_name:
        print(llm.model_name)
        # Deliberate: pause for manual inspection of the reasoner run.
        pdb.set_trace()
6589

66-
def test_azure_openai_model():
67-
from langchain_core.messages import HumanMessage
68-
from src.utils import utils
90+
def test_openai_model():
    """Exercise OpenAI gpt-4o on an image-description prompt."""
    test_llm(
        LLMConfig(provider="openai", model_name="gpt-4o"),
        "Describe this image",
        "assets/examples/test.png",
    )
6993

70-
llm = utils.get_llm_model(
71-
provider="azure_openai",
72-
model_name="gpt-4o",
73-
temperature=0.8,
74-
base_url=os.getenv("AZURE_OPENAI_ENDPOINT", ""),
75-
api_key=os.getenv("AZURE_OPENAI_API_KEY", "")
76-
)
77-
image_path = "assets/examples/test.png"
78-
image_data = utils.encode_image(image_path)
79-
message = HumanMessage(
80-
content=[
81-
{"type": "text", "text": "describe this image"},
82-
{
83-
"type": "image_url",
84-
"image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
85-
},
86-
]
87-
)
88-
ai_msg = llm.invoke([message])
89-
print(ai_msg.content)
94+
def test_gemini_model():
    """Exercise Gemini 2.0 Flash on an image-description prompt.

    Enable your API key first if you haven't:
    https://ai.google.dev/palm_docs/oauth_quickstart
    """
    cfg = LLMConfig(provider="gemini", model_name="gemini-2.0-flash-exp")
    test_llm(cfg, "Describe this image", "assets/examples/test.png")
9098

99+
def test_azure_openai_model():
    """Exercise Azure-hosted gpt-4o on an image-description prompt."""
    cfg = LLMConfig(provider="azure_openai", model_name="gpt-4o")
    test_llm(cfg, "Describe this image", "assets/examples/test.png")
91102

92103
def test_deepseek_model():
    """Exercise DeepSeek chat with a plain text prompt (no image)."""
    test_llm(
        LLMConfig(provider="deepseek", model_name="deepseek-chat"),
        "Who are you?",
    )
110106

111107
def test_deepseek_r1_model():
    """Exercise the DeepSeek reasoner with a system prompt; test_llm will
    print its reasoning trace and pause in pdb for inspection."""
    cfg = LLMConfig(provider="deepseek", model_name="deepseek-reasoner")
    test_llm(
        cfg,
        "Which is greater, 9.11 or 9.8?",
        system_message="You are a helpful AI assistant.",
    )
138110

139111
def test_ollama_model():
    """Exercise a locally served Ollama model with a text prompt."""
    cfg = LLMConfig(provider="ollama", model_name="qwen2.5:7b")
    test_llm(cfg, "Sing a ballad of LangChain.")
146115
def test_deepseek_r1_ollama_model():
    """Exercise DeepSeek-R1 via Ollama; test_llm pauses in pdb afterwards."""
    cfg = LLMConfig(provider="ollama", model_name="deepseek-r1:14b")
    test_llm(cfg, "How many 'r's are in the word 'strawberry'?")
154118

155-
if __name__ == "__main__":
    # Manual smoke tests: uncomment the provider you want to exercise.
    # Each call hits a live endpoint, so only one is enabled at a time.
    # test_openai_model()
    # test_gemini_model()
    # test_azure_openai_model()
    test_deepseek_model()
    # test_ollama_model()
    # test_deepseek_r1_model()
    # test_deepseek_r1_ollama_model()

0 commit comments

Comments
 (0)