Skip to content

Commit b1c3817

Browse files
authored
[feat] add distribute-dp api server least_request route (vllm-project#1866)
Signed-off-by: yangyouchuan <1184540833@qq.com>
1 parent 4dc43dd commit b1c3817

File tree

10 files changed

+295
-6
lines changed

10 files changed

+295
-6
lines changed

pkg/cache/cache_trace.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package cache
1818
import (
1919
"context"
2020
"fmt"
21+
"strconv"
2122
"sync/atomic"
2223
"time"
2324

@@ -49,6 +50,8 @@ func (c *Store) addPodStats(ctx *types.RoutingContext, requestID string) {
4950
return
5051
}
5152
pod := ctx.TargetPod()
53+
port := ctx.TargetPort()
54+
5255
metaPod, ok := c.metaPods.Load(utils.GeneratePodKey(pod.Namespace, pod.Name))
5356
if !ok {
5457
klog.Warningf("can't find routing pod: %s, requestID: %s", pod.Name, requestID)
@@ -57,7 +60,11 @@ func (c *Store) addPodStats(ctx *types.RoutingContext, requestID string) {
5760

5861
// Update running requests
5962
requests := atomic.AddInt32(&metaPod.runningRequests, 1)
60-
if err := c.updatePodRecord(metaPod, "", metrics.RealtimeNumRequestsRunning, metrics.PodMetricScope, &metrics.SimpleMetricValue{Value: float64(requests)}); err != nil {
63+
metricName := metrics.RealtimeNumRequestsRunning
64+
if port > 0 {
65+
metricName = metricName + "/" + strconv.Itoa(port)
66+
}
67+
if err := c.updatePodRecord(metaPod, "", metricName, metrics.PodMetricScope, &metrics.SimpleMetricValue{Value: float64(requests)}); err != nil {
6168
klog.Warningf("can't update realtime metric: %s, pod: %s, requestID: %s, err: %v", metrics.RealtimeNumRequestsRunning, metaPod.Name, requestID, err)
6269
}
6370

@@ -86,6 +93,7 @@ func (c *Store) donePodStats(ctx *types.RoutingContext, requestID string) {
8693
return
8794
}
8895
pod := ctx.TargetPod()
96+
port := ctx.TargetPort()
8997

9098
// Now that pendingLoadProvider must be set.
9199
metaPod, ok := c.metaPods.Load(utils.GeneratePodKey(pod.Namespace, pod.Name))
@@ -96,7 +104,11 @@ func (c *Store) donePodStats(ctx *types.RoutingContext, requestID string) {
96104

97105
// Update running requests
98106
requests := atomic.AddInt32(&metaPod.runningRequests, -1)
99-
if err := c.updatePodRecord(metaPod, ctx.Model, metrics.RealtimeNumRequestsRunning, metrics.PodMetricScope, &metrics.SimpleMetricValue{Value: float64(requests)}); err != nil {
107+
metricName := metrics.RealtimeNumRequestsRunning
108+
if port > 0 {
109+
metricName = metricName + "/" + strconv.Itoa(port)
110+
}
111+
if err := c.updatePodRecord(metaPod, ctx.Model, metricName, metrics.PodMetricScope, &metrics.SimpleMetricValue{Value: float64(requests)}); err != nil {
100112
klog.Warningf("can't update realtime metric: %s, pod: %s, requestID: %s", metrics.RealtimeNumRequestsRunning, pod.Name, requestID)
101113
}
102114

pkg/plugins/gateway/algorithms/least_load_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ func (m *mockPodList) ListByIndex(index string) []*v1.Pod {
9898
return nil
9999
}
100100

101+
func (m *mockPodList) ListPortsForPod() map[string][]int {
102+
return nil
103+
}
104+
101105
func newMockPodList(pods []*v1.Pod, indexes map[string][]*v1.Pod) *mockPodList {
102106
if indexes == nil {
103107
indexes = make(map[string][]*v1.Pod)

pkg/plugins/gateway/algorithms/least_request.go

Lines changed: 121 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ limitations under the License.
1717
package routingalgorithms
1818

1919
import (
20+
"fmt"
2021
"math"
2122
"math/rand"
23+
"strconv"
24+
"strings"
2225

2326
"github.com/vllm-project/aibrix/pkg/cache"
2427
"github.com/vllm-project/aibrix/pkg/metrics"
@@ -45,14 +48,19 @@ func NewLeastRequestRouter() (types.Router, error) {
4548
return nil, err
4649
}
4750

48-
return leastRequestRouter{
51+
return &leastRequestRouter{
4952
cache: c,
5053
}, nil
5154
}
5255

5356
// Route routes the request to the ready pod with the fewest active requests.
54-
func (r leastRequestRouter) Route(ctx *types.RoutingContext, readyPodList types.PodList) (string, error) {
57+
func (r *leastRequestRouter) Route(ctx *types.RoutingContext, readyPodList types.PodList) (string, error) {
5558
readyPods := readyPodList.All()
59+
// Use distributed DP-level API server routing when pods have multiple ports
60+
if isMultiPortPods(readyPods) {
61+
return r.apiServerRoute(ctx, readyPods, readyPodList.ListPortsForPod())
62+
}
63+
// Use default Pod-level routing
5664
targetPod := selectTargetPodWithLeastRequestCount(r.cache, readyPods)
5765

5866
// Use fallback if no valid metrics
@@ -68,6 +76,20 @@ func (r leastRequestRouter) Route(ctx *types.RoutingContext, readyPodList types.
6876
return ctx.TargetAddress(), nil
6977
}
7078

79+
func (r *leastRequestRouter) apiServerRoute(ctx *types.RoutingContext, readyPods []*v1.Pod, portsMap map[string][]int) (string, error) {
80+
targetPod, targetPort := selectTargetPodAndPortWithLeastRequestCount(r.cache, readyPods, portsMap)
81+
if targetPod == nil {
82+
return "", fmt.Errorf("no target pod selected")
83+
}
84+
85+
if targetPort == 0 {
86+
return "", fmt.Errorf("target pod does not have a port")
87+
}
88+
ctx.SetTargetPod(targetPod)
89+
ctx.SetTargetPort(targetPort)
90+
return ctx.TargetAddress(), nil
91+
}
92+
7193
func (r *leastRequestRouter) SubscribedMetrics() []string {
7294
return []string{
7395
metrics.RealtimeNumRequestsRunning,
@@ -95,13 +117,67 @@ func selectTargetPodWithLeastRequestCount(cache cache.Cache, readyPods []*v1.Pod
95117
return targetPod
96118
}
97119

120+
func selectTargetPodAndPortWithLeastRequestCount(cache cache.Cache, readyPods []*v1.Pod, portsMap map[string][]int) (*v1.Pod, int) {
121+
readyPodsMap := make(map[string]*v1.Pod, len(readyPods))
122+
for _, pod := range readyPods {
123+
readyPodsMap[pod.Name] = pod
124+
}
125+
126+
minCount := math.MaxInt32
127+
128+
var targetApiServers []string
129+
podRequestCount := getRequestCountsWithPort(cache, readyPods, portsMap)
130+
if len(podRequestCount) == 0 {
131+
return nil, 0
132+
}
133+
134+
klog.V(4).InfoS("selectTargetPodAndPortWithLeastRequestCount", "podRequestCount", podRequestCount)
135+
for servername, totalReq := range podRequestCount {
136+
if totalReq < minCount {
137+
minCount = totalReq
138+
targetApiServers = []string{servername}
139+
} else if totalReq == minCount {
140+
targetApiServers = append(targetApiServers, servername)
141+
}
142+
}
143+
144+
if len(targetApiServers) == 0 {
145+
return nil, 0
146+
}
147+
148+
// Random selection among candidates
149+
selectedServer := targetApiServers[rand.Intn(len(targetApiServers))]
150+
parts := strings.Split(selectedServer, "/")
151+
if len(parts) != 2 {
152+
klog.ErrorS(nil, "Invalid server name format", "serverName", selectedServer)
153+
return nil, 0
154+
}
155+
156+
podName := parts[0]
157+
portStr := parts[1]
158+
159+
targetPod, found := readyPodsMap[podName]
160+
if !found {
161+
klog.ErrorS(nil, "Selected pod not found in ready pods list", "podName", podName)
162+
return nil, 0
163+
}
164+
165+
targetPort, err := strconv.Atoi(portStr)
166+
if err != nil {
167+
klog.ErrorS(err, "Failed to parse port", "port", portStr)
168+
return targetPod, 0
169+
}
170+
171+
return targetPod, targetPort
172+
}
173+
98174
// getRequestCounts returns running request count for each pod tracked by gateway.
99175
// Note: Currently, gateway instance tracks active running request counts for each pod locally,
100176
// if multiple gateway instances are active then state is not shared across them.
101177
// It is advised to run on leader gateway instance.
102178
// TODO: Support stateful information sync across gateway instances: https://github.com/vllm-project/aibrix/issues/761
103179
func getRequestCounts(cache cache.Cache, readyPods []*v1.Pod) map[string]int {
104-
podRequestCount := map[string]int{}
180+
podRequestCount := make(map[string]int, len(readyPods))
105181
for _, pod := range readyPods {
106182
runningReq, err := cache.GetMetricValueByPod(pod.Name, pod.Namespace, metrics.RealtimeNumRequestsRunning)
107183
if err != nil {
@@ -112,3 +188,45 @@ func getRequestCounts(cache cache.Cache, readyPods []*v1.Pod) map[string]int {
112188

113189
return podRequestCount
114190
}
191+
192+
// getRequestCountsWithPort returns running request count for each pod with port tracked by gateway
193+
func getRequestCountsWithPort(cache cache.Cache, readyPods []*v1.Pod, portsMap map[string][]int) map[string]int {
194+
podRequestCount := make(map[string]int)
195+
for _, pod := range readyPods {
196+
podPorts, exists := portsMap[pod.Name]
197+
if !exists || len(podPorts) == 0 {
198+
continue
199+
}
200+
201+
for _, port := range podPorts {
202+
var metricName string
203+
var keyName string
204+
205+
if len(podPorts) == 1 {
206+
metricName = metrics.RealtimeNumRequestsRunning
207+
keyName = pod.Name
208+
} else {
209+
metricName = metrics.RealtimeNumRequestsRunning + "/" + strconv.Itoa(port)
210+
keyName = pod.Name + "/" + strconv.Itoa(port)
211+
}
212+
213+
var count int
214+
if val, err := cache.GetMetricValueByPod(pod.Name, pod.Namespace, metricName); err == nil && val != nil {
215+
count = int(val.GetSimpleValue())
216+
}
217+
podRequestCount[keyName] = count
218+
}
219+
}
220+
221+
return podRequestCount
222+
}
223+
224+
func isMultiPortPods(pods []*v1.Pod) bool {
225+
for _, pod := range pods {
226+
if utils.IsDataParallelPod(pod) {
227+
return true
228+
}
229+
}
230+
231+
return false
232+
}

pkg/plugins/gateway/algorithms/prefix_cache_preble_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ func (m *MockPodList) ListByIndex(index string) []*v1.Pod {
4848
return m.pods
4949
}
5050

51+
func (m *MockPodList) ListPortsForPod() map[string][]int {
52+
return nil
53+
}
54+
5155
func createTestRoutingContext(model, message, requestID string) *types.RoutingContext {
5256
ctx := context.Background()
5357
return types.NewRoutingContext(ctx, RouterPrefixCachePreble, model, message, requestID, "")

pkg/plugins/gateway/algorithms/vtc/vtc_basic_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ func (p *SimplePodList) ListByIndex(index string) []*v1.Pod {
106106
return p.pods
107107
}
108108

109+
func (p *SimplePodList) ListPortsForPod() map[string][]int {
110+
return nil
111+
}
112+
109113
func TestVTCRouterSimple(t *testing.T) {
110114
trackerConfig := &VTCConfig{
111115
InputTokenWeight: 1.0,

pkg/plugins/gateway/gateway.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ func (s *Server) selectTargetPod(ctx *types.RoutingContext, pods types.PodList,
180180
if len(readyPods) == 0 {
181181
return "", fmt.Errorf("no ready pods for routing")
182182
}
183-
if len(readyPods) == 1 {
183+
if len(readyPods) == 1 && len(utils.GetPortsForPod(readyPods[0])) <= 1 {
184184
ctx.SetTargetPod(readyPods[0])
185185
return ctx.TargetAddress(), nil
186186
}

pkg/types/pod_list.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,7 @@ type PodList interface {
3030

3131
// ListByIndex returns a slice of pods that match the given index.
3232
ListByIndex(index string) []*v1.Pod
33+
34+
// ListPortsForPod returns a map from pod name to the list of ports bound to that pod.
35+
ListPortsForPod() map[string][]int
3336
}

pkg/types/router_context.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ type RoutingContext struct {
7171

7272
targetPodSet chan struct{}
7373
targetPod atomic.Pointer[v1.Pod]
74+
targetPort atomic.Int32
7475
lastError atomic.Pointer[error]
7576
tokens []int // Cache of tokenized prompts
7677
predictor OutputPredictor // OutputPredictor gained from cache
@@ -204,6 +205,14 @@ func (r *RoutingContext) TargetPod() *v1.Pod {
204205
return targetPod
205206
}
206207

208+
func (r *RoutingContext) TargetPort() int {
209+
return int(r.targetPort.Load())
210+
}
211+
212+
func (r *RoutingContext) SetTargetPort(port int) {
213+
r.targetPort.Store(int32(port))
214+
}
215+
207216
// GetError returns the error of the routing context.
208217
func (r *RoutingContext) GetError() error {
209218
if r.TargetPod() == nil {
@@ -218,6 +227,11 @@ func (r *RoutingContext) TargetAddress() string {
218227
if pod == nil {
219228
return ""
220229
}
230+
231+
port := r.TargetPort()
232+
if port != 0 {
233+
return r.targetAddressWithPort(pod.Status.PodIP, port)
234+
}
221235
return r.targetAddress(r.TargetPod())
222236
}
223237

@@ -256,6 +270,10 @@ func (r *RoutingContext) targetAddress(pod *v1.Pod) string {
256270
return fmt.Sprintf("%v:%v", pod.Status.PodIP, utils.GetModelPortForPod(r.RequestID, pod))
257271
}
258272

273+
func (r *RoutingContext) targetAddressWithPort(podIP string, port int) string {
274+
return fmt.Sprintf("%v:%v", podIP, port)
275+
}
276+
259277
func (r *RoutingContext) getError() (err error) {
260278
errAddr := r.lastError.Load()
261279
if errAddr != nil {

pkg/utils/pod_array.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,22 @@ func (arr *PodArray) initDeployments() {
118118
arr.deployments = deployments
119119
arr.podsByDeployment = podsByDeployment
120120
}
121+
122+
func (arr *PodArray) ListPortsForPod() map[string][]int {
123+
pods := arr.All()
124+
if len(pods) == 0 {
125+
return nil
126+
}
127+
128+
podWithPort := make(map[string][]int, len(pods))
129+
for _, pod := range pods {
130+
ports := GetPortsForPod(pod)
131+
if len(ports) > 0 {
132+
podWithPort[pod.Name] = append(podWithPort[pod.Name], ports...)
133+
} else {
134+
podWithPort[pod.Name] = []int{}
135+
}
136+
}
137+
138+
return podWithPort
139+
}

0 commit comments

Comments
 (0)