Skip to content

Commit 83828d0

Browse files
authored
Added support for the vLLM --data-parallel-rank command line argument (#279)
* Added the --data-parallel-rank parameter Signed-off-by: Shmuel Kallner <[email protected]> * Use rank in logger is specified Signed-off-by: Shmuel Kallner <[email protected]> * Corrected test and correcting error message text Signed-off-by: Shmuel Kallner <[email protected]> * Improved error message text Signed-off-by: Shmuel Kallner <[email protected]> * Added a test for the new --data-parallel-rank command line argument Signed-off-by: Shmuel Kallner <[email protected]> --------- Signed-off-by: Shmuel Kallner <[email protected]>
1 parent a15923f commit 83828d0

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

pkg/common/config.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ type Configuration struct {
202202
// DPSize is data parallel size - a number of ranks to run, minimum is 1, maximum is 8, default is 1
203203
DPSize int `yaml:"data-parallel-size" json:"data-parallel-size"`
204204

205+
// Rank is the vLLM parameter used to specify the rank of this instance. Here only
206+
// used when running Data Parallel ranks as separate processes
207+
Rank int `yaml:"data-parallel-rank" json:"data-parallel-rank"`
208+
205209
// SSLCertFile is the path to the SSL certificate file for HTTPS
206210
SSLCertFile string `yaml:"ssl-certfile" json:"ssl-certfile"`
207211
// SSLKeyFile is the path to the SSL private key file for HTTPS
@@ -375,6 +379,7 @@ func newConfig() *Configuration {
375379
ZMQEndpoint: "tcp://localhost:5557",
376380
EventBatchSize: 16,
377381
DPSize: 1,
382+
Rank: -1,
378383
}
379384
}
380385

@@ -653,7 +658,11 @@ func (c *Configuration) validate() error {
653658
}
654659

655660
if c.DPSize < 1 || c.DPSize > 8 {
656-
return errors.New("data parallel size must be between 1 ans 8")
661+
return errors.New("data parallel size must be between 1 and 8")
662+
}
663+
664+
if c.Rank > 7 {
665+
return errors.New("data parallel rank must be between 0 and 7")
657666
}
658667

659668
if (c.SSLCertFile == "") != (c.SSLKeyFile == "") {
@@ -751,6 +760,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
751760
f.IntVar(&config.ZMQMaxConnectAttempts, "zmq-max-connect-attempts", config.ZMQMaxConnectAttempts, "Maximum number of times to try ZMQ connect")
752761
f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together")
753762
f.IntVar(&config.DPSize, "data-parallel-size", config.DPSize, "Number of ranks to run")
763+
f.IntVar(&config.Rank, "data-parallel-rank", config.Rank, "The rank when running each rank in a process")
754764

755765
f.StringVar(&config.DatasetPath, "dataset-path", config.DatasetPath, "Local path to the sqlite db file for response generation from a dataset")
756766
f.StringVar(&config.DatasetURL, "dataset-url", config.DatasetURL, "URL to download the sqlite db file for response generation from a dataset")

pkg/common/config_test.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,13 @@ var _ = Describe("Simulator configuration", func() {
476476
name: "invalid data-parallel-size",
477477
args: []string{"cmd", "--data-parallel-size", "15",
478478
"--config", "../../manifests/config.yaml"},
479-
expectedError: "data parallel size must be between 1 ans 8",
479+
expectedError: "data parallel size must be between 1 and 8",
480+
},
481+
{
482+
name: "invalid data-parallel-rank",
483+
args: []string{"cmd", "--data-parallel-rank", "15",
484+
"--config", "../../manifests/config.yaml"},
485+
expectedError: "data parallel rank must be between 0 and 7",
480486
},
481487
{
482488
name: "invalid max-num-seqs",

pkg/llm-d-inference-sim/simulator.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,8 @@ func (s *VllmSimulator) Start(ctx context.Context) error {
277277
})
278278
}
279279
s.logger = klog.LoggerWithValues(s.logger, "rank", 0)
280+
} else if s.config.Rank >= 0 {
281+
s.logger = klog.LoggerWithValues(s.logger, "rank", s.config.Rank)
280282
}
281283
g.Go(func() error {
282284
return s.startSim(ctx)

0 commit comments

Comments
 (0)