Skip to content

Commit bbeb5b6

Browse files
Log typed name in director.go and remove redundant director nil check in response.go
1 parent 0174f0b commit bbeb5b6

File tree

3 files changed

+35 -10 lines changed

pkg/epp/handlers/response.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,25 +61,26 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques
6161
reqCtx.ResponseComplete = true
6262

6363
reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true)
64-
if s.director != nil {
65-
s.director.HandleResponseBodyComplete(ctx, reqCtx)
66-
}
67-
return reqCtx, nil
64+
65+
return s.director.HandleResponseBodyComplete(ctx, reqCtx)
6866
}
6967

7068
// HandleResponseBodyModelStreaming handles a streamed response chunk when the model server is streaming its response.
7169
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
72-
if s.director != nil {
73-
s.director.HandleResponseBodyStreaming(ctx, reqCtx)
70+
logger := log.FromContext(ctx)
71+
_, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx)
72+
if err != nil {
73+
logger.Error(err, "error in HandleResponseBodyStreaming")
7474
}
7575
if strings.Contains(responseText, streamingEndMsg) {
7676
reqCtx.ResponseComplete = true
7777
resp := parseRespForUsage(ctx, responseText)
7878
reqCtx.Usage = resp.Usage
7979
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens)
8080
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens)
81-
if s.director != nil {
82-
s.director.HandleResponseBodyComplete(ctx, reqCtx)
81+
_, err := s.director.HandleResponseBodyComplete(ctx, reqCtx)
82+
if err != nil {
83+
logger.Error(err, "error in HandleResponseBodyComplete")
8384
}
8485
}
8586
}

pkg/epp/handlers/response_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
"github.com/google/go-cmp/cmp"
2525

26+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
2627
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2728
)
2829

@@ -59,6 +60,27 @@ data: [DONE]
5960
`
6061
)
6162

63+
type mockDirector struct{}
64+
65+
func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
66+
return reqCtx, nil
67+
}
68+
func (m *mockDirector) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
69+
return reqCtx, nil
70+
}
71+
func (m *mockDirector) HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
72+
return reqCtx, nil
73+
}
74+
func (m *mockDirector) HandlePreRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
75+
return reqCtx, nil
76+
}
77+
func (m *mockDirector) GetRandomPod() *backend.Pod {
78+
return &backend.Pod{}
79+
}
80+
func (m *mockDirector) HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
81+
return reqCtx, nil
82+
}
83+
6284
func TestHandleResponseBody(t *testing.T) {
6385
ctx := logutil.NewTestLoggerIntoContext(context.Background())
6486

@@ -83,6 +105,7 @@ func TestHandleResponseBody(t *testing.T) {
83105
for _, test := range tests {
84106
t.Run(test.name, func(t *testing.T) {
85107
server := &StreamingServer{}
108+
server.director = &mockDirector{}
86109
reqCtx := test.reqCtx
87110
if reqCtx == nil {
88111
reqCtx = &RequestContext{}
@@ -143,6 +166,7 @@ func TestHandleStreamedResponseBody(t *testing.T) {
143166
for _, test := range tests {
144167
t.Run(test.name, func(t *testing.T) {
145168
server := &StreamingServer{}
169+
server.director = &mockDirector{}
146170
reqCtx := test.reqCtx
147171
if reqCtx == nil {
148172
reqCtx = &RequestContext{}

pkg/epp/requestcontrol/director.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ func (d *Director) runPostResponseReceivedPlugins(ctx context.Context, request *
361361
func (d *Director) runPostResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
362362
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
363363
for _, plugin := range d.postResponseStreamingPlugins {
364-
loggerTrace.Info("Running post-response chunk plugin", "plugin", plugin.TypedName().Type)
364+
loggerTrace.Info("Running post-response chunk plugin", "plugin", plugin.TypedName())
365365
before := time.Now()
366366
plugin.PostResponseStreaming(ctx, request, response, targetPod)
367367
metrics.RecordPluginProcessingLatency(PostResponseStreamingExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
@@ -371,7 +371,7 @@ func (d *Director) runPostResponseStreamingPlugins(ctx context.Context, request
371371
func (d *Director) runPostResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
372372
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
373373
for _, plugin := range d.postResponseCompletePlugins {
374-
loggerDebug.Info("Running post-response complete plugin", "plugin", plugin.TypedName().Type)
374+
loggerDebug.Info("Running post-response complete plugin", "plugin", plugin.TypedName())
375375
before := time.Now()
376376
plugin.PostResponseComplete(ctx, request, response, targetPod)
377377
metrics.RecordPluginProcessingLatency(PostResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))

0 commit comments

Comments
 (0)