@@ -23,7 +23,9 @@ import (
2323 "testing"
2424
2525 "github.com/stretchr/testify/assert"
26+ "k8s.io/apimachinery/pkg/types"
2627 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
28+ schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
2729 latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
2830)
2931
@@ -39,11 +41,24 @@ func TestBulkPredictWithMetrics(t *testing.T) {
3941 {KVCacheUsagePercent : 0.5 },
4042 {KVCacheUsagePercent : 0.6 },
4143 }
44+ requestBuilder := & DefaultPredictionRequestBuilder {}
45+ pods := []schedulingtypes.Endpoint {
46+ & schedulingtypes.PodMetrics {
47+ EndpointMetadata : & datalayer.EndpointMetadata {
48+ NamespacedName : types.NamespacedName {Namespace : "default" , Name : "pod1" },
49+ },
50+ },
51+ & schedulingtypes.PodMetrics {
52+ EndpointMetadata : & datalayer.EndpointMetadata {
53+ NamespacedName : types.NamespacedName {Namespace : "default" , Name : "pod2" },
54+ },
55+ },
56+ }
4257 prompts := []string {"prompt1" , "prompt2" }
4358 generatedTokenCounts := []int {1 , 1 }
4459 prefixCacheScores := []float64 {0.0 , 0.0 }
4560
46- results , err := bulkPredictWithMetrics (context .Background (), mockPredictor , metricsStates , prompts , generatedTokenCounts , prefixCacheScores )
61+ results , err := bulkPredictWithMetrics (context .Background (), mockPredictor , metricsStates , requestBuilder , pods , prompts , generatedTokenCounts , prefixCacheScores )
4762
4863 assert .NoError (t , err )
4964 assert .Len (t , results , 2 )
@@ -61,11 +76,19 @@ func TestBulkPredictWithMetrics_Error(t *testing.T) {
6176 metricsStates := []* datalayer.Metrics {
6277 {KVCacheUsagePercent : 0.5 },
6378 }
79+ requestBuilder := & DefaultPredictionRequestBuilder {}
80+ pods := []schedulingtypes.Endpoint {
81+ & schedulingtypes.PodMetrics {
82+ EndpointMetadata : & datalayer.EndpointMetadata {
83+ NamespacedName : types.NamespacedName {Namespace : "default" , Name : "pod1" },
84+ },
85+ },
86+ }
6487 prompts := []string {"prompt1" }
6588 generatedTokenCounts := []int {1 }
6689 prefixCacheScores := []float64 {0.0 }
6790
68- results , err := bulkPredictWithMetrics (context .Background (), mockPredictor , metricsStates , prompts , generatedTokenCounts , prefixCacheScores )
91+ results , err := bulkPredictWithMetrics (context .Background (), mockPredictor , metricsStates , requestBuilder , pods , prompts , generatedTokenCounts , prefixCacheScores )
6992
7093 assert .Error (t , err )
7194 assert .Nil (t , results )
@@ -74,11 +97,19 @@ func TestBulkPredictWithMetrics_Error(t *testing.T) {
7497func TestBulkPredictWithMetrics_InputMismatch (t * testing.T ) {
7598 mockPredictor := & mockPredictor {}
7699 metricsStates := []* datalayer.Metrics {{}}
100+ requestBuilder := & DefaultPredictionRequestBuilder {}
101+ pods := []schedulingtypes.Endpoint {
102+ & schedulingtypes.PodMetrics {
103+ EndpointMetadata : & datalayer.EndpointMetadata {
104+ NamespacedName : types.NamespacedName {Namespace : "default" , Name : "pod1" },
105+ },
106+ },
107+ }
77108 prompts := []string {"prompt1" , "prompt2" } // Mismatch length
78109 generatedTokenCounts := []int {1 }
79110 prefixCacheScores := []float64 {0.0 }
80111
81- results , err := bulkPredictWithMetrics (context .Background (), mockPredictor , metricsStates , prompts , generatedTokenCounts , prefixCacheScores )
112+ results , err := bulkPredictWithMetrics (context .Background (), mockPredictor , metricsStates , requestBuilder , pods , prompts , generatedTokenCounts , prefixCacheScores )
82113
83114 assert .Error (t , err )
84115 assert .Nil (t , results )
@@ -88,11 +119,19 @@ func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
88119func TestBulkPredictWithMetrics_NilMetricsState (t * testing.T ) {
89120 mockPredictor := & mockPredictor {}
90121 metricsStates := []* datalayer.Metrics {nil } // Nil metrics state
122+ requestBuilder := & DefaultPredictionRequestBuilder {}
123+ pods := []schedulingtypes.Endpoint {
124+ & schedulingtypes.PodMetrics {
125+ EndpointMetadata : & datalayer.EndpointMetadata {
126+ NamespacedName : types.NamespacedName {Namespace : "default" , Name : "pod1" },
127+ },
128+ },
129+ }
91130 prompts := []string {"prompt1" }
92131 generatedTokenCounts := []int {1 }
93132 prefixCacheScores := []float64 {0.0 }
94133
95- results , err := bulkPredictWithMetrics (context .Background (), mockPredictor , metricsStates , prompts , generatedTokenCounts , prefixCacheScores )
134+ results , err := bulkPredictWithMetrics (context .Background (), mockPredictor , metricsStates , requestBuilder , pods , prompts , generatedTokenCounts , prefixCacheScores )
96135
97136 assert .Error (t , err )
98137 assert .Nil (t , results )
0 commit comments