diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go index 7dfaf3b2e..5760cbfc6 100644 --- a/pkg/epp/handlers/response.go +++ b/pkg/epp/handlers/response.go @@ -61,16 +61,27 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques reqCtx.ResponseComplete = true reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true) - return reqCtx, nil + + return s.director.HandleResponseBodyComplete(ctx, reqCtx) } // The function is to handle streaming response if the modelServer is streaming. func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) { + logger := log.FromContext(ctx) + _, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx, logger) + if err != nil { + logger.Error(err, "error in HandleResponseBodyStreaming") + } if strings.Contains(responseText, streamingEndMsg) { + reqCtx.ResponseComplete = true resp := parseRespForUsage(ctx, responseText) reqCtx.Usage = resp.Usage metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens) metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens) + _, err := s.director.HandleResponseBodyComplete(ctx, reqCtx) + if err != nil { + logger.Error(err, "error in HandleResponseBodyComplete") + } } } @@ -83,7 +94,7 @@ func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *Req } } - reqCtx, err := s.director.HandleResponse(ctx, reqCtx) + reqCtx, err := s.director.HandleResponseReceived(ctx, reqCtx) return reqCtx, err } diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go index 6eb7734e4..290161167 100644 --- a/pkg/epp/handlers/response_test.go +++ b/pkg/epp/handlers/response_test.go @@ -21,8 +21,10 @@ import ( "encoding/json" "testing" + "github.com/go-logr/logr" "github.com/google/go-cmp/cmp" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" logutil 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -59,6 +61,27 @@ data: [DONE] ` ) +type mockDirector struct{} + +func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext, logger logr.Logger) (*RequestContext, error) { + return reqCtx, nil +} +func (m *mockDirector) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { + return reqCtx, nil +} +func (m *mockDirector) HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { + return reqCtx, nil +} +func (m *mockDirector) HandlePreRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { + return reqCtx, nil +} +func (m *mockDirector) GetRandomPod() *backend.Pod { + return &backend.Pod{} +} +func (m *mockDirector) HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { + return reqCtx, nil +} + func TestHandleResponseBody(t *testing.T) { ctx := logutil.NewTestLoggerIntoContext(context.Background()) @@ -83,6 +106,7 @@ func TestHandleResponseBody(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { server := &StreamingServer{} + server.director = &mockDirector{} reqCtx := test.reqCtx if reqCtx == nil { reqCtx = &RequestContext{} @@ -143,6 +167,7 @@ func TestHandleStreamedResponseBody(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { server := &StreamingServer{} + server.director = &mockDirector{} reqCtx := test.reqCtx if reqCtx == nil { reqCtx = &RequestContext{} diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index ddfb3316c..59cde8949 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -54,7 +54,9 @@ func NewStreamingServer(datastore Datastore, director Director) *StreamingServer type Director interface { HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) - HandleResponse(ctx 
context.Context, reqCtx *RequestContext) (*RequestContext, error) + HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) + HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext, logger logr.Logger) (*RequestContext, error) + HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) GetRandomPod() *backend.Pod } @@ -121,7 +123,7 @@ const ( HeaderRequestResponseComplete StreamRequestState = 1 BodyRequestResponsesComplete StreamRequestState = 2 TrailerRequestResponsesComplete StreamRequestState = 3 - ResponseRecieved StreamRequestState = 4 + ResponseReceived StreamRequestState = 4 HeaderResponseResponseComplete StreamRequestState = 5 BodyResponseResponsesComplete StreamRequestState = 6 TrailerResponseResponsesComplete StreamRequestState = 7 @@ -251,7 +253,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) loggerTrace.Info("model server is streaming response") } } - reqCtx.RequestState = ResponseRecieved + reqCtx.RequestState = ResponseReceived var responseErr error reqCtx, responseErr = s.HandleResponseHeaders(ctx, reqCtx, v) @@ -377,7 +379,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) } } - if r.RequestState == ResponseRecieved && r.respHeaderResp != nil { + if r.RequestState == ResponseReceived && r.respHeaderResp != nil { loggerTrace.Info("Sending response header response", "obj", r.respHeaderResp) if err := srv.Send(r.respHeaderResp); err != nil { return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index a3e2d6d13..56b4b1870 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -27,6 +27,7 @@ import ( "strings" "time" + "github.com/go-logr/logr" 
"sigs.k8s.io/controller-runtime/pkg/log" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" @@ -62,22 +63,20 @@ type SaturationDetector interface { // NewDirectorWithConfig creates a new Director instance with all dependencies. func NewDirectorWithConfig(datastore Datastore, scheduler Scheduler, saturationDetector SaturationDetector, config *Config) *Director { return &Director{ - datastore: datastore, - scheduler: scheduler, - saturationDetector: saturationDetector, - preRequestPlugins: config.preRequestPlugins, - postResponsePlugins: config.postResponsePlugins, - defaultPriority: 0, // define default priority explicitly + datastore: datastore, + scheduler: scheduler, + saturationDetector: saturationDetector, + requestControlPlugins: *config, + defaultPriority: 0, // define default priority explicitly } } // Director orchestrates the request handling flow, including scheduling. type Director struct { - datastore Datastore - scheduler Scheduler - saturationDetector SaturationDetector - preRequestPlugins []PreRequest - postResponsePlugins []PostResponse + datastore Datastore + scheduler Scheduler + saturationDetector SaturationDetector + requestControlPlugins Config // we just need a pointer to an int variable since priority is a pointer in InferenceObjective // no need to set this in the constructor, since the value we want is the default int val // and value types cannot be nil @@ -278,7 +277,8 @@ func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []sch return pm } -func (d *Director) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { +// HandleResponseReceived is called when the response headers are received. 
+func (d *Director) HandleResponseReceived(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { response := &Response{ RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey], Headers: reqCtx.Response.Headers, @@ -286,11 +286,39 @@ func (d *Director) HandleResponse(ctx context.Context, reqCtx *handlers.RequestC // TODO: to extend fallback functionality, handle cases where target pod is unavailable // https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1224 - d.runPostResponsePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod) + d.runResponseReceivedPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod) return reqCtx, nil } +// HandleResponseBodyStreaming is called every time a chunk of the response body is received. +func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext, logger logr.Logger) (*handlers.RequestContext, error) { + logger.V(logutil.TRACE).Info("Entering HandleResponseBodyStreaming") + response := &Response{ + RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey], + Headers: reqCtx.Response.Headers, + } + + d.runResponseStreamingPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod) + logger.V(logutil.TRACE).Info("Exiting HandleResponseBodyStreaming") + return reqCtx, nil +} + +// HandleResponseBodyComplete is called when the response body is fully received. 
+func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { + logger := log.FromContext(ctx).WithValues("stage", "bodyComplete") + logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete") + response := &Response{ + RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey], + Headers: reqCtx.Response.Headers, + } + + d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod) + + logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete") + return reqCtx, nil +} + func (d *Director) GetRandomPod() *backend.Pod { pods := d.datastore.PodList(backendmetrics.AllPodsPredicate) if len(pods) == 0 { @@ -304,22 +332,44 @@ func (d *Director) GetRandomPod() *backend.Pod { func (d *Director) runPreRequestPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult, targetPort int) { loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) - for _, plugin := range d.preRequestPlugins { - loggerDebug.Info("Running pre-request plugin", "plugin", plugin.TypedName()) + for _, plugin := range d.requestControlPlugins.preRequestPlugins { + loggerDebug.Info("Running PreRequest plugin", "plugin", plugin.TypedName()) before := time.Now() plugin.PreRequest(ctx, request, schedulingResult, targetPort) metrics.RecordPluginProcessingLatency(PreRequestExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before)) - loggerDebug.Info("Completed running pre-request plugin successfully", "plugin", plugin.TypedName()) + loggerDebug.Info("Completed running PreRequest plugin successfully", "plugin", plugin.TypedName()) + } +} + +func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { + loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) + for _, plugin := range d.requestControlPlugins.responseReceivedPlugins { + 
loggerDebug.Info("Running ResponseReceived plugin", "plugin", plugin.TypedName()) + before := time.Now() + plugin.ResponseReceived(ctx, request, response, targetPod) + metrics.RecordPluginProcessingLatency(ResponseReceivedExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before)) + loggerDebug.Info("Completed running ResponseReceived plugin successfully", "plugin", plugin.TypedName()) + } +} + +func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { + loggerTrace := log.FromContext(ctx).V(logutil.TRACE) + for _, plugin := range d.requestControlPlugins.responseStreamingPlugins { + loggerTrace.Info("Running ResponseStreaming plugin", "plugin", plugin.TypedName()) + before := time.Now() + plugin.ResponseStreaming(ctx, request, response, targetPod) + metrics.RecordPluginProcessingLatency(ResponseStreamingExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before)) + loggerTrace.Info("Completed running ResponseStreaming plugin successfully", "plugin", plugin.TypedName()) } } -func (d *Director) runPostResponsePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { +func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) - for _, plugin := range d.postResponsePlugins { - loggerDebug.Info("Running post-response plugin", "plugin", plugin.TypedName()) + for _, plugin := range d.requestControlPlugins.responseCompletePlugins { + loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName()) before := time.Now() - plugin.PostResponse(ctx, request, response, targetPod) - metrics.RecordPluginProcessingLatency(PostResponseExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before)) - 
loggerDebug.Info("Completed running post-response plugin successfully", "plugin", plugin.TypedName()) + plugin.ResponseComplete(ctx, request, response, targetPod) + metrics.RecordPluginProcessingLatency(ResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before)) + loggerDebug.Info("Completed running ResponseComplete plugin successfully", "plugin", plugin.TypedName()) } } diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index a0cb7c325..ea486d3c4 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -32,6 +32,7 @@ import ( "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/log" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" @@ -592,13 +593,13 @@ func TestGetRandomPod(t *testing.T) { } } -func TestDirector_HandleResponse(t *testing.T) { - pr1 := newTestPostResponse("pr1") +func TestDirector_HandleResponseReceived(t *testing.T) { + pr1 := newTestResponseReceived("pr1") ctx := logutil.NewTestLoggerIntoContext(context.Background()) ds := datastore.NewDatastore(t.Context(), nil) mockSched := &mockScheduler{} - director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithPostResponsePlugins(pr1)) + director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseReceivedPlugins(pr1)) reqCtx := &handlers.RequestContext{ Request: &handlers.Request{ @@ -613,7 +614,7 @@ func TestDirector_HandleResponse(t *testing.T) { TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, } - _, err := director.HandleResponse(ctx, reqCtx) + _, err := director.HandleResponseReceived(ctx, reqCtx) if err != nil { t.Fatalf("HandleResponse() returned unexpected error: %v", err) } @@ -629,27 +630,144 @@ 
func TestDirector_HandleResponse(t *testing.T) { } } +func TestDirector_HandleResponseStreaming(t *testing.T) { + ps1 := newTestResponseStreaming("ps1") + + ctx := logutil.NewTestLoggerIntoContext(context.Background()) + ds := datastore.NewDatastore(t.Context(), nil) + mockSched := &mockScheduler{} + director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseStreamingPlugins(ps1)) + logger := log.FromContext(ctx) + + reqCtx := &handlers.RequestContext{ + Request: &handlers.Request{ + Headers: map[string]string{ + requtil.RequestIdHeaderKey: "test-req-id-for-streaming", + }, + }, + Response: &handlers.Response{ + Headers: map[string]string{"X-Test-Streaming-Header": "StreamValue"}, + }, + TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, + } + + _, err := director.HandleResponseBodyStreaming(ctx, reqCtx, logger) + if err != nil { + t.Fatalf("HandleResponseBodyStreaming() returned unexpected error: %v", err) + } + + if diff := cmp.Diff("test-req-id-for-streaming", ps1.lastRespOnStreaming.RequestId); diff != "" { + t.Errorf("Scheduler.OnStreaming RequestId mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(reqCtx.Response.Headers, ps1.lastRespOnStreaming.Headers); diff != "" { + t.Errorf("Scheduler.OnStreaming Headers mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff("namespace1/test-pod-name", ps1.lastTargetPodOnStreaming); diff != "" { + t.Errorf("Scheduler.OnStreaming TargetPodName mismatch (-want +got):\n%s", diff) + } +} + +func TestDirector_HandleResponseComplete(t *testing.T) { + pc1 := newTestResponseComplete("pc1") + + ctx := logutil.NewTestLoggerIntoContext(context.Background()) + ds := datastore.NewDatastore(t.Context(), nil) + mockSched := &mockScheduler{} + director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseCompletePlugins(pc1)) + + reqCtx := &handlers.RequestContext{ + Request: &handlers.Request{ + Headers: map[string]string{ + 
requtil.RequestIdHeaderKey: "test-req-id-for-complete", + }, + }, + Response: &handlers.Response{ + Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"}, + }, + TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, + } + + _, err := director.HandleResponseBodyComplete(ctx, reqCtx) + if err != nil { + t.Fatalf("HandleResponseBodyComplete() returned unexpected error: %v", err) + } + + if diff := cmp.Diff("test-req-id-for-complete", pc1.lastRespOnComplete.RequestId); diff != "" { + t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" { + t.Errorf("Scheduler.OnComplete Headers mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" { + t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff) + } +} + const ( - testPostResponseType = "test-post-response" + testResponseReceivedType = "test-response-received" + testPostStreamingType = "test-response-streaming" + testPostCompleteType = "test-response-complete" ) -type testPostResponse struct { +type testResponseReceived struct { tn plugins.TypedName lastRespOnResponse *Response lastTargetPodOnResponse string } -func newTestPostResponse(name string) *testPostResponse { - return &testPostResponse{ - tn: plugins.TypedName{Type: testPostResponseType, Name: name}, +type testResponseStreaming struct { + tn plugins.TypedName + lastRespOnStreaming *Response + lastTargetPodOnStreaming string +} + +type testResponseComplete struct { + tn plugins.TypedName + lastRespOnComplete *Response + lastTargetPodOnComplete string +} + +func newTestResponseReceived(name string) *testResponseReceived { + return &testResponseReceived{ + tn: plugins.TypedName{Type: testResponseReceivedType, Name: name}, } } -func (p *testPostResponse) TypedName() plugins.TypedName 
{ +func newTestResponseStreaming(name string) *testResponseStreaming { + return &testResponseStreaming{ + tn: plugins.TypedName{Type: testPostStreamingType, Name: name}, + } +} + +func newTestResponseComplete(name string) *testResponseComplete { + return &testResponseComplete{ + tn: plugins.TypedName{Type: testPostCompleteType, Name: name}, + } +} + +func (p *testResponseReceived) TypedName() plugins.TypedName { + return p.tn +} + +func (p *testResponseStreaming) TypedName() plugins.TypedName { + return p.tn +} + +func (p *testResponseComplete) TypedName() plugins.TypedName { return p.tn } -func (p *testPostResponse) PostResponse(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { +func (p *testResponseReceived) ResponseReceived(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { p.lastRespOnResponse = response p.lastTargetPodOnResponse = targetPod.NamespacedName.String() } + +func (p *testResponseStreaming) ResponseStreaming(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { + p.lastRespOnStreaming = response + p.lastTargetPodOnStreaming = targetPod.NamespacedName.String() +} + +func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { + p.lastRespOnComplete = response + p.lastTargetPodOnComplete = targetPod.NamespacedName.String() +} diff --git a/pkg/epp/requestcontrol/plugins.go b/pkg/epp/requestcontrol/plugins.go index ca823a670..44334c68f 100644 --- a/pkg/epp/requestcontrol/plugins.go +++ b/pkg/epp/requestcontrol/plugins.go @@ -25,8 +25,10 @@ import ( ) const ( - PreRequestExtensionPoint = "PreRequest" - PostResponseExtensionPoint = "PostResponse" + PreRequestExtensionPoint = "PreRequest" + ResponseReceivedExtensionPoint = "ResponseReceived" + ResponseStreamingExtensionPoint = "ResponseStreaming" + ResponseCompleteExtensionPoint = 
"ResponseComplete" ) // PreRequest is called by the director after a getting result from scheduling layer and @@ -36,9 +38,22 @@ type PreRequest interface { PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, targetPort int) } -// PostResponse is called by the director after a successful response was sent. +// ResponseReceived is called by the director after the response headers are successfully received +// which indicates the beginning of the response handling by the model server. // The given pod argument is the pod that served the request. -type PostResponse interface { +type ResponseReceived interface { plugins.Plugin - PostResponse(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod) + ResponseReceived(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod) +} + +// ResponseStreaming is called by the director after each chunk of streaming response is sent. +type ResponseStreaming interface { + plugins.Plugin + ResponseStreaming(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod) +} + +// ResponseComplete is called by the director after the complete response is sent. +type ResponseComplete interface { + plugins.Plugin + ResponseComplete(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod) } diff --git a/pkg/epp/requestcontrol/request_control_config.go b/pkg/epp/requestcontrol/request_control_config.go index 2d6dc95e7..ffa6c6609 100644 --- a/pkg/epp/requestcontrol/request_control_config.go +++ b/pkg/epp/requestcontrol/request_control_config.go @@ -23,15 +23,19 @@ import ( // NewConfig creates a new Config object and returns its pointer. 
func NewConfig() *Config { return &Config{ - preRequestPlugins: []PreRequest{}, - postResponsePlugins: []PostResponse{}, + preRequestPlugins: []PreRequest{}, + responseReceivedPlugins: []ResponseReceived{}, + responseStreamingPlugins: []ResponseStreaming{}, + responseCompletePlugins: []ResponseComplete{}, } } // Config provides a configuration for the requestcontrol plugins. type Config struct { - preRequestPlugins []PreRequest - postResponsePlugins []PostResponse + preRequestPlugins []PreRequest + responseReceivedPlugins []ResponseReceived + responseStreamingPlugins []ResponseStreaming + responseCompletePlugins []ResponseComplete } // WithPreRequestPlugins sets the given plugins as the PreRequest plugins. @@ -41,20 +45,44 @@ func (c *Config) WithPreRequestPlugins(plugins ...PreRequest) *Config { return c } -// WithPostResponsePlugins sets the given plugins as the PostResponse plugins. -// If the Config has PostResponse plugins already, this call replaces the existing plugins with the given ones. -func (c *Config) WithPostResponsePlugins(plugins ...PostResponse) *Config { - c.postResponsePlugins = plugins +// WithResponseReceivedPlugins sets the given plugins as the ResponseReceived plugins. +// If the Config has ResponseReceived plugins already, this call replaces the existing plugins with the given ones. +func (c *Config) WithResponseReceivedPlugins(plugins ...ResponseReceived) *Config { + c.responseReceivedPlugins = plugins return c } +// WithResponseStreamingPlugins sets the given plugins as the ResponseStreaming plugins. +// If the Config has ResponseStreaming plugins already, this call replaces the existing plugins with the given ones. +func (c *Config) WithResponseStreamingPlugins(plugins ...ResponseStreaming) *Config { + c.responseStreamingPlugins = plugins + return c +} + +// WithResponseCompletePlugins sets the given plugins as the ResponseComplete plugins. 
+// If the Config has ResponseComplete plugins already, this call replaces the existing plugins with the given ones. +func (c *Config) WithResponseCompletePlugins(plugins ...ResponseComplete) *Config { + c.responseCompletePlugins = plugins + return c +} + +// AddPlugins adds the given plugins to the Config. +// The type of each plugin is checked and added to the corresponding list of plugins in the Config. +// If a plugin implements multiple plugin interfaces, it will be added to each corresponding list. + func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) { for _, plugin := range pluginObjects { if preRequestPlugin, ok := plugin.(PreRequest); ok { c.preRequestPlugins = append(c.preRequestPlugins, preRequestPlugin) } - if postResponsePlugin, ok := plugin.(PostResponse); ok { - c.postResponsePlugins = append(c.postResponsePlugins, postResponsePlugin) + if responseReceivedPlugin, ok := plugin.(ResponseReceived); ok { + c.responseReceivedPlugins = append(c.responseReceivedPlugins, responseReceivedPlugin) + } + if responseStreamingPlugin, ok := plugin.(ResponseStreaming); ok { + c.responseStreamingPlugins = append(c.responseStreamingPlugins, responseStreamingPlugin) + } + if responseCompletePlugin, ok := plugin.(ResponseComplete); ok { + c.responseCompletePlugins = append(c.responseCompletePlugins, responseCompletePlugin) } } } diff --git a/pkg/epp/requestcontrol/types.go b/pkg/epp/requestcontrol/types.go index 8604e1dda..c881ed713 100644 --- a/pkg/epp/requestcontrol/types.go +++ b/pkg/epp/requestcontrol/types.go @@ -16,7 +16,7 @@ limitations under the License. 
package requestcontrol -// Response contains information from the response received to be passed to PostResponse plugins +// Response contains information from the response received to be passed to the Response requestcontrol plugins type Response struct { // RequestId is the Envoy generated Id for the request being processed RequestId string diff --git a/pkg/epp/server/server_test.go b/pkg/epp/server/server_test.go index aff6d4644..f3fa16bfd 100644 --- a/pkg/epp/server/server_test.go +++ b/pkg/epp/server/server_test.go @@ -22,6 +22,7 @@ import ( "testing" pb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/go-logr/logr" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -181,7 +182,15 @@ func (ts *testDirector) HandleRequest(ctx context.Context, reqCtx *handlers.Requ return reqCtx, nil } -func (ts *testDirector) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { +func (ts *testDirector) HandleResponseReceived(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { + return reqCtx, nil +} + +func (ts *testDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext, logger logr.Logger) (*handlers.RequestContext, error) { + return reqCtx, nil +} + +func (ts *testDirector) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { return reqCtx, nil }