package com.plexpt.chatgpt.util;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import com.knuddels.jtokkit.api.ModelType;
import com.plexpt.chatgpt.entity.chat.Message;
import lombok.experimental.UtilityClass;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.List;
import java.util.Optional;

1616@ UtilityClass
1717public class TokensUtil {
1818
19- private static final Map < String , Encoding > modelEncodingMap = new HashMap <> ();
20- private static final EncodingRegistry encodingRegistry = Encodings . newDefaultEncodingRegistry ( );
19+ public static EncodingRegistry registry = Encodings . newDefaultEncodingRegistry ();
20+ public static Encoding encoding = registry . getEncoding ( EncodingType . CL100K_BASE );
2121
22- static {
23- for (ChatCompletion .Model model : ChatCompletion .Model .values ()) {
24- Optional <Encoding > encodingForModel = encodingRegistry .getEncodingForModel (model .getName ());
25- encodingForModel .ifPresent (encoding -> modelEncodingMap .put (model .getName (), encoding ));
26- }
22+
23+ /**
24+ * 计算text信息的tokens
25+ *
26+ * @param text
27+ * @return
28+ */
29+ public static int countTextTokens (String text ) {
30+ return encoding .countTokens (text );
2731 }
2832
33+
2934 /**
30- * 计算tokens
31- * @param modelName 模型名称
32- * @param messages 消息列表
33- * @return 计算出的tokens数量
35+ * 获取modelType
36+ *
37+ * @param name
38+ * @return
3439 */
40+ private static ModelType getModelTypeByName (String name ) {
41+ Optional <ModelType > optional = ModelType .fromName (name );
3542
36- public static int tokens (String modelName , List <Message > messages ) {
37- Encoding encoding = modelEncodingMap .get (modelName );
38- if (encoding == null ) {
39- throw new IllegalArgumentException ("Unsupported model: " + modelName );
43+ return optional .orElse (ModelType .GPT_3_5_TURBO );
44+ }
45+
46+ /**
47+ * 通过模型名称计算messages获取编码数组
48+ * 参考官方的处理逻辑:
49+ * <a href=https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb>https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb</a>
50+ *
51+ * @param messages 消息体
52+ * @return
53+ */
54+ public static int tokens (List <Message > messages , String model ) {
55+ if (CollectionUtils .isEmpty (messages )) {
56+ return 0 ;
4057 }
4158
42- int tokensPerMessage = 0 ;
43- int tokensPerName = 0 ;
44- if (modelName .startsWith ("gpt-4" )) {
59+ //"gpt-3.5-turbo"
60+ // every message follows <|start|>{role/name}\n{content}<|end|>\n
61+ int tokensPerMessage = 4 ;
62+ // if there's a name, the role is omitted
63+ int tokensPerName = -1 ;
64+
65+ if (StringUtils .startsWithIgnoreCase (model , ModelType .GPT_4 .getName ())) {
4566 tokensPerMessage = 3 ;
4667 tokensPerName = 1 ;
47- } else if (modelName .startsWith ("gpt-3.5-turbo" )) {
48- tokensPerMessage = 4 ; // every message follows <|start|>{role/name}\n{content}<|end|>\n
49- tokensPerName = -1 ; // if there's a name, the role is omitted
5068 }
69+
5170 int sum = 0 ;
52- for (Message message : messages ) {
71+ for (final Message message : messages ) {
5372 sum += tokensPerMessage ;
5473 sum += encoding .countTokens (message .getContent ());
5574 sum += encoding .countTokens (message .getRole ());
56- if (StrUtil . isNotBlank (message .getName ())) {
75+ if (! StringUtils . isEmpty (message .getName ())) {
5776 sum += encoding .countTokens (message .getName ());
5877 sum += tokensPerName ;
5978 }
6079 }
80+
81+ // every reply is primed with <|start|>assistant<|message|>
6182 sum += 3 ;
83+
6284 return sum ;
6385 }
64- }
86+
87+ /**
88+ * 计算tokens
89+ *
90+ * @param modelName 模型名称
91+ * @param messages 消息列表
92+ * @return 计算出的tokens数量
93+ */
94+
95+ public static int tokens (String modelName , List <Message > messages ) {
96+ return tokens (messages , modelName );
97+ }
98+ }