Skip to content

Commit f9ba10f

Browse files
Break out PostResponse plugin into 3 constituent plugins for request recieved, streaming, and complete
1 parent f806677 commit f9ba10f

File tree

7 files changed

+277
-44
lines changed

7 files changed

+277
-44
lines changed

pkg/epp/handlers/response.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,26 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques
6161
reqCtx.ResponseComplete = true
6262

6363
reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true)
64+
if s.director != nil {
65+
s.director.HandleResponseBodyComplete(ctx, reqCtx)
66+
}
6467
return reqCtx, nil
6568
}
6669

6770
// The function is to handle streaming response if the modelServer is streaming.
6871
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
72+
if s.director != nil {
73+
s.director.HandleResponseBodyStreaming(ctx, reqCtx)
74+
}
6975
if strings.Contains(responseText, streamingEndMsg) {
76+
reqCtx.ResponseComplete = true
7077
resp := parseRespForUsage(ctx, responseText)
7178
reqCtx.Usage = resp.Usage
7279
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens)
7380
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens)
81+
if s.director != nil {
82+
s.director.HandleResponseBodyComplete(ctx, reqCtx)
83+
}
7484
}
7585
}
7686

@@ -83,7 +93,7 @@ func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *Req
8393
}
8494
}
8595

86-
reqCtx, err := s.director.HandleResponse(ctx, reqCtx)
96+
reqCtx, err := s.director.HandleResponseRecieved(ctx, reqCtx)
8797

8898
return reqCtx, err
8999
}

pkg/epp/handlers/server.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ func NewStreamingServer(datastore Datastore, director Director) *StreamingServer
5454

5555
type Director interface {
5656
HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
57-
HandleResponse(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
57+
HandleResponseRecieved(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
58+
HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
59+
HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
5860
GetRandomPod() *backend.Pod
5961
}
6062

pkg/epp/requestcontrol/director.go

Lines changed: 72 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,22 +62,26 @@ type SaturationDetector interface {
6262
// NewDirectorWithConfig creates a new Director instance with all dependencies.
6363
func NewDirectorWithConfig(datastore Datastore, scheduler Scheduler, saturationDetector SaturationDetector, config *Config) *Director {
6464
return &Director{
65-
datastore: datastore,
66-
scheduler: scheduler,
67-
saturationDetector: saturationDetector,
68-
preRequestPlugins: config.preRequestPlugins,
69-
postResponsePlugins: config.postResponsePlugins,
70-
defaultPriority: 0, // define default priority explicitly
65+
datastore: datastore,
66+
scheduler: scheduler,
67+
saturationDetector: saturationDetector,
68+
preRequestPlugins: config.preRequestPlugins,
69+
postResponseRecievedPlugins: config.postResponseRecievedPlugins,
70+
postResponseStreamingPlugins: config.postResponseStreamingPlugins,
71+
postResponseCompletePlugins: config.postResponseCompletePlugins,
72+
defaultPriority: 0, // define default priority explicitly
7173
}
7274
}
7375

7476
// Director orchestrates the request handling flow, including scheduling.
7577
type Director struct {
76-
datastore Datastore
77-
scheduler Scheduler
78-
saturationDetector SaturationDetector
79-
preRequestPlugins []PreRequest
80-
postResponsePlugins []PostResponse
78+
datastore Datastore
79+
scheduler Scheduler
80+
saturationDetector SaturationDetector
81+
preRequestPlugins []PreRequest
82+
postResponseRecievedPlugins []PostResponseRecieved
83+
postResponseStreamingPlugins []PostResponseStreaming
84+
postResponseCompletePlugins []PostResponseComplete
8185
// we just need a pointer to an int variable since priority is a pointer in InferenceObjective
8286
// no need to set this in the constructor, since the value we want is the default int val
8387
// and value types cannot be nil
@@ -278,19 +282,49 @@ func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []sch
278282
return pm
279283
}
280284

281-
func (d *Director) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
285+
// HandleResponseRecieved is called when the first chunk of the response arrives.
286+
func (d *Director) HandleResponseRecieved(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
282287
response := &Response{
283288
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
284289
Headers: reqCtx.Response.Headers,
285290
}
286291

287292
// TODO: to extend fallback functionality, handle cases where target pod is unavailable
288293
// https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1224
289-
d.runPostResponsePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
294+
d.runPostResponseRecievedPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
290295

291296
return reqCtx, nil
292297
}
293298

299+
// HandleResponseBodyStreaming is called every time a chunk of the response body is received.
300+
func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
301+
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
302+
logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk")
303+
response := &Response{
304+
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
305+
Headers: reqCtx.Response.Headers,
306+
}
307+
308+
d.runPostResponseStreamingPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
309+
logger.V(logutil.TRACE).Info("Exiting HandleResponseBodyChunk")
310+
return reqCtx, nil
311+
}
312+
313+
// HandleResponseBodyComplete is called when the response body is fully received.
314+
func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
315+
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
316+
logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
317+
response := &Response{
318+
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
319+
Headers: reqCtx.Response.Headers,
320+
}
321+
322+
d.runPostResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
323+
324+
logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
325+
return reqCtx, nil
326+
}
327+
294328
func (d *Director) GetRandomPod() *backend.Pod {
295329
pods := d.datastore.PodList(backendmetrics.AllPodsPredicate)
296330
if len(pods) == 0 {
@@ -313,13 +347,34 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
313347
}
314348
}
315349

316-
func (d *Director) runPostResponsePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
350+
func (d *Director) runPostResponseRecievedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
317351
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
318-
for _, plugin := range d.postResponsePlugins {
352+
for _, plugin := range d.postResponseRecievedPlugins {
319353
loggerDebug.Info("Running post-response plugin", "plugin", plugin.TypedName())
320354
before := time.Now()
321-
plugin.PostResponse(ctx, request, response, targetPod)
322-
metrics.RecordPluginProcessingLatency(PostResponseExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
355+
plugin.PostResponseRecieved(ctx, request, response, targetPod)
356+
metrics.RecordPluginProcessingLatency(PostResponseRecievedExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
323357
loggerDebug.Info("Completed running post-response plugin successfully", "plugin", plugin.TypedName())
324358
}
325359
}
360+
361+
func (d *Director) runPostResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
362+
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
363+
for _, plugin := range d.postResponseStreamingPlugins {
364+
loggerTrace.Info("Running post-response chunk plugin", "plugin", plugin.TypedName().Type)
365+
before := time.Now()
366+
plugin.PostResponseStreaming(ctx, request, response, targetPod)
367+
metrics.RecordPluginProcessingLatency(PostResponseStreamingExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
368+
}
369+
}
370+
371+
func (d *Director) runPostResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
372+
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
373+
for _, plugin := range d.postResponseCompletePlugins {
374+
loggerDebug.Info("Running post-response complete plugin", "plugin", plugin.TypedName().Type)
375+
before := time.Now()
376+
plugin.PostResponseComplete(ctx, request, response, targetPod)
377+
metrics.RecordPluginProcessingLatency(PostResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
378+
loggerDebug.Info("Completed running post-response complete plugin successfully", "plugin", plugin.TypedName())
379+
}
380+
}

pkg/epp/requestcontrol/director_test.go

Lines changed: 127 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -592,13 +592,13 @@ func TestGetRandomPod(t *testing.T) {
592592
}
593593
}
594594

595-
func TestDirector_HandleResponse(t *testing.T) {
596-
pr1 := newTestPostResponse("pr1")
595+
func TestDirector_HandleResponseRecieved(t *testing.T) {
596+
pr1 := newTestPostResponseRecieved("pr1")
597597

598598
ctx := logutil.NewTestLoggerIntoContext(context.Background())
599599
ds := datastore.NewDatastore(t.Context(), nil)
600600
mockSched := &mockScheduler{}
601-
director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithPostResponsePlugins(pr1))
601+
director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithPostResponseRecievedPlugins(pr1))
602602

603603
reqCtx := &handlers.RequestContext{
604604
Request: &handlers.Request{
@@ -613,7 +613,7 @@ func TestDirector_HandleResponse(t *testing.T) {
613613
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
614614
}
615615

616-
_, err := director.HandleResponse(ctx, reqCtx)
616+
_, err := director.HandleResponseRecieved(ctx, reqCtx)
617617
if err != nil {
618618
t.Fatalf("HandleResponse() returned unexpected error: %v", err)
619619
}
@@ -629,27 +629,143 @@ func TestDirector_HandleResponse(t *testing.T) {
629629
}
630630
}
631631

632+
func TestDirector_HandleResponseStreaming(t *testing.T) {
633+
ps1 := newTestPostResponseStreaming("ps1")
634+
635+
ctx := logutil.NewTestLoggerIntoContext(context.Background())
636+
ds := datastore.NewDatastore(t.Context(), nil)
637+
mockSched := &mockScheduler{}
638+
director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithPostResponseStreamingPlugins(ps1))
639+
640+
reqCtx := &handlers.RequestContext{
641+
Request: &handlers.Request{
642+
Headers: map[string]string{
643+
requtil.RequestIdHeaderKey: "test-req-id-for-streaming",
644+
},
645+
},
646+
Response: &handlers.Response{
647+
Headers: map[string]string{"X-Test-Streaming-Header": "StreamValue"},
648+
},
649+
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
650+
}
651+
652+
_, err := director.HandleResponseBodyStreaming(ctx, reqCtx)
653+
if err != nil {
654+
t.Fatalf("HandleResponseBodyStreaming() returned unexpected error: %v", err)
655+
}
656+
657+
if diff := cmp.Diff("test-req-id-for-streaming", ps1.lastRespOnStreaming.RequestId); diff != "" {
658+
t.Errorf("Scheduler.OnStreaming RequestId mismatch (-want +got):\n%s", diff)
659+
}
660+
if diff := cmp.Diff(reqCtx.Response.Headers, ps1.lastRespOnStreaming.Headers); diff != "" {
661+
t.Errorf("Scheduler.OnStreaming Headers mismatch (-want +got):\n%s", diff)
662+
}
663+
if diff := cmp.Diff("namespace1/test-pod-name", ps1.lastTargetPodOnStreaming); diff != "" {
664+
t.Errorf("Scheduler.OnStreaming TargetPodName mismatch (-want +got):\n%s", diff)
665+
}
666+
}
667+
668+
func TestDirector_HandleResponseComplete(t *testing.T) {
669+
pc1 := newTestPostResponseComplete("pc1")
670+
671+
ctx := logutil.NewTestLoggerIntoContext(context.Background())
672+
ds := datastore.NewDatastore(t.Context(), nil)
673+
mockSched := &mockScheduler{}
674+
director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithPostResponseCompletePlugins(pc1))
675+
676+
reqCtx := &handlers.RequestContext{
677+
Request: &handlers.Request{
678+
Headers: map[string]string{
679+
requtil.RequestIdHeaderKey: "test-req-id-for-complete",
680+
},
681+
},
682+
Response: &handlers.Response{
683+
Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"},
684+
},
685+
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
686+
}
687+
688+
_, err := director.HandleResponseBodyComplete(ctx, reqCtx)
689+
if err != nil {
690+
t.Fatalf("HandleResponseBodyComplete() returned unexpected error: %v", err)
691+
}
692+
693+
if diff := cmp.Diff("test-req-id-for-complete", pc1.lastRespOnComplete.RequestId); diff != "" {
694+
t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff)
695+
}
696+
if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" {
697+
t.Errorf("Scheduler.OnComplete Headers mismatch (-want +got):\n%s", diff)
698+
}
699+
if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" {
700+
t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff)
701+
}
702+
}
703+
632704
const (
633-
testPostResponseType = "test-post-response"
705+
testPostResponseRecievedType = "test-post-response"
706+
testPostStreamingType = "test-post-streaming"
707+
testPostCompleteType = "test-post-complete"
634708
)
635709

636-
type testPostResponse struct {
710+
type testPostResponseRecieved struct {
637711
tn plugins.TypedName
638712
lastRespOnResponse *Response
639713
lastTargetPodOnResponse string
640714
}
641715

642-
func newTestPostResponse(name string) *testPostResponse {
643-
return &testPostResponse{
644-
tn: plugins.TypedName{Type: testPostResponseType, Name: name},
716+
type testPostResponseStreaming struct {
717+
tn plugins.TypedName
718+
lastRespOnStreaming *Response
719+
lastTargetPodOnStreaming string
720+
}
721+
722+
type testPostResponseComplete struct {
723+
tn plugins.TypedName
724+
lastRespOnComplete *Response
725+
lastTargetPodOnComplete string
726+
}
727+
728+
func newTestPostResponseRecieved(name string) *testPostResponseRecieved {
729+
return &testPostResponseRecieved{
730+
tn: plugins.TypedName{Type: testPostResponseRecievedType, Name: name},
645731
}
646732
}
647733

648-
func (p *testPostResponse) TypedName() plugins.TypedName {
734+
func newTestPostResponseStreaming(name string) *testPostResponseStreaming {
735+
return &testPostResponseStreaming{
736+
tn: plugins.TypedName{Type: testPostStreamingType, Name: name},
737+
}
738+
}
739+
740+
func newTestPostResponseComplete(name string) *testPostResponseComplete {
741+
return &testPostResponseComplete{
742+
tn: plugins.TypedName{Type: testPostCompleteType, Name: name},
743+
}
744+
}
745+
746+
func (p *testPostResponseRecieved) TypedName() plugins.TypedName {
747+
return p.tn
748+
}
749+
750+
func (p *testPostResponseStreaming) TypedName() plugins.TypedName {
751+
return p.tn
752+
}
753+
754+
func (p *testPostResponseComplete) TypedName() plugins.TypedName {
649755
return p.tn
650756
}
651757

652-
func (p *testPostResponse) PostResponse(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
758+
func (p *testPostResponseRecieved) PostResponseRecieved(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
653759
p.lastRespOnResponse = response
654760
p.lastTargetPodOnResponse = targetPod.NamespacedName.String()
655761
}
762+
763+
func (p *testPostResponseStreaming) PostResponseStreaming(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
764+
p.lastRespOnStreaming = response
765+
p.lastTargetPodOnStreaming = targetPod.NamespacedName.String()
766+
}
767+
768+
func (p *testPostResponseComplete) PostResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
769+
p.lastRespOnComplete = response
770+
p.lastTargetPodOnComplete = targetPod.NamespacedName.String()
771+
}

pkg/epp/requestcontrol/plugins.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ import (
2525
)
2626

2727
const (
28-
PreRequestExtensionPoint = "PreRequest"
29-
PostResponseExtensionPoint = "PostResponse"
28+
PreRequestExtensionPoint = "PreRequest"
29+
PostResponseRecievedExtensionPoint = "PostResponseRecieved"
30+
PostResponseStreamingExtensionPoint = "PostResponseStreaming"
31+
PostResponseCompleteExtensionPoint = "PostResponseComplete"
3032
)
3133

3234
// PreRequest is called by the director after a getting result from scheduling layer and
@@ -36,9 +38,21 @@ type PreRequest interface {
3638
PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, targetPort int)
3739
}
3840

39-
// PostResponse is called by the director after a successful response was sent.
41+
// PostResponseRecieved is called by the director after a successful response is sent.
4042
// The given pod argument is the pod that served the request.
41-
type PostResponse interface {
43+
type PostResponseRecieved interface {
4244
plugins.Plugin
43-
PostResponse(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
45+
PostResponseRecieved(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
46+
}
47+
48+
// PostResponseStreaming is called by the director after each chunk of streaming response is sent.
49+
type PostResponseStreaming interface {
50+
plugins.Plugin
51+
PostResponseStreaming(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
52+
}
53+
54+
// PostResponseComplete is called by the director after the complete response is sent.
55+
type PostResponseComplete interface {
56+
plugins.Plugin
57+
PostResponseComplete(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
4458
}

0 commit comments

Comments
 (0)