Skip to content

Commit f37ed5b

Browse files
authored
Support DP (#188)
Signed-off-by: Ira <[email protected]>
1 parent e442062 commit f37ed5b

File tree

9 files changed

+96
-55
lines changed

9 files changed

+96
-55
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ For more details see the <a href="https://docs.vllm.ai/en/stable/getting_started
146146

147147
Example:
148148
{"running-requests":10,"waiting-requests":30,"kv-cache-usage":0.4,"loras":[{"running":"lora4,lora2","waiting":"lora3","timestamp":1257894567},{"running":"lora4,lora3","waiting":"","timestamp":1257894569}]}
149+
---
150+
- `data-parallel-size`: number of ranks to run in Data Parallel deployment, from 1 to 8, default is 1. The ports will be assigned as follows: rank 0 will run on the configured `port`, rank 1 on `port`+1, etc.
149151

150152
In addition, as we are using klog, the following parameters are available:
151153
- `add_dir_header`: if true, adds the file directory to the header of the log messages

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ require (
1818
github.com/spf13/pflag v1.0.6
1919
github.com/valyala/fasthttp v1.59.0
2020
github.com/vmihailenco/msgpack/v5 v5.4.1
21+
golang.org/x/sync v0.12.0
2122
gopkg.in/yaml.v3 v3.0.1
2223
k8s.io/klog/v2 v2.130.1
2324
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT
163163
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
164164
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
165165
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
166+
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
167+
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
166168
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
167169
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
168170
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

pkg/common/config.go

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,15 @@ const (
3434
vLLMDefaultPort = 8000
3535
ModeRandom = "random"
3636
ModeEcho = "echo"
37+
dummy = "dummy"
38+
3739
// Failure type constants
3840
FailureTypeRateLimit = "rate_limit"
3941
FailureTypeInvalidAPIKey = "invalid_api_key"
4042
FailureTypeContextLength = "context_length"
4143
FailureTypeServerError = "server_error"
4244
FailureTypeInvalidRequest = "invalid_request"
4345
FailureTypeModelNotFound = "model_not_found"
44-
dummy = "dummy"
4546
)
4647

4748
type Configuration struct {
@@ -162,6 +163,9 @@ type Configuration struct {
162163
FailureInjectionRate int `yaml:"failure-injection-rate" json:"failure-injection-rate"`
163164
// FailureTypes is a list of specific failure types to inject (empty means all types)
164165
FailureTypes []string `yaml:"failure-types" json:"failure-types"`
166+
167+
// DPSize is data parallel size - a number of ranks to run, minimum is 1, maximum is 8, default is 1
168+
DPSize int `yaml:"data-parallel-size" json:"data-parallel-size"`
165169
}
166170

167171
type Metrics struct {
@@ -265,6 +269,7 @@ func newConfig() *Configuration {
265269
TokenBlockSize: 16,
266270
ZMQEndpoint: "tcp://localhost:5557",
267271
EventBatchSize: 16,
272+
DPSize: 1,
268273
}
269274
}
270275

@@ -440,9 +445,23 @@ func (c *Configuration) validate() error {
440445
return errors.New("fake metrics KV cache usage must be between 0 ans 1")
441446
}
442447
}
448+
449+
if c.DPSize < 1 || c.DPSize > 8 {
450+
return errors.New("data parallel size must be between 1 ans 8")
451+
}
443452
return nil
444453
}
445454

455+
func (c *Configuration) Copy() (*Configuration, error) {
456+
var dst Configuration
457+
data, err := json.Marshal(c)
458+
if err != nil {
459+
return nil, err
460+
}
461+
err = json.Unmarshal(data, &dst)
462+
return &dst, err
463+
}
464+
446465
// ParseCommandParamsAndLoadConfig loads configuration, parses command line parameters, merges the values
447466
// (command line values overwrite the config file ones), and validates the configuration
448467
func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
@@ -501,12 +520,15 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
501520
f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events")
502521
f.UintVar(&config.ZMQMaxConnectAttempts, "zmq-max-connect-attempts", config.ZMQMaxConnectAttempts, "Maximum number of times to try ZMQ connect")
503522
f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together")
523+
f.IntVar(&config.DPSize, "data-parallel-size", config.DPSize, "Number of ranks to run")
504524

505525
f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures")
506-
507526
failureTypes := getParamValueFromArgs("failure-types")
508527
var dummyFailureTypes multiString
509-
f.Var(&dummyFailureTypes, "failure-types", "List of specific failure types to inject (rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found)")
528+
failureTypesDescription := fmt.Sprintf("List of specific failure types to inject (%s, %s, %s, %s, %s, %s)",
529+
FailureTypeRateLimit, FailureTypeInvalidAPIKey, FailureTypeContextLength, FailureTypeServerError, FailureTypeInvalidRequest,
530+
FailureTypeModelNotFound)
531+
f.Var(&dummyFailureTypes, "failure-types", failureTypesDescription)
510532
f.Lookup("failure-types").NoOptDefVal = dummy
511533

512534
// These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help

pkg/common/config_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,11 @@ var _ = Describe("Simulator configuration", func() {
426426
args: []string{"cmd", "--kv-cache-transfer-time-std-dev", "-1",
427427
"--config", "../../manifests/config.yaml"},
428428
},
429+
{
430+
name: "invalid data-parallel-size",
431+
args: []string{"cmd", "--data-parallel-size", "15",
432+
"--config", "../../manifests/config.yaml"},
433+
},
429434
}
430435

431436
for _, test := range invalidTests {

pkg/llm-d-inference-sim/metrics.go

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ import (
3434
// Metrics reported:
3535
// - lora_requests_info
3636
func (s *VllmSimulator) createAndRegisterPrometheus() error {
37+
s.registry = prometheus.NewRegistry()
38+
3739
s.loraInfo = prometheus.NewGaugeVec(
3840
prometheus.GaugeOpts{
3941
Subsystem: "",
@@ -43,7 +45,7 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error {
4345
[]string{vllmapi.PromLabelMaxLora, vllmapi.PromLabelRunningLoraAdapters, vllmapi.PromLabelWaitingLoraAdapters},
4446
)
4547

46-
if err := prometheus.Register(s.loraInfo); err != nil {
48+
if err := s.registry.Register(s.loraInfo); err != nil {
4749
s.logger.Error(err, "Prometheus lora info gauge register failed")
4850
return err
4951
}
@@ -57,7 +59,7 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error {
5759
[]string{vllmapi.PromLabelModelName},
5860
)
5961

60-
if err := prometheus.Register(s.runningRequests); err != nil {
62+
if err := s.registry.Register(s.runningRequests); err != nil {
6163
s.logger.Error(err, "Prometheus number of running requests gauge register failed")
6264
return err
6365
}
@@ -72,7 +74,7 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error {
7274
[]string{vllmapi.PromLabelModelName},
7375
)
7476

75-
if err := prometheus.Register(s.waitingRequests); err != nil {
77+
if err := s.registry.Register(s.waitingRequests); err != nil {
7678
s.logger.Error(err, "Prometheus number of requests in queue gauge register failed")
7779
return err
7880
}
@@ -87,7 +89,7 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error {
8789
[]string{vllmapi.PromLabelModelName},
8890
)
8991

90-
if err := prometheus.Register(s.kvCacheUsagePercentage); err != nil {
92+
if err := s.registry.Register(s.kvCacheUsagePercentage); err != nil {
9193
s.logger.Error(err, "Prometheus kv cache usage percentage gauge register failed")
9294
return err
9395
}
@@ -179,13 +181,6 @@ func (s *VllmSimulator) reportWaitingRequests() {
179181
}
180182
}
181183

182-
func (s *VllmSimulator) unregisterPrometheus() {
183-
prometheus.Unregister(s.loraInfo)
184-
prometheus.Unregister(s.runningRequests)
185-
prometheus.Unregister(s.waitingRequests)
186-
prometheus.Unregister(s.kvCacheUsagePercentage)
187-
}
188-
189184
// startMetricsUpdaters starts the various metrics updaters
190185
func (s *VllmSimulator) startMetricsUpdaters(ctx context.Context) {
191186
go s.waitingRequestsUpdater(ctx)

pkg/llm-d-inference-sim/metrics_test.go

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,8 @@ var _ = Describe("Simulator metrics", Ordered, func() {
6969
args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom,
7070
"--time-to-first-token", "3000", "--max-num-seqs", "2"}
7171

72-
s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true)
72+
client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
7373
Expect(err).NotTo(HaveOccurred())
74-
defer s.unregisterPrometheus()
7574

7675
openaiclient := openai.NewClient(
7776
option.WithBaseURL(baseURL),
@@ -121,9 +120,8 @@ var _ = Describe("Simulator metrics", Ordered, func() {
121120
"--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}",
122121
"{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"}
123122

124-
s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true)
123+
client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
125124
Expect(err).NotTo(HaveOccurred())
126-
defer s.unregisterPrometheus()
127125

128126
openaiclient := openai.NewClient(
129127
option.WithBaseURL(baseURL),
@@ -175,11 +173,9 @@ var _ = Describe("Simulator metrics", Ordered, func() {
175173
"--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}",
176174
"{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"}
177175

178-
s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true)
176+
client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
179177
Expect(err).NotTo(HaveOccurred())
180178

181-
defer s.unregisterPrometheus()
182-
183179
openaiclient := openai.NewClient(
184180
option.WithBaseURL(baseURL),
185181
option.WithHTTPClient(client))
@@ -253,11 +249,9 @@ var _ = Describe("Simulator metrics", Ordered, func() {
253249
"--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}",
254250
"{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"}
255251

256-
s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true)
252+
client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
257253
Expect(err).NotTo(HaveOccurred())
258254

259-
defer s.unregisterPrometheus()
260-
261255
openaiclient := openai.NewClient(
262256
option.WithBaseURL(baseURL),
263257
option.WithHTTPClient(client))
@@ -328,11 +322,9 @@ var _ = Describe("Simulator metrics", Ordered, func() {
328322
"{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":0.4,\"loras\":[{\"running\":\"lora4,lora2\",\"waiting\":\"lora3\",\"timestamp\":1257894567},{\"running\":\"lora4,lora3\",\"waiting\":\"\",\"timestamp\":1257894569}]}",
329323
}
330324

331-
s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true)
325+
client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
332326
Expect(err).NotTo(HaveOccurred())
333327

334-
defer s.unregisterPrometheus()
335-
336328
resp, err := client.Get(metricsUrl)
337329
Expect(err).NotTo(HaveOccurred())
338330
Expect(resp.StatusCode).To(Equal(http.StatusOK))

pkg/llm-d-inference-sim/simulator.go

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ package llmdinferencesim
2020
import (
2121
"context"
2222
"encoding/json"
23-
"errors"
2423
"fmt"
2524
"net"
2625
"os"
@@ -34,6 +33,8 @@ import (
3433
"github.com/prometheus/client_golang/prometheus/promhttp"
3534
"github.com/valyala/fasthttp"
3635
"github.com/valyala/fasthttp/fasthttpadaptor"
36+
"golang.org/x/sync/errgroup"
37+
"k8s.io/klog/v2"
3738

3839
"github.com/llm-d/llm-d-inference-sim/pkg/common"
3940
kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
@@ -94,6 +95,8 @@ type VllmSimulator struct {
9495
nWaitingReqs int64
9596
// waitingReqChan is a channel to update nWaitingReqs
9697
waitingReqChan chan int64
98+
// registry is a Prometheus registry
99+
registry *prometheus.Registry
97100
// loraInfo is prometheus gauge
98101
loraInfo *prometheus.GaugeVec
99102
// runningRequests is prometheus gauge
@@ -136,27 +139,54 @@ func New(logger logr.Logger) (*VllmSimulator, error) {
136139

137140
// Start starts the simulator
138141
func (s *VllmSimulator) Start(ctx context.Context) error {
142+
var err error
139143
// parse command line parameters
140-
config, err := common.ParseCommandParamsAndLoadConfig()
144+
s.config, err = common.ParseCommandParamsAndLoadConfig()
141145
if err != nil {
142146
return err
143147
}
144148

145-
s.config = config
146-
147-
err = s.showConfig(s.logger)
149+
err = s.showConfig(s.config.DPSize > 1)
148150
if err != nil {
149151
return err
150152
}
151153

152-
for _, lora := range config.LoraModules {
154+
// For Data Parallel, start data-parallel-size - 1 additional simulators
155+
if s.config.DPSize > 1 {
156+
g, ctx := errgroup.WithContext(context.Background())
157+
for i := 2; i <= s.config.DPSize; i++ {
158+
newConfig, err := s.config.Copy()
159+
if err != nil {
160+
return err
161+
}
162+
dpRank := i - 1
163+
newConfig.Port = s.config.Port + dpRank
164+
newSim, err := New(klog.LoggerWithValues(s.logger, "rank", dpRank))
165+
if err != nil {
166+
return err
167+
}
168+
newSim.config = newConfig
169+
g.Go(func() error {
170+
return newSim.startSim(ctx)
171+
})
172+
}
173+
if err := g.Wait(); err != nil {
174+
return err
175+
}
176+
s.logger = klog.LoggerWithValues(s.logger, "rank", 0)
177+
}
178+
return s.startSim(ctx)
179+
}
180+
181+
func (s *VllmSimulator) startSim(ctx context.Context) error {
182+
for _, lora := range s.config.LoraModules {
153183
s.loraAdaptors.Store(lora.Name, "")
154184
}
155185

156186
common.InitRandom(s.config.Seed)
157187

158188
// initialize prometheus metrics
159-
err = s.createAndRegisterPrometheus()
189+
err := s.createAndRegisterPrometheus()
160190
if err != nil {
161191
return err
162192
}
@@ -208,7 +238,7 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener)
208238
r.POST("/v1/load_lora_adapter", s.HandleLoadLora)
209239
r.POST("/v1/unload_lora_adapter", s.HandleUnloadLora)
210240
// supports /metrics prometheus API
211-
r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.Handler()))
241+
r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.HandlerFor(s.registry, promhttp.HandlerOpts{})))
212242
// supports standard Kubernetes health and readiness checks
213243
r.GET("/health", s.HandleHealth)
214244
r.GET("/ready", s.HandleReady)
@@ -759,21 +789,22 @@ func (s *VllmSimulator) getDisplayedModelName(reqModel string) string {
759789
return s.config.ServedModelNames[0]
760790
}
761791

762-
func (s *VllmSimulator) showConfig(tgtLgr logr.Logger) error {
763-
if tgtLgr == logr.Discard() {
764-
return errors.New("target logger is nil, cannot show configuration")
765-
}
792+
func (s *VllmSimulator) showConfig(dp bool) error {
766793
cfgJSON, err := json.Marshal(s.config)
767794
if err != nil {
768795
return fmt.Errorf("failed to marshal configuration to JSON: %w", err)
769796
}
770797

771-
// clean LoraModulesString field
772798
var m map[string]interface{}
773799
err = json.Unmarshal(cfgJSON, &m)
774800
if err != nil {
775801
return fmt.Errorf("failed to unmarshal JSON to map: %w", err)
776802
}
803+
if dp {
804+
// remove the port
805+
delete(m, "port")
806+
}
807+
// clean LoraModulesString field
777808
m["lora-modules"] = m["LoraModules"]
778809
delete(m, "LoraModules")
779810
delete(m, "LoraModulesString")
@@ -788,6 +819,6 @@ func (s *VllmSimulator) showConfig(tgtLgr logr.Logger) error {
788819
if err != nil {
789820
return fmt.Errorf("failed to marshal configuration to JSON: %w", err)
790821
}
791-
tgtLgr.Info("Configuration:", "", string(cfgJSON))
822+
s.logger.Info("Configuration:", "", string(cfgJSON))
792823
return nil
793824
}

0 commit comments

Comments
 (0)