import os
import pdb
from dataclasses import dataclass
from typing import Optional

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_ollama import ChatOllama

load_dotenv()

import sys

sys.path.append(".")


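# Provider-agnostic configuration for the chat models exercised by the tests below.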
@dataclass
class LLMConfig:
    provider: str
    model_name: str
    temperature: float = 0.8
    base_url: Optional[str] = None
    api_key: Optional[str] = None


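# Build chat message content, optionally attaching an image as a base64 data URL.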
def create_message_content(text, image_path=None):
    content = [{"type": "text", "text": text}]

    if image_path:
        from src.utils import utils
        image_data = utils.encode_image(image_path)
        content.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}
        })

    return content


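# Resolve a provider's credentials from the environment; returns "" when unset.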
def get_env_value(key, provider):
    env_mappings = {
        "openai": {"api_key": "OPENAI_API_KEY", "base_url": "OPENAI_ENDPOINT"},
        "azure_openai": {"api_key": "AZURE_OPENAI_API_KEY", "base_url": "AZURE_OPENAI_ENDPOINT"},
        "gemini": {"api_key": "GOOGLE_API_KEY"},
        "deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"}
    }

    if provider in env_mappings and key in env_mappings[provider]:
        return os.getenv(env_mappings[provider][key], "")
    return ""


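# Shared driver: build the model described by `config`, send the query, print the reply.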
def test_llm(config, query, image_path=None, system_message=None):
    from src.utils import utils

    # Special handling for Ollama-based models
    if config.provider == "ollama":
        if "deepseek-r1" in config.model_name:
            from src.utils.llm import DeepSeekR1ChatOllama
            llm = DeepSeekR1ChatOllama(model=config.model_name)
        else:
            llm = ChatOllama(model=config.model_name)

        ai_msg = llm.invoke(query)
        print(ai_msg.content)
        if "deepseek-r1" in config.model_name:
            pdb.set_trace()  # drop into the debugger to inspect the R1 response
        return

    # For other providers, use the standard configuration
    llm = utils.get_llm_model(
        provider=config.provider,
        model_name=config.model_name,
        temperature=config.temperature,
        base_url=config.base_url or get_env_value("base_url", config.provider),
        api_key=config.api_key or get_env_value("api_key", config.provider)
    )

    # Prepare messages for non-Ollama models
    messages = []
    if system_message:
        messages.append(SystemMessage(content=create_message_content(system_message)))
    messages.append(HumanMessage(content=create_message_content(query, image_path)))
    ai_msg = llm.invoke(messages)

    # Handle different response types
    if hasattr(ai_msg, "reasoning_content"):
        print(ai_msg.reasoning_content)
    print(ai_msg.content)

    if config.provider == "deepseek" and "deepseek-reasoner" in config.model_name:
        print(llm.model_name)
        pdb.set_trace()  # drop into the debugger to inspect the reasoner response


def test_openai_model():
    config = LLMConfig(provider="openai", model_name="gpt-4o")
    test_llm(config, "Describe this image", "assets/examples/test.png")


def test_gemini_model():
    # Enable your API key first if you haven't: https://ai.google.dev/palm_docs/oauth_quickstart
    config = LLMConfig(provider="gemini", model_name="gemini-2.0-flash-exp")
    test_llm(config, "Describe this image", "assets/examples/test.png")


def test_azure_openai_model():
    config = LLMConfig(provider="azure_openai", model_name="gpt-4o")
    test_llm(config, "Describe this image", "assets/examples/test.png")


def test_deepseek_model():
    config = LLMConfig(provider="deepseek", model_name="deepseek-chat")
    test_llm(config, "Who are you?")


def test_deepseek_r1_model():
    config = LLMConfig(provider="deepseek", model_name="deepseek-reasoner")
    test_llm(config, "Which is greater, 9.11 or 9.8?", system_message="You are a helpful AI assistant.")


def test_ollama_model():
    config = LLMConfig(provider="ollama", model_name="qwen2.5:7b")
    test_llm(config, "Sing a ballad of LangChain.")


def test_deepseek_r1_ollama_model():
    config = LLMConfig(provider="ollama", model_name="deepseek-r1:14b")
    test_llm(config, "How many 'r's are in the word 'strawberry'?")


if __name__ == "__main__":
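    # Uncomment a line to smoke-test that provider.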
    # test_openai_model()
    # test_gemini_model()
    # test_azure_openai_model()
    test_deepseek_model()
    # test_ollama_model()
    # test_deepseek_r1_model()
    # test_deepseek_r1_ollama_model()