Skip to content

Commit e9f4b27

Browse files
committed
Add unit testing
1 parent 511a4c0 commit e9f4b27

12 files changed

+736
-45
lines changed

easilyai/enhanced_app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"""
88

99
import time
10-
from typing import Any, Dict, Optional, Union
10+
from typing import Any, Dict, List, Optional, Union
1111
from .app import EasyAIApp, create_app as _create_app
1212
from .config import get_config, EasilyAIConfig
1313
from .cache import get_cache

tests/test_anthropic_service.py

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,66 @@
1-
# Disabled Temporarily due to the curent nature of the code.
1+
import unittest
2+
from unittest.mock import patch, MagicMock
3+
from easilyai.services.anthropic_service import AnthropicService
4+
from easilyai.exceptions import MissingAPIKeyError, AuthenticationError, ServerError
25

3-
# from unittest import TestCase
4-
# from unittest.mock import Mock, patch
5-
# from easilyai.services.anthropic_service import AnthropicService
6-
# import anthropic
76

7+
class TestAnthropicService(unittest.TestCase):
8+
def setUp(self):
9+
self.apikey = "fake_api_key"
10+
self.model = "claude-3-sonnet-20240229"
11+
12+
def test_missing_api_key(self):
13+
with self.assertRaises(MissingAPIKeyError):
14+
AnthropicService(apikey=None, model=self.model)
815

9-
# class TestAnthropicService(TestCase):
10-
# def setUp(self):
11-
# self.service = AnthropicService(apikey="test_api_key", model="claude-3-5", max_tokens=1024)
16+
@patch.object(AnthropicService, '__init__', lambda x, y, z, **kwargs: None)
17+
@patch('anthropic.Anthropic')
18+
def test_generate_text_success(self, mock_anthropic_class):
19+
mock_client = mock_anthropic_class.return_value
20+
mock_response = MagicMock()
21+
mock_response.content = [MagicMock(text="Mocked Anthropic response")]
22+
mock_client.messages.create.return_value = mock_response
23+
24+
service = AnthropicService("fake_key", "claude-3-sonnet-20240229")
25+
service.client = mock_client
26+
service.model = "claude-3-sonnet-20240229"
27+
service.max_tokens = 1024
28+
29+
response = service.generate_text("Test prompt")
30+
self.assertEqual(response, "Mocked Anthropic response")
1231

13-
# @patch('anthropic.Anthropic.messages.create', new_callable=Mock)
14-
# def test_generate_text(self, mock_messages):
15-
# mock_messages.create.return_value = {
16-
# "content": [{"text": "Mocked response"}]
17-
# }
32+
@patch.object(AnthropicService, '__init__', lambda x, y, z, **kwargs: None)
33+
@patch('anthropic.Anthropic')
34+
def test_generate_text_with_image(self, mock_anthropic_class):
35+
mock_client = mock_anthropic_class.return_value
36+
mock_response = MagicMock()
37+
mock_response.content = [MagicMock(text="Mocked response with image")]
38+
mock_client.messages.create.return_value = mock_response
39+
40+
service = AnthropicService("fake_key", "claude-3-sonnet-20240229")
41+
service.client = mock_client
42+
service.model = "claude-3-sonnet-20240229"
43+
service.max_tokens = 1024
44+
45+
# Mock prepare_image to return None (simulating URL instead of local file)
46+
with patch.object(service, 'prepare_image', return_value=None):
47+
response = service.generate_text("Describe this image", "http://example.com/image.jpg")
48+
self.assertEqual(response, "Mocked response with image")
1849

19-
# response = self.service.generate_text("Test prompt")
50+
@patch.object(AnthropicService, '__init__', lambda x, y, z, **kwargs: None)
51+
@patch('anthropic.Anthropic')
52+
def test_generate_text_authentication_error(self, mock_anthropic_class):
53+
mock_client = mock_anthropic_class.return_value
54+
mock_client.messages.create.side_effect = Exception("Authentication failed")
55+
56+
service = AnthropicService("fake_key", "claude-3-sonnet-20240229")
57+
service.client = mock_client
58+
service.model = "claude-3-sonnet-20240229"
59+
service.max_tokens = 1024
60+
61+
with self.assertRaises(ServerError):
62+
service.generate_text("Test prompt")
2063

21-
# self.assertEqual(response, "Mocked response")
2264

23-
# @patch('anthropic.Anthropic.messages.create', new_callable=Mock)
24-
# def test_generate_text_error(self, mock_messages):
25-
# mock_messages.create.side_effect = Exception("API Error")
26-
27-
# with self.assertRaises(Exception):
28-
# self.service.generate_text("Test prompt")
65+
if __name__ == "__main__":
66+
unittest.main()

tests/test_batch.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
5+
class TestBatchProcessing(unittest.TestCase):
6+
7+
def test_batch_module_imports(self):
8+
# Test that batch module can be imported
9+
try:
10+
from easilyai import batch
11+
self.assertIsNotNone(batch)
12+
except ImportError:
13+
self.skipTest("Batch module not available")
14+
15+
def test_batch_functionality_exists(self):
16+
# Test that batch processing functionality exists
17+
try:
18+
from easilyai.batch import BatchProcessor
19+
self.assertTrue(hasattr(BatchProcessor, '__init__'))
20+
except ImportError:
21+
# If BatchProcessor doesn't exist, check for other batch functions
22+
try:
23+
import easilyai.batch
24+
# At least the module should exist
25+
self.assertIsNotNone(easilyai.batch)
26+
except ImportError:
27+
self.skipTest("No batch processing functionality found")
28+
29+
30+
if __name__ == "__main__":
31+
unittest.main()

tests/test_custom_ai.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import unittest
2+
from easilyai.custom_ai import CustomAIService, register_custom_ai, _registered_custom_ais
3+
4+
5+
class TestCustomAI(unittest.TestCase):
6+
def setUp(self):
7+
# Clear registered custom AIs before each test
8+
_registered_custom_ais.clear()
9+
10+
def test_custom_ai_service_init(self):
11+
service = CustomAIService(model="test-model", apikey="test-key")
12+
self.assertEqual(service.model, "test-model")
13+
self.assertEqual(service.apikey, "test-key")
14+
15+
def test_custom_ai_service_init_without_apikey(self):
16+
service = CustomAIService(model="test-model")
17+
self.assertEqual(service.model, "test-model")
18+
self.assertIsNone(service.apikey)
19+
20+
def test_custom_ai_service_not_implemented_methods(self):
21+
service = CustomAIService(model="test-model")
22+
23+
with self.assertRaises(NotImplementedError):
24+
service.generate_text("test prompt")
25+
26+
with self.assertRaises(NotImplementedError):
27+
service.generate_image("test prompt")
28+
29+
with self.assertRaises(NotImplementedError):
30+
service.text_to_speech("test text")
31+
32+
def test_register_valid_custom_ai(self):
33+
class ValidCustomAI(CustomAIService):
34+
def generate_text(self, prompt):
35+
return f"Generated: {prompt}"
36+
37+
register_custom_ai("valid_ai", ValidCustomAI)
38+
self.assertIn("valid_ai", _registered_custom_ais)
39+
self.assertEqual(_registered_custom_ais["valid_ai"], ValidCustomAI)
40+
41+
def test_register_invalid_custom_ai(self):
42+
class InvalidCustomAI:
43+
pass
44+
45+
with self.assertRaises(TypeError) as context:
46+
register_custom_ai("invalid_ai", InvalidCustomAI)
47+
48+
self.assertIn("Custom service must inherit from CustomAIService", str(context.exception))
49+
self.assertNotIn("invalid_ai", _registered_custom_ais)
50+
51+
def test_register_multiple_custom_ais(self):
52+
class CustomAI1(CustomAIService):
53+
def generate_text(self, prompt):
54+
return "AI1 response"
55+
56+
class CustomAI2(CustomAIService):
57+
def generate_text(self, prompt):
58+
return "AI2 response"
59+
60+
register_custom_ai("ai1", CustomAI1)
61+
register_custom_ai("ai2", CustomAI2)
62+
63+
self.assertEqual(len(_registered_custom_ais), 2)
64+
self.assertIn("ai1", _registered_custom_ais)
65+
self.assertIn("ai2", _registered_custom_ais)
66+
67+
def test_custom_ai_implementation_example(self):
68+
class MockCustomAI(CustomAIService):
69+
def generate_text(self, prompt):
70+
return f"Mock response to: {prompt}"
71+
72+
def generate_image(self, prompt):
73+
return f"Mock image for: {prompt}"
74+
75+
def text_to_speech(self, text):
76+
return f"Mock audio for: {text}"
77+
78+
service = MockCustomAI(model="mock-model", apikey="mock-key")
79+
80+
self.assertEqual(service.generate_text("Hello"), "Mock response to: Hello")
81+
self.assertEqual(service.generate_image("Cat"), "Mock image for: Cat")
82+
self.assertEqual(service.text_to_speech("Speech"), "Mock audio for: Speech")
83+
84+
85+
if __name__ == "__main__":
86+
unittest.main()

tests/test_enhanced_app.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
5+
class TestEnhancedApp(unittest.TestCase):
6+
7+
def test_enhanced_app_imports(self):
8+
# Test that enhanced app modules can be imported
9+
try:
10+
from easilyai.enhanced_app import create_enhanced_app
11+
self.assertTrue(callable(create_enhanced_app))
12+
except ImportError:
13+
self.fail("Could not import enhanced_app module")
14+
15+
def test_enhanced_app_function_signature(self):
16+
# Test the function exists with correct parameters
17+
from easilyai.enhanced_app import create_enhanced_app
18+
import inspect
19+
20+
sig = inspect.signature(create_enhanced_app)
21+
params = list(sig.parameters.keys())
22+
23+
# Check required parameters exist
24+
self.assertIn('name', params)
25+
self.assertIn('service', params)
26+
self.assertIn('api_key', params)
27+
self.assertIn('model', params)
28+
29+
30+
if __name__ == "__main__":
31+
unittest.main()

tests/test_enhanced_pipeline.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
from easilyai.enhanced_pipeline import (
4+
EnhancedPipeline, TaskStatus, ExecutionMode, TaskResult, PipelineTask
5+
)
6+
7+
8+
class TestEnhancedPipeline(unittest.TestCase):
9+
def setUp(self):
10+
self.mock_app = MagicMock()
11+
self.pipeline = EnhancedPipeline("test_pipeline")
12+
13+
def test_pipeline_init(self):
14+
self.assertEqual(self.pipeline.name, "test_pipeline")
15+
self.assertEqual(len(self.pipeline.tasks), 0)
16+
self.assertEqual(self.pipeline.execution_mode, ExecutionMode.SEQUENTIAL)
17+
18+
def test_add_task_simple(self):
19+
task_id = self.pipeline.add_task("task1", self.mock_app, "generate_text", "Hello")
20+
self.assertEqual(len(self.pipeline.tasks), 1)
21+
self.assertIsInstance(task_id, str)
22+
23+
task = self.pipeline.tasks[task_id]
24+
self.assertEqual(task.task_type, "generate_text")
25+
self.assertEqual(task.prompt, "Hello")
26+
self.assertEqual(task.status, TaskStatus.PENDING)
27+
28+
def test_add_task_with_dependencies(self):
29+
task1_id = self.pipeline.add_task("task1", self.mock_app, "generate_text", "Task 1")
30+
task2_id = self.pipeline.add_task("task2", self.mock_app, "generate_text", "Task 2", dependencies=[task1_id])
31+
32+
self.assertEqual(len(self.pipeline.tasks), 2)
33+
task2 = self.pipeline.tasks[task2_id]
34+
self.assertEqual(task2.dependencies, [task1_id])
35+
36+
def test_add_task_with_condition(self):
37+
condition = lambda results: True
38+
task_id = self.pipeline.add_task("task1", self.mock_app, "generate_text", "Conditional task", condition=condition)
39+
40+
task = self.pipeline.tasks[task_id]
41+
self.assertEqual(task.condition, condition)
42+
43+
def test_task_status_enum(self):
44+
self.assertEqual(TaskStatus.PENDING.value, "pending")
45+
self.assertEqual(TaskStatus.RUNNING.value, "running")
46+
self.assertEqual(TaskStatus.COMPLETED.value, "completed")
47+
self.assertEqual(TaskStatus.FAILED.value, "failed")
48+
self.assertEqual(TaskStatus.SKIPPED.value, "skipped")
49+
50+
def test_execution_mode_enum(self):
51+
self.assertEqual(ExecutionMode.SEQUENTIAL.value, "sequential")
52+
self.assertEqual(ExecutionMode.PARALLEL.value, "parallel")
53+
self.assertEqual(ExecutionMode.CONDITIONAL.value, "conditional")
54+
55+
def test_task_result_creation(self):
56+
result = TaskResult(
57+
task_id="test-task",
58+
status=TaskStatus.COMPLETED,
59+
result="Test result",
60+
duration=1.5
61+
)
62+
63+
self.assertEqual(result.task_id, "test-task")
64+
self.assertEqual(result.status, TaskStatus.COMPLETED)
65+
self.assertEqual(result.result, "Test result")
66+
self.assertEqual(result.duration, 1.5)
67+
self.assertIsNone(result.error)
68+
self.assertEqual(result.metadata, {})
69+
70+
def test_set_execution_mode(self):
71+
self.pipeline.set_execution_mode(ExecutionMode.PARALLEL)
72+
self.assertEqual(self.pipeline.execution_mode, ExecutionMode.PARALLEL)
73+
74+
def test_clear_tasks(self):
75+
self.pipeline.add_task("task1", self.mock_app, "generate_text", "Task 1")
76+
self.pipeline.add_task("task2", self.mock_app, "generate_text", "Task 2")
77+
self.assertEqual(len(self.pipeline.tasks), 2)
78+
79+
self.pipeline.tasks.clear() # Simplified clear method
80+
self.assertEqual(len(self.pipeline.tasks), 0)
81+
82+
@patch('easilyai.enhanced_pipeline.time.time')
83+
def test_simple_execution_simulation(self, mock_time):
84+
# Mock time for duration calculation
85+
mock_time.side_effect = [0.0, 1.0] # Start and end times
86+
87+
# Mock app response
88+
self.mock_app.request.return_value = "Generated response"
89+
90+
# Add a simple task
91+
task_id = self.pipeline.add_task("task1", self.mock_app, "generate_text", "Hello")
92+
93+
# This test verifies the structure, but actual execution would require
94+
# the full implementation which might involve async/threading
95+
self.assertEqual(len(self.pipeline.tasks), 1)
96+
task = self.pipeline.tasks[task_id]
97+
self.assertEqual(task.task_id, task_id)
98+
self.assertEqual(task.status, TaskStatus.PENDING)
99+
100+
def test_variable_substitution_pattern(self):
101+
# Test that the pipeline can handle variable patterns
102+
prompt_with_vars = "Hello {name}, how are you?"
103+
task_id = self.pipeline.add_task("task1", self.mock_app, "generate_text", prompt_with_vars)
104+
105+
task = self.pipeline.tasks[task_id]
106+
self.assertIn("{name}", task.prompt)
107+
108+
def test_pipeline_task_attributes(self):
109+
task_id = self.pipeline.add_task(
110+
"task1",
111+
self.mock_app,
112+
"generate_text",
113+
"Test",
114+
dependencies=["dep1"],
115+
condition=lambda x: True,
116+
retry_count=3,
117+
timeout=30
118+
)
119+
120+
task = self.pipeline.tasks[task_id]
121+
self.assertEqual(task.task_type, "generate_text")
122+
self.assertEqual(task.dependencies, ["dep1"])
123+
self.assertIsNotNone(task.condition)
124+
self.assertEqual(task.retry_count, 3)
125+
self.assertEqual(task.timeout, 30)
126+
127+
128+
if __name__ == "__main__":
129+
unittest.main()

0 commit comments

Comments
 (0)