From bf18c3234317a0c83fc65fb470564404a743e23c Mon Sep 17 00:00:00 2001 From: Ira Date: Thu, 4 Sep 2025 11:13:13 +0300 Subject: [PATCH] Support DP Signed-off-by: Ira --- README.md | 2 + go.mod | 1 + go.sum | 2 + pkg/common/config.go | 28 +++++++++-- pkg/common/config_test.go | 5 ++ pkg/llm-d-inference-sim/metrics.go | 17 +++---- pkg/llm-d-inference-sim/metrics_test.go | 18 ++----- pkg/llm-d-inference-sim/simulator.go | 59 +++++++++++++++++------ pkg/llm-d-inference-sim/simulator_test.go | 19 ++------ 9 files changed, 96 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index adf7c48a..535e77fc 100644 --- a/README.md +++ b/README.md @@ -146,6 +146,8 @@ For more details see the 8 { + return errors.New("data parallel size must be between 1 and 8") + } return nil } +func (c *Configuration) Copy() (*Configuration, error) { + var dst Configuration + data, err := json.Marshal(c) + if err != nil { + return nil, err + } + err = json.Unmarshal(data, &dst) + return &dst, err +} + // ParseCommandParamsAndLoadConfig loads configuration, parses command line parameters, merges the values // (command line values overwrite the config file ones), and validates the configuration func ParseCommandParamsAndLoadConfig() (*Configuration, error) { @@ -501,12 +520,15 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events") f.UintVar(&config.ZMQMaxConnectAttempts, "zmq-max-connect-attempts", config.ZMQMaxConnectAttempts, "Maximum number of times to try ZMQ connect") f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") + f.IntVar(&config.DPSize, "data-parallel-size", config.DPSize, "Number of ranks to run") f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures") - failureTypes := 
getParamValueFromArgs("failure-types") var dummyFailureTypes multiString - f.Var(&dummyFailureTypes, "failure-types", "List of specific failure types to inject (rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found)") + failureTypesDescription := fmt.Sprintf("List of specific failure types to inject (%s, %s, %s, %s, %s, %s)", + FailureTypeRateLimit, FailureTypeInvalidAPIKey, FailureTypeContextLength, FailureTypeServerError, FailureTypeInvalidRequest, + FailureTypeModelNotFound) + f.Var(&dummyFailureTypes, "failure-types", failureTypesDescription) f.Lookup("failure-types").NoOptDefVal = dummy // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go index 7d5fae13..20aba9a4 100644 --- a/pkg/common/config_test.go +++ b/pkg/common/config_test.go @@ -426,6 +426,11 @@ var _ = Describe("Simulator configuration", func() { args: []string{"cmd", "--kv-cache-transfer-time-std-dev", "-1", "--config", "../../manifests/config.yaml"}, }, + { + name: "invalid data-parallel-size", + args: []string{"cmd", "--data-parallel-size", "15", + "--config", "../../manifests/config.yaml"}, + }, } for _, test := range invalidTests { diff --git a/pkg/llm-d-inference-sim/metrics.go b/pkg/llm-d-inference-sim/metrics.go index fffd5824..850db935 100644 --- a/pkg/llm-d-inference-sim/metrics.go +++ b/pkg/llm-d-inference-sim/metrics.go @@ -34,6 +34,8 @@ import ( // Metrics reported: // - lora_requests_info func (s *VllmSimulator) createAndRegisterPrometheus() error { + s.registry = prometheus.NewRegistry() + s.loraInfo = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Subsystem: "", @@ -43,7 +45,7 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { []string{vllmapi.PromLabelMaxLora, vllmapi.PromLabelRunningLoraAdapters, vllmapi.PromLabelWaitingLoraAdapters}, ) - if err := prometheus.Register(s.loraInfo); err != nil { + if 
err := s.registry.Register(s.loraInfo); err != nil { s.logger.Error(err, "Prometheus lora info gauge register failed") return err } @@ -57,7 +59,7 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { []string{vllmapi.PromLabelModelName}, ) - if err := prometheus.Register(s.runningRequests); err != nil { + if err := s.registry.Register(s.runningRequests); err != nil { s.logger.Error(err, "Prometheus number of running requests gauge register failed") return err } @@ -72,7 +74,7 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { []string{vllmapi.PromLabelModelName}, ) - if err := prometheus.Register(s.waitingRequests); err != nil { + if err := s.registry.Register(s.waitingRequests); err != nil { s.logger.Error(err, "Prometheus number of requests in queue gauge register failed") return err } @@ -87,7 +89,7 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { []string{vllmapi.PromLabelModelName}, ) - if err := prometheus.Register(s.kvCacheUsagePercentage); err != nil { + if err := s.registry.Register(s.kvCacheUsagePercentage); err != nil { s.logger.Error(err, "Prometheus kv cache usage percentage gauge register failed") return err } @@ -179,13 +181,6 @@ func (s *VllmSimulator) reportWaitingRequests() { } } -func (s *VllmSimulator) unregisterPrometheus() { - prometheus.Unregister(s.loraInfo) - prometheus.Unregister(s.runningRequests) - prometheus.Unregister(s.waitingRequests) - prometheus.Unregister(s.kvCacheUsagePercentage) -} - // startMetricsUpdaters starts the various metrics updaters func (s *VllmSimulator) startMetricsUpdaters(ctx context.Context) { go s.waitingRequestsUpdater(ctx) diff --git a/pkg/llm-d-inference-sim/metrics_test.go b/pkg/llm-d-inference-sim/metrics_test.go index 0d359e95..f721093e 100644 --- a/pkg/llm-d-inference-sim/metrics_test.go +++ b/pkg/llm-d-inference-sim/metrics_test.go @@ -69,9 +69,8 @@ var _ = Describe("Simulator metrics", Ordered, func() { args := []string{"cmd", "--model", modelName, "--mode", 
common.ModeRandom, "--time-to-first-token", "3000", "--max-num-seqs", "2"} - s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) Expect(err).NotTo(HaveOccurred()) - defer s.unregisterPrometheus() openaiclient := openai.NewClient( option.WithBaseURL(baseURL), @@ -121,9 +120,8 @@ var _ = Describe("Simulator metrics", Ordered, func() { "--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} - s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) Expect(err).NotTo(HaveOccurred()) - defer s.unregisterPrometheus() openaiclient := openai.NewClient( option.WithBaseURL(baseURL), @@ -175,11 +173,9 @@ var _ = Describe("Simulator metrics", Ordered, func() { "--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} - s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) Expect(err).NotTo(HaveOccurred()) - defer s.unregisterPrometheus() - openaiclient := openai.NewClient( option.WithBaseURL(baseURL), option.WithHTTPClient(client)) @@ -253,11 +249,9 @@ var _ = Describe("Simulator metrics", Ordered, func() { "--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} - s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) Expect(err).NotTo(HaveOccurred()) - defer s.unregisterPrometheus() - openaiclient := openai.NewClient( option.WithBaseURL(baseURL), option.WithHTTPClient(client)) @@ -328,11 +322,9 @@ var _ = Describe("Simulator metrics", Ordered, func() { 
"{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":0.4,\"loras\":[{\"running\":\"lora4,lora2\",\"waiting\":\"lora3\",\"timestamp\":1257894567},{\"running\":\"lora4,lora3\",\"waiting\":\"\",\"timestamp\":1257894569}]}", } - s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) Expect(err).NotTo(HaveOccurred()) - defer s.unregisterPrometheus() - resp, err := client.Get(metricsUrl) Expect(err).NotTo(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 026a55c4..b452200b 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -20,7 +20,6 @@ package llmdinferencesim import ( "context" "encoding/json" - "errors" "fmt" "net" "os" @@ -34,6 +33,8 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttpadaptor" + "golang.org/x/sync/errgroup" + "k8s.io/klog/v2" "github.com/llm-d/llm-d-inference-sim/pkg/common" kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache" @@ -94,6 +95,8 @@ type VllmSimulator struct { nWaitingReqs int64 // waitingReqChan is a channel to update nWaitingReqs waitingReqChan chan int64 + // registry is a Prometheus registry + registry *prometheus.Registry // loraInfo is prometheus gauge loraInfo *prometheus.GaugeVec // runningRequests is prometheus gauge @@ -136,27 +139,54 @@ func New(logger logr.Logger) (*VllmSimulator, error) { // Start starts the simulator func (s *VllmSimulator) Start(ctx context.Context) error { + var err error // parse command line parameters - config, err := common.ParseCommandParamsAndLoadConfig() + s.config, err = common.ParseCommandParamsAndLoadConfig() if err != nil { return err } - s.config = config - - err = s.showConfig(s.logger) + err = s.showConfig(s.config.DPSize > 1) if 
err != nil { return err } - for _, lora := range config.LoraModules { + // For Data Parallel, start data-parallel-size - 1 additional simulators + if s.config.DPSize > 1 { + g, ctx := errgroup.WithContext(context.Background()) + for i := 2; i <= s.config.DPSize; i++ { + newConfig, err := s.config.Copy() + if err != nil { + return err + } + dpRank := i - 1 + newConfig.Port = s.config.Port + dpRank + newSim, err := New(klog.LoggerWithValues(s.logger, "rank", dpRank)) + if err != nil { + return err + } + newSim.config = newConfig + g.Go(func() error { + return newSim.startSim(ctx) + }) + } + if err := g.Wait(); err != nil { + return err + } + s.logger = klog.LoggerWithValues(s.logger, "rank", 0) + } + return s.startSim(ctx) +} + +func (s *VllmSimulator) startSim(ctx context.Context) error { + for _, lora := range s.config.LoraModules { s.loraAdaptors.Store(lora.Name, "") } common.InitRandom(s.config.Seed) // initialize prometheus metrics - err = s.createAndRegisterPrometheus() + err := s.createAndRegisterPrometheus() if err != nil { return err } @@ -208,7 +238,7 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) r.POST("/v1/load_lora_adapter", s.HandleLoadLora) r.POST("/v1/unload_lora_adapter", s.HandleUnloadLora) // supports /metrics prometheus API - r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.Handler())) + r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.HandlerFor(s.registry, promhttp.HandlerOpts{}))) // supports standard Kubernetes health and readiness checks r.GET("/health", s.HandleHealth) r.GET("/ready", s.HandleReady) @@ -755,21 +785,22 @@ func (s *VllmSimulator) getDisplayedModelName(reqModel string) string { return s.config.ServedModelNames[0] } -func (s *VllmSimulator) showConfig(tgtLgr logr.Logger) error { - if tgtLgr == logr.Discard() { - return errors.New("target logger is nil, cannot show configuration") - } +func (s *VllmSimulator) showConfig(dp bool) error { cfgJSON, err := 
json.Marshal(s.config) if err != nil { return fmt.Errorf("failed to marshal configuration to JSON: %w", err) } - // clean LoraModulesString field var m map[string]interface{} err = json.Unmarshal(cfgJSON, &m) if err != nil { return fmt.Errorf("failed to unmarshal JSON to map: %w", err) } + if dp { + // remove the port + delete(m, "port") + } + // clean LoraModulesString field m["lora-modules"] = m["LoraModules"] delete(m, "LoraModules") delete(m, "LoraModulesString") @@ -784,6 +815,6 @@ func (s *VllmSimulator) showConfig(tgtLgr logr.Logger) error { if err != nil { return fmt.Errorf("failed to marshal configuration to JSON: %w", err) } - tgtLgr.Info("Configuration:", "", string(cfgJSON)) + s.logger.Info("Configuration:", "", string(cfgJSON)) return nil } diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 1c9c8805..df43ff57 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -49,12 +49,6 @@ func startServer(ctx context.Context, mode string) (*http.Client, error) { } func startServerWithArgs(ctx context.Context, mode string, args []string, envs map[string]string) (*http.Client, error) { - _, client, err := startServerWithArgsAndMetrics(ctx, mode, args, envs, false) - return client, err -} - -func startServerWithArgsAndMetrics(ctx context.Context, mode string, args []string, envs map[string]string, - setMetrics bool) (*VllmSimulator, *http.Client, error) { oldArgs := os.Args defer func() { os.Args = oldArgs @@ -84,11 +78,11 @@ func startServerWithArgsAndMetrics(ctx context.Context, mode string, args []stri s, err := New(logger) if err != nil { - return nil, nil, err + return nil, err } config, err := common.ParseCommandParamsAndLoadConfig() if err != nil { - return nil, nil, err + return nil, err } s.config = config @@ -98,11 +92,8 @@ func startServerWithArgsAndMetrics(ctx context.Context, mode string, args []stri common.InitRandom(s.config.Seed) - if setMetrics 
{ - err = s.createAndRegisterPrometheus() - if err != nil { - return nil, nil, err - } + if err := s.createAndRegisterPrometheus(); err != nil { + return nil, err } // calculate number of tokens for user message, @@ -125,7 +116,7 @@ func startServerWithArgsAndMetrics(ctx context.Context, mode string, args []stri } }() - return s, &http.Client{ + return &http.Client{ Transport: &http.Transport{ DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { return listener.Dial()