@@ -1100,14 +1100,14 @@ func CalculateLLMLatencyScore(took time.Duration, tokensUsed int) float64 {
11001100 return took .Seconds () / float64 (tokensUsed )
11011101}
11021102
1103- func processLLM (ctx context.Context , params aiRequestParams , req worker.GenLLMFormdataRequestBody ) (interface {}, error ) {
1103+ func processLLM (ctx context.Context , params aiRequestParams , req worker.GenLLMJSONRequestBody ) (interface {}, error ) {
11041104 resp , err := processAIRequest (ctx , params , req )
11051105 if err != nil {
11061106 return nil , err
11071107 }
11081108
11091109 if req .Stream != nil && * req .Stream {
1110- streamChan , ok := resp .(chan worker.LlmStreamChunk )
1110+ streamChan , ok := resp .(chan * worker.LLMResponse )
11111111 if ! ok {
11121112 return nil , errors .New ("unexpected response type for streaming request" )
11131113 }
@@ -1122,20 +1122,12 @@ func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMFo
11221122 return llmResp , nil
11231123}
11241124
1125- func submitLLM (ctx context.Context , params aiRequestParams , sess * AISession , req worker.GenLLMFormdataRequestBody ) (interface {}, error ) {
1126- var buf bytes.Buffer
1127- mw , err := worker .NewLLMMultipartWriter (& buf , req )
1128- if err != nil {
1129- if monitor .Enabled {
1130- monitor .AIRequestError (err .Error (), "llm" , * req .ModelId , nil )
1131- }
1132- return nil , err
1133- }
1125+ func submitLLM (ctx context.Context , params aiRequestParams , sess * AISession , req worker.GenLLMJSONRequestBody ) (interface {}, error ) {
11341126
11351127 client , err := worker .NewClientWithResponses (sess .Transcoder (), worker .WithHTTPClient (httpClient ))
11361128 if err != nil {
11371129 if monitor .Enabled {
1138- monitor .AIRequestError (err .Error (), "llm" , * req .ModelId , sess .OrchestratorInfo )
1130+ monitor .AIRequestError (err .Error (), "llm" , * req .Model , sess .OrchestratorInfo )
11391131 }
11401132 return nil , err
11411133 }
@@ -1148,17 +1140,17 @@ func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req
11481140 setHeaders , balUpdate , err := prepareAIPayment (ctx , sess , int64 (* req .MaxTokens ))
11491141 if err != nil {
11501142 if monitor .Enabled {
1151- monitor .AIRequestError (err .Error (), "llm" , * req .ModelId , sess .OrchestratorInfo )
1143+ monitor .AIRequestError (err .Error (), "llm" , * req .Model , sess .OrchestratorInfo )
11521144 }
11531145 return nil , err
11541146 }
11551147 defer completeBalanceUpdate (sess .BroadcastSession , balUpdate )
11561148
11571149 start := time .Now ()
1158- resp , err := client .GenLLMWithBody (ctx , mw . FormDataContentType (), & buf , setHeaders )
1150+ resp , err := client .GenLLM (ctx , req , setHeaders )
11591151 if err != nil {
11601152 if monitor .Enabled {
1161- monitor .AIRequestError (err .Error (), "llm" , * req .ModelId , sess .OrchestratorInfo )
1153+ monitor .AIRequestError (err .Error (), "llm" , * req .Model , sess .OrchestratorInfo )
11621154 }
11631155 return nil , err
11641156 }
@@ -1168,83 +1160,90 @@ func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req
11681160 return nil , fmt .Errorf ("unexpected status code: %d, body: %s" , resp .StatusCode , string (body ))
11691161 }
11701162
1163+ // We treat a response as "receiving change" where the change is the difference between the credit and debit for the update
1164+ // TODO: move this status update to after the response has been received in handleSSEStream and handleNonStreamingResponse so that input tokens can be counted
1165+ if balUpdate != nil {
1166+ balUpdate .Status = ReceivedChange
1167+ }
1168+
11711169 if req .Stream != nil && * req .Stream {
11721170 return handleSSEStream (ctx , resp .Body , sess , req , start )
11731171 }
11741172
11751173 return handleNonStreamingResponse (ctx , resp .Body , sess , req , start )
11761174}
11771175
1178- func handleSSEStream (ctx context.Context , body io.ReadCloser , sess * AISession , req worker.GenLLMFormdataRequestBody , start time.Time ) (chan worker.LlmStreamChunk , error ) {
1179- streamChan := make (chan worker.LlmStreamChunk , 100 )
1176+ func handleSSEStream (ctx context.Context , body io.ReadCloser , sess * AISession , req worker.GenLLMJSONRequestBody , start time.Time ) (chan * worker.LLMResponse , error ) {
1177+ streamChan := make (chan * worker.LLMResponse , 100 )
11801178 go func () {
11811179 defer close (streamChan )
11821180 defer body .Close ()
11831181 scanner := bufio .NewScanner (body )
1184- var totalTokens int
1182+ var totalTokens worker. LLMTokenUsage
11851183 for scanner .Scan () {
11861184 line := scanner .Text ()
11871185 if strings .HasPrefix (line , "data: " ) {
11881186 data := strings .TrimPrefix (line , "data: " )
1189- if data == "[DONE]" {
1190- streamChan <- worker.LlmStreamChunk {Done : true , TokensUsed : totalTokens }
1191- break
1192- }
1193- var chunk worker.LlmStreamChunk
1187+
1188+ var chunk worker.LLMResponse
11941189 if err := json .Unmarshal ([]byte (data ), & chunk ); err != nil {
11951190 clog .Errorf (ctx , "Error unmarshaling SSE data: %v" , err )
11961191 continue
11971192 }
1198- totalTokens += chunk .TokensUsed
1199- streamChan <- chunk
1193+ totalTokens = chunk .TokensUsed
1194+ streamChan <- & chunk
1195+ //check if stream is finished
1196+ if chunk .Choices [0 ].FinishReason != nil && * chunk .Choices [0 ].FinishReason != "" {
1197+ break
1198+ }
12001199 }
12011200 }
12021201 if err := scanner .Err (); err != nil {
12031202 clog .Errorf (ctx , "Error reading SSE stream: %v" , err )
12041203 }
12051204
12061205 took := time .Since (start )
1207- sess .LatencyScore = CalculateLLMLatencyScore (took , totalTokens )
1206+ sess .LatencyScore = CalculateLLMLatencyScore (took , totalTokens . TotalTokens )
12081207
12091208 if monitor .Enabled {
12101209 var pricePerAIUnit float64
12111210 if priceInfo := sess .OrchestratorInfo .GetPriceInfo (); priceInfo != nil && priceInfo .PixelsPerUnit != 0 {
12121211 pricePerAIUnit = float64 (priceInfo .PricePerUnit ) / float64 (priceInfo .PixelsPerUnit )
12131212 }
1214- monitor .AIRequestFinished (ctx , "llm" , * req .ModelId , monitor.AIJobInfo {LatencyScore : sess .LatencyScore , PricePerUnit : pricePerAIUnit }, sess .OrchestratorInfo )
1213+ monitor .AIRequestFinished (ctx , "llm" , * req .Model , monitor.AIJobInfo {LatencyScore : sess .LatencyScore , PricePerUnit : pricePerAIUnit }, sess .OrchestratorInfo )
12151214 }
12161215 }()
12171216
12181217 return streamChan , nil
12191218}
12201219
1221- func handleNonStreamingResponse (ctx context.Context , body io.ReadCloser , sess * AISession , req worker.GenLLMFormdataRequestBody , start time.Time ) (* worker.LLMResponse , error ) {
1220+ func handleNonStreamingResponse (ctx context.Context , body io.ReadCloser , sess * AISession , req worker.GenLLMJSONRequestBody , start time.Time ) (* worker.LLMResponse , error ) {
12221221 data , err := io .ReadAll (body )
12231222 defer body .Close ()
12241223 if err != nil {
12251224 if monitor .Enabled {
1226- monitor .AIRequestError (err .Error (), "llm" , * req .ModelId , sess .OrchestratorInfo )
1225+ monitor .AIRequestError (err .Error (), "llm" , * req .Model , sess .OrchestratorInfo )
12271226 }
12281227 return nil , err
12291228 }
12301229
12311230 var res worker.LLMResponse
12321231 if err := json .Unmarshal (data , & res ); err != nil {
12331232 if monitor .Enabled {
1234- monitor .AIRequestError (err .Error (), "llm" , * req .ModelId , sess .OrchestratorInfo )
1233+ monitor .AIRequestError (err .Error (), "llm" , * req .Model , sess .OrchestratorInfo )
12351234 }
12361235 return nil , err
12371236 }
12381237
12391238 took := time .Since (start )
1240- sess .LatencyScore = CalculateLLMLatencyScore (took , res .TokensUsed )
1239+ sess .LatencyScore = CalculateLLMLatencyScore (took , res .TokensUsed . TotalTokens )
12411240
12421241 if monitor .Enabled {
12431242 var pricePerAIUnit float64
12441243 if priceInfo := sess .OrchestratorInfo .GetPriceInfo (); priceInfo != nil && priceInfo .PixelsPerUnit != 0 {
12451244 pricePerAIUnit = float64 (priceInfo .PricePerUnit ) / float64 (priceInfo .PixelsPerUnit )
12461245 }
1247- monitor .AIRequestFinished (ctx , "llm" , * req .ModelId , monitor.AIJobInfo {LatencyScore : sess .LatencyScore , PricePerUnit : pricePerAIUnit }, sess .OrchestratorInfo )
1246+ monitor .AIRequestFinished (ctx , "llm" , * req .Model , monitor.AIJobInfo {LatencyScore : sess .LatencyScore , PricePerUnit : pricePerAIUnit }, sess .OrchestratorInfo )
12481247 }
12491248
12501249 return & res , nil
@@ -1403,16 +1402,16 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
14031402 submitFn = func (ctx context.Context , params aiRequestParams , sess * AISession ) (interface {}, error ) {
14041403 return submitAudioToText (ctx , params , sess , v )
14051404 }
1406- case worker.GenLLMFormdataRequestBody :
1405+ case worker.GenLLMJSONRequestBody :
14071406 cap = core .Capability_LLM
14081407 modelID = defaultLLMModelID
1409- if v .ModelId != nil {
1410- modelID = * v .ModelId
1408+ if v .Model != nil {
1409+ modelID = * v .Model
14111410 }
14121411 submitFn = func (ctx context.Context , params aiRequestParams , sess * AISession ) (interface {}, error ) {
14131412 return submitLLM (ctx , params , sess , v )
14141413 }
1415- ctx = clog . AddVal ( ctx , "prompt" , v . Prompt )
1414+
14161415 case worker.GenSegmentAnything2MultipartRequestBody :
14171416 cap = core .Capability_SegmentAnything2
14181417 modelID = defaultSegmentAnything2ModelID
0 commit comments