|
1 | 1 | package tokenizer |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "strings" |
| 5 | + |
4 | 6 | "github.com/pkoukk/tiktoken-go" |
5 | 7 | "github.com/sashabaranov/go-openai" |
6 | 8 | ) |
7 | 9 |
|
8 | 10 | var encoders = map[string]*tiktoken.Tiktoken{} |
9 | 11 |
|
10 | | -func CountTokens(model, text string) int { |
| 12 | +func getEncoding(model string) (*tiktoken.Tiktoken, error) { |
11 | 13 | enc, ok := encoders[model] |
| 14 | + var err error |
12 | 15 | if !ok { |
13 | | - enc, _ = tiktoken.EncodingForModel(model) |
| 16 | + enc, err = tiktoken.EncodingForModel(model) |
| 17 | + if err != nil { |
| 18 | + return nil, err |
| 19 | + } |
14 | 20 | encoders[model] = enc |
15 | 21 | } |
| 22 | + return enc, nil |
| 23 | +} |
| 24 | + |
| 25 | +func CheckModel(model string) error { |
| 26 | + _, err := getEncoding(model) |
| 27 | + return err |
| 28 | +} |
| 29 | + |
| 30 | +func CountTokens(model, text string) int { |
| 31 | + enc, err := getEncoding(model) |
| 32 | + if err != nil { |
| 33 | + panic(err) |
| 34 | + } |
16 | 35 | return len(enc.Encode(text, nil, nil)) |
17 | 36 | } |
18 | 37 |
|
19 | 38 | // CountMessagesTokens based on https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb |
20 | 39 | func CountMessagesTokens(model string, messages []openai.ChatCompletionMessage) int { |
21 | | - var tokens int |
22 | | - var tokensPerMessage int |
23 | | - var tokensPerName int |
| 40 | + var ( |
| 41 | + tokens int |
| 42 | + tokensPerMessage int |
| 43 | + tokensPerName int |
| 44 | + ) |
24 | 45 |
|
25 | | - switch model { |
26 | | - case openai.GPT3Dot5Turbo, openai.GPT3Dot5Turbo0301: |
| 46 | + if strings.HasPrefix(model, "gpt-3.5") { |
27 | 47 | tokensPerMessage = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n |
28 | 48 | tokensPerName = -1 // if there's a name, the role is omitted |
29 | | - case openai.GPT4, openai.GPT40314, openai.GPT432K, openai.GPT432K0314: |
| 49 | + } else if strings.HasPrefix(model, "gpt-4") { |
30 | 50 | tokensPerMessage = 3 |
31 | 51 | tokensPerName = 1 |
32 | 52 | } |
|
0 commit comments