Skip to content

Commit 6583a70

Browse files
committed
fix: recover inferOpenAI function and adjust inferMoss2 code logic
1 parent 8d7d274 commit 6583a70

File tree

2 files changed

+159
-8
lines changed

2 files changed

+159
-8
lines changed

apis/record/infer.go

Lines changed: 158 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ func InferMoss2(
144144
if len(response.Choices) == 0 {
145145
return unknownError
146146
}
147-
147+
148148
resultBuilder.WriteString(response.Choices[0].Delta.Content)
149149
nowOutput = resultBuilder.String()
150150

@@ -156,16 +156,17 @@ func InferMoss2(
156156

157157
if slices.Contains(model.EndDelimiter, FuncCallEnd) && strings.Contains(nowOutput, FuncCallEnd) {
158158
// if FuncCallEnd is found, call tool apis
159-
var result string
160-
err = GetFuncCallResult(nowOutput, &result)
159+
var funcCallResult string
160+
161+
err = GetFuncCallResult(nowOutput, &funcCallResult)
161162
if err != nil {
162163
return err
163164
}
164165

165-
// TODO: Do we need this? or ignore <im_start>func_ret and simply send the json content
166+
// TODO: Do we need this? or ignore <im_start>func_ret and simply send the funcCallResult
166167
message := openai.ChatCompletionMessage{
167168
Role: "func_ret",
168-
Content: result,
169+
Content: funcCallResult,
169170
}
170171

171172
request = openai.ChatCompletionRequest{
@@ -182,9 +183,10 @@ func InferMoss2(
182183
if err != nil {
183184
return err
184185
}
185-
// erase the content of funcall
186+
// erase the content of fun_call
186187
nowOutput = strings.Split(nowOutput, FuncCallStart)[0]
187-
continue
188+
resultBuilder.Reset()
189+
resultBuilder.WriteString(nowOutput)
188190
}
189191

190192
before, _, found := CutLastAny(nowOutput, ",.?!\n,。?!")
@@ -351,6 +353,152 @@ func GetFuncCallResult(
351353

352354

353355

356+
func InferOpenAI(
357+
record *Record,
358+
postRecord RecordModels,
359+
model *ModelConfig,
360+
user *User,
361+
ctx *InferWsContext,
362+
) (
363+
err error,
364+
) {
365+
defer func() {
366+
if v := recover(); v != nil {
367+
Logger.Error("infer openai panicked", zap.Any("error", v))
368+
err = unknownError
369+
}
370+
}()
371+
372+
openaiConfig := openai.DefaultConfig("")
373+
openaiConfig.BaseURL = model.Url
374+
client := openai.NewClientWithConfig(openaiConfig)
375+
376+
var messages = make([]openai.ChatCompletionMessage, 0, len(postRecord)+2)
377+
messages = append(messages, openai.ChatCompletionMessage{
378+
Role: "system",
379+
Content: model.OpenAISystemPrompt,
380+
})
381+
messages = append(messages, postRecord.ToOpenAIMessages()...)
382+
messages = append(messages, openai.ChatCompletionMessage{
383+
Role: "user",
384+
Content: record.Request,
385+
})
386+
request := openai.ChatCompletionRequest{
387+
Model: model.OpenAIModelName,
388+
Messages: messages,
389+
Stop: model.EndDelimiter,
390+
}
391+
392+
if ctx == nil {
393+
// openai client may panic when status code is 400
394+
response, err := client.CreateChatCompletion(
395+
context.Background(),
396+
request,
397+
)
398+
if err != nil {
399+
return err
400+
}
401+
402+
if len(response.Choices) == 0 {
403+
return unknownError
404+
}
405+
406+
record.Response = response.Choices[0].Message.Content
407+
} else {
408+
// streaming
409+
if config.Config.Debug {
410+
Logger.Info("openai streaming",
411+
zap.String("model", model.OpenAIModelName),
412+
zap.String("url", model.Url),
413+
)
414+
}
415+
416+
stream, err := client.CreateChatCompletionStream(
417+
context.Background(),
418+
request,
419+
)
420+
if err != nil {
421+
return err
422+
}
423+
defer stream.Close()
424+
425+
startTime := time.Now()
426+
427+
var resultBuilder strings.Builder
428+
var nowOutput string
429+
var detectedOutput string
430+
431+
for {
432+
if ctx.connectionClosed.Load() {
433+
return interruptError
434+
}
435+
response, err := stream.Recv()
436+
if errors.Is(err, io.EOF) {
437+
break
438+
}
439+
if err != nil {
440+
return err
441+
}
442+
443+
if len(response.Choices) == 0 {
444+
return unknownError
445+
}
446+
447+
resultBuilder.WriteString(response.Choices[0].Delta.Content)
448+
nowOutput = resultBuilder.String()
449+
450+
if slices.Contains(model.EndDelimiter, MossEnd) && strings.Contains(nowOutput, MossEnd) {
451+
// if MossEnd is found, break the loop
452+
nowOutput = strings.Split(nowOutput, MossEnd)[0]
453+
break
454+
}
455+
456+
before, _, found := CutLastAny(nowOutput, ",.?!\n,。?!")
457+
if !found || before == detectedOutput {
458+
continue
459+
}
460+
detectedOutput = before
461+
if model.EnableSensitiveCheck {
462+
err = sensitiveCheck(ctx.c, record, detectedOutput, startTime, user)
463+
if err != nil {
464+
return err
465+
}
466+
}
467+
468+
_ = ctx.c.WriteJSON(InferResponseModel{
469+
Status: 1,
470+
Output: detectedOutput,
471+
Stage: "MOSS",
472+
})
473+
}
474+
if nowOutput != detectedOutput {
475+
if model.EnableSensitiveCheck {
476+
err = sensitiveCheck(ctx.c, record, nowOutput, startTime, user)
477+
if err != nil {
478+
return err
479+
}
480+
}
481+
482+
_ = ctx.c.WriteJSON(InferResponseModel{
483+
Status: 1,
484+
Output: nowOutput,
485+
Stage: "MOSS",
486+
})
487+
}
488+
489+
record.Response = nowOutput
490+
record.Duration = float64(time.Since(startTime)) / 1000_000_000
491+
_ = ctx.c.WriteJSON(InferResponseModel{
492+
Status: 0,
493+
Output: nowOutput,
494+
Stage: "MOSS",
495+
})
496+
}
497+
498+
return nil
499+
}
500+
501+
354502
func InferCommon(
355503
record *Record,
356504
prefix string,
@@ -381,7 +529,9 @@ func InferCommon(
381529

382530
// dispatch
383531
if model.APIType == APITypeMOSS2 {
384-
return InferMoss2(record, postRecords, model, user, ctx);
532+
return InferMoss2(record, postRecords, model, user, ctx)
533+
} else if model.APIType == APITypeMOSS{
534+
return InferOpenAI(record, postRecords, model, user, ctx)
385535
} else {
386536
return errors.New("unknown API type")
387537
}

models/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ type APIType string
1313

1414

1515
const (
16+
APITypeMOSS APIType = "moss"
1617
APITypeMOSS2 APIType = "moss2"
1718
)
1819

0 commit comments

Comments
 (0)