Skip to content

Commit 0bb0213

Browse files
authored
fix: add check for model name (#83)
1 parent b732ab5 commit 0bb0213

File tree

5 files changed

+53
-17
lines changed

5 files changed

+53
-17
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ Notes:
290290

291291
- `api_type` should be "AZURE" or "AZURE_AD".
292292
- `api_version` defaults to "2023-05-15" if not specified.
293-
- Configure `model_mapping` to map model names to your deployment names. If not specified, the model name will be used as the deployment name with `.` or `:` removed (e.g. "gpt-3.5-turbo" -> "gpt-35-turbo").
293+
- Configure `model_mapping` to map model names to your deployment names. The key must be a valid OpenAI model name. If not specified, the model name will be used as the deployment name with `.` or `:` removed (e.g. "gpt-3.5-turbo" -> "gpt-35-turbo").
294294

295295
Find more details about Azure OpenAI service here: https://learn.microsoft.com/en-US/azure/ai-services/openai/reference.
296296

cmd/chatgpt/main.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,10 @@ func main() {
9393
defer func() { _ = lockFile.Unlock() }()
9494
}
9595

96-
conversations := chatgpt.NewConversationManager(conf, chatgpt.ConversationHistoryFile())
96+
conversations, err := chatgpt.NewConversationManager(conf, chatgpt.ConversationHistoryFile())
97+
if err != nil {
98+
exit(err)
99+
}
97100

98101
if *startNewConversation {
99102
conversations.New(conf.Conversation)

config.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010

1111
"github.com/mitchellh/go-homedir"
1212
"github.com/sashabaranov/go-openai"
13+
14+
"github.com/j178/chatgpt/tokenizer"
1315
)
1416

1517
type ConversationConfig struct {
@@ -140,7 +142,7 @@ func InitConfig() (GlobalConfig, error) {
140142
Prompt: "default",
141143
ContextLength: 6,
142144
Stream: true,
143-
Temperature: 0,
145+
Temperature: 1.0,
144146
MaxTokens: 1024,
145147
},
146148
KeyMap: defaultKeyMapConfig(),
@@ -157,14 +159,21 @@ func InitConfig() (GlobalConfig, error) {
157159
if endpoint != "" {
158160
conf.Endpoint = endpoint
159161
}
162+
160163
if conf.APIKey == "" {
161164
return GlobalConfig{}, errors.New("Missing API key. Set it in `~/.config/chatgpt/config.json` or by setting the `OPENAI_API_KEY` environment variable. You can find or create your API key at https://platform.openai.com/account/api-keys.")
162165
}
166+
163167
conf.APIType = openai.APIType(strings.ToUpper(string(conf.APIType)))
164168
switch conf.APIType {
169+
case openai.APITypeOpenAI, openai.APITypeAzure, openai.APITypeAzureAD:
165170
default:
166171
return GlobalConfig{}, fmt.Errorf("unknown API type: %s", conf.APIType)
167-
case openai.APITypeOpenAI, openai.APITypeAzure, openai.APITypeAzureAD:
172+
}
173+
174+
err = tokenizer.CheckModel(conf.Conversation.Model)
175+
if err != nil {
176+
return GlobalConfig{}, fmt.Errorf("invalid model %s", conf.Conversation.Model)
168177
}
169178
return conf, nil
170179
}

conversation.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package chatgpt
33
import (
44
"encoding/json"
55
"errors"
6-
"log"
6+
"fmt"
77
"os"
88

99
"github.com/sashabaranov/go-openai"
@@ -18,7 +18,7 @@ type ConversationManager struct {
1818
Idx int `json:"last_idx"`
1919
}
2020

21-
func NewConversationManager(conf GlobalConfig, historyFile string) *ConversationManager {
21+
func NewConversationManager(conf GlobalConfig, historyFile string) (*ConversationManager, error) {
2222
h := &ConversationManager{
2323
file: historyFile,
2424
globalConf: conf,
@@ -27,9 +27,9 @@ func NewConversationManager(conf GlobalConfig, historyFile string) *Conversation
2727

2828
err := h.Load()
2929
if err != nil {
30-
log.Println("Failed to load history:", err)
30+
return nil, fmt.Errorf("Failed to load conversation history: %w", err)
3131
}
32-
return h
32+
return h, nil
3333
}
3434

3535
func (m *ConversationManager) Dump() error {
@@ -65,7 +65,11 @@ func (m *ConversationManager) Load() error {
6565
if err != nil {
6666
return err
6767
}
68-
for _, c := range m.Conversations {
68+
for i, c := range m.Conversations {
69+
err = tokenizer.CheckModel(c.Config.Model)
70+
if err != nil {
71+
return fmt.Errorf("invalid model %s in conversation %d", c.Config.Model, i+1)
72+
}
6973
c.manager = m
7074
}
7175
return nil

tokenizer/tokenize.go

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,52 @@
11
package tokenizer
22

33
import (
4+
"strings"
5+
46
"github.com/pkoukk/tiktoken-go"
57
"github.com/sashabaranov/go-openai"
68
)
79

810
var encoders = map[string]*tiktoken.Tiktoken{}
911

10-
func CountTokens(model, text string) int {
12+
func getEncoding(model string) (*tiktoken.Tiktoken, error) {
1113
enc, ok := encoders[model]
14+
var err error
1215
if !ok {
13-
enc, _ = tiktoken.EncodingForModel(model)
16+
enc, err = tiktoken.EncodingForModel(model)
17+
if err != nil {
18+
return nil, err
19+
}
1420
encoders[model] = enc
1521
}
22+
return enc, nil
23+
}
24+
25+
func CheckModel(model string) error {
26+
_, err := getEncoding(model)
27+
return err
28+
}
29+
30+
func CountTokens(model, text string) int {
31+
enc, err := getEncoding(model)
32+
if err != nil {
33+
panic(err)
34+
}
1635
return len(enc.Encode(text, nil, nil))
1736
}
1837

1938
// CountMessagesTokens based on https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
2039
func CountMessagesTokens(model string, messages []openai.ChatCompletionMessage) int {
21-
var tokens int
22-
var tokensPerMessage int
23-
var tokensPerName int
40+
var (
41+
tokens int
42+
tokensPerMessage int
43+
tokensPerName int
44+
)
2445

25-
switch model {
26-
case openai.GPT3Dot5Turbo, openai.GPT3Dot5Turbo0301:
46+
if strings.HasPrefix(model, "gpt-3.5") {
2747
tokensPerMessage = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
2848
tokensPerName = -1 // if there's a name, the role is omitted
29-
case openai.GPT4, openai.GPT40314, openai.GPT432K, openai.GPT432K0314:
49+
} else if strings.HasPrefix(model, "gpt-4") {
3050
tokensPerMessage = 3
3151
tokensPerName = 1
3252
}

0 commit comments

Comments
 (0)