Skip to content

Commit 5905dea

Browse files
committed
fix: remove ollama library since ollama is compatible with the openai api format
1 parent d6be309 commit 5905dea

File tree

4 files changed

+80
-206
lines changed

4 files changed

+80
-206
lines changed

.env.example

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,6 @@ OPENAI_HOST=https://api.openai.com
2121
OPENAI_HOST_PATH=/v1
2222
OPENAI_API_KEY=
2323

24-
# -----------------------------------
25-
# Ollama Configuration
26-
# -----------------------------------
27-
OLLAMA_HOST=
28-
OLLAMA_HOST_PATH=/
29-
3024
# -----------------------------------
3125
# GPT Configuration
3226
# -----------------------------------

go.mod

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ toolchain go1.22.6
77
require (
88
github.com/joho/godotenv v1.5.1
99
github.com/lib/pq v1.10.9
10-
github.com/ollama/ollama v0.3.14
1110
github.com/sashabaranov/go-openai v1.36.1
1211
github.com/sirupsen/logrus v1.9.3
1312
github.com/spf13/cobra v1.8.1
@@ -23,7 +22,7 @@ require (
2322
github.com/inconshreveable/mousetrap v1.1.0 // indirect
2423
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
2524
github.com/mattn/go-colorable v0.1.13 // indirect
26-
github.com/mattn/go-isatty v0.0.20 // indirect
25+
github.com/mattn/go-isatty v0.0.19 // indirect
2726
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect
2827
github.com/rs/zerolog v1.33.0 // indirect
2928
github.com/spf13/pflag v1.0.5 // indirect

go.sum

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4
99
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
1010
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
1111
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
12-
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
13-
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
12+
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
13+
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
1414
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
1515
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
1616
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -28,14 +28,11 @@ github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxec
2828
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
2929
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
3030
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
31+
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
3132
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
32-
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
33-
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
3433
github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
3534
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
3635
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
37-
github.com/ollama/ollama v0.3.14 h1:e94+Fb1PDqmD3O90g5cqUSkSxfNm9U3fHMIyaKQ8aSc=
38-
github.com/ollama/ollama v0.3.14/go.mod h1:YrWoNkFnPOYsnDvsf/Ztb1wxU9/IXrNsQHqcxbY2r94=
3936
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
4037
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
4138
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=

pkg/gpt/gpt.go

Lines changed: 76 additions & 192 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,19 @@ package gpt
22

33
import (
44
"context"
5-
"encoding/json"
6-
"net"
7-
"net/http"
8-
"net/url"
5+
"errors"
6+
"io"
97
"os"
108
"regexp"
119
"strings"
1210

13-
Ollama "github.com/ollama/ollama/api"
1411
OpenAI "github.com/sashabaranov/go-openai"
1512

1613
"github.com/dimaskiddo/go-whatsapp-multidevice-gpt/pkg/env"
1714
"github.com/dimaskiddo/go-whatsapp-multidevice-gpt/pkg/log"
1815
)
1916

2017
var OAIClient *OpenAI.Client
21-
var OClient *Ollama.Client
2218

2319
var (
2420
WAGPTEngine,
@@ -32,11 +28,6 @@ var (
3228
OAIAPIKey string
3329
)
3430

35-
var (
36-
OHost,
37-
OHostPath string
38-
)
39-
4031
var (
4132
GPTModelName,
4233
GPTModelPrompt string
@@ -58,51 +49,29 @@ func init() {
5849
// -----------------------------------------------------------------------
5950
// WhatsApp GPT Configuration Environment
6051
// -----------------------------------------------------------------------
61-
WAGPTEngine, err = env.GetEnvString("WHATSAPP_GPT_ENGINE")
62-
if err != nil {
63-
log.Println(log.LogLevelFatal, "Error Parse Environment Variable for WhatsApp GPT Engine")
64-
}
65-
6652
WAGPTBlockedWord = strings.TrimSpace(os.Getenv("WHATSAPP_GPT_BLOCKED_WORD"))
6753
if len(WAGPTBlockedWord) > 0 {
6854
WAGPTBlockedWordRegex = regexp.MustCompile("\\b(?i)(" + listBlockedWord + "|" + WAGPTBlockedWord + ")")
6955
} else {
7056
WAGPTBlockedWordRegex = regexp.MustCompile("\\b(?i)(" + listBlockedWord + ")")
7157
}
7258

73-
switch strings.ToLower(WAGPTEngine) {
74-
case "openai":
75-
// -----------------------------------------------------------------------
76-
// OpenAI Configuration Environment
77-
// -----------------------------------------------------------------------
78-
OAIHost, err = env.GetEnvString("OPENAI_HOST")
79-
if err != nil {
80-
OAIHost = "https://api.openai.com"
81-
}
82-
83-
OAIHostPath, err = env.GetEnvString("OPENAI_HOST_PATH")
84-
if err != nil {
85-
OAIHostPath = "/v1"
86-
}
87-
88-
OAIAPIKey, err = env.GetEnvString("OPENAI_API_KEY")
89-
if err != nil {
90-
log.Println(log.LogLevelFatal, "Error Parse Environment Variable for OpenAI API Key")
91-
}
59+
// -----------------------------------------------------------------------
60+
// OpenAI Configuration Environment
61+
// -----------------------------------------------------------------------
62+
OAIHost, err = env.GetEnvString("OPENAI_HOST")
63+
if err != nil {
64+
OAIHost = "https://api.openai.com"
65+
}
9266

93-
default:
94-
// -----------------------------------------------------------------------
95-
// Ollama Configuration Environment
96-
// -----------------------------------------------------------------------
97-
OHost, err = env.GetEnvString("OLLAMA_HOST")
98-
if err != nil {
99-
log.Println(log.LogLevelFatal, "Error Parse Environment Variable for Ollama Host")
100-
}
67+
OAIHostPath, err = env.GetEnvString("OPENAI_HOST_PATH")
68+
if err != nil {
69+
OAIHostPath = "/v1"
70+
}
10171

102-
OHostPath, err = env.GetEnvString("OLLAMA_HOST_PATH")
103-
if err != nil {
104-
OHostPath = "/"
105-
}
72+
OAIAPIKey, err = env.GetEnvString("OPENAI_API_KEY")
73+
if err != nil {
74+
log.Println(log.LogLevelFatal, "Error Parse Environment Variable for OpenAI API Key")
10675
}
10776

10877
// -----------------------------------------------------------------------
@@ -146,38 +115,10 @@ func init() {
146115
// -----------------------------------------------------------------------
147116
// GPT Engine Initialization
148117
// -----------------------------------------------------------------------
149-
switch strings.ToLower(WAGPTEngine) {
150-
case "openai":
151-
OAIConfig := OpenAI.DefaultConfig(OAIAPIKey)
152-
OAIConfig.BaseURL = OAIHost + OAIHostPath
153-
154-
OAIClient = OpenAI.NewClientWithConfig(OAIConfig)
155-
156-
default:
157-
var OHostPort string
158-
159-
OHostSchema, OHostURL, isOK := strings.Cut(OHost, "://")
118+
OAIConfig := OpenAI.DefaultConfig(OAIAPIKey)
119+
OAIConfig.BaseURL = OAIHost + OAIHostPath
160120

161-
if !isOK {
162-
OHostSchema = "http"
163-
OHostURL = OHost
164-
OHostPort = "11434"
165-
}
166-
167-
switch OHostSchema {
168-
case "http":
169-
OHostPort = "80"
170-
171-
case "https":
172-
OHostPort = "443"
173-
}
174-
175-
OClient = Ollama.NewClient(&url.URL{
176-
Scheme: OHostSchema,
177-
Host: net.JoinHostPort(OHostURL, OHostPort),
178-
Path: OHostPath,
179-
}, http.DefaultClient)
180-
}
121+
OAIClient = OpenAI.NewClientWithConfig(OAIConfig)
181122
}
182123

183124
func GPTResponse(question string) (response string, err error) {
@@ -186,131 +127,74 @@ func GPTResponse(question string) (response string, err error) {
186127
}
187128

188129
isStream := new(bool)
189-
*isStream = false
190-
191-
switch strings.ToLower(WAGPTEngine) {
192-
case "openai":
193-
var OAIGPTResponseText string
194-
var OAIGPTChatCompletion []OpenAI.ChatCompletionMessage
195-
196-
if len(strings.TrimSpace(GPTModelPrompt)) != 0 {
197-
OAIGPTChatCompletion = []OpenAI.ChatCompletionMessage{
198-
{
199-
Role: OpenAI.ChatMessageRoleSystem,
200-
Content: GPTModelPrompt,
201-
},
202-
{
203-
Role: OpenAI.ChatMessageRoleUser,
204-
Content: question,
205-
},
206-
}
207-
} else {
208-
OAIGPTChatCompletion = []OpenAI.ChatCompletionMessage{
209-
{
210-
Role: OpenAI.ChatMessageRoleUser,
211-
Content: question,
212-
},
213-
}
214-
}
215-
216-
OAIGPTPrompt := OpenAI.ChatCompletionRequest{
217-
Model: GPTModelName,
218-
MaxTokens: GPTModelToken,
219-
Temperature: GPTModelTemperature,
220-
TopP: GPTModelTopP,
221-
PresencePenalty: GPTModelPenaltyPresence,
222-
FrequencyPenalty: GPTModelPenaltyFreq,
223-
Messages: OAIGPTChatCompletion,
224-
Stream: *isStream,
130+
*isStream = true
131+
132+
var OAIGPTResponseText string
133+
var OAIGPTChatCompletion []OpenAI.ChatCompletionMessage
134+
135+
if len(strings.TrimSpace(GPTModelPrompt)) != 0 {
136+
OAIGPTChatCompletion = []OpenAI.ChatCompletionMessage{
137+
{
138+
Role: OpenAI.ChatMessageRoleSystem,
139+
Content: GPTModelPrompt,
140+
},
141+
{
142+
Role: OpenAI.ChatMessageRoleUser,
143+
Content: question,
144+
},
225145
}
226-
227-
OAIGPTResponse, err := OAIClient.CreateChatCompletion(
228-
context.Background(),
229-
OAIGPTPrompt,
230-
)
231-
232-
if err != nil {
233-
return "", err
146+
} else {
147+
OAIGPTChatCompletion = []OpenAI.ChatCompletionMessage{
148+
{
149+
Role: OpenAI.ChatMessageRoleUser,
150+
Content: question,
151+
},
234152
}
153+
}
235154

236-
if len(OAIGPTResponse.Choices) > 0 {
237-
OAIGPTResponseText = OAIGPTResponse.Choices[0].Message.Content
238-
}
155+
OAIGPTPrompt := OpenAI.ChatCompletionRequest{
156+
Model: GPTModelName,
157+
MaxTokens: GPTModelToken,
158+
Temperature: GPTModelTemperature,
159+
TopP: GPTModelTopP,
160+
PresencePenalty: GPTModelPenaltyPresence,
161+
FrequencyPenalty: GPTModelPenaltyFreq,
162+
Messages: OAIGPTChatCompletion,
163+
Stream: *isStream,
164+
}
239165

240-
OAIGPTResponseBuffer := strings.TrimSpace(OAIGPTResponseText)
241-
OAIGPTResponseBuffer = strings.TrimLeft(OAIGPTResponseBuffer, "?\n")
242-
OAIGPTResponseBuffer = strings.TrimLeft(OAIGPTResponseBuffer, "!\n")
243-
OAIGPTResponseBuffer = strings.TrimLeft(OAIGPTResponseBuffer, ":\n")
244-
OAIGPTResponseBuffer = strings.TrimLeft(OAIGPTResponseBuffer, "'\n")
245-
OAIGPTResponseBuffer = strings.TrimLeft(OAIGPTResponseBuffer, ".\n")
246-
OAIGPTResponseBuffer = strings.TrimLeft(OAIGPTResponseBuffer, "\n")
247-
248-
return OAIGPTResponseBuffer, nil
249-
250-
default:
251-
var OGPTResponseText string
252-
var OGPTChatCompletion []Ollama.Message
253-
254-
if len(strings.TrimSpace(GPTModelPrompt)) != 0 {
255-
OGPTChatCompletion = []Ollama.Message{
256-
{
257-
Role: "system",
258-
Content: GPTModelPrompt,
259-
},
260-
{
261-
Role: "user",
262-
Content: question,
263-
},
264-
}
265-
} else {
266-
OGPTChatCompletion = []Ollama.Message{
267-
{
268-
Role: "user",
269-
Content: question,
270-
},
271-
}
272-
}
166+
OAIGPTStream, err := OAIClient.CreateChatCompletionStream(
167+
context.Background(),
168+
OAIGPTPrompt,
169+
)
273170

274-
OGPTOptions := map[string]interface{}{}
275-
OGPTOptionsMarshal, _ := json.Marshal(Ollama.Options{
276-
Temperature: GPTModelTemperature,
277-
TopP: GPTModelTopP,
278-
PresencePenalty: GPTModelPenaltyPresence,
279-
FrequencyPenalty: GPTModelPenaltyFreq,
280-
})
281-
282-
json.Unmarshal(OGPTOptionsMarshal, &OGPTOptions)
283-
284-
OGPTPrompt := &Ollama.ChatRequest{
285-
Model: GPTModelName,
286-
Options: OGPTOptions,
287-
Messages: OGPTChatCompletion,
288-
Stream: isStream,
289-
}
171+
if err != nil {
172+
return "", err
173+
}
174+
defer OAIGPTStream.Close()
290175

291-
OGTPResponseFunc := func(OGPTResponse Ollama.ChatResponse) error {
292-
OGPTResponseText = OGPTResponse.Message.Content
293-
return nil
176+
for {
177+
OAIGPTResponse, err := OAIGPTStream.Recv()
178+
if errors.Is(err, io.EOF) {
179+
break
294180
}
295181

296-
err := OClient.Chat(
297-
context.Background(),
298-
OGPTPrompt,
299-
OGTPResponseFunc,
300-
)
301-
302182
if err != nil {
303183
return "", err
304184
}
305185

306-
OGPTResponseBuffer := strings.TrimSpace(OGPTResponseText)
307-
OGPTResponseBuffer = strings.TrimLeft(OGPTResponseBuffer, "?\n")
308-
OGPTResponseBuffer = strings.TrimLeft(OGPTResponseBuffer, "!\n")
309-
OGPTResponseBuffer = strings.TrimLeft(OGPTResponseBuffer, ":\n")
310-
OGPTResponseBuffer = strings.TrimLeft(OGPTResponseBuffer, "'\n")
311-
OGPTResponseBuffer = strings.TrimLeft(OGPTResponseBuffer, ".\n")
312-
OGPTResponseBuffer = strings.TrimLeft(OGPTResponseBuffer, "\n")
313-
314-
return OGPTResponseBuffer, nil
186+
if len(OAIGPTResponse.Choices) > 0 {
187+
OAIGPTResponseText = OAIGPTResponseText + OAIGPTResponse.Choices[0].Delta.Content
188+
}
315189
}
190+
191+
OAIGPTResponseBuffer := strings.TrimSpace(OAIGPTResponseText)
192+
OAIGPTResponseBuffer = strings.TrimLeft(OAIGPTResponseBuffer, "?\n")
193+
OAIGPTResponseBuffer = strings.TrimLeft(OAIGPTResponseBuffer, "!\n")
194+
OAIGPTResponseBuffer = strings.TrimLeft(OAIGPTResponseBuffer, ":\n")
195+
OAIGPTResponseBuffer = strings.TrimLeft(OAIGPTResponseBuffer, "'\n")
196+
OAIGPTResponseBuffer = strings.TrimLeft(OAIGPTResponseBuffer, ".\n")
197+
OAIGPTResponseBuffer = strings.TrimLeft(OAIGPTResponseBuffer, "\n")
198+
199+
return OAIGPTResponseBuffer, nil
316200
}

0 commit comments

Comments
 (0)