@@ -2,23 +2,19 @@ package gpt
22
33import (
44 "context"
5- "encoding/json"
6- "net"
7- "net/http"
8- "net/url"
5+ "errors"
6+ "io"
97 "os"
108 "regexp"
119 "strings"
1210
13- Ollama "github.com/ollama/ollama/api"
1411 OpenAI "github.com/sashabaranov/go-openai"
1512
1613 "github.com/dimaskiddo/go-whatsapp-multidevice-gpt/pkg/env"
1714 "github.com/dimaskiddo/go-whatsapp-multidevice-gpt/pkg/log"
1815)
1916
2017var OAIClient * OpenAI.Client
21- var OClient * Ollama.Client
2218
2319var (
2420 WAGPTEngine ,
3228 OAIAPIKey string
3329)
3430
35- var (
36- OHost ,
37- OHostPath string
38- )
39-
4031var (
4132 GPTModelName ,
4233 GPTModelPrompt string
@@ -58,51 +49,29 @@ func init() {
5849 // -----------------------------------------------------------------------
5950 // WhatsApp GPT Configuration Environment
6051 // -----------------------------------------------------------------------
61- WAGPTEngine , err = env .GetEnvString ("WHATSAPP_GPT_ENGINE" )
62- if err != nil {
63- log .Println (log .LogLevelFatal , "Error Parse Environment Variable for WhatsApp GPT Engine" )
64- }
65-
6652 WAGPTBlockedWord = strings .TrimSpace (os .Getenv ("WHATSAPP_GPT_BLOCKED_WORD" ))
6753 if len (WAGPTBlockedWord ) > 0 {
6854 WAGPTBlockedWordRegex = regexp .MustCompile ("\\ b(?i)(" + listBlockedWord + "|" + WAGPTBlockedWord + ")" )
6955 } else {
7056 WAGPTBlockedWordRegex = regexp .MustCompile ("\\ b(?i)(" + listBlockedWord + ")" )
7157 }
7258
73- switch strings .ToLower (WAGPTEngine ) {
74- case "openai" :
75- // -----------------------------------------------------------------------
76- // OpenAI Configuration Environment
77- // -----------------------------------------------------------------------
78- OAIHost , err = env .GetEnvString ("OPENAI_HOST" )
79- if err != nil {
80- OAIHost = "https://api.openai.com"
81- }
82-
83- OAIHostPath , err = env .GetEnvString ("OPENAI_HOST_PATH" )
84- if err != nil {
85- OAIHostPath = "/v1"
86- }
87-
88- OAIAPIKey , err = env .GetEnvString ("OPENAI_API_KEY" )
89- if err != nil {
90- log .Println (log .LogLevelFatal , "Error Parse Environment Variable for OpenAI API Key" )
91- }
59+ // -----------------------------------------------------------------------
60+ // OpenAI Configuration Environment
61+ // -----------------------------------------------------------------------
62+ OAIHost , err = env .GetEnvString ("OPENAI_HOST" )
63+ if err != nil {
64+ OAIHost = "https://api.openai.com"
65+ }
9266
93- default :
94- // -----------------------------------------------------------------------
95- // Ollama Configuration Environment
96- // -----------------------------------------------------------------------
97- OHost , err = env .GetEnvString ("OLLAMA_HOST" )
98- if err != nil {
99- log .Println (log .LogLevelFatal , "Error Parse Environment Variable for Ollama Host" )
100- }
67+ OAIHostPath , err = env .GetEnvString ("OPENAI_HOST_PATH" )
68+ if err != nil {
69+ OAIHostPath = "/v1"
70+ }
10171
102- OHostPath , err = env .GetEnvString ("OLLAMA_HOST_PATH" )
103- if err != nil {
104- OHostPath = "/"
105- }
72+ OAIAPIKey , err = env .GetEnvString ("OPENAI_API_KEY" )
73+ if err != nil {
74+ log .Println (log .LogLevelFatal , "Error Parse Environment Variable for OpenAI API Key" )
10675 }
10776
10877 // -----------------------------------------------------------------------
@@ -146,38 +115,10 @@ func init() {
146115 // -----------------------------------------------------------------------
147116 // GPT Engine Initialization
148117 // -----------------------------------------------------------------------
149- switch strings .ToLower (WAGPTEngine ) {
150- case "openai" :
151- OAIConfig := OpenAI .DefaultConfig (OAIAPIKey )
152- OAIConfig .BaseURL = OAIHost + OAIHostPath
153-
154- OAIClient = OpenAI .NewClientWithConfig (OAIConfig )
155-
156- default :
157- var OHostPort string
158-
159- OHostSchema , OHostURL , isOK := strings .Cut (OHost , "://" )
118+ OAIConfig := OpenAI .DefaultConfig (OAIAPIKey )
119+ OAIConfig .BaseURL = OAIHost + OAIHostPath
160120
161- if ! isOK {
162- OHostSchema = "http"
163- OHostURL = OHost
164- OHostPort = "11434"
165- }
166-
167- switch OHostSchema {
168- case "http" :
169- OHostPort = "80"
170-
171- case "https" :
172- OHostPort = "443"
173- }
174-
175- OClient = Ollama .NewClient (& url.URL {
176- Scheme : OHostSchema ,
177- Host : net .JoinHostPort (OHostURL , OHostPort ),
178- Path : OHostPath ,
179- }, http .DefaultClient )
180- }
121+ OAIClient = OpenAI .NewClientWithConfig (OAIConfig )
181122}
182123
183124func GPTResponse (question string ) (response string , err error ) {
@@ -186,131 +127,74 @@ func GPTResponse(question string) (response string, err error) {
186127 }
187128
188129 isStream := new (bool )
189- * isStream = false
190-
191- switch strings .ToLower (WAGPTEngine ) {
192- case "openai" :
193- var OAIGPTResponseText string
194- var OAIGPTChatCompletion []OpenAI.ChatCompletionMessage
195-
196- if len (strings .TrimSpace (GPTModelPrompt )) != 0 {
197- OAIGPTChatCompletion = []OpenAI.ChatCompletionMessage {
198- {
199- Role : OpenAI .ChatMessageRoleSystem ,
200- Content : GPTModelPrompt ,
201- },
202- {
203- Role : OpenAI .ChatMessageRoleUser ,
204- Content : question ,
205- },
206- }
207- } else {
208- OAIGPTChatCompletion = []OpenAI.ChatCompletionMessage {
209- {
210- Role : OpenAI .ChatMessageRoleUser ,
211- Content : question ,
212- },
213- }
214- }
215-
216- OAIGPTPrompt := OpenAI.ChatCompletionRequest {
217- Model : GPTModelName ,
218- MaxTokens : GPTModelToken ,
219- Temperature : GPTModelTemperature ,
220- TopP : GPTModelTopP ,
221- PresencePenalty : GPTModelPenaltyPresence ,
222- FrequencyPenalty : GPTModelPenaltyFreq ,
223- Messages : OAIGPTChatCompletion ,
224- Stream : * isStream ,
130+ * isStream = true
131+
132+ var OAIGPTResponseText string
133+ var OAIGPTChatCompletion []OpenAI.ChatCompletionMessage
134+
135+ if len (strings .TrimSpace (GPTModelPrompt )) != 0 {
136+ OAIGPTChatCompletion = []OpenAI.ChatCompletionMessage {
137+ {
138+ Role : OpenAI .ChatMessageRoleSystem ,
139+ Content : GPTModelPrompt ,
140+ },
141+ {
142+ Role : OpenAI .ChatMessageRoleUser ,
143+ Content : question ,
144+ },
225145 }
226-
227- OAIGPTResponse , err := OAIClient .CreateChatCompletion (
228- context .Background (),
229- OAIGPTPrompt ,
230- )
231-
232- if err != nil {
233- return "" , err
146+ } else {
147+ OAIGPTChatCompletion = []OpenAI.ChatCompletionMessage {
148+ {
149+ Role : OpenAI .ChatMessageRoleUser ,
150+ Content : question ,
151+ },
234152 }
153+ }
235154
236- if len (OAIGPTResponse .Choices ) > 0 {
237- OAIGPTResponseText = OAIGPTResponse .Choices [0 ].Message .Content
238- }
155+ OAIGPTPrompt := OpenAI.ChatCompletionRequest {
156+ Model : GPTModelName ,
157+ MaxTokens : GPTModelToken ,
158+ Temperature : GPTModelTemperature ,
159+ TopP : GPTModelTopP ,
160+ PresencePenalty : GPTModelPenaltyPresence ,
161+ FrequencyPenalty : GPTModelPenaltyFreq ,
162+ Messages : OAIGPTChatCompletion ,
163+ Stream : * isStream ,
164+ }
239165
240- OAIGPTResponseBuffer := strings .TrimSpace (OAIGPTResponseText )
241- OAIGPTResponseBuffer = strings .TrimLeft (OAIGPTResponseBuffer , "?\n " )
242- OAIGPTResponseBuffer = strings .TrimLeft (OAIGPTResponseBuffer , "!\n " )
243- OAIGPTResponseBuffer = strings .TrimLeft (OAIGPTResponseBuffer , ":\n " )
244- OAIGPTResponseBuffer = strings .TrimLeft (OAIGPTResponseBuffer , "'\n " )
245- OAIGPTResponseBuffer = strings .TrimLeft (OAIGPTResponseBuffer , ".\n " )
246- OAIGPTResponseBuffer = strings .TrimLeft (OAIGPTResponseBuffer , "\n " )
247-
248- return OAIGPTResponseBuffer , nil
249-
250- default :
251- var OGPTResponseText string
252- var OGPTChatCompletion []Ollama.Message
253-
254- if len (strings .TrimSpace (GPTModelPrompt )) != 0 {
255- OGPTChatCompletion = []Ollama.Message {
256- {
257- Role : "system" ,
258- Content : GPTModelPrompt ,
259- },
260- {
261- Role : "user" ,
262- Content : question ,
263- },
264- }
265- } else {
266- OGPTChatCompletion = []Ollama.Message {
267- {
268- Role : "user" ,
269- Content : question ,
270- },
271- }
272- }
166+ OAIGPTStream , err := OAIClient .CreateChatCompletionStream (
167+ context .Background (),
168+ OAIGPTPrompt ,
169+ )
273170
274- OGPTOptions := map [string ]interface {}{}
275- OGPTOptionsMarshal , _ := json .Marshal (Ollama.Options {
276- Temperature : GPTModelTemperature ,
277- TopP : GPTModelTopP ,
278- PresencePenalty : GPTModelPenaltyPresence ,
279- FrequencyPenalty : GPTModelPenaltyFreq ,
280- })
281-
282- json .Unmarshal (OGPTOptionsMarshal , & OGPTOptions )
283-
284- OGPTPrompt := & Ollama.ChatRequest {
285- Model : GPTModelName ,
286- Options : OGPTOptions ,
287- Messages : OGPTChatCompletion ,
288- Stream : isStream ,
289- }
171+ if err != nil {
172+ return "" , err
173+ }
174+ defer OAIGPTStream .Close ()
290175
291- OGTPResponseFunc := func (OGPTResponse Ollama.ChatResponse ) error {
292- OGPTResponseText = OGPTResponse .Message .Content
293- return nil
176+ for {
177+ OAIGPTResponse , err := OAIGPTStream .Recv ()
178+ if errors .Is (err , io .EOF ) {
179+ break
294180 }
295181
296- err := OClient .Chat (
297- context .Background (),
298- OGPTPrompt ,
299- OGTPResponseFunc ,
300- )
301-
302182 if err != nil {
303183 return "" , err
304184 }
305185
306- OGPTResponseBuffer := strings .TrimSpace (OGPTResponseText )
307- OGPTResponseBuffer = strings .TrimLeft (OGPTResponseBuffer , "?\n " )
308- OGPTResponseBuffer = strings .TrimLeft (OGPTResponseBuffer , "!\n " )
309- OGPTResponseBuffer = strings .TrimLeft (OGPTResponseBuffer , ":\n " )
310- OGPTResponseBuffer = strings .TrimLeft (OGPTResponseBuffer , "'\n " )
311- OGPTResponseBuffer = strings .TrimLeft (OGPTResponseBuffer , ".\n " )
312- OGPTResponseBuffer = strings .TrimLeft (OGPTResponseBuffer , "\n " )
313-
314- return OGPTResponseBuffer , nil
186+ if len (OAIGPTResponse .Choices ) > 0 {
187+ OAIGPTResponseText = OAIGPTResponseText + OAIGPTResponse .Choices [0 ].Delta .Content
188+ }
315189 }
190+
191+ OAIGPTResponseBuffer := strings .TrimSpace (OAIGPTResponseText )
192+ OAIGPTResponseBuffer = strings .TrimLeft (OAIGPTResponseBuffer , "?\n " )
193+ OAIGPTResponseBuffer = strings .TrimLeft (OAIGPTResponseBuffer , "!\n " )
194+ OAIGPTResponseBuffer = strings .TrimLeft (OAIGPTResponseBuffer , ":\n " )
195+ OAIGPTResponseBuffer = strings .TrimLeft (OAIGPTResponseBuffer , "'\n " )
196+ OAIGPTResponseBuffer = strings .TrimLeft (OAIGPTResponseBuffer , ".\n " )
197+ OAIGPTResponseBuffer = strings .TrimLeft (OAIGPTResponseBuffer , "\n " )
198+
199+ return OAIGPTResponseBuffer , nil
316200}
0 commit comments