@@ -18,20 +18,25 @@ package filter
1818
1919import (
2020 "context"
21+ "strings"
2122
2223 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
2324 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
2425)
2526
2627const (
28+ // headerTestEppEndPointSelectionKey is the header used for testing purposes to make EPP behavior controllable.
29+ // The header value should be a comma-separated list of endpoint IP addresses.
30+ // E.g., "test-epp-endpoint-selection": "10.0.0.7,10.0.0.8"
31+ // The returned order is the same as the order provided in the header.
2732 headerTestEppEndPointSelectionKey = "test-epp-endpoint-selection"
2833)
2934
3035// compile-time type assertion
3136var _ framework.Filter = & HeaderBasedTestingFilter {}
3237
33- // NewHeaderBasedTestingFilter initializes a new HeaderBasedTestingFilter and returns its pointer .
34- // This should be only used in testing purpose .
38+ // NewHeaderBasedTestingFilter initializes a new HeaderBasedTestingFilter.
39+ // This should only be used for testing purposes .
3540func NewHeaderBasedTestingFilter () * HeaderBasedTestingFilter {
3641 return & HeaderBasedTestingFilter {}
3742}
@@ -41,20 +46,26 @@ type HeaderBasedTestingFilter struct{}
4146
4247// Name returns the name of the filter.
4348func (f * HeaderBasedTestingFilter ) Name () string {
44- return "test- header-based"
49+ return "header-based-testing "
4550}
4651
47- // Filter filters out pods that doesn't meet the filter criteria .
52+ // Filter selects pods that match the IP addresses specified in the request header .
4853func (f * HeaderBasedTestingFilter ) Filter (_ context.Context , request * types.LLMRequest , _ * types.CycleState , pods []types.Pod ) []types.Pod {
49- filteredPods := []types.Pod {}
50-
51- endPointInReqeust , found := request .Headers [headerTestEppEndPointSelectionKey ]
52- if ! found {
53- return filteredPods
54+ headerValue , ok := request .Headers [headerTestEppEndPointSelectionKey ]
55+ if ! ok || headerValue == "" {
56+ return []types.Pod {}
5457 }
5558
59+ podAddressMap := make (map [string ]types.Pod , len (pods ))
5660 for _ , pod := range pods {
57- if pod .GetPod ().Address == endPointInReqeust {
61+ podAddressMap [pod .GetPod ().Address ] = pod
62+ }
63+
64+ endpoints := strings .Split (headerValue , "," )
65+ filteredPods := make ([]types.Pod , 0 , len (endpoints ))
66+ for _ , endpoint := range endpoints {
67+ trimmedEndpoint := strings .TrimSpace (endpoint )
68+ if pod , found := podAddressMap [trimmedEndpoint ]; found {
5869 filteredPods = append (filteredPods , pod )
5970 }
6071 }
0 commit comments