Skip to content

Commit 8f7d1ea

Browse files
committed
token count update
1 parent b31bc66 commit 8f7d1ea

File tree

3 files changed

+74
-43
lines changed

3 files changed

+74
-43
lines changed

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@
132132
<dependency>
133133
<groupId>com.knuddels</groupId>
134134
<artifactId>jtokkit</artifactId>
135-
<version>0.4.0</version>
135+
<version>1.0.0</version>
136136
</dependency>
137137
</dependencies>
138138

src/main/java/com/plexpt/chatgpt/ChatGPTStream.java

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,15 @@
11
package com.plexpt.chatgpt;
22

3-
import com.fasterxml.jackson.databind.ObjectMapper;
3+
import cn.hutool.core.util.RandomUtil;
4+
import cn.hutool.http.ContentType;
45
import com.plexpt.chatgpt.api.Api;
56
import com.plexpt.chatgpt.entity.chat.ChatCompletion;
67
import com.plexpt.chatgpt.entity.chat.Message;
7-
8-
import java.net.Proxy;
9-
import java.util.List;
10-
import java.util.Objects;
11-
import java.util.concurrent.TimeUnit;
12-
13-
import cn.hutool.core.util.RandomUtil;
14-
import cn.hutool.http.ContentType;
8+
import com.plexpt.chatgpt.util.fastjson.JSON;
159
import lombok.AllArgsConstructor;
1610
import lombok.Builder;
1711
import lombok.Data;
1812
import lombok.NoArgsConstructor;
19-
import lombok.NonNull;
2013
import lombok.extern.slf4j.Slf4j;
2114
import okhttp3.MediaType;
2215
import okhttp3.OkHttpClient;
@@ -26,6 +19,11 @@
2619
import okhttp3.sse.EventSourceListener;
2720
import okhttp3.sse.EventSources;
2821

22+
import java.net.Proxy;
23+
import java.util.List;
24+
import java.util.Objects;
25+
import java.util.concurrent.TimeUnit;
26+
2927

3028
/**
3129
* open ai 客户端
@@ -89,8 +87,8 @@ public void streamChatCompletion(ChatCompletion chatCompletion,
8987

9088
try {
9189
EventSource.Factory factory = EventSources.createFactory(okHttpClient);
92-
ObjectMapper mapper = new ObjectMapper();
93-
String requestBody = mapper.writeValueAsString(chatCompletion);
90+
91+
String requestBody = JSON.toJSONString(chatCompletion);
9492
String key = apiKey;
9593
if (apiKeyList != null && !apiKeyList.isEmpty()) {
9694
key = RandomUtil.randomEle(apiKeyList);
@@ -99,8 +97,7 @@ public void streamChatCompletion(ChatCompletion chatCompletion,
9997

10098
Request request = new Request.Builder()
10199
.url(apiHost + "v1/chat/completions")
102-
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()),
103-
requestBody))
100+
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
104101
.header("Authorization", "Bearer " + key)
105102
.build();
106103
factory.newEventSource(request, eventSourceListener);
Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,98 @@
11
package com.plexpt.chatgpt.util;
22

3-
import cn.hutool.core.util.StrUtil;
43
import com.knuddels.jtokkit.Encodings;
54
import com.knuddels.jtokkit.api.Encoding;
65
import com.knuddels.jtokkit.api.EncodingRegistry;
7-
import com.plexpt.chatgpt.entity.chat.ChatCompletion;
6+
import com.knuddels.jtokkit.api.EncodingType;
7+
import com.knuddels.jtokkit.api.ModelType;
88
import com.plexpt.chatgpt.entity.chat.Message;
99
import lombok.experimental.UtilityClass;
10+
import org.springframework.util.CollectionUtils;
11+
import org.springframework.util.StringUtils;
1012

11-
import java.util.HashMap;
1213
import java.util.List;
13-
import java.util.Map;
1414
import java.util.Optional;
1515

1616
@UtilityClass
1717
public class TokensUtil {
1818

19-
private static final Map<String, Encoding> modelEncodingMap = new HashMap<>();
20-
private static final EncodingRegistry encodingRegistry = Encodings.newDefaultEncodingRegistry();
19+
public static EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
20+
public static Encoding encoding = registry.getEncoding(EncodingType.CL100K_BASE);
2121

22-
static {
23-
for (ChatCompletion.Model model : ChatCompletion.Model.values()) {
24-
Optional<Encoding> encodingForModel = encodingRegistry.getEncodingForModel(model.getName());
25-
encodingForModel.ifPresent(encoding -> modelEncodingMap.put(model.getName(), encoding));
26-
}
22+
23+
/**
24+
* 计算text信息的tokens
25+
*
26+
* @param text
27+
* @return
28+
*/
29+
public static int countTextTokens(String text) {
30+
return encoding.countTokens(text);
2731
}
2832

33+
2934
/**
30-
* 计算tokens
31-
* @param modelName 模型名称
32-
* @param messages 消息列表
33-
* @return 计算出的tokens数量
35+
* 获取modelType
36+
*
37+
* @param name
38+
* @return
3439
*/
40+
private static ModelType getModelTypeByName(String name) {
41+
Optional<ModelType> optional = ModelType.fromName(name);
3542

36-
public static int tokens(String modelName, List<Message> messages) {
37-
Encoding encoding = modelEncodingMap.get(modelName);
38-
if (encoding == null) {
39-
throw new IllegalArgumentException("Unsupported model: " + modelName);
43+
return optional.orElse(ModelType.GPT_3_5_TURBO);
44+
}
45+
46+
/**
47+
* 通过模型名称计算messages获取编码数组
48+
* 参考官方的处理逻辑:
49+
* <a href=https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb>https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb</a>
50+
*
51+
* @param messages 消息体
52+
* @return
53+
*/
54+
public static int tokens(List<Message> messages, String model) {
55+
if (CollectionUtils.isEmpty(messages)) {
56+
return 0;
4057
}
4158

42-
int tokensPerMessage = 0;
43-
int tokensPerName = 0;
44-
if (modelName.startsWith("gpt-4")) {
59+
//"gpt-3.5-turbo"
60+
// every message follows <|start|>{role/name}\n{content}<|end|>\n
61+
int tokensPerMessage = 4;
62+
// if there's a name, the role is omitted
63+
int tokensPerName = -1;
64+
65+
if (StringUtils.startsWithIgnoreCase(model, ModelType.GPT_4.getName())) {
4566
tokensPerMessage = 3;
4667
tokensPerName = 1;
47-
} else if (modelName.startsWith("gpt-3.5-turbo")) {
48-
tokensPerMessage = 4; // every message follows <|start|>{role/name}\n{content}<|end|>\n
49-
tokensPerName = -1; // if there's a name, the role is omitted
5068
}
69+
5170
int sum = 0;
52-
for (Message message : messages) {
71+
for (final Message message : messages) {
5372
sum += tokensPerMessage;
5473
sum += encoding.countTokens(message.getContent());
5574
sum += encoding.countTokens(message.getRole());
56-
if (StrUtil.isNotBlank(message.getName())) {
75+
if (!StringUtils.isEmpty(message.getName())) {
5776
sum += encoding.countTokens(message.getName());
5877
sum += tokensPerName;
5978
}
6079
}
80+
81+
// every reply is primed with <|start|>assistant<|message|>
6182
sum += 3;
83+
6284
return sum;
6385
}
64-
}
86+
87+
/**
88+
* 计算tokens
89+
*
90+
* @param modelName 模型名称
91+
* @param messages 消息列表
92+
* @return 计算出的tokens数量
93+
*/
94+
95+
public static int tokens(String modelName, List<Message> messages) {
96+
return tokens(messages, modelName);
97+
}
98+
}

0 commit comments

Comments
 (0)