Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions pkg/epp/handlers/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,27 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques
reqCtx.ResponseComplete = true

reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true)
return reqCtx, nil

return s.director.HandleResponseBodyComplete(ctx, reqCtx)
}

// The function is to handle streaming response if the modelServer is streaming.
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
logger := log.FromContext(ctx)
_, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx)
if err != nil {
logger.Error(err, "error in HandleResponseBodyStreaming")
}
if strings.Contains(responseText, streamingEndMsg) {
reqCtx.ResponseComplete = true
resp := parseRespForUsage(ctx, responseText)
reqCtx.Usage = resp.Usage
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens)
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens)
_, err := s.director.HandleResponseBodyComplete(ctx, reqCtx)
if err != nil {
logger.Error(err, "error in HandleResponseBodyComplete")
}
}
}

Expand All @@ -83,7 +94,7 @@ func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *Req
}
}

reqCtx, err := s.director.HandleResponse(ctx, reqCtx)
reqCtx, err := s.director.HandleResponseReceived(ctx, reqCtx)

return reqCtx, err
}
Expand Down
24 changes: 24 additions & 0 deletions pkg/epp/handlers/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/google/go-cmp/cmp"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

Expand Down Expand Up @@ -59,6 +60,27 @@ data: [DONE]
`
)

type mockDirector struct{}

func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
return reqCtx, nil
}
func (m *mockDirector) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
return reqCtx, nil
}
func (m *mockDirector) HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
return reqCtx, nil
}
func (m *mockDirector) HandlePreRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
return reqCtx, nil
}
func (m *mockDirector) GetRandomPod() *backend.Pod {
return &backend.Pod{}
}
func (m *mockDirector) HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
return reqCtx, nil
}

func TestHandleResponseBody(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())

Expand All @@ -83,6 +105,7 @@ func TestHandleResponseBody(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
server := &StreamingServer{}
server.director = &mockDirector{}
reqCtx := test.reqCtx
if reqCtx == nil {
reqCtx = &RequestContext{}
Expand Down Expand Up @@ -143,6 +166,7 @@ func TestHandleStreamedResponseBody(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
server := &StreamingServer{}
server.director = &mockDirector{}
reqCtx := test.reqCtx
if reqCtx == nil {
reqCtx = &RequestContext{}
Expand Down
10 changes: 6 additions & 4 deletions pkg/epp/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ func NewStreamingServer(datastore Datastore, director Director) *StreamingServer

type Director interface {
HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
HandleResponse(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
GetRandomPod() *backend.Pod
}

Expand Down Expand Up @@ -121,7 +123,7 @@ const (
HeaderRequestResponseComplete StreamRequestState = 1
BodyRequestResponsesComplete StreamRequestState = 2
TrailerRequestResponsesComplete StreamRequestState = 3
ResponseRecieved StreamRequestState = 4
ResponseReceived StreamRequestState = 4
HeaderResponseResponseComplete StreamRequestState = 5
BodyResponseResponsesComplete StreamRequestState = 6
TrailerResponseResponsesComplete StreamRequestState = 7
Expand Down Expand Up @@ -251,7 +253,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
loggerTrace.Info("model server is streaming response")
}
}
reqCtx.RequestState = ResponseRecieved
reqCtx.RequestState = ResponseReceived

var responseErr error
reqCtx, responseErr = s.HandleResponseHeaders(ctx, reqCtx, v)
Expand Down Expand Up @@ -377,7 +379,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
}
}
if r.RequestState == ResponseRecieved && r.respHeaderResp != nil {
if r.RequestState == ResponseReceived && r.respHeaderResp != nil {
loggerTrace.Info("Sending response header response", "obj", r.respHeaderResp)
if err := srv.Send(r.respHeaderResp); err != nil {
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
Expand Down
99 changes: 81 additions & 18 deletions pkg/epp/requestcontrol/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,25 +59,37 @@ type SaturationDetector interface {
IsSaturated(ctx context.Context, candidatePods []backendmetrics.PodMetrics) bool
}

type RequestControlPlugins struct {
preRequestPlugins []PreRequest
responseReceivedPlugins []ResponseReceived
responseStreamingPlugins []ResponseStreaming
responseCompletePlugins []ResponseComplete
}

// NewDirectorWithConfig creates a new Director instance with all dependencies.
func NewDirectorWithConfig(datastore Datastore, scheduler Scheduler, saturationDetector SaturationDetector, config *Config) *Director {
RCPlugins := RequestControlPlugins{
preRequestPlugins: config.preRequestPlugins,
responseReceivedPlugins: config.responseReceivedPlugins,
responseStreamingPlugins: config.responseStreamingPlugins,
responseCompletePlugins: config.responseCompletePlugins,
}

return &Director{
datastore: datastore,
scheduler: scheduler,
saturationDetector: saturationDetector,
preRequestPlugins: config.preRequestPlugins,
postResponsePlugins: config.postResponsePlugins,
defaultPriority: 0, // define default priority explicitly
datastore: datastore,
scheduler: scheduler,
saturationDetector: saturationDetector,
requestControlPlugins: RCPlugins,
defaultPriority: 0, // define default priority explicitly
}
}

// Director orchestrates the request handling flow, including scheduling.
type Director struct {
datastore Datastore
scheduler Scheduler
saturationDetector SaturationDetector
preRequestPlugins []PreRequest
postResponsePlugins []PostResponse
datastore Datastore
scheduler Scheduler
saturationDetector SaturationDetector
requestControlPlugins RequestControlPlugins
// we just need a pointer to an int variable since priority is a pointer in InferenceObjective
// no need to set this in the constructor, since the value we want is the default int val
// and value types cannot be nil
Expand Down Expand Up @@ -278,16 +290,46 @@ func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []sch
return pm
}

func (d *Director) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
// HandleResponseReceived is called when the first chunk of the response arrives.
func (d *Director) HandleResponseReceived(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
response := &Response{
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
Headers: reqCtx.Response.Headers,
}

// TODO: to extend fallback functionality, handle cases where target pod is unavailable
// https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1224
d.runPostResponsePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
d.runResponseReceivedPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)

return reqCtx, nil
}

// HandleResponseBodyStreaming is called every time a chunk of the response body is received.
func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk")
response := &Response{
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
Headers: reqCtx.Response.Headers,
}

d.runResponseStreamingPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
logger.V(logutil.TRACE).Info("Exiting HandleResponseBodyChunk")
return reqCtx, nil
}

// HandleResponseBodyComplete is called when the response body is fully received.
func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
response := &Response{
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
Headers: reqCtx.Response.Headers,
}

d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)

logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
return reqCtx, nil
}

Expand All @@ -304,7 +346,7 @@ func (d *Director) GetRandomPod() *backend.Pod {
func (d *Director) runPreRequestPlugins(ctx context.Context, request *schedulingtypes.LLMRequest,
schedulingResult *schedulingtypes.SchedulingResult, targetPort int) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.preRequestPlugins {
for _, plugin := range d.requestControlPlugins.preRequestPlugins {
loggerDebug.Info("Running pre-request plugin", "plugin", plugin.TypedName())
before := time.Now()
plugin.PreRequest(ctx, request, schedulingResult, targetPort)
Expand All @@ -313,13 +355,34 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
}
}

func (d *Director) runPostResponsePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.postResponsePlugins {
for _, plugin := range d.requestControlPlugins.responseReceivedPlugins {
loggerDebug.Info("Running post-response plugin", "plugin", plugin.TypedName())
before := time.Now()
plugin.PostResponse(ctx, request, response, targetPod)
metrics.RecordPluginProcessingLatency(PostResponseExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
plugin.ResponseReceived(ctx, request, response, targetPod)
metrics.RecordPluginProcessingLatency(ResponseReceivedExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
loggerDebug.Info("Completed running post-response plugin successfully", "plugin", plugin.TypedName())
}
}

func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
for _, plugin := range d.requestControlPlugins.responseStreamingPlugins {
loggerTrace.Info("Running post-response chunk plugin", "plugin", plugin.TypedName())
before := time.Now()
plugin.ResponseStreaming(ctx, request, response, targetPod)
metrics.RecordPluginProcessingLatency(ResponseStreamingExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
}
}

func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.requestControlPlugins.responseCompletePlugins {
loggerDebug.Info("Running post-response complete plugin", "plugin", plugin.TypedName())
before := time.Now()
plugin.ResponseComplete(ctx, request, response, targetPod)
metrics.RecordPluginProcessingLatency(ResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
loggerDebug.Info("Completed running post-response complete plugin successfully", "plugin", plugin.TypedName())
}
}
Loading