1
+ import unittest
2
+ from unittest .mock import patch , MagicMock
3
+ from Chat_GPT_Function import gpt , dalle3 , dalle2
4
+ from dotenv import load_dotenv
5
+ import os
6
+
7
+ load_dotenv (override = True )
8
+
9
+ gpt_api_key = os .getenv ("GPT_API_KEY" )
10
+ class TestGPT (unittest .TestCase ):
11
+ @patch ('Chat_GPT_Function.OpenAI' )
12
+ def test_gpt_success (self , mock_openai ):
13
+ mock_client = MagicMock ()
14
+ mock_openai .return_value = mock_client
15
+ mock_response = MagicMock ()
16
+ mock_response .choices = [MagicMock (message = MagicMock (content = "Test output" ))]
17
+ mock_client .chat .completions .create .return_value = mock_response
18
+
19
+ result = gpt ("gpt-3.5-turbo" , "prompt" , "sys_prompt" , 0.5 )
20
+
21
+ self .assertIsInstance (result , str )
22
+ self .assertTrue (result .strip ())
23
+ mock_openai .assert_called_once_with (api_key = gpt_api_key )
24
+ mock_client .chat .completions .create .assert_called_once_with (
25
+ model = "gpt-3.5-turbo" ,
26
+ messages = [
27
+ {"role" : "system" , "content" : "sys_prompt" },
28
+ {"role" : "user" , "content" : "prompt" }
29
+ ],
30
+ temperature = 0.5 ,
31
+ top_p = 1
32
+ )
33
+
34
+ @patch ('Chat_GPT_Function.OpenAI' )
35
+ def test_gpt_failure (self , mock_openai ):
36
+ mock_client = MagicMock ()
37
+ mock_openai .return_value = mock_client
38
+ mock_client .chat .completions .create .side_effect = Exception ("API error" )
39
+
40
+ with self .assertRaises (Exception ):
41
+ gpt ("model" , "prompt" , "sys_prompt" , 0.5 )
42
+
43
+ class TestDALLE3 (unittest .TestCase ):
44
+ @patch ('Chat_GPT_Function.OpenAI' )
45
+ def test_dalle3_success (self , mock_openai ):
46
+ mock_client = MagicMock ()
47
+ mock_openai .return_value = mock_client
48
+ mock_response = MagicMock ()
49
+ mock_response .data = [MagicMock (url = "https://example.com/image.png" )]
50
+ mock_client .images .generate .return_value = mock_response
51
+
52
+ result = dalle3 ("prompt" , "hd" , "1792x1024" , "vivid" )
53
+
54
+ self .assertIsInstance (result , str )
55
+ self .assertTrue (result .startswith ("https://" ))
56
+ mock_openai .assert_called_once_with (api_key = gpt_api_key )
57
+ mock_client .images .generate .assert_called_once_with (
58
+ model = "dall-e-3" ,
59
+ prompt = "prompt" ,
60
+ size = "1792x1024" ,
61
+ quality = "hd" ,
62
+ style = "vivid" ,
63
+ n = 1
64
+ )
65
+
66
+
67
+ @patch ('Chat_GPT_Function.OpenAI' )
68
+ def test_dalle3_failure (self , mock_openai ):
69
+ mock_client = MagicMock ()
70
+ mock_openai .return_value = mock_client
71
+ mock_client .images .generate .side_effect = Exception ("API error" )
72
+
73
+ with self .assertRaises (Exception ):
74
+ dalle3 ("prompt" , "quality" , "size" , "style" )
75
+
76
+ class TestDALLE2 (unittest .TestCase ):
77
+ @patch ('Chat_GPT_Function.OpenAI' )
78
+ def test_dalle2_success (self , mock_openai ):
79
+ mock_client = MagicMock ()
80
+ mock_openai .return_value = mock_client
81
+ mock_response = MagicMock ()
82
+ mock_response .data = [MagicMock (url = "https://example.com/image.png" )]
83
+ mock_client .images .generate .return_value = mock_response
84
+
85
+ result = dalle2 ("prompt" , "256x256" )
86
+
87
+ self .assertIsInstance (result , str )
88
+ self .assertTrue (result .startswith ("https://" ))
89
+ mock_openai .assert_called_once_with (api_key = gpt_api_key )
90
+ mock_client .images .generate .assert_called_once_with (
91
+ model = "dall-e-2" ,
92
+ prompt = "prompt" ,
93
+ size = "256x256" ,
94
+ n = 1
95
+ )
96
+
97
+ @patch ('Chat_GPT_Function.OpenAI' )
98
+ def test_dalle2_failure (self , mock_openai ):
99
+ mock_client = MagicMock ()
100
+ mock_openai .return_value = mock_client
101
+ mock_client .images .generate .side_effect = Exception ("API error" )
102
+
103
+ with self .assertRaises (Exception ):
104
+ dalle2 ("prompt" , "size" )
105
+
106
+ if __name__ == '__main__' :
107
+ unittest .main ()
0 commit comments