Skip to content

Commit f5225a1

Browse files
committed
换用廉价模型 (#271)
1 parent b629ce3 commit f5225a1

File tree

15 files changed

+241
-56
lines changed

15 files changed

+241
-56
lines changed

internal/ai/handlers.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,14 @@ func InitHandlerFacade(common []handler.Builder,
3737

3838
func InitZhipu() *zhipu.Handler {
3939
type Config struct {
40-
APIKey string `yaml:"apikey"`
41-
Price float64 `yaml:"price"`
40+
APIKey string `yaml:"apikey"`
4241
}
4342
var cfg Config
4443
err := econf.UnmarshalKey("zhipu", &cfg)
4544
if err != nil {
4645
panic(err)
4746
}
48-
h, err := zhipu.NewHandler(cfg.APIKey, cfg.Price)
47+
h, err := zhipu.NewHandler(cfg.APIKey)
4948
if err != nil {
5049
panic(err)
5150
}

internal/ai/internal/domain/llm.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ type LLMResponse struct {
2626
}
2727

2828
type BizConfig struct {
29+
// 使用的模型
30+
Model string
31+
// 多少分钱/1000 token
32+
Price int64
33+
34+
Temperature float64
35+
TopP float64
36+
37+
// 系统 Prompt
38+
SystemPrompt string
2939
// 允许的最长输入
3040
// 这里我们不用计算 token,只需要简单约束一下字符串长度就可以
3141
MaxInput int

internal/ai/internal/integration/llm_service_test.go

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func (s *LLMServiceSuite) SetupSuite() {
5151
err = s.db.Create(&dao.BizConfig{
5252
Biz: domain.BizQuestionExamine,
5353
MaxInput: 100,
54-
PromptTemplate: "这是问题 %s,这是用户输入 %s",
54+
PromptTemplate: "这是问题 %s,这是问题内容 %s,这是用户输入 %s",
5555
KnowledgeId: knowledgeId,
5656
Ctime: now,
5757
Utime: now,
@@ -60,7 +60,7 @@ func (s *LLMServiceSuite) SetupSuite() {
6060
err = s.db.Create(&dao.BizConfig{
6161
Biz: domain.BizCaseExamine,
6262
MaxInput: 100,
63-
PromptTemplate: "这是案例 %s,这是用户输入 %s",
63+
PromptTemplate: "这是案例 %s,这是案例内容 %s,这是用户输入 %s",
6464
KnowledgeId: knowledgeId,
6565
Ctime: now,
6666
Utime: now,
@@ -97,6 +97,7 @@ func (s *LLMServiceSuite) TestService() {
9797
Tid: "11",
9898
Input: []string{
9999
"问题1",
100+
"问题1内容",
100101
"用户输入1",
101102
},
102103
},
@@ -142,11 +143,12 @@ func (s *LLMServiceSuite) TestService() {
142143
Valid: true,
143144
Val: []string{
144145
"问题1",
146+
"问题1内容",
145147
"用户输入1",
146148
},
147149
},
148150
Status: 1,
149-
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是用户输入 %s"),
151+
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是问题内容 %s,这是用户输入 %s"),
150152
Answer: sqlx.NewNullString("aians"),
151153
}, logModel)
152154
// 校验credit写入的内容是否正确
@@ -171,6 +173,7 @@ func (s *LLMServiceSuite) TestService() {
171173
Tid: "13",
172174
Input: []string{
173175
"案例1",
176+
"案例1内容",
174177
"用户输入1",
175178
},
176179
},
@@ -217,11 +220,12 @@ func (s *LLMServiceSuite) TestService() {
217220
Valid: true,
218221
Val: []string{
219222
"案例1",
223+
"案例1内容",
220224
"用户输入1",
221225
},
222226
},
223227
Status: 1,
224-
PromptTemplate: sqlx.NewNullString("这是案例 %s,这是用户输入 %s"),
228+
PromptTemplate: sqlx.NewNullString("这是案例 %s,这是案例内容 %s,这是用户输入 %s"),
225229
Answer: sqlx.NewNullString("aians"),
226230
}, logModel)
227231
// 校验credit写入的内容是否正确
@@ -274,11 +278,12 @@ func (s *LLMServiceSuite) TestService() {
274278
Valid: true,
275279
Val: []string{
276280
"问题1",
281+
"问题1内容",
277282
"用户输入1",
278283
},
279284
},
280285
Status: domain.RecordStatusFailed.ToUint8(),
281-
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是用户输入 %s"),
286+
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是问题内容 %s,这是用户输入 %s"),
282287
}, logModel)
283288
},
284289
assertFunc: assert.Error,
@@ -291,6 +296,7 @@ func (s *LLMServiceSuite) TestService() {
291296
Tid: "11",
292297
Input: []string{
293298
"问题1",
299+
"问题1内容",
294300
"用户输入1",
295301
},
296302
},
@@ -323,11 +329,12 @@ func (s *LLMServiceSuite) TestService() {
323329
Valid: true,
324330
Val: []string{
325331
"问题1",
332+
"问题1内容",
326333
"用户输入1",
327334
},
328335
},
329336
Status: domain.CreditStatusFailed.ToUint8(),
330-
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是用户输入 %s"),
337+
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是问题内容 %s,这是用户输入 %s"),
331338
Answer: sqlx.NewNullString("aians"),
332339
}, logModel)
333340
// 校验credit写入的内容是否正确
@@ -353,6 +360,7 @@ func (s *LLMServiceSuite) TestService() {
353360
Tid: "11",
354361
Input: []string{
355362
"问题1",
363+
"问题1内容",
356364
"用户输入1",
357365
},
358366
},
@@ -412,11 +420,12 @@ func (s *LLMServiceSuite) TestService() {
412420
Valid: true,
413421
Val: []string{
414422
"问题1",
423+
"问题1内容",
415424
"用户输入1",
416425
},
417426
},
418427
Status: domain.RecordStatusFailed.ToUint8(),
419-
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是用户输入 %s"),
428+
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是问题内容 %s,这是用户输入 %s"),
420429
Answer: sqlx.NewNullString("aians"),
421430
}, logModel)
422431
// 校验credit写入的内容是否正确

internal/ai/internal/repository/config.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ func (repo *CachedConfigRepository) GetConfig(ctx context.Context, biz string) (
4141
return domain.BizConfig{}, err
4242
}
4343
return domain.BizConfig{
44+
Model: res.Model,
45+
Price: res.Price,
46+
Temperature: res.Temperature,
47+
TopP: res.TopP,
48+
SystemPrompt: res.SystemPrompt,
4449
MaxInput: res.MaxInput,
4550
PromptTemplate: res.PromptTemplate,
4651
KnowledgeId: res.KnowledgeId,

internal/ai/internal/repository/dao/config.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,15 @@ func (dao *GORMConfigDAO) GetConfig(ctx context.Context, biz string) (BizConfig,
3939
}
4040

4141
type BizConfig struct {
42-
Id int64 `gorm:"primaryKey;autoIncrement;comment:AI biz 配置表ID"`
43-
Biz string `gorm:"type:varchar(256);uniqueIndex;not null;comment:业务类型名"`
44-
MaxInput int `gorm:"comment:最大输入长度"`
42+
Id int64 `gorm:"primaryKey;autoIncrement;comment:AI biz 配置表ID"`
43+
Biz string `gorm:"type:varchar(256);uniqueIndex;not null;comment:业务类型名"`
44+
MaxInput int `gorm:"comment:最大输入长度"`
45+
Model string `gorm:"type:varchar(256)"`
46+
Price int64
47+
Temperature float64
48+
TopP float64
49+
// 系统 prompt
50+
SystemPrompt string
4551
PromptTemplate string
4652
KnowledgeId string `gorm:"type:varchar(256);not null;comment:使用的知识库 ID"`
4753
// 其它字段按需添加

internal/ai/internal/service/llm/handler/biz/case_examine.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@ func NewCaseExamineBizHandlerBuilder() *CaseExamineBizHandlerBuilder {
1919
func (h *CaseExamineBizHandlerBuilder) Next(next handler.Handler) handler.Handler {
2020
return handler.HandleFunc(func(ctx context.Context, req domain.LLMRequest) (domain.LLMResponse, error) {
2121
title := req.Input[0]
22-
userInput := req.Input[1]
22+
refCase := req.Input[1]
23+
userInput := req.Input[2]
2324
userInputLen := utf8.RuneCount([]byte(userInput))
2425

2526
if userInputLen > req.Config.MaxInput {
2627
return domain.LLMResponse{}, fmt.Errorf("输入太长,最常不超过 %d,现有长度 %d", req.Config.MaxInput, userInputLen)
2728
}
2829
// 把 input 和 prompt 结合起来
29-
prompt := fmt.Sprintf(req.Config.PromptTemplate, title, userInput)
30+
prompt := fmt.Sprintf(req.Config.PromptTemplate, title, refCase, userInput)
3031
req.Prompt = prompt
3132
return next.Handle(ctx, req)
3233
})

internal/ai/internal/service/llm/handler/biz/question_examine.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,15 @@ func NewQuestionExamineBizHandlerBuilder() *QuestionExamineBizHandlerBuilder {
3333
func (h *QuestionExamineBizHandlerBuilder) Next(next handler.Handler) handler.Handler {
3434
return handler.HandleFunc(func(ctx context.Context, req domain.LLMRequest) (domain.LLMResponse, error) {
3535
title := req.Input[0]
36-
userInput := req.Input[1]
36+
answer := req.Input[1]
37+
userInput := req.Input[2]
3738
userInputLen := utf8.RuneCount([]byte(userInput))
3839

3940
if userInputLen > req.Config.MaxInput {
4041
return domain.LLMResponse{}, fmt.Errorf("输入太长,最常不超过 %d,现有长度 %d", req.Config.MaxInput, userInputLen)
4142
}
4243
// 把 input 和 prompt 结合起来
43-
prompt := fmt.Sprintf(req.Config.PromptTemplate, title, userInput)
44+
prompt := fmt.Sprintf(req.Config.PromptTemplate, title, answer, userInput)
4445
req.Prompt = prompt
4546
return next.Handle(ctx, req)
4647
})

internal/ai/internal/service/llm/handler/platform/zhipu/handler.go

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,16 @@ import (
1111
// Handler 如果后续有不同的实现,就提供不同的实现
1212
type Handler struct {
1313
client *zhipu.Client
14-
svc *zhipu.ChatCompletionService
15-
// 价格和 model 进行绑定的
16-
price float64
1714
}
1815

19-
func NewHandler(apikey string,
20-
price float64) (*Handler, error) {
16+
func NewHandler(apikey string) (*Handler, error) {
2117
client, err := zhipu.NewClient(zhipu.WithAPIKey(apikey))
2218
if err != nil {
2319
return nil, err
2420
}
25-
const model = "glm-4-0520"
26-
svc := client.ChatCompletion(model)
2721
return &Handler{
2822
client: client,
2923
// 后续可以做成可配置的
30-
svc: svc,
31-
price: price,
3224
}, err
3325
}
3426

@@ -38,19 +30,15 @@ func (h *Handler) Name() string {
3830

3931
func (h *Handler) Handle(ctx context.Context, req domain.LLMRequest) (domain.LLMResponse, error) {
4032
// 这边它不会调用 next,因为它是最终的出口
41-
completion, err := h.svc.AddTool(zhipu.ChatCompletionToolRetrieval{
42-
KnowledgeID: req.Config.KnowledgeId,
43-
}).AddMessage(zhipu.ChatCompletionMessage{
44-
Role: "user",
45-
Content: req.Prompt,
46-
}).Do(ctx)
33+
chatReq := h.buildReq(req)
34+
completion, err := chatReq.Do(ctx)
4735
if err != nil {
4836
return domain.LLMResponse{}, err
4937
}
5038
tokens := completion.Usage.TotalTokens
5139
// 现在的报价都是 N/1k token
5240
// 而后向上取整
53-
amt := math.Ceil(float64(tokens) * h.price / 1000)
41+
amt := math.Ceil(float64(tokens*req.Config.Price) / float64(1000))
5442
// 金额只有具体的模型才知道怎么算
5543
resp := domain.LLMResponse{
5644
Tokens: tokens,
@@ -62,3 +50,33 @@ func (h *Handler) Handle(ctx context.Context, req domain.LLMRequest) (domain.LLM
6250
}
6351
return resp, nil
6452
}
53+
54+
func (h *Handler) buildReq(req domain.LLMRequest) *zhipu.ChatCompletionService {
55+
svc := h.client.ChatCompletion(req.Config.Model)
56+
chatReq := svc.AddMessage(zhipu.ChatCompletionMessage{
57+
Role: zhipu.RoleUser,
58+
Content: req.Prompt,
59+
})
60+
61+
if req.Config.Temperature > 0 {
62+
chatReq = chatReq.SetTemperature(req.Config.Temperature)
63+
}
64+
65+
if req.Config.TopP > 0 {
66+
chatReq = chatReq.SetTopP(req.Config.TopP)
67+
}
68+
69+
if req.Config.SystemPrompt != "" {
70+
chatReq = chatReq.AddMessage(zhipu.ChatCompletionMessage{
71+
Role: zhipu.RoleSystem,
72+
Content: req.Config.SystemPrompt,
73+
})
74+
}
75+
76+
if req.Config.KnowledgeId != "" {
77+
chatReq = chatReq.AddTool(zhipu.ChatCompletionToolRetrieval{
78+
KnowledgeID: req.Config.KnowledgeId,
79+
})
80+
}
81+
return chatReq
82+
}

internal/cases/internal/service/examine.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func (svc *LLMExamineService) Examine(ctx context.Context,
7272
Uid: uid,
7373
Tid: tid,
7474
Biz: biz,
75-
Input: []string{ca.Title, input},
75+
Input: []string{ca.Title, ca.Content, input},
7676
}
7777
aiResp, err := svc.aiSvc.Invoke(ctx, aiReq)
7878
if err != nil {

internal/question/internal/domain/question.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
package domain
1616

17-
import "time"
17+
import (
18+
"strings"
19+
"time"
20+
)
1821

1922
// Question 和 QuestionSet 是一个多对多的关系
2023
type Question struct {
@@ -47,6 +50,20 @@ type Answer struct {
4750
Utime time.Time
4851
}
4952

53+
func (a Answer) String() string {
54+
var sb strings.Builder
55+
sb.WriteString("15K: ")
56+
sb.WriteString(a.Basic.Content)
57+
sb.WriteString("\n")
58+
sb.WriteString("25K: ")
59+
sb.WriteString(a.Intermediate.Content)
60+
sb.WriteString("\n")
61+
sb.WriteString("35K: ")
62+
sb.WriteString(a.Advanced.Content)
63+
sb.WriteString("\n")
64+
return sb.String()
65+
}
66+
5067
type AnswerElement struct {
5168
Id int64
5269
Content string

0 commit comments

Comments
 (0)