Skip to content

Commit 601d30c

Browse files
update llm pipeline PR (livepeer#3336)
1 parent 21c2f02 commit 601d30c

File tree

13 files changed

+86
-79
lines changed

13 files changed

+86
-79
lines changed

core/ai.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ type AI interface {
2323
ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
2424
Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
2525
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
26-
LLM(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
26+
LLM(context.Context, worker.GenLLMJSONRequestBody) (interface{}, error)
2727
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
2828
ImageToText(context.Context, worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error)
2929
TextToSpeech(context.Context, worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error)

core/ai_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -690,8 +690,11 @@ func (a *stubAIWorker) SegmentAnything2(ctx context.Context, req worker.GenSegme
690690
return &worker.MasksResponse{Logits: "logits", Masks: "masks", Scores: "scores"}, nil
691691
}
692692

693-
func (a *stubAIWorker) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
694-
return &worker.LLMResponse{Response: "response tokens", TokensUsed: 10}, nil
693+
func (a *stubAIWorker) LLM(ctx context.Context, req worker.GenLLMJSONRequestBody) (interface{}, error) {
694+
var choices []worker.LLMChoice
695+
choices = append(choices, worker.LLMChoice{Delta: &worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
696+
tokensUsed := worker.LLMTokenUsage{PromptTokens: 40, CompletionTokens: 10, TotalTokens: 50}
697+
return &worker.LLMResponse{Choices: choices, Created: 1, Model: "llm_model", TokensUsed: tokensUsed}, nil
695698
}
696699

697700
func (a *stubAIWorker) ImageToText(ctx context.Context, req worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error) {

core/ai_worker.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -824,14 +824,14 @@ func (orch *orchestrator) SegmentAnything2(ctx context.Context, requestID string
824824
}
825825

826826
// Return type is LLMResponse; for streaming requests a <-chan *worker.LLMResponse is returned instead
827-
func (orch *orchestrator) LLM(ctx context.Context, requestID string, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
827+
func (orch *orchestrator) LLM(ctx context.Context, requestID string, req worker.GenLLMJSONRequestBody) (interface{}, error) {
828828
// local AIWorker processes job if combined orchestrator/ai worker
829829
if orch.node.AIWorker != nil {
830830
// no file response to save, response is text sent back to gateway
831831
return orch.node.AIWorker.LLM(ctx, req)
832832
}
833833

834-
res, err := orch.node.AIWorkerManager.Process(ctx, requestID, "llm", *req.ModelId, "", AIJobRequestData{Request: req})
834+
res, err := orch.node.AIWorkerManager.Process(ctx, requestID, "llm", *req.Model, "", AIJobRequestData{Request: req})
835835
if err != nil {
836836
return nil, err
837837
}
@@ -842,7 +842,7 @@ func (orch *orchestrator) LLM(ctx context.Context, requestID string, req worker.
842842
if err != nil {
843843
clog.Errorf(ctx, "Error saving remote ai result err=%q", err)
844844
if monitor.Enabled {
845-
monitor.AIResultSaveError(ctx, "llm", *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
845+
monitor.AIResultSaveError(ctx, "llm", *req.Model, string(monitor.SegmentUploadErrorUnknown))
846846
}
847847
return nil, err
848848

@@ -1087,7 +1087,7 @@ func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.GenSegme
10871087
return n.AIWorker.SegmentAnything2(ctx, req)
10881088
}
10891089

1090-
func (n *LivepeerNode) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
1090+
func (n *LivepeerNode) LLM(ctx context.Context, req worker.GenLLMJSONRequestBody) (interface{}, error) {
10911091
return n.AIWorker.LLM(ctx, req)
10921092
}
10931093

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ require (
1414
github.com/google/uuid v1.6.0
1515
github.com/jaypipes/ghw v0.10.0
1616
github.com/jaypipes/pcidb v1.0.0
17-
github.com/livepeer/ai-worker v0.12.7-0.20241219141308-c19289d128a3
17+
github.com/livepeer/ai-worker v0.13.1
1818
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
1919
github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
2020
github.com/livepeer/lpms v0.0.0-20241203012405-fc96cadb6393

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -607,8 +607,8 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
607607
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
608608
github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
609609
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
610-
github.com/livepeer/ai-worker v0.12.7-0.20241219141308-c19289d128a3 h1:uutmGZq2YdIKnKhn6QGHtGnKfBGYAUMMOr44LXYs23w=
611-
github.com/livepeer/ai-worker v0.12.7-0.20241219141308-c19289d128a3/go.mod h1:ZibfmZQQh6jFvnPLHeIPInghfX5ln+JpN845nS3GuyM=
610+
github.com/livepeer/ai-worker v0.13.1 h1:BnqzmBD/E5gHM0P6UXt9M2/bZwU3ZryEfNpbW+NYJr0=
611+
github.com/livepeer/ai-worker v0.13.1/go.mod h1:ZibfmZQQh6jFvnPLHeIPInghfX5ln+JpN845nS3GuyM=
612612
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
613613
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
614614
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=

server/ai_http.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func startAIServer(lp *lphttp) error {
6666
lp.transRPC.Handle("/image-to-video", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenImageToVideoMultipartRequestBody])))
6767
lp.transRPC.Handle("/upscale", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenUpscaleMultipartRequestBody])))
6868
lp.transRPC.Handle("/audio-to-text", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenAudioToTextMultipartRequestBody])))
69-
lp.transRPC.Handle("/llm", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenLLMFormdataRequestBody])))
69+
lp.transRPC.Handle("/llm", oapiReqValidator(aiHttpHandle(lp, jsonDecoder[worker.GenLLMJSONRequestBody])))
7070
lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenSegmentAnything2MultipartRequestBody])))
7171
lp.transRPC.Handle("/image-to-text", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenImageToTextMultipartRequestBody])))
7272
lp.transRPC.Handle("/text-to-speech", oapiReqValidator(aiHttpHandle(lp, jsonDecoder[worker.GenTextToSpeechJSONRequestBody])))
@@ -405,10 +405,10 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
405405
return
406406
}
407407
outPixels *= 1000 // Convert to milliseconds
408-
case worker.GenLLMFormdataRequestBody:
408+
case worker.GenLLMJSONRequestBody:
409409
pipeline = "llm"
410410
cap = core.Capability_LLM
411-
modelID = *v.ModelId
411+
modelID = *v.Model
412412
submitFn = func(ctx context.Context) (interface{}, error) {
413413
return orch.LLM(ctx, requestID, v)
414414
}
@@ -586,7 +586,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
586586
}
587587

588588
// Check if the response is a streaming response
589-
if streamChan, ok := resp.(<-chan worker.LlmStreamChunk); ok {
589+
if streamChan, ok := resp.(<-chan *worker.LLMResponse); ok {
590590
glog.Infof("Streaming response for request id=%v", requestID)
591591

592592
// Set headers for SSE
@@ -610,7 +610,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
610610
fmt.Fprintf(w, "data: %s\n\n", data)
611611
flusher.Flush()
612612

613-
if chunk.Done {
613+
if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
614614
break
615615
}
616616
}
@@ -685,8 +685,8 @@ func (h *lphttp) AIResults() http.Handler {
685685
case "text/event-stream":
686686
resultType = "streaming"
687687
glog.Infof("Received %s response from remote worker=%s taskId=%d", resultType, r.RemoteAddr, tid)
688-
resChan := make(chan worker.LlmStreamChunk, 100)
689-
workerResult.Results = (<-chan worker.LlmStreamChunk)(resChan)
688+
resChan := make(chan *worker.LLMResponse, 100)
689+
workerResult.Results = (<-chan *worker.LLMResponse)(resChan)
690690

691691
defer r.Body.Close()
692692
defer close(resChan)
@@ -705,12 +705,12 @@ func (h *lphttp) AIResults() http.Handler {
705705
line := scanner.Text()
706706
if strings.HasPrefix(line, "data: ") {
707707
data := strings.TrimPrefix(line, "data: ")
708-
var chunk worker.LlmStreamChunk
708+
var chunk worker.LLMResponse
709709
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
710710
clog.Errorf(ctx, "Error unmarshaling stream data: %v", err)
711711
continue
712712
}
713-
resChan <- chunk
713+
resChan <- &chunk
714714
}
715715
}
716716
}

server/ai_mediaserver.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -259,20 +259,19 @@ func (ls *LivepeerServer) LLM() http.Handler {
259259
requestID := string(core.RandomManifestID())
260260
ctx = clog.AddVal(ctx, "request_id", requestID)
261261

262-
var req worker.GenLLMFormdataRequestBody
263-
264-
multiRdr, err := r.MultipartReader()
265-
if err != nil {
262+
var req worker.GenLLMJSONRequestBody
263+
if err := jsonDecoder(&req, r); err != nil {
266264
respondJsonError(ctx, w, err, http.StatusBadRequest)
267265
return
268266
}
269267

270-
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
271-
respondJsonError(ctx, w, err, http.StatusBadRequest)
268+
// Check that all required request fields are present.
269+
if req.Model == nil || req.Messages == nil || req.Stream == nil || req.MaxTokens == nil || len(req.Messages) == 0 {
270+
respondJsonError(ctx, w, errors.New("missing required fields"), http.StatusBadRequest)
272271
return
273272
}
274273

275-
clog.V(common.VERBOSE).Infof(ctx, "Received LLM request prompt=%v model_id=%v stream=%v", req.Prompt, *req.ModelId, *req.Stream)
274+
clog.V(common.VERBOSE).Infof(ctx, "Received LLM request model_id=%v stream=%v", *req.Model, *req.Stream)
276275

277276
orchAddr := r.Header.Get("OrchAddr")
278277
params := aiRequestParams{
@@ -295,9 +294,9 @@ func (ls *LivepeerServer) LLM() http.Handler {
295294
}
296295

297296
took := time.Since(start)
298-
clog.V(common.VERBOSE).Infof(ctx, "Processed LLM request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took)
297+
clog.V(common.VERBOSE).Infof(ctx, "Processed LLM request model_id=%v took=%v", *req.Model, took)
299298

300-
if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
299+
if streamChan, ok := resp.(chan *worker.LLMResponse); ok {
301300
// Handle streaming response (SSE)
302301
w.Header().Set("Content-Type", "text/event-stream")
303302
w.Header().Set("Cache-Control", "no-cache")
@@ -307,7 +306,7 @@ func (ls *LivepeerServer) LLM() http.Handler {
307306
data, _ := json.Marshal(chunk)
308307
fmt.Fprintf(w, "data: %s\n\n", data)
309308
w.(http.Flusher).Flush()
310-
if chunk.Done {
309+
if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
311310
break
312311
}
313312
}

server/ai_process.go

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,14 +1100,14 @@ func CalculateLLMLatencyScore(took time.Duration, tokensUsed int) float64 {
11001100
return took.Seconds() / float64(tokensUsed)
11011101
}
11021102

1103-
func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
1103+
func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMJSONRequestBody) (interface{}, error) {
11041104
resp, err := processAIRequest(ctx, params, req)
11051105
if err != nil {
11061106
return nil, err
11071107
}
11081108

11091109
if req.Stream != nil && *req.Stream {
1110-
streamChan, ok := resp.(chan worker.LlmStreamChunk)
1110+
streamChan, ok := resp.(chan *worker.LLMResponse)
11111111
if !ok {
11121112
return nil, errors.New("unexpected response type for streaming request")
11131113
}
@@ -1122,20 +1122,12 @@ func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMFo
11221122
return llmResp, nil
11231123
}
11241124

1125-
func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
1126-
var buf bytes.Buffer
1127-
mw, err := worker.NewLLMMultipartWriter(&buf, req)
1128-
if err != nil {
1129-
if monitor.Enabled {
1130-
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, nil)
1131-
}
1132-
return nil, err
1133-
}
1125+
func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLLMJSONRequestBody) (interface{}, error) {
11341126

11351127
client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
11361128
if err != nil {
11371129
if monitor.Enabled {
1138-
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
1130+
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)
11391131
}
11401132
return nil, err
11411133
}
@@ -1148,17 +1140,17 @@ func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req
11481140
setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, int64(*req.MaxTokens))
11491141
if err != nil {
11501142
if monitor.Enabled {
1151-
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
1143+
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)
11521144
}
11531145
return nil, err
11541146
}
11551147
defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)
11561148

11571149
start := time.Now()
1158-
resp, err := client.GenLLMWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
1150+
resp, err := client.GenLLM(ctx, req, setHeaders)
11591151
if err != nil {
11601152
if monitor.Enabled {
1161-
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
1153+
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)
11621154
}
11631155
return nil, err
11641156
}
@@ -1168,83 +1160,90 @@ func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req
11681160
return nil, fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
11691161
}
11701162

1163+
// We treat a response as "receiving change" where the change is the difference between the credit and debit for the update
1164+
// TODO: move this to after the stream response is received in handleSSEStream and handleNonStreamingResponse so input tokens can be counted
1165+
if balUpdate != nil {
1166+
balUpdate.Status = ReceivedChange
1167+
}
1168+
11711169
if req.Stream != nil && *req.Stream {
11721170
return handleSSEStream(ctx, resp.Body, sess, req, start)
11731171
}
11741172

11751173
return handleNonStreamingResponse(ctx, resp.Body, sess, req, start)
11761174
}
11771175

1178-
func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
1179-
streamChan := make(chan worker.LlmStreamChunk, 100)
1176+
func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMJSONRequestBody, start time.Time) (chan *worker.LLMResponse, error) {
1177+
streamChan := make(chan *worker.LLMResponse, 100)
11801178
go func() {
11811179
defer close(streamChan)
11821180
defer body.Close()
11831181
scanner := bufio.NewScanner(body)
1184-
var totalTokens int
1182+
var totalTokens worker.LLMTokenUsage
11851183
for scanner.Scan() {
11861184
line := scanner.Text()
11871185
if strings.HasPrefix(line, "data: ") {
11881186
data := strings.TrimPrefix(line, "data: ")
1189-
if data == "[DONE]" {
1190-
streamChan <- worker.LlmStreamChunk{Done: true, TokensUsed: totalTokens}
1191-
break
1192-
}
1193-
var chunk worker.LlmStreamChunk
1187+
1188+
var chunk worker.LLMResponse
11941189
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
11951190
clog.Errorf(ctx, "Error unmarshaling SSE data: %v", err)
11961191
continue
11971192
}
1198-
totalTokens += chunk.TokensUsed
1199-
streamChan <- chunk
1193+
totalTokens = chunk.TokensUsed
1194+
streamChan <- &chunk
1195+
// Check whether the stream has finished.
1196+
if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
1197+
break
1198+
}
12001199
}
12011200
}
12021201
if err := scanner.Err(); err != nil {
12031202
clog.Errorf(ctx, "Error reading SSE stream: %v", err)
12041203
}
12051204

12061205
took := time.Since(start)
1207-
sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens)
1206+
sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens.TotalTokens)
12081207

12091208
if monitor.Enabled {
12101209
var pricePerAIUnit float64
12111210
if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
12121211
pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
12131212
}
1214-
monitor.AIRequestFinished(ctx, "llm", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
1213+
monitor.AIRequestFinished(ctx, "llm", *req.Model, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
12151214
}
12161215
}()
12171216

12181217
return streamChan, nil
12191218
}
12201219

1221-
func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMFormdataRequestBody, start time.Time) (*worker.LLMResponse, error) {
1220+
func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMJSONRequestBody, start time.Time) (*worker.LLMResponse, error) {
12221221
data, err := io.ReadAll(body)
12231222
defer body.Close()
12241223
if err != nil {
12251224
if monitor.Enabled {
1226-
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
1225+
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)
12271226
}
12281227
return nil, err
12291228
}
12301229

12311230
var res worker.LLMResponse
12321231
if err := json.Unmarshal(data, &res); err != nil {
12331232
if monitor.Enabled {
1234-
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
1233+
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)
12351234
}
12361235
return nil, err
12371236
}
12381237

12391238
took := time.Since(start)
1240-
sess.LatencyScore = CalculateLLMLatencyScore(took, res.TokensUsed)
1239+
sess.LatencyScore = CalculateLLMLatencyScore(took, res.TokensUsed.TotalTokens)
12411240

12421241
if monitor.Enabled {
12431242
var pricePerAIUnit float64
12441243
if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
12451244
pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
12461245
}
1247-
monitor.AIRequestFinished(ctx, "llm", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
1246+
monitor.AIRequestFinished(ctx, "llm", *req.Model, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
12481247
}
12491248

12501249
return &res, nil
@@ -1403,16 +1402,16 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
14031402
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
14041403
return submitAudioToText(ctx, params, sess, v)
14051404
}
1406-
case worker.GenLLMFormdataRequestBody:
1405+
case worker.GenLLMJSONRequestBody:
14071406
cap = core.Capability_LLM
14081407
modelID = defaultLLMModelID
1409-
if v.ModelId != nil {
1410-
modelID = *v.ModelId
1408+
if v.Model != nil {
1409+
modelID = *v.Model
14111410
}
14121411
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
14131412
return submitLLM(ctx, params, sess, v)
14141413
}
1415-
ctx = clog.AddVal(ctx, "prompt", v.Prompt)
1414+
14161415
case worker.GenSegmentAnything2MultipartRequestBody:
14171416
cap = core.Capability_SegmentAnything2
14181417
modelID = defaultSegmentAnything2ModelID

server/ai_process_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func Test_submitLLM(t *testing.T) {
1313
ctx context.Context
1414
params aiRequestParams
1515
sess *AISession
16-
req worker.GenLLMFormdataRequestBody
16+
req worker.GenLLMJSONRequestBody
1717
}
1818
tests := []struct {
1919
name string

0 commit comments

Comments
 (0)