@@ -144,7 +144,7 @@ func InferMoss2(
144
144
if len (response .Choices ) == 0 {
145
145
return unknownError
146
146
}
147
-
147
+
148
148
resultBuilder .WriteString (response .Choices [0 ].Delta .Content )
149
149
nowOutput = resultBuilder .String ()
150
150
@@ -156,16 +156,17 @@ func InferMoss2(
156
156
157
157
if slices .Contains (model .EndDelimiter , FuncCallEnd ) && strings .Contains (nowOutput , FuncCallEnd ) {
158
158
// if FuncCallEnd is found, call tool apis
159
- var result string
160
- err = GetFuncCallResult (nowOutput , & result )
159
+ var funcCallResult string
160
+
161
+ err = GetFuncCallResult (nowOutput , & funcCallResult )
161
162
if err != nil {
162
163
return err
163
164
}
164
165
165
- // TODO: Do we need this? or ignore <im_start>func_ret and simply send the json content
166
+ // TODO: Do we need this? or ignore <im_start>func_ret and simply send the funcCallResult
166
167
message := openai.ChatCompletionMessage {
167
168
Role : "func_ret" ,
168
- Content : result ,
169
+ Content : funcCallResult ,
169
170
}
170
171
171
172
request = openai.ChatCompletionRequest {
@@ -182,9 +183,10 @@ func InferMoss2(
182
183
if err != nil {
183
184
return err
184
185
}
185
- // erase the content of funcall
186
+ // erase the content of fun_call
186
187
nowOutput = strings .Split (nowOutput , FuncCallStart )[0 ]
187
- continue
188
+ resultBuilder .Reset ()
189
+ resultBuilder .WriteString (nowOutput )
188
190
}
189
191
190
192
before , _ , found := CutLastAny (nowOutput , ",.?!\n ,。?!" )
@@ -351,6 +353,152 @@ func GetFuncCallResult(
351
353
352
354
353
355
356
+ func InferOpenAI (
357
+ record * Record ,
358
+ postRecord RecordModels ,
359
+ model * ModelConfig ,
360
+ user * User ,
361
+ ctx * InferWsContext ,
362
+ ) (
363
+ err error ,
364
+ ) {
365
+ defer func () {
366
+ if v := recover (); v != nil {
367
+ Logger .Error ("infer openai panicked" , zap .Any ("error" , v ))
368
+ err = unknownError
369
+ }
370
+ }()
371
+
372
+ openaiConfig := openai .DefaultConfig ("" )
373
+ openaiConfig .BaseURL = model .Url
374
+ client := openai .NewClientWithConfig (openaiConfig )
375
+
376
+ var messages = make ([]openai.ChatCompletionMessage , 0 , len (postRecord )+ 2 )
377
+ messages = append (messages , openai.ChatCompletionMessage {
378
+ Role : "system" ,
379
+ Content : model .OpenAISystemPrompt ,
380
+ })
381
+ messages = append (messages , postRecord .ToOpenAIMessages ()... )
382
+ messages = append (messages , openai.ChatCompletionMessage {
383
+ Role : "user" ,
384
+ Content : record .Request ,
385
+ })
386
+ request := openai.ChatCompletionRequest {
387
+ Model : model .OpenAIModelName ,
388
+ Messages : messages ,
389
+ Stop : model .EndDelimiter ,
390
+ }
391
+
392
+ if ctx == nil {
393
+ // openai client may panic when status code is 400
394
+ response , err := client .CreateChatCompletion (
395
+ context .Background (),
396
+ request ,
397
+ )
398
+ if err != nil {
399
+ return err
400
+ }
401
+
402
+ if len (response .Choices ) == 0 {
403
+ return unknownError
404
+ }
405
+
406
+ record .Response = response .Choices [0 ].Message .Content
407
+ } else {
408
+ // streaming
409
+ if config .Config .Debug {
410
+ Logger .Info ("openai streaming" ,
411
+ zap .String ("model" , model .OpenAIModelName ),
412
+ zap .String ("url" , model .Url ),
413
+ )
414
+ }
415
+
416
+ stream , err := client .CreateChatCompletionStream (
417
+ context .Background (),
418
+ request ,
419
+ )
420
+ if err != nil {
421
+ return err
422
+ }
423
+ defer stream .Close ()
424
+
425
+ startTime := time .Now ()
426
+
427
+ var resultBuilder strings.Builder
428
+ var nowOutput string
429
+ var detectedOutput string
430
+
431
+ for {
432
+ if ctx .connectionClosed .Load () {
433
+ return interruptError
434
+ }
435
+ response , err := stream .Recv ()
436
+ if errors .Is (err , io .EOF ) {
437
+ break
438
+ }
439
+ if err != nil {
440
+ return err
441
+ }
442
+
443
+ if len (response .Choices ) == 0 {
444
+ return unknownError
445
+ }
446
+
447
+ resultBuilder .WriteString (response .Choices [0 ].Delta .Content )
448
+ nowOutput = resultBuilder .String ()
449
+
450
+ if slices .Contains (model .EndDelimiter , MossEnd ) && strings .Contains (nowOutput , MossEnd ) {
451
+ // if MossEnd is found, break the loop
452
+ nowOutput = strings .Split (nowOutput , MossEnd )[0 ]
453
+ break
454
+ }
455
+
456
+ before , _ , found := CutLastAny (nowOutput , ",.?!\n ,。?!" )
457
+ if ! found || before == detectedOutput {
458
+ continue
459
+ }
460
+ detectedOutput = before
461
+ if model .EnableSensitiveCheck {
462
+ err = sensitiveCheck (ctx .c , record , detectedOutput , startTime , user )
463
+ if err != nil {
464
+ return err
465
+ }
466
+ }
467
+
468
+ _ = ctx .c .WriteJSON (InferResponseModel {
469
+ Status : 1 ,
470
+ Output : detectedOutput ,
471
+ Stage : "MOSS" ,
472
+ })
473
+ }
474
+ if nowOutput != detectedOutput {
475
+ if model .EnableSensitiveCheck {
476
+ err = sensitiveCheck (ctx .c , record , nowOutput , startTime , user )
477
+ if err != nil {
478
+ return err
479
+ }
480
+ }
481
+
482
+ _ = ctx .c .WriteJSON (InferResponseModel {
483
+ Status : 1 ,
484
+ Output : nowOutput ,
485
+ Stage : "MOSS" ,
486
+ })
487
+ }
488
+
489
+ record .Response = nowOutput
490
+ record .Duration = float64 (time .Since (startTime )) / 1000_000_000
491
+ _ = ctx .c .WriteJSON (InferResponseModel {
492
+ Status : 0 ,
493
+ Output : nowOutput ,
494
+ Stage : "MOSS" ,
495
+ })
496
+ }
497
+
498
+ return nil
499
+ }
500
+
501
+
354
502
func InferCommon (
355
503
record * Record ,
356
504
prefix string ,
@@ -381,7 +529,9 @@ func InferCommon(
381
529
382
530
// dispatch
383
531
if model .APIType == APITypeMOSS2 {
384
- return InferMoss2 (record , postRecords , model , user , ctx );
532
+ return InferMoss2 (record , postRecords , model , user , ctx )
533
+ } else if model .APIType == APITypeMOSS {
534
+ return InferOpenAI (record , postRecords , model , user , ctx )
385
535
} else {
386
536
return errors .New ("unknown API type" )
387
537
}
0 commit comments