Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pkg/epp/handlers/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,26 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques
reqCtx.ResponseComplete = true

reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true)
if s.director != nil {
s.director.HandleResponseBodyComplete(ctx, reqCtx)
}
return reqCtx, nil
}

// The function is to handle streaming response if the modelServer is streaming.
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
if s.director != nil {
s.director.HandleResponseBodyStreaming(ctx, reqCtx)
}
if strings.Contains(responseText, streamingEndMsg) {
reqCtx.ResponseComplete = true
resp := parseRespForUsage(ctx, responseText)
reqCtx.Usage = resp.Usage
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens)
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens)
if s.director != nil {
s.director.HandleResponseBodyComplete(ctx, reqCtx)
}
}
}

Expand All @@ -83,7 +93,7 @@ func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *Req
}
}

reqCtx, err := s.director.HandleResponse(ctx, reqCtx)
reqCtx, err := s.director.HandleResponseReceived(ctx, reqCtx)

return reqCtx, err
}
Expand Down
10 changes: 6 additions & 4 deletions pkg/epp/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ func NewStreamingServer(datastore Datastore, director Director) *StreamingServer

type Director interface {
HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
HandleResponse(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
GetRandomPod() *backend.Pod
}

Expand Down Expand Up @@ -121,7 +123,7 @@ const (
HeaderRequestResponseComplete StreamRequestState = 1
BodyRequestResponsesComplete StreamRequestState = 2
TrailerRequestResponsesComplete StreamRequestState = 3
ResponseRecieved StreamRequestState = 4
ResponseReceived StreamRequestState = 4
HeaderResponseResponseComplete StreamRequestState = 5
BodyResponseResponsesComplete StreamRequestState = 6
TrailerResponseResponsesComplete StreamRequestState = 7
Expand Down Expand Up @@ -251,7 +253,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
loggerTrace.Info("model server is streaming response")
}
}
reqCtx.RequestState = ResponseRecieved
reqCtx.RequestState = ResponseReceived

var responseErr error
reqCtx, responseErr = s.HandleResponseHeaders(ctx, reqCtx, v)
Expand Down Expand Up @@ -377,7 +379,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
}
}
if r.RequestState == ResponseRecieved && r.respHeaderResp != nil {
if r.RequestState == ResponseReceived && r.respHeaderResp != nil {
loggerTrace.Info("Sending response header response", "obj", r.respHeaderResp)
if err := srv.Send(r.respHeaderResp); err != nil {
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
Expand Down
89 changes: 72 additions & 17 deletions pkg/epp/requestcontrol/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,26 @@ type SaturationDetector interface {
// NewDirectorWithConfig creates a new Director instance with all dependencies.
func NewDirectorWithConfig(datastore Datastore, scheduler Scheduler, saturationDetector SaturationDetector, config *Config) *Director {
return &Director{
datastore: datastore,
scheduler: scheduler,
saturationDetector: saturationDetector,
preRequestPlugins: config.preRequestPlugins,
postResponsePlugins: config.postResponsePlugins,
defaultPriority: 0, // define default priority explicitly
datastore: datastore,
scheduler: scheduler,
saturationDetector: saturationDetector,
preRequestPlugins: config.preRequestPlugins,
postResponseReceivedPlugins: config.postResponseReceivedPlugins,
postResponseStreamingPlugins: config.postResponseStreamingPlugins,
postResponseCompletePlugins: config.postResponseCompletePlugins,
defaultPriority: 0, // define default priority explicitly
}
}

// Director orchestrates the request handling flow, including scheduling.
type Director struct {
datastore Datastore
scheduler Scheduler
saturationDetector SaturationDetector
preRequestPlugins []PreRequest
postResponsePlugins []PostResponse
datastore Datastore
scheduler Scheduler
saturationDetector SaturationDetector
preRequestPlugins []PreRequest
postResponseReceivedPlugins []PostResponseReceived
postResponseStreamingPlugins []PostResponseStreaming
postResponseCompletePlugins []PostResponseComplete
// we just need a pointer to an int variable since priority is a pointer in InferenceObjective
// no need to set this in the constructor, since the value we want is the default int val
// and value types cannot be nil
Expand Down Expand Up @@ -278,19 +282,49 @@ func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []sch
return pm
}

func (d *Director) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
// HandleResponseReceived is called when the first chunk of the response arrives.
func (d *Director) HandleResponseReceived(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
response := &Response{
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
Headers: reqCtx.Response.Headers,
}

// TODO: to extend fallback functionality, handle cases where target pod is unavailable
// https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1224
d.runPostResponsePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
d.runPostResponseReceivedPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)

return reqCtx, nil
}

// HandleResponseBodyStreaming is called every time a chunk of the response body is received.
func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk")
response := &Response{
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
Headers: reqCtx.Response.Headers,
}

d.runPostResponseStreamingPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
logger.V(logutil.TRACE).Info("Exiting HandleResponseBodyChunk")
return reqCtx, nil
}

// HandleResponseBodyComplete is called when the response body is fully received.
func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
response := &Response{
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
Headers: reqCtx.Response.Headers,
}

d.runPostResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)

logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
return reqCtx, nil
}

func (d *Director) GetRandomPod() *backend.Pod {
pods := d.datastore.PodList(backendmetrics.AllPodsPredicate)
if len(pods) == 0 {
Expand All @@ -313,13 +347,34 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
}
}

func (d *Director) runPostResponsePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
func (d *Director) runPostResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.postResponsePlugins {
for _, plugin := range d.postResponseReceivedPlugins {
loggerDebug.Info("Running post-response plugin", "plugin", plugin.TypedName())
before := time.Now()
plugin.PostResponse(ctx, request, response, targetPod)
metrics.RecordPluginProcessingLatency(PostResponseExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
plugin.PostResponseReceived(ctx, request, response, targetPod)
metrics.RecordPluginProcessingLatency(PostResponseReceivedExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
loggerDebug.Info("Completed running post-response plugin successfully", "plugin", plugin.TypedName())
}
}

func (d *Director) runPostResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
for _, plugin := range d.postResponseStreamingPlugins {
loggerTrace.Info("Running post-response chunk plugin", "plugin", plugin.TypedName().Type)
before := time.Now()
plugin.PostResponseStreaming(ctx, request, response, targetPod)
metrics.RecordPluginProcessingLatency(PostResponseStreamingExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
}
}

func (d *Director) runPostResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.postResponseCompletePlugins {
loggerDebug.Info("Running post-response complete plugin", "plugin", plugin.TypedName().Type)
before := time.Now()
plugin.PostResponseComplete(ctx, request, response, targetPod)
metrics.RecordPluginProcessingLatency(PostResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
loggerDebug.Info("Completed running post-response complete plugin successfully", "plugin", plugin.TypedName())
}
}
138 changes: 127 additions & 11 deletions pkg/epp/requestcontrol/director_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,13 +592,13 @@ func TestGetRandomPod(t *testing.T) {
}
}

func TestDirector_HandleResponse(t *testing.T) {
pr1 := newTestPostResponse("pr1")
func TestDirector_HandleResponseReceived(t *testing.T) {
pr1 := newTestPostResponseReceived("pr1")

ctx := logutil.NewTestLoggerIntoContext(context.Background())
ds := datastore.NewDatastore(t.Context(), nil)
mockSched := &mockScheduler{}
director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithPostResponsePlugins(pr1))
director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithPostResponseReceivedPlugins(pr1))

reqCtx := &handlers.RequestContext{
Request: &handlers.Request{
Expand All @@ -613,7 +613,7 @@ func TestDirector_HandleResponse(t *testing.T) {
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
}

_, err := director.HandleResponse(ctx, reqCtx)
_, err := director.HandleResponseReceived(ctx, reqCtx)
if err != nil {
t.Fatalf("HandleResponse() returned unexpected error: %v", err)
}
Expand All @@ -629,27 +629,143 @@ func TestDirector_HandleResponse(t *testing.T) {
}
}

func TestDirector_HandleResponseStreaming(t *testing.T) {
ps1 := newTestPostResponseStreaming("ps1")

ctx := logutil.NewTestLoggerIntoContext(context.Background())
ds := datastore.NewDatastore(t.Context(), nil)
mockSched := &mockScheduler{}
director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithPostResponseStreamingPlugins(ps1))

reqCtx := &handlers.RequestContext{
Request: &handlers.Request{
Headers: map[string]string{
requtil.RequestIdHeaderKey: "test-req-id-for-streaming",
},
},
Response: &handlers.Response{
Headers: map[string]string{"X-Test-Streaming-Header": "StreamValue"},
},
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
}

_, err := director.HandleResponseBodyStreaming(ctx, reqCtx)
if err != nil {
t.Fatalf("HandleResponseBodyStreaming() returned unexpected error: %v", err)
}

if diff := cmp.Diff("test-req-id-for-streaming", ps1.lastRespOnStreaming.RequestId); diff != "" {
t.Errorf("Scheduler.OnStreaming RequestId mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(reqCtx.Response.Headers, ps1.lastRespOnStreaming.Headers); diff != "" {
t.Errorf("Scheduler.OnStreaming Headers mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff("namespace1/test-pod-name", ps1.lastTargetPodOnStreaming); diff != "" {
t.Errorf("Scheduler.OnStreaming TargetPodName mismatch (-want +got):\n%s", diff)
}
}

func TestDirector_HandleResponseComplete(t *testing.T) {
pc1 := newTestPostResponseComplete("pc1")

ctx := logutil.NewTestLoggerIntoContext(context.Background())
ds := datastore.NewDatastore(t.Context(), nil)
mockSched := &mockScheduler{}
director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithPostResponseCompletePlugins(pc1))

reqCtx := &handlers.RequestContext{
Request: &handlers.Request{
Headers: map[string]string{
requtil.RequestIdHeaderKey: "test-req-id-for-complete",
},
},
Response: &handlers.Response{
Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"},
},
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
}

_, err := director.HandleResponseBodyComplete(ctx, reqCtx)
if err != nil {
t.Fatalf("HandleResponseBodyComplete() returned unexpected error: %v", err)
}

if diff := cmp.Diff("test-req-id-for-complete", pc1.lastRespOnComplete.RequestId); diff != "" {
t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" {
t.Errorf("Scheduler.OnComplete Headers mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" {
t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff)
}
}

const (
testPostResponseType = "test-post-response"
testPostResponseReceivedType = "test-post-response"
testPostStreamingType = "test-post-streaming"
testPostCompleteType = "test-post-complete"
)

type testPostResponse struct {
type testPostResponseReceived struct {
tn plugins.TypedName
lastRespOnResponse *Response
lastTargetPodOnResponse string
}

func newTestPostResponse(name string) *testPostResponse {
return &testPostResponse{
tn: plugins.TypedName{Type: testPostResponseType, Name: name},
type testPostResponseStreaming struct {
tn plugins.TypedName
lastRespOnStreaming *Response
lastTargetPodOnStreaming string
}

type testPostResponseComplete struct {
tn plugins.TypedName
lastRespOnComplete *Response
lastTargetPodOnComplete string
}

func newTestPostResponseReceived(name string) *testPostResponseReceived {
return &testPostResponseReceived{
tn: plugins.TypedName{Type: testPostResponseReceivedType, Name: name},
}
}

func (p *testPostResponse) TypedName() plugins.TypedName {
func newTestPostResponseStreaming(name string) *testPostResponseStreaming {
return &testPostResponseStreaming{
tn: plugins.TypedName{Type: testPostStreamingType, Name: name},
}
}

func newTestPostResponseComplete(name string) *testPostResponseComplete {
return &testPostResponseComplete{
tn: plugins.TypedName{Type: testPostCompleteType, Name: name},
}
}

func (p *testPostResponseReceived) TypedName() plugins.TypedName {
return p.tn
}

func (p *testPostResponseStreaming) TypedName() plugins.TypedName {
return p.tn
}

func (p *testPostResponseComplete) TypedName() plugins.TypedName {
return p.tn
}

func (p *testPostResponse) PostResponse(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
func (p *testPostResponseReceived) PostResponseReceived(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
p.lastRespOnResponse = response
p.lastTargetPodOnResponse = targetPod.NamespacedName.String()
}

func (p *testPostResponseStreaming) PostResponseStreaming(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
p.lastRespOnStreaming = response
p.lastTargetPodOnStreaming = targetPod.NamespacedName.String()
}

func (p *testPostResponseComplete) PostResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
p.lastRespOnComplete = response
p.lastTargetPodOnComplete = targetPod.NamespacedName.String()
}
Loading