diff --git a/pkg/llm-d-inference-sim/metrics.go b/pkg/llm-d-inference-sim/metrics.go index 5b065648..fffd5824 100644 --- a/pkg/llm-d-inference-sim/metrics.go +++ b/pkg/llm-d-inference-sim/metrics.go @@ -22,6 +22,7 @@ import ( "context" "strconv" "strings" + "sync" "time" "github.com/prometheus/client_golang/prometheus" @@ -135,20 +136,25 @@ func (s *VllmSimulator) reportLoras() { return } - var loras []string - s.runningLoras.Range(func(key interface{}, _ interface{}) bool { + var runningLoras []string + s.runningLoras.Range(func(key any, _ any) bool { if lora, ok := key.(string); ok { - loras = append(loras, lora) + runningLoras = append(runningLoras, lora) + } + return true + }) + var waitingLoras []string + s.waitingLoras.Range(func(key any, _ any) bool { + if lora, ok := key.(string); ok { + waitingLoras = append(waitingLoras, lora) } return true }) - allLoras := strings.Join(loras, ",") s.loraInfo.WithLabelValues( strconv.Itoa(s.config.MaxLoras), - allLoras, - // TODO - add names of loras in queue - "").Set(float64(time.Now().Unix())) + strings.Join(runningLoras, ","), + strings.Join(waitingLoras, ",")).Set(float64(time.Now().Unix())) } // reportRunningRequests sets information about running completion requests @@ -184,6 +190,7 @@ func (s *VllmSimulator) unregisterPrometheus() { func (s *VllmSimulator) startMetricsUpdaters(ctx context.Context) { go s.waitingRequestsUpdater(ctx) go s.runningRequestsUpdater(ctx) + go s.lorasUpdater(ctx) } // waitingRequestsUpdater updates the waiting requests metric by listening on the relevant channel @@ -211,3 +218,48 @@ func (s *VllmSimulator) runningRequestsUpdater(ctx context.Context) { } } } + +// lorasUpdater updates the running and waiting loras metric by listening on the relevant channel +// one function updates both waiting and running loras since they are part of the same prometheus gauge +func (s *VllmSimulator) lorasUpdater(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case loraUpdate := <-s.lorasChan: + switch loraUpdate.state { + case waitingUsageState: + s.incrementLoraRefCount(loraUpdate.name, &s.waitingLoras) + case runningUsageState: + s.decrementLoraRefCount(loraUpdate.name, &s.waitingLoras) + s.incrementLoraRefCount(loraUpdate.name, &s.runningLoras) + case doneUsageState: + s.decrementLoraRefCount(loraUpdate.name, &s.runningLoras) + } + s.reportLoras() + } + } +} + +func (s *VllmSimulator) incrementLoraRefCount(lora string, theMap *sync.Map) { + count := 0 + if value, ok := theMap.Load(lora); ok { + // if lora is already in the map - increment its counter + count = value.(int) + } + theMap.Store(lora, count+1) +} + +func (s *VllmSimulator) decrementLoraRefCount(lora string, theMap *sync.Map) { + if value, ok := theMap.Load(lora); ok { + count := value.(int) + if count > 1 { + theMap.Store(lora, count-1) + } else { + // last lora instance stopped its execution - remove from the map + theMap.Delete(lora) + } + } else { + s.logger.Error(nil, "Zero model reference", "model", lora) + } +} diff --git a/pkg/llm-d-inference-sim/metrics_test.go b/pkg/llm-d-inference-sim/metrics_test.go index 0d4e1f3c..0d359e95 100644 --- a/pkg/llm-d-inference-sim/metrics_test.go +++ b/pkg/llm-d-inference-sim/metrics_test.go @@ -18,9 +18,11 @@ package llmdinferencesim import ( "context" + "errors" "io" "net/http" "regexp" + "sort" "strconv" "strings" "sync" @@ -33,6 +35,31 @@ import ( "github.com/openai/openai-go/option" ) +const ( + metricsUrl = "http://localhost/metrics" + + lora1 = "lora1" + lora2 = "lora2" +) + +var emptyArray = []string{}
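+// Shared fixtures for the LoRA metrics tests below: expected adapter-name slices and chat completion requests targeting each of the two test adapters.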
+var lora1Arr = []string{lora1} +var lora2Arr = []string{lora2} + +var paramsLora1 openai.ChatCompletionNewParams = openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: "lora1", +} + +var paramsLora2 openai.ChatCompletionNewParams = openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: "lora2", +} + var _ = Describe("Simulator metrics", Ordered, func() { It("Should send correct running and waiting requests metrics", func() { modelName := "testmodel" @@ -73,7 +100,7 @@ var _ = Describe("Simulator metrics", Ordered, func() { defer GinkgoRecover() time.Sleep(300 * time.Millisecond) - metricsResp, err := client.Get("http://localhost/metrics") + metricsResp, err := client.Get(metricsUrl) Expect(err).NotTo(HaveOccurred()) Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) @@ -102,59 +129,49 @@ var _ = Describe("Simulator metrics", Ordered, func() { option.WithBaseURL(baseURL), option.WithHTTPClient(client)) - params1 := openai.ChatCompletionNewParams{ - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - Model: "lora1", - } - - _, err = openaiclient.Chat.Completions.New(ctx, params1) + _, err = openaiclient.Chat.Completions.New(ctx, paramsLora1) Expect(err).NotTo(HaveOccurred()) - params2 := openai.ChatCompletionNewParams{ - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - Model: "lora2", - } - - _, err = openaiclient.Chat.Completions.New(ctx, params2) + _, err = openaiclient.Chat.Completions.New(ctx, paramsLora2) Expect(err).NotTo(HaveOccurred()) - metricsResp, err := client.Get("http://localhost/metrics") + metricsResp, err := client.Get(metricsUrl) Expect(err).NotTo(HaveOccurred()) Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) data, err := io.ReadAll(metricsResp.Body) Expect(err).NotTo(HaveOccurred()) - metrics := string(data) + metrics := strings.Split(string(data), "\n") // We sent two sequential requests to two different LoRAs, we expect to see (in this order) - // 1. running_lora_adapter = lora1 - // 2. running_lora_adapter = lora2 - // 3. running_lora_adapter = {} - lora1 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora1\",waiting_lora_adapters=\"\"}" - lora2 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora2\",waiting_lora_adapters=\"\"}" - empty := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"\",waiting_lora_adapters=\"\"}" - - Expect(metrics).To(ContainSubstring(lora1)) - Expect(metrics).To(ContainSubstring(lora2)) - Expect(metrics).To(ContainSubstring(empty)) + // 1. running: empty, waiting: lora1 + // 2. running: lora1, waiting: empty + // 3. running: empty, waiting: lora2 + // 4. running: lora2, waiting: empty + // 5. 
running: empty, waiting: empty + Expect(isLoraMetricPresent(metrics, emptyArray, lora1Arr)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, lora1Arr, emptyArray)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, emptyArray, lora2Arr)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, lora2Arr, emptyArray)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, emptyArray, emptyArray)).To(BeTrue()) // Check the order - lora1Timestamp := extractTimestamp(metrics, lora1) - lora2Timestamp := extractTimestamp(metrics, lora2) - noLorasTimestamp := extractTimestamp(metrics, empty) - - Expect(lora1Timestamp < lora2Timestamp).To(BeTrue()) - Expect(lora2Timestamp < noLorasTimestamp).To(BeTrue()) + timestamp1 := getLoraValidTimestamp(metrics, emptyArray, lora1Arr) + timestamp2 := getLoraValidTimestamp(metrics, lora1Arr, emptyArray) + timestamp3 := getLoraValidTimestamp(metrics, emptyArray, lora2Arr) + timestamp4 := getLoraValidTimestamp(metrics, lora2Arr, emptyArray) + timestamp5 := getLoraValidTimestamp(metrics, emptyArray, emptyArray) + + Expect(timestamp1 <= timestamp2).To(BeTrue()) + Expect(timestamp2 <= timestamp3).To(BeTrue()) + Expect(timestamp3 <= timestamp4).To(BeTrue()) + Expect(timestamp4 <= timestamp5).To(BeTrue()) }) - It("Should send correct lora metrics for parallel requests", func() { + It("Should send correct lora metrics for parallel requests with delay", func() { ctx := context.TODO() args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, - "--time-to-first-token", "2000", + "--time-to-first-token", "3000", "--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} @@ -167,74 +184,140 @@ var _ = Describe("Simulator metrics", Ordered, func() { option.WithBaseURL(baseURL), option.WithHTTPClient(client)) - params1 := openai.ChatCompletionNewParams{ - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - Model: "lora1", - } - - params2 := openai.ChatCompletionNewParams{ - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - Model: "lora2", - } - var wg sync.WaitGroup wg.Add(1) + // sends three requests with a delay of 0.5 second between them + // request1 for lora1, request2 for lora2, and request 3 for lora1 + go func() { + time.Sleep(500 * time.Millisecond) + defer GinkgoRecover() + _, err := openaiclient.Chat.Completions.New(ctx, paramsLora2) + Expect(err).NotTo(HaveOccurred()) + }() go func() { time.Sleep(1 * time.Second) defer wg.Done() defer GinkgoRecover() - _, err := openaiclient.Chat.Completions.New(ctx, params2) + _, err := openaiclient.Chat.Completions.New(ctx, paramsLora1) Expect(err).NotTo(HaveOccurred()) }() - _, err = openaiclient.Chat.Completions.New(ctx, params1) + _, err = openaiclient.Chat.Completions.New(ctx, paramsLora1) Expect(err).NotTo(HaveOccurred()) wg.Wait() - metricsResp, err := client.Get("http://localhost/metrics") + metricsResp, err := client.Get(metricsUrl) Expect(err).NotTo(HaveOccurred()) Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) data, err := io.ReadAll(metricsResp.Body) Expect(err).NotTo(HaveOccurred()) - metrics := string(data) - - // We sent two parallel requests: first to lora1 and then to lora2 (with a delay), we expect - // to see (in this order) - // 1. running_lora_adapter = lora1 - // 2. running_lora_adapter = lora2,lora1 (the order of LoRAs doesn't matter here) - // 3. running_lora_adapter = lora2 - // 4. 
running_lora_adapter = {} - lora1 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora1\",waiting_lora_adapters=\"\"}" - lora12 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora1,lora2\",waiting_lora_adapters=\"\"}" - lora21 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora2,lora1\",waiting_lora_adapters=\"\"}" - lora2 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora2\",waiting_lora_adapters=\"\"}" - empty := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"\",waiting_lora_adapters=\"\"}" - - Expect(metrics).To(ContainSubstring(lora1)) - Expect(metrics).To(Or(ContainSubstring(lora12), ContainSubstring(lora21))) - Expect(metrics).To(ContainSubstring(lora2)) - Expect(metrics).To(ContainSubstring(empty)) + metrics := strings.Split(string(data), "\n") + + // We sent 3 requests, we expect to see (in this order) + // 1. running: empty, waiting: lora1 + // 2. running: lora1, waiting: lora2 + // 3. running: lora1, lora2 (in any order), waiting: lora1 + // 4. running: lora1, lora2 (in any order), waiting: empty + // 5. running: lora1, waiting: empty + // 6. running: empty, waiting: empty + Expect(isLoraMetricPresent(metrics, emptyArray, lora1Arr)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, lora1Arr, lora2Arr)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, []string{lora1, lora2}, lora1Arr)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, []string{lora1, lora2}, emptyArray)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, lora1Arr, emptyArray)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, emptyArray, emptyArray)).To(BeTrue()) // Check the order - lora1Timestamp := extractTimestamp(metrics, lora1) - lora2Timestamp := extractTimestamp(metrics, lora2) - noLorasTimestamp := extractTimestamp(metrics, empty) - var twoLorasTimestamp float64 - if strings.Contains(metrics, lora12) { - twoLorasTimestamp = extractTimestamp(metrics, lora12) + timestamp1 := getLoraValidTimestamp(metrics, emptyArray, lora1Arr) + timestamp2 := getLoraValidTimestamp(metrics, lora1Arr, lora2Arr) + timestamp3 := getLoraValidTimestamp(metrics, []string{lora1, lora2}, lora1Arr) + timestamp4 := getLoraValidTimestamp(metrics, []string{lora1, lora2}, emptyArray) + timestamp5 := getLoraValidTimestamp(metrics, lora1Arr, emptyArray) + timestamp6 := getLoraValidTimestamp(metrics, emptyArray, emptyArray) + + // in case of requests sent with delay the order is well-defined + Expect(timestamp1 <= timestamp2).To(BeTrue()) + Expect(timestamp2 <= timestamp3).To(BeTrue()) + Expect(timestamp3 <= timestamp4).To(BeTrue()) + Expect(timestamp4 <= timestamp5).To(BeTrue()) + Expect(timestamp5 <= timestamp6).To(BeTrue()) + }) + + It("Should send correct lora metrics for parallel requests without delay", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, + "--time-to-first-token", "3000", + "--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", + "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} + + s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + Expect(err).NotTo(HaveOccurred()) + + defer s.unregisterPrometheus() + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + var wg sync.WaitGroup + wg.Add(1) + + // send two requests with lora1 and lora2 in parallel + go func() { + defer wg.Done() + defer GinkgoRecover() + _, err := openaiclient.Chat.Completions.New(ctx, 
paramsLora2) + Expect(err).NotTo(HaveOccurred()) + }() + + _, err = openaiclient.Chat.Completions.New(ctx, paramsLora1) + Expect(err).NotTo(HaveOccurred()) + + wg.Wait() + + metricsResp, err := client.Get(metricsUrl) + Expect(err).NotTo(HaveOccurred()) + Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) + + data, err := io.ReadAll(metricsResp.Body) + Expect(err).NotTo(HaveOccurred()) + metrics := strings.Split(string(data), "\n") + + // We sent two parallel requests: first to lora1 and then to lora2, + // we expect to see metrics in this order: + // 1. running: empty, waiting: lora1 or lora2 (depends which request received first) + // 2. running: one of the loras, waiting: another lora + // 3. running: both lora2 and lora1 (the order of LoRAs doesn't matter here), waiting: empty + // 4. running: empty, waiting: empty + Expect(isLoraMetricPresent(metrics, emptyArray, lora1Arr) || isLoraMetricPresent(metrics, emptyArray, lora2Arr)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, lora1Arr, lora2Arr) || isLoraMetricPresent(metrics, lora2Arr, lora1Arr)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, []string{lora1, lora2}, emptyArray)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, emptyArray, emptyArray)).To(BeTrue()) + + // Check the order: + // 1. one of the loras in the waiting list + // 2. both loras in the running list + // 3. empty + l1WaitingTimestamp, err := getLoraTimestamp(metrics, emptyArray, lora1Arr) + Expect(err).NotTo(HaveOccurred()) + l2WaitingTimestamp, err := getLoraTimestamp(metrics, emptyArray, lora2Arr) + Expect(err).NotTo(HaveOccurred()) + Expect((l1WaitingTimestamp != nil)).ToNot(Equal((l2WaitingTimestamp != nil))) + var singleWaitingTimestamp float64 + if l1WaitingTimestamp != nil { + singleWaitingTimestamp = *l1WaitingTimestamp } else { - twoLorasTimestamp = extractTimestamp(metrics, lora21) + singleWaitingTimestamp = *l2WaitingTimestamp } - Expect(lora1Timestamp < twoLorasTimestamp).To(BeTrue()) - Expect(twoLorasTimestamp < lora2Timestamp).To(BeTrue()) - Expect(lora2Timestamp < noLorasTimestamp).To(BeTrue()) + + bothRunningTimestamp := getLoraValidTimestamp(metrics, []string{lora1, lora2}, emptyArray) + emptyTimestamp := getLoraValidTimestamp(metrics, emptyArray, emptyArray) + + Expect(singleWaitingTimestamp <= bothRunningTimestamp).To(BeTrue()) + Expect(bothRunningTimestamp <= emptyTimestamp).To(BeTrue()) }) Context("fake metrics", func() { @@ -250,7 +333,7 @@ var _ = Describe("Simulator metrics", Ordered, func() { defer s.unregisterPrometheus() - resp, err := client.Get("http://localhost/metrics") + resp, err := client.Get(metricsUrl) Expect(err).NotTo(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -266,11 +349,80 @@ var _ = Describe("Simulator metrics", Ordered, func() { }) }) -func extractTimestamp(metrics string, key string) float64 { - re := regexp.MustCompile(key + ` (\S+)`) - result := re.FindStringSubmatch(metrics) - Expect(len(result)).To(BeNumerically(">", 1)) - f, err := strconv.ParseFloat(result[1], 64) +// isLoraMetricPresent checks if a matching metric exists +// metrics: the list of metrics +// running: list of loras in running_lora_adapters, the order does not matter +// waiting: list of loras in waiting_lora_adapters, the order does not matter +func isLoraMetricPresent(metrics []string, running, waiting []string) bool { + return findLoraMetric(metrics, running, waiting) != "" +} + +// getLoraTimestamp returns timestamp or nil, error +func getLoraTimestamp(metrics []string, running, waiting []string) (*float64, 
error) { + metric := findLoraMetric(metrics, running, waiting) + if metric == "" { + return nil, nil // not found + } + // Extract timestamp: last part after space + parts := strings.Split(metric, " ") + if len(parts) < 2 { + return nil, errors.New("invalid metric format") + } + timestampStr := parts[len(parts)-1] + timestamp, err := strconv.ParseFloat(timestampStr, 64) Expect(err).NotTo(HaveOccurred()) - return f + + return &timestamp, nil +} + +func getLoraValidTimestamp(metrics []string, running, waiting []string) float64 { + timestamp, err := getLoraTimestamp(metrics, running, waiting) + Expect(err).NotTo(HaveOccurred()) + Expect(timestamp).ToNot(BeNil()) + return *timestamp +} + +// findLoraMetric finds the relevant metric by comparing with the given loras sets (ignoring order) + // metrics: lines of metrics +// running: list of running loras to find +// waiting: list of waiting loras to find +// Looks for a line with the given running and waiting loras sets, the comparison is order agnostic. +// Return metric should match in both running and waiting sets. +// E.g. for input running=["l1", "l2", "l3"] and waiting=[] will return metric +// with running_lora_adapters=["l3", "l1", "l2"] and waiting_lora_adapters=[] +func findLoraMetric(metrics []string, running, waiting []string) string { + // sort input arrays before compare, create string of all values, separated by comma + sort.Strings(running) + sort.Strings(waiting) + runStr := strings.Join(running, ",") + waitStr := strings.Join(waiting, ",") + + // regex to extract lora metrics and values + re := regexp.MustCompile(`vllm:lora_requests_info\{.*running_lora_adapters="([^"]*)".*waiting_lora_adapters="([^"]*)".*\}\s+([0-9.e\+\-]+)`) + for _, metric := range metrics { + matches := re.FindStringSubmatch(metric) + if len(matches) == 4 { + // this line contains loraInfo metric, check running and waiting loras lists + // split and sort metric's running and waiting loras lists for the comparison + metricRun := splitString(matches[1]) + metricWait := splitString(matches[2]) + sort.Strings(metricRun) + sort.Strings(metricWait) + // if both lists are the same - return the metric + if strings.Join(metricRun, ",") == runStr && strings.Join(metricWait, ",") == waitStr { + return metric + } + } // if the metric is not in the required format - skip it + } + + // required metric was not found + return "" +} + +// splits the given string to array of strings with separator = "," +func splitString(str string) []string { + if str == "" { + return []string{} + } + return strings.Split(str, ",") } diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 9f56f798..179b2ec9 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -55,6 +55,21 @@ const ( maxNumberOfRequests = 1000 ) +type loraUsageState int + +const ( + waitingUsageState loraUsageState = iota + runningUsageState + doneUsageState +) + +type loraUsage struct { + // the lora adapter name + name string + // state of the lora usage - waiting/running/done + state loraUsageState +} + // VllmSimulator simulates vLLM server supporting OpenAI API type VllmSimulator struct { // logger is used for information and errors logging @@ -63,11 +78,14 @@ type VllmSimulator struct { config *common.Configuration // loraAdaptors contains list of LoRA available adaptors loraAdaptors sync.Map - // runningLoras is a collection of running loras, key of lora's name, value is number of requests using this lora + // runningLoras is a collection of 
running loras, + // the key is lora's name, the value is the number of running requests using this lora runningLoras sync.Map - // waitingLoras will represent collection of loras defined in requests in the queue - Not implemented yet - // nolint:unused + // waitingLoras is a collection of waiting loras, + // the key is lora's name, the value is the number of waiting requests using this lora waitingLoras sync.Map + // lorasChan is a channel to update waitingLoras and runningLoras + lorasChan chan loraUsage // nRunningReqs is the number of inference requests that are currently being processed nRunningReqs int64 // runReqChan is a channel to update nRunningReqs @@ -112,6 +130,7 @@ func New(logger logr.Logger) (*VllmSimulator, error) { pod: os.Getenv(podNameEnv), runReqChan: make(chan int64, maxNumberOfRequests), waitingReqChan: make(chan int64, maxNumberOfRequests), + lorasChan: make(chan loraUsage, maxNumberOfRequests), }, nil } @@ -388,7 +407,13 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple IsChatCompletion: isChatCompletion, Wg: &wg, } + // increment the waiting requests metric s.waitingReqChan <- 1 + if s.isLora(reqCtx.CompletionReq.GetModel()) { + // update loraInfo metrics with the new waiting request + s.lorasChan <- loraUsage{reqCtx.CompletionReq.GetModel(), waitingUsageState} + } + // send the request to the waiting queue (channel) s.reqChan <- reqCtx wg.Wait() } @@ -405,32 +430,20 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { return } - s.waitingReqChan <- -1 - req := reqCtx.CompletionReq model := req.GetModel() displayModel := s.getDisplayedModelName(model) - if s.isLora(model) { - // if current request's model is LoRA, add it to the list of running loras - value, ok := s.runningLoras.Load(model) - intValue := 0 - - if !ok { - s.logger.Info("Create reference counter", "model", model) - intValue = 0 - } else { - intValue = value.(int) - } - s.runningLoras.Store(model, intValue+1) - s.logger.Info("Update LoRA reference counter", "model", model, "old value", intValue, "new value", intValue+1) + // decrement waiting and increment running requests count + s.waitingReqChan <- -1 + s.runReqChan <- 1 - // TODO - check if this request went to the waiting queue - add it to waiting map - s.reportLoras() + if s.isLora(model) { + // update loraInfo metric to reflect that + // the request has changed its status from waiting to running + s.lorasChan <- loraUsage{model, runningUsageState} } - s.runReqChan <- 1 - var responseTokens []string var finishReason string var err error @@ -500,31 +513,13 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { // decrease model usage reference number func (s *VllmSimulator) responseSentCallback(model string) { + // decrement running requests count s.runReqChan <- -1 - // Only LoRA models require reference-count handling. 
- if !s.isLora(model) { - return + if s.isLora(model) { + // update loraInfo metrics to reflect that the request processing has been finished + s.lorasChan <- loraUsage{model, doneUsageState} } - - value, ok := s.runningLoras.Load(model) - - if !ok { - s.logger.Info("Error: nil reference counter", "model", model) - s.logger.Error(nil, "Zero model reference", "model", model) - } else { - intValue := value.(int) - if intValue > 1 { - s.runningLoras.Store(model, intValue-1) - s.logger.Info("Update LoRA reference counter", "model", model, "prev value", intValue, "new value", intValue-1) - } else { - // last lora instance stopped its execution - remove from the map - s.runningLoras.Delete(model) - s.logger.Info("Remove LoRA from set of running loras", "model", model) - } - } - - s.reportLoras() } // sendCompletionError sends an error response for the current completion request
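For readers following the new bookkeeping, the sketch below is a minimal, self-contained model of the same pattern introduced in this diff: a single updater goroutine drains a channel of waiting/running/done events, keeps per-adapter reference counts, and rebuilds the comma-joined label values that the gauge would be set with. All names here (usageEvent, bump, label, etc.) are illustrative and are not part of the simulator's code.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

type usageState int

const (
	waiting usageState = iota
	running
	done
)

// usageEvent mirrors the shape of a per-request LoRA state change: adapter name plus its new state.
type usageEvent struct {
	name  string
	state usageState
}

// bump adjusts a reference count and removes the entry once it drops to zero.
func bump(counts map[string]int, name string, delta int) {
	counts[name] += delta
	if counts[name] <= 0 {
		delete(counts, name)
	}
}

// label joins the adapter names the way the gauge labels are built (comma-separated).
// Sorting is only for deterministic demo output.
func label(counts map[string]int) string {
	names := make([]string, 0, len(counts))
	for n := range counts {
		names = append(names, n)
	}
	sort.Strings(names)
	return strings.Join(names, ",")
}

func main() {
	events := make(chan usageEvent, 16)
	go func() {
		// one request is queued, starts running, then finishes
		events <- usageEvent{"lora1", waiting}
		events <- usageEvent{"lora1", running}
		events <- usageEvent{"lora1", done}
		close(events)
	}()

	waitingCounts := map[string]int{}
	runningCounts := map[string]int{}
	for ev := range events {
		switch ev.state {
		case waiting:
			bump(waitingCounts, ev.name, +1)
		case running:
			bump(waitingCounts, ev.name, -1)
			bump(runningCounts, ev.name, +1)
		case done:
			bump(runningCounts, ev.name, -1)
		}
		// in the simulator this is where reportLoras() refreshes the gauge
		fmt.Printf("running=%q waiting=%q\n", label(runningCounts), label(waitingCounts))
	}
}
```

Because only the updater goroutine touches the two maps, every intermediate combination of running/waiting adapters is published in order, which is what the new tests assert by comparing the gauge timestamps.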