Skip to content

Commit b2a7d45

Browse files
Fix streamed request being called one final time after request complete, add predictor check to the beginning of each requestcontrol hook
1 parent 37fe013 commit b2a7d45

File tree

3 files changed

+28
-9
lines changed

3 files changed

+28
-9
lines changed

pkg/epp/requestcontrol/director.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,9 @@ func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *hand
269269
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
270270
logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk")
271271
response := &Response{
272-
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
273-
Headers: reqCtx.Response.Headers,
272+
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
273+
Headers: reqCtx.Response.Headers,
274+
EndOfStream: reqCtx.ResponseComplete,
274275
}
275276

276277
d.runResponseStreamingPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype
113113
}
114114

115115
targetPod := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName].TargetPods[0].GetPod()
116+
if !t.CheckPredictor(logger, targetPod) {
117+
return
118+
}
116119

117120
podName := types.NamespacedName{
118121
Name: targetPod.NamespacedName.Name,
@@ -153,6 +156,10 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype
153156

154157
func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) {
155158
logger := log.FromContext(ctx)
159+
if !t.CheckPredictor(logger, targetPod) {
160+
return
161+
}
162+
156163
id := request.Headers[requtil.RequestIdHeaderKey]
157164

158165
sloCtx, err := t.getSLOContextForRequest(request)
@@ -161,10 +168,6 @@ func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *scheduli
161168
return
162169
}
163170

164-
if !t.CheckPredictor(logger, targetPod) {
165-
return
166-
}
167-
168171
if err := ProcessHeaderForLatencyPrediction(ctx, t.latencypredictor, sloCtx); err != nil {
169172
logger.V(logutil.DEBUG).Error(err, "ProcessHeader in latencypredictor failed")
170173
}
@@ -173,7 +176,7 @@ func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *scheduli
173176

174177
func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) {
175178
logger := log.FromContext(ctx)
176-
if !t.CheckPredictor(logger, pod) {
179+
if !t.CheckPredictor(logger, pod) || response.EndOfStream {
177180
return
178181
}
179182

@@ -248,11 +251,11 @@ func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *scheduli
248251

249252
func (t *SLOAwareRouter) CheckPredictor(logger logr.Logger, targetPod *backend.Pod) bool {
250253
if targetPod == nil {
251-
logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping PostResponse because no target pod was provided.")
254+
logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because no target pod was provided.")
252255
return false
253256
}
254257
if t.latencypredictor == nil {
255-
logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping PostResponse because predictor missing")
258+
logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because predictor missing")
256259
return false
257260
}
258261
return true

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ func TestSLOAwareRouter_PreRequest_EmptySchedulingResult(t *testing.T) {
150150

151151
func TestSLOAwareRouter_PreRequest_Success(t *testing.T) {
152152
router := createTestRouter()
153+
mockPredictor := new(mockPredictor)
154+
router.latencypredictor = mockPredictor
155+
153156
ctx := context.Background()
154157
pod := createTestPod("test-pod", 1, 1, 1)
155158
request := createTestLLMRequest("test", 100, 50, true)
@@ -180,6 +183,9 @@ func TestSLOAwareRouter_PreRequest_Success(t *testing.T) {
180183

181184
func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) {
182185
router := createTestRouter()
186+
mockPredictor := new(mockPredictor)
187+
router.latencypredictor = mockPredictor
188+
183189
ctx := context.Background()
184190
pod := createTestPod("test-pod", 1, 1, 1)
185191
request := createTestLLMRequest("test", 100, 50, true)
@@ -201,6 +207,9 @@ func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) {
201207

202208
func TestSLOAwareRouter_PreRequest_QueueAlreadyExists(t *testing.T) {
203209
router := createTestRouter()
210+
mockPredictor := new(mockPredictor)
211+
router.latencypredictor = mockPredictor
212+
204213
ctx := context.Background()
205214
pod := createTestPod("test-pod", 1, 1, 1)
206215
request1 := createTestLLMRequest("test-id-1", 100, 50, true)
@@ -729,6 +738,9 @@ func TestSLOAwareRouter_ConcurrentContextAccess(t *testing.T) {
729738

730739
func TestSLOAwareRouter_MultipleRequests_SamePod(t *testing.T) {
731740
router := createTestRouter()
741+
mockPredictor := new(mockPredictor)
742+
router.latencypredictor = mockPredictor
743+
732744
ctx := context.Background()
733745
pod := createTestPod("test-pod", 1, 1, 1)
734746

@@ -807,6 +819,9 @@ func TestSLOAwareRouter_RequestLifecycle_Complete(t *testing.T) {
807819

808820
func TestSLOAwareRouter_MultipleRequests_DifferentPods(t *testing.T) {
809821
router := createTestRouter()
822+
mockPredictor := new(mockPredictor)
823+
router.latencypredictor = mockPredictor
824+
810825
ctx := context.Background()
811826

812827
pod1 := createTestPod("test-pod-1", 1, 1, 1)

0 commit comments

Comments
 (0)