-
-
Notifications
You must be signed in to change notification settings - Fork 65
Expand file tree
/
Copy pathtest_module_llm.py
More file actions
157 lines (130 loc) · 5.98 KB
/
test_module_llm.py
File metadata and controls
157 lines (130 loc) · 5.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
from unittest.mock import Mock, patch
import pytest
from pydantic import BaseModel
from mesa_llm.module_llm import ModuleLLM
class TestModuleLLM:
    """Unit tests for ModuleLLM: construction, message building, structured
    output parsing, and (a)sync generation via mocked litellm calls."""

    def test_missing_provider_prefix(self):
        """ModuleLLM should raise ValueError when llm_model has no provider prefix."""
        with pytest.raises(ValueError, match="Invalid model format"):
            ModuleLLM(llm_model="gpt-4o")

    def test_module_llm_initialization(self, mock_environment):
        """Constructor resolves api_key/api_base per provider and rejects a
        missing API key."""
        # Test initialization with default values
        llm = ModuleLLM(llm_model="openai/gpt-4o")
        assert llm.api_key == "test_openai_key"
        assert llm.api_base is None
        assert llm.llm_model == "openai/gpt-4o"
        assert llm.system_prompt is None

        # Test initialization with ollama provider (defaults to local server)
        llm = ModuleLLM(llm_model="ollama/llama2")
        assert llm.api_base == "http://localhost:11434"
        assert llm.llm_model == "ollama/llama2"
        assert llm.system_prompt is None

        # Test initialization with ollama provider + custom api_base
        llm = ModuleLLM(llm_model="ollama/llama2", api_base="http://localhost:99999")
        assert llm.api_base == "http://localhost:99999"
        assert llm.llm_model == "ollama/llama2"
        assert llm.system_prompt is None

        # Test init without api_key in dotenv: clearing the environment must
        # make construction fail with ValueError.
        with patch.dict(os.environ, {}, clear=True), pytest.raises(ValueError):
            ModuleLLM(llm_model="openai/gpt-4o")

    def test_build_messages(self, mock_environment):
        """_build_messages composes the system message plus one user message
        per prompt, honoring instance-level and per-call system prompts."""
        # NOTE(review): mock_environment added — constructing ModuleLLM with an
        # openai model raises ValueError without an API key in the environment
        # (see test_module_llm_initialization), so without the fixture this
        # test silently depended on the developer's real environment.

        # Test _build_messages with string prompt
        llm = ModuleLLM(llm_model="openai/gpt-4o")
        messages = llm._build_messages("Hello, how are you?")
        assert messages == [
            {"role": "system", "content": ""},
            {"role": "user", "content": "Hello, how are you?"},
        ]

        # Test _build_messages with list of prompts
        messages = llm._build_messages(
            ["Hello, how are you?", "What is the weather in Tokyo?"]
        )
        assert messages == [
            {"role": "system", "content": ""},
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "user", "content": "What is the weather in Tokyo?"},
        ]

        # Test _build_messages with system prompt
        llm = ModuleLLM(
            llm_model="openai/gpt-4o", system_prompt="You are a helpful assistant."
        )
        messages = llm._build_messages("Hello, how are you?")
        assert messages == [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"},
        ]

        # Test _build_messages with system prompt and list of prompts
        messages = llm._build_messages(
            ["Hello, how are you?", "What is the weather in Tokyo?"]
        )
        assert messages == [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "user", "content": "What is the weather in Tokyo?"},
        ]

        # Test _build_messages no system prompt and no prompt: only the (empty)
        # system message remains.
        llm = ModuleLLM(llm_model="openai/gpt-4o")
        messages = llm._build_messages(prompt=None)
        assert messages == [{"role": "system", "content": ""}]

        # A per-call system_prompt overrides the instance default.
        messages = llm._build_messages(
            "Hello, how are you?", system_prompt="Per-call prompt"
        )
        assert messages == [
            {"role": "system", "content": "Per-call prompt"},
            {"role": "user", "content": "Hello, how are you?"},
        ]

    def test_parse_structured_output(self, mock_environment):
        """parse_structured_output handles an already-parsed model, a raw dict,
        and a JSON string fallback in message.content."""
        # NOTE(review): mock_environment added for the same reason as in
        # test_build_messages — ModuleLLM construction needs an API key.

        class DummyOutput(BaseModel):
            answer: str

        llm = ModuleLLM(llm_model="openai/gpt-4o")
        response = Mock()
        response.choices = [Mock()]
        response.choices[0].message = Mock()

        # Case 1: message.parsed is already a model instance.
        response.choices[0].message.parsed = DummyOutput(answer="parsed")
        parsed = llm.parse_structured_output(response, DummyOutput)
        assert parsed.answer == "parsed"

        # Case 2: message.parsed is a plain dict that must be validated.
        response.choices[0].message.parsed = {"answer": "dict"}
        parsed = llm.parse_structured_output(response, DummyOutput)
        assert parsed.answer == "dict"

        # Case 3: no parsed payload — fall back to JSON in message.content.
        response.choices[0].message.parsed = None
        response.choices[0].message.content = '{"answer":"json"}'
        parsed = llm.parse_structured_output(response, DummyOutput)
        assert parsed.answer == "json"

    def test_generate(self, monkeypatch, llm_response_factory, mock_environment):
        """generate returns a response for string and list prompts on both
        OpenAI and Ollama providers (completion is patched out)."""
        # NOTE(review): mock_environment added — openai construction requires a key.
        monkeypatch.setattr(
            "mesa_llm.module_llm.completion", lambda **kwargs: llm_response_factory()
        )

        # Test generate with string prompt
        llm = ModuleLLM(llm_model="openai/gpt-4o")
        response = llm.generate(prompt="Hello, how are you?")
        assert response is not None

        # Test generate with list of prompts
        response = llm.generate(
            prompt=["Hello, how are you?", "What is the weather in Tokyo?"]
        )
        assert response is not None

        # Test generate with string prompt for Ollama
        llm = ModuleLLM(llm_model="ollama/llama2")
        response = llm.generate(prompt="Hello, how are you?")
        assert response is not None

        # Test generate with list of prompts
        response = llm.generate(
            prompt=["Hello, how are you?", "What is the weather in Tokyo?"]
        )
        assert response is not None

    @pytest.mark.asyncio
    async def test_agenerate(self, monkeypatch, llm_response_factory, mock_environment):
        """agenerate mirrors generate through the async acompletion path."""
        # NOTE(review): mock_environment added — openai construction requires a key.

        async def _dummy_acompletion(**kwargs):
            return llm_response_factory()

        monkeypatch.setattr("mesa_llm.module_llm.acompletion", _dummy_acompletion)

        # Test agenerate with string prompt
        llm = ModuleLLM(llm_model="openai/gpt-4o")
        response = await llm.agenerate(prompt="Hello, how are you?")
        assert response is not None

        # Test agenerate with list of prompts
        response = await llm.agenerate(
            prompt=["Hello, how are you?", "What is the weather in Tokyo?"]
        )
        assert response is not None