89 changes: 30 additions & 59 deletions pkg/epp/handlers/response.go
@@ -17,16 +17,15 @@ limitations under the License.
package handlers

import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
"sigs.k8s.io/controller-runtime/pkg/log"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

@@ -36,49 +35,56 @@ const (
)

// HandleResponseBody always returns the requestContext even in the error case, as the request context is used in error handling.
func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, response map[string]any) (*RequestContext, error) {
func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, body []byte) (*RequestContext, error) {
logger := log.FromContext(ctx)
responseBytes, err := json.Marshal(response)
llmResponse, err := types.NewLLMResponseFromBytes(body)
if err != nil {
return reqCtx, fmt.Errorf("error marshalling responseBody - %w", err)
}
if response["usage"] != nil {
usg := response["usage"].(map[string]any)
usage := Usage{
PromptTokens: int(usg["prompt_tokens"].(float64)),
CompletionTokens: int(usg["completion_tokens"].(float64)),
TotalTokens: int(usg["total_tokens"].(float64)),
logger.Error(err, "failed to create LLMResponse from bytes")
} else {
reqCtx.SchedulingResponse = llmResponse
if usage := reqCtx.SchedulingResponse.Usage(); usage != nil {
reqCtx.Usage = usage
logger.V(logutil.VERBOSE).Info("Response generated", "usage", usage)
}
reqCtx.Usage = usage
logger.V(logutil.VERBOSE).Info("Response generated", "usage", reqCtx.Usage)
}
reqCtx.ResponseSize = len(responseBytes)
reqCtx.ResponseSize = len(body)
// ResponseComplete indicates that the response is complete. In the non-streaming
// case, it is set to true once the response is processed; in the streaming
// case, it is set to true once the last chunk is processed.
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/178)
// will add the processing for streaming case.
reqCtx.ResponseComplete = true

reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true)
reqCtx.respBodyResp = generateResponseBodyResponses(body, true)

return s.director.HandleResponseBodyComplete(ctx, reqCtx)
}

// HandleResponseBodyModelStreaming handles the response body when the model server is streaming.
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, streamBody []byte) {
logger := log.FromContext(ctx)
_, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx)
if err != nil {
logger.Error(err, "error in HandleResponseBodyStreaming")
}
if strings.Contains(responseText, streamingEndMsg) {
}

func (s *StreamingServer) HandleResponseBodyModelStreamingComplete(ctx context.Context, reqCtx *RequestContext, streamBody []byte) {
logger := log.FromContext(ctx)
if bytes.Contains(streamBody, []byte(streamingEndMsg)) {
reqCtx.ResponseComplete = true
resp := parseRespForUsage(ctx, responseText)
reqCtx.Usage = resp.Usage
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens)
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens)
_, err := s.director.HandleResponseBodyComplete(ctx, reqCtx)
resp, err := types.NewLLMResponseFromStream(streamBody)
if err != nil {
logger.Error(err, "error in converting stream response to LLMResponse.")
} else {
reqCtx.SchedulingResponse = resp
if usage := resp.Usage(); usage != nil {
reqCtx.Usage = usage
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.PromptTokens)
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.CompletionTokens)
}
}
_, err = s.director.HandleResponseBodyComplete(ctx, reqCtx)
if err != nil {
logger.Error(err, "error in HandleResponseBodyComplete")
}
@@ -153,41 +159,6 @@ func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*con
return headers
}

// Example message if "stream_options": {"include_usage": "true"} is included in the request:
// data: {"id":"...","object":"text_completion","created":1739400043,"model":"food-review-0","choices":[],
// "usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}
//
// data: [DONE]
//
// Note that vLLM returns two entries in one response.
// We need to strip the `data:` prefix and the trailing `data: [DONE]` from the message to fetch the response data.
//
// If include_usage is not included in the request, `data: [DONE]` is returned separately, which
// indicates end of streaming.
func parseRespForUsage(ctx context.Context, responseText string) ResponseBody {
response := ResponseBody{}
logger := log.FromContext(ctx)

lines := strings.Split(responseText, "\n")
for _, line := range lines {
if !strings.HasPrefix(line, streamingRespPrefix) {
continue
}
content := strings.TrimPrefix(line, streamingRespPrefix)
if content == "[DONE]" {
continue
}

byteSlice := []byte(content)
if err := json.Unmarshal(byteSlice, &response); err != nil {
logger.Error(err, "unmarshaling response body")
continue
}
}

return response
}

type ResponseBody struct {
Usage Usage `json:"usage"`
}
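Note on the removed helper: the SSE scanning that used to live in parseRespForUsage now sits behind types.NewLLMResponseFromStream (with types.NewLLMResponseFromBytes for the buffered non-streaming body). Neither constructor's implementation appears in this diff, so the following is a minimal sketch of the stream variant, assuming it walks `data:` frames the way the deleted helper did; only the Usage, LLMResponse, and NewLLMResponseFromStream names come from the diff, everything else is illustrative.

package types

import (
	"bytes"
	"encoding/json"
)

// Usage mirrors the OpenAI-style usage object used throughout this PR.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// LLMResponse is a stand-in for the real type in pkg/epp/scheduling/types.
type LLMResponse struct {
	usage *Usage
}

func (r *LLMResponse) Usage() *Usage { return r.usage }

// NewLLMResponseFromStream scans the fully buffered SSE body line by line,
// strips the "data: " prefix, skips the "[DONE]" sentinel, and keeps the
// usage from the last chunk that carries one (vLLM only attaches usage to
// the final chunk when "stream_options": {"include_usage": true} is set).
func NewLLMResponseFromStream(body []byte) (*LLMResponse, error) {
	resp := &LLMResponse{}
	for _, line := range bytes.Split(body, []byte("\n")) {
		line = bytes.TrimSpace(line)
		if !bytes.HasPrefix(line, []byte("data: ")) {
			continue
		}
		payload := bytes.TrimPrefix(line, []byte("data: "))
		if bytes.Equal(payload, []byte("[DONE]")) {
			continue
		}
		var chunk struct {
			Usage *Usage `json:"usage"`
		}
		if err := json.Unmarshal(payload, &chunk); err != nil {
			return nil, err
		}
		if chunk.Usage != nil {
			resp.usage = chunk.Usage
		}
	}
	return resp, nil
}

Under that assumption, a malformed chunk fails the whole parse, matching the error handling added in HandleResponseBodyModelStreamingComplete above, whereas the deleted parseRespForUsage logged bad chunks and kept going.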
57 changes: 37 additions & 20 deletions pkg/epp/handlers/response_test.go
@@ -18,12 +18,12 @@ package handlers

import (
"context"
"encoding/json"
"testing"

"github.com/google/go-cmp/cmp"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

@@ -52,12 +52,33 @@ const (
}
`

streamingBodyWithoutUsage = `data: {"id":"cmpl-41764c93-f9d2-4f31-be08-3ba04fa25394","object":"text_completion","created":1740002445,"model":"food-review-0","choices":[],"usage":null}
`
streamingBodyWithoutUsage = `
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]}

streamingBodyWithUsage = `data: {"id":"cmpl-41764c93-f9d2-4f31-be08-3ba04fa25394","object":"text_completion","created":1740002445,"model":"food-review-0","choices":[],"usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}
data: [DONE]
`
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":null}

data: [DONE]
`

streamingBodyWithUsage = `
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}}

data: [DONE]
`
)

type mockDirector struct{}
@@ -88,13 +109,13 @@ func TestHandleResponseBody(t *testing.T) {
name string
body []byte
reqCtx *RequestContext
want Usage
want *types.Usage
wantErr bool
}{
{
name: "success",
body: []byte(body),
want: Usage{
want: &types.Usage{
PromptTokens: 11,
TotalTokens: 111,
CompletionTokens: 100,
@@ -110,12 +131,7 @@
if reqCtx == nil {
reqCtx = &RequestContext{}
}
var responseMap map[string]any
marshalErr := json.Unmarshal(test.body, &responseMap)
if marshalErr != nil {
t.Error(marshalErr, "Error unmarshaling request body")
}
_, err := server.HandleResponseBody(ctx, reqCtx, responseMap)
_, err := server.HandleResponseBody(ctx, reqCtx, test.body)
if err != nil {
if !test.wantErr {
t.Fatalf("HandleResponseBody returned unexpected error: %v, want %v", err, test.wantErr)
@@ -136,7 +152,7 @@ func TestHandleStreamedResponseBody(t *testing.T) {
name string
body string
reqCtx *RequestContext
want Usage
want *types.Usage
wantErr bool
}{
{
@@ -155,10 +171,10 @@
modelServerStreaming: true,
},
wantErr: false,
want: Usage{
PromptTokens: 7,
TotalTokens: 17,
CompletionTokens: 10,
want: &types.Usage{
PromptTokens: 5,
TotalTokens: 12,
CompletionTokens: 7,
},
},
}
@@ -171,7 +187,8 @@
if reqCtx == nil {
reqCtx = &RequestContext{}
}
server.HandleResponseBodyModelStreaming(ctx, reqCtx, test.body)
server.HandleResponseBodyModelStreaming(ctx, reqCtx, []byte(test.body))
server.HandleResponseBodyModelStreamingComplete(ctx, reqCtx, []byte(test.body))

if diff := cmp.Diff(test.want, reqCtx.Usage); diff != "" {
t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff)
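These tests drive both streaming handlers explicitly: the per-chunk hook first, then the end-of-stream hook with the full buffered body, mirroring what Process does in server.go below. The bodies of mockDirector's methods are not shown in this diff; a minimal stand-in, with signatures inferred from the call sites in response.go and director.go, might look like:

package handlers

import "context"

// noopDirector satisfies the response hooks StreamingServer calls in these
// tests. The method set is inferred from the call sites; the real director
// contract lives elsewhere in the package and likely has more methods.
type noopDirector struct{}

func (noopDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
	return reqCtx, nil
}

func (noopDirector) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
	return reqCtx, nil
}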
69 changes: 34 additions & 35 deletions pkg/epp/handlers/server.go
@@ -85,14 +85,15 @@ type RequestContext struct {
RequestReceivedTimestamp time.Time
ResponseCompleteTimestamp time.Time
RequestSize int
Usage Usage
Usage *schedulingtypes.Usage
ResponseSize int
ResponseComplete bool
ResponseStatusCode string
RequestRunning bool
Request *Request

SchedulingRequest *schedulingtypes.LLMRequest
SchedulingRequest *schedulingtypes.LLMRequest
SchedulingResponse *schedulingtypes.LLMResponse

RequestState StreamRequestState
modelServerStreaming bool
@@ -267,52 +268,50 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
reqCtx.respHeaderResp = s.generateResponseHeaderResponse(reqCtx)

case *extProcPb.ProcessingRequest_ResponseBody:
body = append(body, v.ResponseBody.Body...)
if reqCtx.modelServerStreaming {
// Currently we punt on response parsing if the model server is streaming, and we just pass it through.

responseText := string(v.ResponseBody.Body)
s.HandleResponseBodyModelStreaming(ctx, reqCtx, responseText)
s.HandleResponseBodyModelStreaming(ctx, reqCtx, v.ResponseBody.Body)
if v.ResponseBody.EndOfStream {
loggerTrace.Info("stream completed")
s.HandleResponseBodyModelStreamingComplete(ctx, reqCtx, body)

reqCtx.ResponseCompleteTimestamp = time.Now()
metrics.RecordRequestLatencies(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
metrics.RecordResponseSizes(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.ResponseSize)
}

reqCtx.respBodyResp = generateResponseBodyResponses(v.ResponseBody.Body, v.ResponseBody.EndOfStream)
} else {
body = append(body, v.ResponseBody.Body...)

// The message is fully buffered, so we can read and decode it.
if v.ResponseBody.EndOfStream {
loggerTrace.Info("stream completed")
// Don't send a 500 on a response error. Just let the message pass through and log our error for debugging purposes.
// We assume the body is valid JSON; error messages are not guaranteed to be JSON, and so capturing and sending a 500 obfuscates the response message.
// Using the standard 'err' var will send an immediate error response back to the caller.
var responseErr error
responseErr = json.Unmarshal(body, &responseBody)
if responseErr != nil {
if logger.V(logutil.DEBUG).Enabled() {
logger.V(logutil.DEBUG).Error(responseErr, "Error unmarshalling response body", "body", string(body))
} else {
logger.V(logutil.DEFAULT).Error(responseErr, "Error unmarshalling response body", "body", string(body))
}
reqCtx.respBodyResp = generateResponseBodyResponses(body, true)
break
} else if v.ResponseBody.EndOfStream {
loggerTrace.Info("stream completed")
// Don't send a 500 on a response error. Just let the message pass through and log our error for debugging purposes.
// We assume the body is valid JSON; error messages are not guaranteed to be JSON, and so capturing and sending a 500 obfuscates the response message.
// Using the standard 'err' var will send an immediate error response back to the caller.
var responseErr error
responseErr = json.Unmarshal(body, &responseBody)
if responseErr != nil {
if logger.V(logutil.DEBUG).Enabled() {
logger.V(logutil.DEBUG).Error(responseErr, "Error unmarshalling response body", "body", string(body))
} else {
logger.V(logutil.DEFAULT).Error(responseErr, "Error unmarshalling response body", "body", string(body))
}
reqCtx.respBodyResp = generateResponseBodyResponses(body, true)
break
}

reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody)
if responseErr != nil {
if logger.V(logutil.DEBUG).Enabled() {
logger.V(logutil.DEBUG).Error(responseErr, "Failed to process response body", "request", req)
} else {
logger.V(logutil.DEFAULT).Error(responseErr, "Failed to process response body")
}
} else if reqCtx.ResponseComplete {
reqCtx.ResponseCompleteTimestamp = time.Now()
metrics.RecordRequestLatencies(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
metrics.RecordResponseSizes(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.ResponseSize)
reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, body)
if responseErr != nil {
if logger.V(logutil.DEBUG).Enabled() {
logger.V(logutil.DEBUG).Error(responseErr, "Failed to process response body", "request", req)
} else {
logger.V(logutil.DEFAULT).Error(responseErr, "Failed to process response body")
}
} else if reqCtx.ResponseComplete {
reqCtx.ResponseCompleteTimestamp = time.Now()
metrics.RecordRequestLatencies(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
metrics.RecordResponseSizes(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.ResponseSize)
if reqCtx.Usage != nil {
// Response complete does not guarantee the Usage is populated.
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.Usage.PromptTokens)
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.Usage.CompletionTokens)
}
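Two details in this hunk are easy to miss. First, `body = append(body, v.ResponseBody.Body...)` now runs for the streaming branch too, so HandleResponseBodyModelStreamingComplete receives the entire buffered stream rather than only the final chunk. Second, the new nil guard around the token metrics only works because RequestContext.Usage became a pointer; with the old value type, "no usage reported" and "zero tokens" were indistinguishable. A self-contained illustration of that distinction (Usage simplified for the example):

package main

import "fmt"

// Usage is simplified from pkg/epp/scheduling/types for illustration.
type Usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

func main() {
	var byValue Usage    // zero value: "0 tokens" and "never populated" look identical
	var byPointer *Usage // nil: unambiguously "never populated"

	fmt.Println(byValue.TotalTokens) // prints 0 - a real count or missing data? Can't tell.
	if byPointer == nil {
		fmt.Println("no usage seen; skip RecordInputTokens/RecordOutputTokens")
	}
}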
15 changes: 8 additions & 7 deletions pkg/epp/requestcontrol/director.go
@@ -20,6 +20,7 @@ package requestcontrol

import (
"context"
"errors"
"fmt"
"math/rand"
"net"
@@ -289,14 +290,14 @@ func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *hand

// HandleResponseBodyComplete is called when the response body is fully received.
func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey]
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk", requtil.RequestIdHeaderKey, requestID)
logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
response := &Response{
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
Headers: reqCtx.Response.Headers,
if reqCtx.SchedulingResponse == nil {
err := errors.New("nil scheduling response from reqCtx")
return reqCtx, err
}

d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, reqCtx.SchedulingResponse, reqCtx.TargetPod)

logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
return reqCtx, nil
@@ -346,7 +347,7 @@ func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *sch
}
}

func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.requestControlPlugins.responseCompletePlugins {
loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName())
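With runResponseCompletePlugins now forwarding the parsed *schedulingtypes.LLMResponse instead of the old local Response struct, completion plugins get direct access to parsed usage. Below is a sketch of a consumer, assuming the hook receives exactly the arguments forwarded above; the interface shape, method name, and logging are illustrative, not the repo's actual plugin contract.

package requestcontrol

import (
	"context"

	"sigs.k8s.io/controller-runtime/pkg/log"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// usageLogger is a hypothetical ResponseComplete plugin.
type usageLogger struct{}

func (usageLogger) ResponseComplete(ctx context.Context, request *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) {
	// Usage() can be nil: response-complete does not guarantee usage was reported.
	if usage := response.Usage(); usage != nil {
		log.FromContext(ctx).Info("response complete",
			"promptTokens", usage.PromptTokens,
			"completionTokens", usage.CompletionTokens,
			"totalTokens", usage.TotalTokens)
	}
}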