Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions pkg/epp/handlers/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,27 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques
reqCtx.ResponseComplete = true

reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true)
return reqCtx, nil

return s.director.HandleResponseBodyComplete(ctx, reqCtx)
}

// HandleResponseBodyModelStreaming handles one piece of response text when the
// model server is streaming its response.
//
// Every call is forwarded to the director's streaming hook. When responseText
// contains streamingEndMsg the response is additionally marked complete, the
// usage parsed from the text is stored on reqCtx and recorded as input/output
// token metrics, and the director's completion hook is invoked. Errors returned
// by either director hook are logged and otherwise ignored.
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
	logger := log.FromContext(ctx)

	if _, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx, logger); err != nil {
		logger.Error(err, "error in HandleResponseBodyStreaming")
	}

	// Nothing more to do until the end-of-stream marker shows up.
	if !strings.Contains(responseText, streamingEndMsg) {
		return
	}

	reqCtx.ResponseComplete = true

	usage := parseRespForUsage(ctx, responseText).Usage
	reqCtx.Usage = usage
	metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.PromptTokens)
	metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.CompletionTokens)

	if _, err := s.director.HandleResponseBodyComplete(ctx, reqCtx); err != nil {
		logger.Error(err, "error in HandleResponseBodyComplete")
	}
}

Expand All @@ -83,7 +94,7 @@ func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *Req
}
}

reqCtx, err := s.director.HandleResponse(ctx, reqCtx)
reqCtx, err := s.director.HandleResponseReceived(ctx, reqCtx)

return reqCtx, err
}
Expand Down
25 changes: 25 additions & 0 deletions pkg/epp/handlers/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import (
"encoding/json"
"testing"

"github.com/go-logr/logr"
"github.com/google/go-cmp/cmp"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

Expand Down Expand Up @@ -59,6 +61,27 @@ data: [DONE]
`
)

// mockDirector is a no-op director for the handler tests: every request and
// response hook hands back the RequestContext it was given with a nil error.
type mockDirector struct{}

// HandleRequest returns reqCtx untouched.
func (*mockDirector) HandleRequest(_ context.Context, reqCtx *RequestContext) (*RequestContext, error) {
	return reqCtx, nil
}

// HandlePreRequest returns reqCtx untouched.
func (*mockDirector) HandlePreRequest(_ context.Context, reqCtx *RequestContext) (*RequestContext, error) {
	return reqCtx, nil
}

// HandleResponseReceived returns reqCtx untouched.
func (*mockDirector) HandleResponseReceived(_ context.Context, reqCtx *RequestContext) (*RequestContext, error) {
	return reqCtx, nil
}

// HandleResponseBodyStreaming returns reqCtx untouched; the logger is unused.
func (*mockDirector) HandleResponseBodyStreaming(_ context.Context, reqCtx *RequestContext, _ logr.Logger) (*RequestContext, error) {
	return reqCtx, nil
}

// HandleResponseBodyComplete returns reqCtx untouched.
func (*mockDirector) HandleResponseBodyComplete(_ context.Context, reqCtx *RequestContext) (*RequestContext, error) {
	return reqCtx, nil
}

// GetRandomPod returns a pointer to a freshly allocated zero-value backend.Pod.
func (*mockDirector) GetRandomPod() *backend.Pod {
	return new(backend.Pod)
}

func TestHandleResponseBody(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())

Expand All @@ -83,6 +106,7 @@ func TestHandleResponseBody(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
server := &StreamingServer{}
server.director = &mockDirector{}
reqCtx := test.reqCtx
if reqCtx == nil {
reqCtx = &RequestContext{}
Expand Down Expand Up @@ -143,6 +167,7 @@ func TestHandleStreamedResponseBody(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
server := &StreamingServer{}
server.director = &mockDirector{}
reqCtx := test.reqCtx
if reqCtx == nil {
reqCtx = &RequestContext{}
Expand Down
10 changes: 6 additions & 4 deletions pkg/epp/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ func NewStreamingServer(datastore Datastore, director Director) *StreamingServer

// Director is the set of hooks the StreamingServer calls on its director while
// handling a request and the corresponding response.
// NOTE(review): this excerpt is a rendered diff, so HandleResponse and
// HandleResponseReceived both appear below; HandleResponse looks like the
// pre-rename form — confirm only one of them exists in the actual file.
type Director interface {
// HandleRequest handles the incoming request.
// NOTE(review): its call site is not visible in this excerpt.
HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
HandleResponse(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
// HandleResponseReceived is invoked from HandleResponseHeaders.
HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
// HandleResponseBodyStreaming is invoked on every call to
// HandleResponseBodyModelStreaming, which passes its logger along.
HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext, logger logr.Logger) (*RequestContext, error)
// HandleResponseBodyComplete is invoked once the response is marked complete:
// at the end of HandleResponseBody, or when streamed text contains
// streamingEndMsg.
HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
// GetRandomPod returns a pod.
// NOTE(review): selection logic and the empty-list result are not visible here.
GetRandomPod() *backend.Pod
}

Expand Down Expand Up @@ -121,7 +123,7 @@ const (
HeaderRequestResponseComplete StreamRequestState = 1
BodyRequestResponsesComplete StreamRequestState = 2
TrailerRequestResponsesComplete StreamRequestState = 3
ResponseRecieved StreamRequestState = 4
ResponseReceived StreamRequestState = 4
HeaderResponseResponseComplete StreamRequestState = 5
BodyResponseResponsesComplete StreamRequestState = 6
TrailerResponseResponsesComplete StreamRequestState = 7
Expand Down Expand Up @@ -251,7 +253,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
loggerTrace.Info("model server is streaming response")
}
}
reqCtx.RequestState = ResponseRecieved
reqCtx.RequestState = ResponseReceived

var responseErr error
reqCtx, responseErr = s.HandleResponseHeaders(ctx, reqCtx, v)
Expand Down Expand Up @@ -377,7 +379,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
}
}
if r.RequestState == ResponseRecieved && r.respHeaderResp != nil {
if r.RequestState == ResponseReceived && r.respHeaderResp != nil {
loggerTrace.Info("Sending response header response", "obj", r.respHeaderResp)
if err := srv.Send(r.respHeaderResp); err != nil {
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
Expand Down
94 changes: 72 additions & 22 deletions pkg/epp/requestcontrol/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"strings"
"time"

"github.com/go-logr/logr"
"sigs.k8s.io/controller-runtime/pkg/log"

v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
Expand Down Expand Up @@ -62,22 +63,20 @@ type SaturationDetector interface {
// NewDirectorWithConfig creates a new Director instance with all dependencies.
// The datastore, scheduler and saturation detector are stored as given. The
// plugin configuration is read through the config pointer, so a nil config
// panics here.
// NOTE(review): this excerpt is a rendered diff and lists both the
// preRequestPlugins/postResponsePlugins fields and the requestControlPlugins
// field (a by-value copy of *config) — confirm which set the file contains.
func NewDirectorWithConfig(datastore Datastore, scheduler Scheduler, saturationDetector SaturationDetector, config *Config) *Director {
return &Director{
datastore: datastore,
scheduler: scheduler,
saturationDetector: saturationDetector,
preRequestPlugins: config.preRequestPlugins,
postResponsePlugins: config.postResponsePlugins,
defaultPriority: 0, // define default priority explicitly
datastore: datastore,
scheduler: scheduler,
saturationDetector: saturationDetector,
requestControlPlugins: *config,
defaultPriority: 0, // define default priority explicitly
}
}

// Director orchestrates the request handling flow, including scheduling.
type Director struct {
datastore Datastore
scheduler Scheduler
saturationDetector SaturationDetector
preRequestPlugins []PreRequest
postResponsePlugins []PostResponse
datastore Datastore
scheduler Scheduler
saturationDetector SaturationDetector
requestControlPlugins Config
// we just need a pointer to an int variable since priority is a pointer in InferenceObjective
// no need to set this in the constructor, since the value we want is the default int val
// and value types cannot be nil
Expand Down Expand Up @@ -278,19 +277,48 @@ func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []sch
return pm
}

func (d *Director) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
// HandleResponseReceived is called when the response headers are received.
// It packages the request ID and the response headers into a Response, runs the
// plugins with it, and returns the same reqCtx pointer with a nil error.
func (d *Director) HandleResponseReceived(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
// The request ID is read from the request headers under
// requtil.RequestIdHeaderKey; the headers are those of the response.
response := &Response{
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
Headers: reqCtx.Response.Headers,
}

// TODO: to extend fallback functionality, handle cases where target pod is unavailable
// https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1224
d.runPostResponsePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
d.runResponseReceivedPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)

return reqCtx, nil
}

// HandleResponseBodyStreaming is called every time a chunk of the response body is received.
func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext, logger logr.Logger) (*handlers.RequestContext, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove logger from the args. We don't pass a logr logger as an argument to functions; we use contextual logging instead.
For example:

log.FromContext(ctx).V(logutil.TRACE).Info(....)

See, for example, here:

loggerTrace := log.FromContext(ctx).V(logutil.TRACE)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I asked him to remove that. In the streaming case, the high volume of logger instantiations will cause allocation/GC pressure. Luke mentioned this as an optimization during flow control benchmarking, and I think the same applies here because there will be multiple streaming calls per request.

Copy link
Contributor Author

@BenjaminBraunDev BenjaminBraunDev Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does log.FromContext(ctx).V(logutil.TRACE).Info(....) actually create another logger? I would think it takes the logger "from the context", given that ctx has a logger. Does it actually allocate more memory? I would think that would be pretty inefficient.

From the documentation of FromContext() in the go-logr package:

// FromContext returns a Logger from ctx or an error if no Logger is found.

logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk")
response := &Response{
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
Headers: reqCtx.Response.Headers,
}

d.runResponseStreamingPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
logger.V(logutil.TRACE).Info("Exiting HandleResponseBodyChunk")
return reqCtx, nil
}

// HandleResponseBodyComplete is called when the response body is fully received.
//
// It packages the request ID (read from the request headers under
// requtil.RequestIdHeaderKey) and the response headers into a Response, runs
// the ResponseComplete plugins with it, and returns the same reqCtx pointer.
// The returned error is always nil.
func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
	// Tag log entries with the stage this handler represents. The label used to
	// be "bodyChunk", which attributed completion logs to the per-chunk stage.
	loggerDebug := log.FromContext(ctx).WithValues("stage", "bodyComplete").V(logutil.DEBUG)
	loggerDebug.Info("Entering HandleResponseBodyComplete")

	response := &Response{
		RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
		Headers:   reqCtx.Response.Headers,
	}

	d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)

	loggerDebug.Info("Exiting HandleResponseBodyComplete")
	return reqCtx, nil
}

func (d *Director) GetRandomPod() *backend.Pod {
pods := d.datastore.PodList(backendmetrics.AllPodsPredicate)
if len(pods) == 0 {
Expand All @@ -304,22 +332,44 @@ func (d *Director) GetRandomPod() *backend.Pod {
// runPreRequestPlugins calls PreRequest on each configured pre-request plugin,
// passing the request, the scheduling result and the target port. Each call is
// bracketed by DEBUG log entries, and the time it took is recorded as plugin
// processing latency under PreRequestExtensionPoint.
// NOTE(review): this excerpt is a rendered diff, so the loop header and both
// log lines appear twice (old and new wording) — confirm which the file keeps.
func (d *Director) runPreRequestPlugins(ctx context.Context, request *schedulingtypes.LLMRequest,
schedulingResult *schedulingtypes.SchedulingResult, targetPort int) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.preRequestPlugins {
loggerDebug.Info("Running pre-request plugin", "plugin", plugin.TypedName())
for _, plugin := range d.requestControlPlugins.preRequestPlugins {
loggerDebug.Info("Running PreRequest plugin", "plugin", plugin.TypedName())
before := time.Now()
plugin.PreRequest(ctx, request, schedulingResult, targetPort)
metrics.RecordPluginProcessingLatency(PreRequestExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
loggerDebug.Info("Completed running pre-request plugin successfully", "plugin", plugin.TypedName())
loggerDebug.Info("Completed running PreRequest plugin successfully", "plugin", plugin.TypedName())
}
}

// runResponseReceivedPlugins calls ResponseReceived on each configured
// response-received plugin with the request, the response and the target pod.
// Each call is bracketed by DEBUG log entries, and its duration is recorded as
// plugin processing latency under ResponseReceivedExtensionPoint.
func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
	debugLog := log.FromContext(ctx).V(logutil.DEBUG)
	registered := d.requestControlPlugins.responseReceivedPlugins

	for _, p := range registered {
		debugLog.Info("Running ResponseReceived plugin", "plugin", p.TypedName())

		startedAt := time.Now()
		p.ResponseReceived(ctx, request, response, targetPod)
		metrics.RecordPluginProcessingLatency(
			ResponseReceivedExtensionPoint,
			p.TypedName().Type,
			p.TypedName().Name,
			time.Since(startedAt),
		)

		debugLog.Info("Completed running ResponseReceived plugin successfully", "plugin", p.TypedName())
	}
}

// runResponseStreamingPlugins calls ResponseStreaming on each configured
// response-streaming plugin with the request, the response and the target pod.
// Each call is bracketed by TRACE log entries, and its duration is recorded as
// plugin processing latency under ResponseStreamingExtensionPoint.
func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
	traceLog := log.FromContext(ctx).V(logutil.TRACE)
	registered := d.requestControlPlugins.responseStreamingPlugins

	for _, p := range registered {
		traceLog.Info("Running ResponseStreaming plugin", "plugin", p.TypedName())

		startedAt := time.Now()
		p.ResponseStreaming(ctx, request, response, targetPod)
		metrics.RecordPluginProcessingLatency(
			ResponseStreamingExtensionPoint,
			p.TypedName().Type,
			p.TypedName().Name,
			time.Since(startedAt),
		)

		traceLog.Info("Completed running ResponseStreaming plugin successfully", "plugin", p.TypedName())
	}
}

func (d *Director) runPostResponsePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
// runResponseCompletePlugins calls each configured plugin with the request, the
// response and the target pod. Each call is bracketed by DEBUG log entries, and
// its duration is recorded as plugin processing latency.
// NOTE(review): this excerpt is a rendered diff, so the signature, loop header,
// plugin call, metric and log lines each appear in a PostResponse and a
// ResponseComplete form — confirm which the file keeps.
func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.postResponsePlugins {
loggerDebug.Info("Running post-response plugin", "plugin", plugin.TypedName())
for _, plugin := range d.requestControlPlugins.responseCompletePlugins {
loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName())
before := time.Now()
plugin.PostResponse(ctx, request, response, targetPod)
metrics.RecordPluginProcessingLatency(PostResponseExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
loggerDebug.Info("Completed running post-response plugin successfully", "plugin", plugin.TypedName())
plugin.ResponseComplete(ctx, request, response, targetPod)
metrics.RecordPluginProcessingLatency(ResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
loggerDebug.Info("Completed running ResponseComplete plugin successfully", "plugin", plugin.TypedName())
}
}
Loading