Skip to content

Commit a3e3409

Browse files
committed
fix test errors
1 parent 3968411 commit a3e3409

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper_test.go

Lines changed: 43 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,9 @@ import (
2323
"testing"
2424

2525
"github.com/stretchr/testify/assert"
26+
"k8s.io/apimachinery/pkg/types"
2627
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
28+
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
2729
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
2830
)
2931

@@ -39,11 +41,24 @@ func TestBulkPredictWithMetrics(t *testing.T) {
3941
{KVCacheUsagePercent: 0.5},
4042
{KVCacheUsagePercent: 0.6},
4143
}
44+
requestBuilder := &DefaultPredictionRequestBuilder{}
45+
pods := []schedulingtypes.Endpoint{
46+
&schedulingtypes.PodMetrics{
47+
EndpointMetadata: &datalayer.EndpointMetadata{
48+
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
49+
},
50+
},
51+
&schedulingtypes.PodMetrics{
52+
EndpointMetadata: &datalayer.EndpointMetadata{
53+
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod2"},
54+
},
55+
},
56+
}
4257
prompts := []string{"prompt1", "prompt2"}
4358
generatedTokenCounts := []int{1, 1}
4459
prefixCacheScores := []float64{0.0, 0.0}
4560

46-
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
61+
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
4762

4863
assert.NoError(t, err)
4964
assert.Len(t, results, 2)
@@ -61,11 +76,19 @@ func TestBulkPredictWithMetrics_Error(t *testing.T) {
6176
metricsStates := []*datalayer.Metrics{
6277
{KVCacheUsagePercent: 0.5},
6378
}
79+
requestBuilder := &DefaultPredictionRequestBuilder{}
80+
pods := []schedulingtypes.Endpoint{
81+
&schedulingtypes.PodMetrics{
82+
EndpointMetadata: &datalayer.EndpointMetadata{
83+
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
84+
},
85+
},
86+
}
6487
prompts := []string{"prompt1"}
6588
generatedTokenCounts := []int{1}
6689
prefixCacheScores := []float64{0.0}
6790

68-
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
91+
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
6992

7093
assert.Error(t, err)
7194
assert.Nil(t, results)
@@ -74,11 +97,19 @@ func TestBulkPredictWithMetrics_Error(t *testing.T) {
7497
func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
7598
mockPredictor := &mockPredictor{}
7699
metricsStates := []*datalayer.Metrics{{}}
100+
requestBuilder := &DefaultPredictionRequestBuilder{}
101+
pods := []schedulingtypes.Endpoint{
102+
&schedulingtypes.PodMetrics{
103+
EndpointMetadata: &datalayer.EndpointMetadata{
104+
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
105+
},
106+
},
107+
}
77108
prompts := []string{"prompt1", "prompt2"} // Mismatch length
78109
generatedTokenCounts := []int{1}
79110
prefixCacheScores := []float64{0.0}
80111

81-
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
112+
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
82113

83114
assert.Error(t, err)
84115
assert.Nil(t, results)
@@ -88,11 +119,19 @@ func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
88119
func TestBulkPredictWithMetrics_NilMetricsState(t *testing.T) {
89120
mockPredictor := &mockPredictor{}
90121
metricsStates := []*datalayer.Metrics{nil} // Nil metrics state
122+
requestBuilder := &DefaultPredictionRequestBuilder{}
123+
pods := []schedulingtypes.Endpoint{
124+
&schedulingtypes.PodMetrics{
125+
EndpointMetadata: &datalayer.EndpointMetadata{
126+
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
127+
},
128+
},
129+
}
91130
prompts := []string{"prompt1"}
92131
generatedTokenCounts := []int{1}
93132
prefixCacheScores := []float64{0.0}
94133

95-
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
134+
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
96135

97136
assert.Error(t, err)
98137
assert.Nil(t, results)

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,7 @@ func createTestRouter() *SLOAwareRouter {
6262
sloContextStore: sync.Map{},
6363
runningRequestLists: make(map[types.NamespacedName]*requestPriorityQueue),
6464
latencypredictor: nil,
65+
requestBuilder: &DefaultPredictionRequestBuilder{},
6566
config: DefaultConfig,
6667
}
6768
}

0 commit comments

Comments
 (0)