Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ For more details see the <a href="https://docs.vllm.ai/en/stable/getting_started

Example:
{"running-requests":10,"waiting-requests":30,"kv-cache-usage":0.4,"loras":[{"running":"lora4,lora2","waiting":"lora3","timestamp":1257894567},{"running":"lora4,lora3","waiting":"","timestamp":1257894569}]}
---
- `data-parallel-size`: number of ranks to run in Data Parallel deployment, from 1 to 8, default is 1. The ports will be assigned as follows: rank 0 will run on the configured `port`, rank 1 on `port`+1, etc.

In addition, as we are using klog, the following parameters are available:
- `add_dir_header`: if true, adds the file directory to the header of the log messages
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ require (
github.com/spf13/pflag v1.0.6
github.com/valyala/fasthttp v1.59.0
github.com/vmihailenco/msgpack/v5 v5.4.1
golang.org/x/sync v0.12.0
gopkg.in/yaml.v3 v3.0.1
k8s.io/klog/v2 v2.130.1
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
28 changes: 25 additions & 3 deletions pkg/common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ const (
vLLMDefaultPort = 8000
ModeRandom = "random"
ModeEcho = "echo"
dummy = "dummy"

// Failure type constants
FailureTypeRateLimit = "rate_limit"
FailureTypeInvalidAPIKey = "invalid_api_key"
FailureTypeContextLength = "context_length"
FailureTypeServerError = "server_error"
FailureTypeInvalidRequest = "invalid_request"
FailureTypeModelNotFound = "model_not_found"
dummy = "dummy"
)

type Configuration struct {
Expand Down Expand Up @@ -162,6 +163,9 @@ type Configuration struct {
FailureInjectionRate int `yaml:"failure-injection-rate" json:"failure-injection-rate"`
// FailureTypes is a list of specific failure types to inject (empty means all types)
FailureTypes []string `yaml:"failure-types" json:"failure-types"`

// DPSize is data parallel size - a number of ranks to run, minimum is 1, maximum is 8, default is 1
DPSize int `yaml:"data-parallel-size" json:"data-parallel-size"`
}

type Metrics struct {
Expand Down Expand Up @@ -265,6 +269,7 @@ func newConfig() *Configuration {
TokenBlockSize: 16,
ZMQEndpoint: "tcp://localhost:5557",
EventBatchSize: 16,
DPSize: 1,
}
}

Expand Down Expand Up @@ -440,9 +445,23 @@ func (c *Configuration) validate() error {
return errors.New("fake metrics KV cache usage must be between 0 ans 1")
}
}

if c.DPSize < 1 || c.DPSize > 8 {
return errors.New("data parallel size must be between 1 ans 8")
}
return nil
}

// Copy returns a deep copy of the configuration by round-tripping it
// through JSON. Only fields visible to the JSON encoder (exported and
// not tagged `json:"-"`) are duplicated; unexported state is lost.
// On any marshal/unmarshal error it returns (nil, err) rather than a
// partially-populated configuration.
func (c *Configuration) Copy() (*Configuration, error) {
	data, err := json.Marshal(c)
	if err != nil {
		return nil, err
	}
	var dst Configuration
	if err := json.Unmarshal(data, &dst); err != nil {
		return nil, err
	}
	return &dst, nil
}

// ParseCommandParamsAndLoadConfig loads configuration, parses command line parameters, merges the values
// (command line values overwrite the config file ones), and validates the configuration
func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
Expand Down Expand Up @@ -501,12 +520,15 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events")
f.UintVar(&config.ZMQMaxConnectAttempts, "zmq-max-connect-attempts", config.ZMQMaxConnectAttempts, "Maximum number of times to try ZMQ connect")
f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together")
f.IntVar(&config.DPSize, "data-parallel-size", config.DPSize, "Number of ranks to run")

f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures")

failureTypes := getParamValueFromArgs("failure-types")
var dummyFailureTypes multiString
f.Var(&dummyFailureTypes, "failure-types", "List of specific failure types to inject (rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found)")
failureTypesDescription := fmt.Sprintf("List of specific failure types to inject (%s, %s, %s, %s, %s, %s)",
FailureTypeRateLimit, FailureTypeInvalidAPIKey, FailureTypeContextLength, FailureTypeServerError, FailureTypeInvalidRequest,
FailureTypeModelNotFound)
f.Var(&dummyFailureTypes, "failure-types", failureTypesDescription)
f.Lookup("failure-types").NoOptDefVal = dummy

// These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help
Expand Down
5 changes: 5 additions & 0 deletions pkg/common/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,11 @@ var _ = Describe("Simulator configuration", func() {
args: []string{"cmd", "--kv-cache-transfer-time-std-dev", "-1",
"--config", "../../manifests/config.yaml"},
},
{
name: "invalid data-parallel-size",
args: []string{"cmd", "--data-parallel-size", "15",
"--config", "../../manifests/config.yaml"},
},
}

for _, test := range invalidTests {
Expand Down
17 changes: 6 additions & 11 deletions pkg/llm-d-inference-sim/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import (
// Metrics reported:
// - lora_requests_info
func (s *VllmSimulator) createAndRegisterPrometheus() error {
s.registry = prometheus.NewRegistry()

s.loraInfo = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Subsystem: "",
Expand All @@ -43,7 +45,7 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error {
[]string{vllmapi.PromLabelMaxLora, vllmapi.PromLabelRunningLoraAdapters, vllmapi.PromLabelWaitingLoraAdapters},
)

if err := prometheus.Register(s.loraInfo); err != nil {
if err := s.registry.Register(s.loraInfo); err != nil {
s.logger.Error(err, "Prometheus lora info gauge register failed")
return err
}
Expand All @@ -57,7 +59,7 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error {
[]string{vllmapi.PromLabelModelName},
)

if err := prometheus.Register(s.runningRequests); err != nil {
if err := s.registry.Register(s.runningRequests); err != nil {
s.logger.Error(err, "Prometheus number of running requests gauge register failed")
return err
}
Expand All @@ -72,7 +74,7 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error {
[]string{vllmapi.PromLabelModelName},
)

if err := prometheus.Register(s.waitingRequests); err != nil {
if err := s.registry.Register(s.waitingRequests); err != nil {
s.logger.Error(err, "Prometheus number of requests in queue gauge register failed")
return err
}
Expand All @@ -87,7 +89,7 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error {
[]string{vllmapi.PromLabelModelName},
)

if err := prometheus.Register(s.kvCacheUsagePercentage); err != nil {
if err := s.registry.Register(s.kvCacheUsagePercentage); err != nil {
s.logger.Error(err, "Prometheus kv cache usage percentage gauge register failed")
return err
}
Expand Down Expand Up @@ -179,13 +181,6 @@ func (s *VllmSimulator) reportWaitingRequests() {
}
}

func (s *VllmSimulator) unregisterPrometheus() {
prometheus.Unregister(s.loraInfo)
prometheus.Unregister(s.runningRequests)
prometheus.Unregister(s.waitingRequests)
prometheus.Unregister(s.kvCacheUsagePercentage)
}

// startMetricsUpdaters starts the various metrics updaters
func (s *VllmSimulator) startMetricsUpdaters(ctx context.Context) {
go s.waitingRequestsUpdater(ctx)
Expand Down
18 changes: 5 additions & 13 deletions pkg/llm-d-inference-sim/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,8 @@ var _ = Describe("Simulator metrics", Ordered, func() {
args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom,
"--time-to-first-token", "3000", "--max-num-seqs", "2"}

s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true)
client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
Expect(err).NotTo(HaveOccurred())
defer s.unregisterPrometheus()

openaiclient := openai.NewClient(
option.WithBaseURL(baseURL),
Expand Down Expand Up @@ -121,9 +120,8 @@ var _ = Describe("Simulator metrics", Ordered, func() {
"--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}",
"{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"}

s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true)
client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
Expect(err).NotTo(HaveOccurred())
defer s.unregisterPrometheus()

openaiclient := openai.NewClient(
option.WithBaseURL(baseURL),
Expand Down Expand Up @@ -175,11 +173,9 @@ var _ = Describe("Simulator metrics", Ordered, func() {
"--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}",
"{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"}

s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true)
client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
Expect(err).NotTo(HaveOccurred())

defer s.unregisterPrometheus()

openaiclient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client))
Expand Down Expand Up @@ -253,11 +249,9 @@ var _ = Describe("Simulator metrics", Ordered, func() {
"--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}",
"{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"}

s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true)
client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
Expect(err).NotTo(HaveOccurred())

defer s.unregisterPrometheus()

openaiclient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client))
Expand Down Expand Up @@ -328,11 +322,9 @@ var _ = Describe("Simulator metrics", Ordered, func() {
"{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":0.4,\"loras\":[{\"running\":\"lora4,lora2\",\"waiting\":\"lora3\",\"timestamp\":1257894567},{\"running\":\"lora4,lora3\",\"waiting\":\"\",\"timestamp\":1257894569}]}",
}

s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true)
client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
Expect(err).NotTo(HaveOccurred())

defer s.unregisterPrometheus()

resp, err := client.Get(metricsUrl)
Expect(err).NotTo(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
Expand Down
59 changes: 45 additions & 14 deletions pkg/llm-d-inference-sim/simulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package llmdinferencesim
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"os"
Expand All @@ -34,6 +33,8 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttpadaptor"
"golang.org/x/sync/errgroup"
"k8s.io/klog/v2"

"github.com/llm-d/llm-d-inference-sim/pkg/common"
kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
Expand Down Expand Up @@ -94,6 +95,8 @@ type VllmSimulator struct {
nWaitingReqs int64
// waitingReqChan is a channel to update nWaitingReqs
waitingReqChan chan int64
// registry is a Prometheus registry
registry *prometheus.Registry
// loraInfo is prometheus gauge
loraInfo *prometheus.GaugeVec
// runningRequests is prometheus gauge
Expand Down Expand Up @@ -136,27 +139,54 @@ func New(logger logr.Logger) (*VllmSimulator, error) {

// Start starts the simulator
func (s *VllmSimulator) Start(ctx context.Context) error {
var err error
// parse command line parameters
config, err := common.ParseCommandParamsAndLoadConfig()
s.config, err = common.ParseCommandParamsAndLoadConfig()
if err != nil {
return err
}

s.config = config

err = s.showConfig(s.logger)
err = s.showConfig(s.config.DPSize > 1)
if err != nil {
return err
}

for _, lora := range config.LoraModules {
// For Data Parallel, start data-parallel-size - 1 additional simulators
if s.config.DPSize > 1 {
g, ctx := errgroup.WithContext(context.Background())
for i := 2; i <= s.config.DPSize; i++ {
newConfig, err := s.config.Copy()
if err != nil {
return err
}
dpRank := i - 1
newConfig.Port = s.config.Port + dpRank
newSim, err := New(klog.LoggerWithValues(s.logger, "rank", dpRank))
if err != nil {
return err
}
newSim.config = newConfig
g.Go(func() error {
return newSim.startSim(ctx)
})
}
if err := g.Wait(); err != nil {
return err
}
s.logger = klog.LoggerWithValues(s.logger, "rank", 0)
}
return s.startSim(ctx)
}

func (s *VllmSimulator) startSim(ctx context.Context) error {
for _, lora := range s.config.LoraModules {
s.loraAdaptors.Store(lora.Name, "")
}

common.InitRandom(s.config.Seed)

// initialize prometheus metrics
err = s.createAndRegisterPrometheus()
err := s.createAndRegisterPrometheus()
if err != nil {
return err
}
Expand Down Expand Up @@ -208,7 +238,7 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener)
r.POST("/v1/load_lora_adapter", s.HandleLoadLora)
r.POST("/v1/unload_lora_adapter", s.HandleUnloadLora)
// supports /metrics prometheus API
r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.Handler()))
r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.HandlerFor(s.registry, promhttp.HandlerOpts{})))
// supports standard Kubernetes health and readiness checks
r.GET("/health", s.HandleHealth)
r.GET("/ready", s.HandleReady)
Expand Down Expand Up @@ -755,21 +785,22 @@ func (s *VllmSimulator) getDisplayedModelName(reqModel string) string {
return s.config.ServedModelNames[0]
}

func (s *VllmSimulator) showConfig(tgtLgr logr.Logger) error {
if tgtLgr == logr.Discard() {
return errors.New("target logger is nil, cannot show configuration")
}
func (s *VllmSimulator) showConfig(dp bool) error {
cfgJSON, err := json.Marshal(s.config)
if err != nil {
return fmt.Errorf("failed to marshal configuration to JSON: %w", err)
}

// clean LoraModulesString field
var m map[string]interface{}
err = json.Unmarshal(cfgJSON, &m)
if err != nil {
return fmt.Errorf("failed to unmarshal JSON to map: %w", err)
}
if dp {
// remove the port
delete(m, "port")
}
// clean LoraModulesString field
m["lora-modules"] = m["LoraModules"]
delete(m, "LoraModules")
delete(m, "LoraModulesString")
Expand All @@ -784,6 +815,6 @@ func (s *VllmSimulator) showConfig(tgtLgr logr.Logger) error {
if err != nil {
return fmt.Errorf("failed to marshal configuration to JSON: %w", err)
}
tgtLgr.Info("Configuration:", "", string(cfgJSON))
s.logger.Info("Configuration:", "", string(cfgJSON))
return nil
}
Loading