diff --git a/pkg/common/config.go b/pkg/common/config.go
index 5472407..7929b2b 100644
--- a/pkg/common/config.go
+++ b/pkg/common/config.go
@@ -220,6 +220,9 @@ type Configuration struct {
 	DatasetURL string `yaml:"dataset-url" json:"dataset-url"`
 	// DatasetInMemory defines whether to load the entire dataset into memory for faster access.
 	DatasetInMemory bool `yaml:"dataset-in-memory" json:"dataset-in-memory"`
+
+	// EnableSleepMode defines whether the sleep mode endpoints (/sleep, /wake_up and /is_sleeping) are enabled
+	EnableSleepMode bool `yaml:"enable-sleep-mode" json:"enable-sleep-mode"`
 }
 
 type Metrics struct {
@@ -741,6 +744,8 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
 	f.StringVar(&config.DatasetURL, "dataset-url", config.DatasetURL, "URL to download the sqlite db file for response generation from a dataset")
 	f.BoolVar(&config.DatasetInMemory, "dataset-in-memory", config.DatasetInMemory, "Load the entire dataset into memory for faster access")
 
+	f.BoolVar(&config.EnableSleepMode, "enable-sleep-mode", config.EnableSleepMode, "Enable the sleep mode endpoints (/sleep, /wake_up and /is_sleeping)")
+
 	f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures")
 	failureTypes := getParamValueFromArgs("failure-types")
 	var dummyFailureTypes multiString
diff --git a/pkg/common/test_utils.go b/pkg/common/test_utils.go
new file mode 100644
index 0000000..10d89ab
--- /dev/null
+++ b/pkg/common/test_utils.go
@@ -0,0 +1,40 @@
+/*
+Copyright 2025 The llm-d-inference-sim Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package common
+
+import (
+	"github.com/onsi/gomega"
+	zmq "github.com/pebbe/zmq4"
+)
+
+// CreateSub creates a ZMQ sub, subscribes to the provided topic, and returns the
+// sub and the endpoint to publish events on
+func CreateSub(topic string) (*zmq.Socket, string) {
+	wildcardEndpoint := "tcp://*:*"
+	zctx, err := zmq.NewContext()
+	gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	sub, err := zctx.NewSocket(zmq.SUB)
+	gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	err = sub.Bind(wildcardEndpoint)
+	gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	// get the actual port
+	endpoint, err := sub.GetLastEndpoint()
+	gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	err = sub.SetSubscribe(topic)
+	gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	return sub, endpoint
+}
diff --git a/pkg/kv-cache/block_cache.go b/pkg/kv-cache/block_cache.go
index 833b6c3..d9d3f5d 100644
--- a/pkg/kv-cache/block_cache.go
+++ b/pkg/kv-cache/block_cache.go
@@ -24,6 +24,7 @@ import (
 
 	"github.com/go-logr/logr"
 	"github.com/llm-d/llm-d-inference-sim/pkg/common"
+	"github.com/llm-d/llm-d-inference-sim/pkg/common/logging"
 )
 
 const (
@@ -42,6 +43,7 @@ type blockCache struct {
 	eventChan   chan EventData // channel for asynchronous event processing
 	usageChan   chan float64   // channel for usage reporting
 	logger      logr.Logger
+	disabled    bool // indicates whether the cache is disabled
 }
 
 // newBlockCache creates a new blockCache with the specified maximum number of blocks
@@ -58,6 +60,9 @@ func newBlockCache(config *common.Configuration, logger logr.Logger, usageChan c
 		}
 	}
 
+	eventSender := NewKVEventSender(publisher, CreateKVEventsTopic(config.Port, config.Model),
+		eChan, config.EventBatchSize, delay, logger)
+
 	return &blockCache{
 		requestToBlocks: make(map[string][]uint64),
 		usedBlocks:      make(map[uint64]int),
@@ -65,24 +70,56 @@ func newBlockCache(config *common.Configuration, logger logr.Logger, usageChan c
 		maxBlocks:       config.KVCacheSize,
 		eventChan:       eChan,
 		usageChan:       usageChan,
-		eventSender:     NewKVEventSender(publisher, createTopic(config), eChan, config.EventBatchSize, delay, logger),
+		eventSender:     eventSender,
 		logger:          logger,
 	}, nil
 }
 
 func (bc *blockCache) start(ctx context.Context) {
+	bc.logger.V(logging.INFO).Info("Starting KV cache")
 	err := bc.eventSender.Run(ctx)
 	if err != nil {
 		bc.logger.Error(err, "Sender stopped with error")
 	}
 }
 
+func (bc *blockCache) discard() {
+	bc.logger.V(logging.INFO).Info("Discarding KV cache")
+
+	bc.mu.Lock()
+	defer bc.mu.Unlock()
+
+	bc.disabled = true
+
+	bc.requestToBlocks = make(map[string][]uint64)
+	bc.usedBlocks = make(map[uint64]int)
+	bc.unusedBlocks = make(map[uint64]time.Time)
+
+	common.WriteToChannel(bc.eventChan,
+		EventData{action: eventActionAllBlocksCleared},
+		bc.logger, "block cache eventChan")
+}
+
+func (bc *blockCache) activate() {
+	bc.logger.V(logging.INFO).Info("Activating KV cache")
+
+	bc.mu.Lock()
+	defer bc.mu.Unlock()
+
+	bc.disabled = false
+}
+
 // startRequest adds a request with its associated block hashes to the cache
 // and returns the number of blocks that were already in the cache
 func (bc *blockCache) startRequest(requestID string, blocks []uint64) (int, error) {
 	bc.mu.Lock()
 	defer bc.mu.Unlock()
 
+	if bc.disabled {
+		bc.logger.V(logging.TRACE).Info("KV cache is disabled, request is not added to the kv cache")
+		return 0, nil
+	}
+
 	if _, exists := bc.requestToBlocks[requestID]; exists {
 		// request with the same id already exists
 		return 0, fmt.Errorf("request already exists for id %s", requestID)
@@ -167,6 +204,11 @@ func (bc *blockCache) finishRequest(requestID string) error {
 	bc.mu.Lock()
 	defer bc.mu.Unlock()
 
+	if bc.disabled {
+		bc.logger.V(logging.TRACE).Info("KV cache is disabled, request completion is not processed by the kv cache")
+		return nil
+	}
+
 	// Get blocks associated with this request
 	blockHashes, exists := bc.requestToBlocks[requestID]
 	if !exists {
@@ -239,6 +281,6 @@ func (bc *blockCache) getBlockInfo(blockHash uint64) (int, bool) {
 	return 0, false
 }
 
-func createTopic(config *common.Configuration) string {
-	return fmt.Sprintf("kv@$localhost:%d@%s", config.Port, config.Model)
+func CreateKVEventsTopic(port int, model string) string {
+	return fmt.Sprintf("kv@$localhost:%d@%s", port, model)
 }
diff --git a/pkg/kv-cache/kv_cache.go b/pkg/kv-cache/kv_cache.go
index cf69b0f..a87c7b7 100644
--- a/pkg/kv-cache/kv_cache.go
+++ b/pkg/kv-cache/kv_cache.go
@@ -63,6 +63,14 @@ func (h *KVCacheHelper) Run(ctx context.Context) {
 	h.blockCache.start(ctx)
 }
 
+func (h *KVCacheHelper) Discard() {
+	h.blockCache.discard()
+}
+
+func (h *KVCacheHelper) Activate() {
+	h.blockCache.activate()
+}
+
 func (h *KVCacheHelper) OnRequestStart(vllmReq openaiserverapi.CompletionRequest) error {
 	h.logger.V(logging.TRACE).Info("KV cache - process request")
 
diff --git a/pkg/kv-cache/kv_cache_sender.go b/pkg/kv-cache/kv_cache_sender.go
index c09b30a..51d0c38 100644
--- a/pkg/kv-cache/kv_cache_sender.go
+++ b/pkg/kv-cache/kv_cache_sender.go
@@ -32,6 +32,7 @@ type EventAction int
 const (
 	eventActionStore EventAction = iota
 	eventActionRemove
+	eventActionAllBlocksCleared
 )
 
 type EventData struct {
@@ -97,6 +98,8 @@ func (s *KVEventSender) Run(ctx context.Context) error {
 				payload, err = msgpack.Marshal(kvevents.BlockStored{BlockHashes: eventData.hashValues}.ToTaggedUnion())
 			case eventActionRemove:
 				payload, err = msgpack.Marshal(kvevents.BlockRemoved{BlockHashes: eventData.hashValues}.ToTaggedUnion())
+			case eventActionAllBlocksCleared:
+				payload, err = msgpack.Marshal(kvevents.AllBlocksCleared{}.ToTaggedUnion())
 			default:
 				return fmt.Errorf("invalid event action %d", eventData.action)
 			}
diff --git a/pkg/kv-cache/kv_cache_test.go b/pkg/kv-cache/kv_cache_test.go
index 029bbfe..826eada 100644
--- a/pkg/kv-cache/kv_cache_test.go
+++ b/pkg/kv-cache/kv_cache_test.go
@@ -18,16 +18,11 @@ package kvcache
 
 import (
 	"context"
-	"encoding/binary"
 	"fmt"
 	"sync"
 	"time"
 
-	zmq "github.com/pebbe/zmq4"
-	"github.com/vmihailenco/msgpack/v5"
-
 	"github.com/llm-d/llm-d-inference-sim/pkg/common"
-	"github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvevents"
 	. "github.com/onsi/ginkgo/v2"
"github.com/onsi/gomega" ) @@ -207,7 +202,9 @@ var _ = Describe("KV cache", Ordered, func() { EventBatchSize: 1, } - sub, topic := createSub(config) + topic := CreateKVEventsTopic(config.Port, config.Model) + sub, endpoint := common.CreateSub(topic) + config.ZMQEndpoint = endpoint //nolint defer sub.Close() @@ -289,7 +286,7 @@ var _ = Describe("KV cache", Ordered, func() { for i := range test.expectedRemovedBlocks + test.expectedStoredBlocks { parts, err := sub.RecvMessageBytes(0) Expect(err).NotTo(HaveOccurred()) - stored, removed := parseEvent(parts, topic, uint64(i+1)) + stored, removed, _ := ParseKVEvent(parts, topic, uint64(i+1)) storedCount += len(stored) removedCount += len(removed) } @@ -309,7 +306,9 @@ var _ = Describe("KV cache", Ordered, func() { ZMQMaxConnectAttempts: 3, } - sub, topic := createSub(config) + topic := CreateKVEventsTopic(config.Port, config.Model) + sub, endpoint := common.CreateSub(topic) + config.ZMQEndpoint = endpoint //nolint defer sub.Close() @@ -378,7 +377,7 @@ var _ = Describe("KV cache", Ordered, func() { for { parts, err := sub.RecvMessageBytes(0) Expect(err).NotTo(HaveOccurred()) - stored, removed := parseEvent(parts, topic, count) + stored, removed, _ := ParseKVEvent(parts, topic, count) storedBlocks = append(storedBlocks, stored...) removedBlocks = append(removedBlocks, removed...) count++ @@ -484,67 +483,3 @@ func createRandomArray(minArrLen, maxArrLen int, maxValue uint64, random *common return arr } - -func parseEvent(parts [][]byte, expectedTopic string, expectedSeq uint64) ([]uint64, []uint64) { - // The message should be [topic, seq, payload] - Expect(parts).To(HaveLen(3)) - - Expect(string(parts[0])).To(Equal(expectedTopic)) - - seq := binary.BigEndian.Uint64(parts[1]) - Expect(seq).To(Equal(expectedSeq)) - - removed := make([]uint64, 0) - stored := make([]uint64, 0) - - var eventBatch kvevents.EventBatch - err := msgpack.Unmarshal(parts[2], &eventBatch) - Expect(err).NotTo(HaveOccurred()) - for _, rawEvent := range eventBatch.Events { - var taggedUnion []msgpack.RawMessage - err := msgpack.Unmarshal(rawEvent, &taggedUnion) - Expect(err).NotTo(HaveOccurred()) - Expect(len(taggedUnion)).To(BeNumerically(">", 1)) - - payloadBytes, err := msgpack.Marshal(taggedUnion[1:]) - Expect(err).NotTo(HaveOccurred()) - - var tag string - err = msgpack.Unmarshal(taggedUnion[0], &tag) - Expect(err).NotTo(HaveOccurred()) - - switch tag { - case kvevents.BlockStoredEventTag: - var bs kvevents.BlockStored - err = msgpack.Unmarshal(payloadBytes, &bs) - stored = append(stored, bs.BlockHashes...) - case kvevents.BlockRemovedEventTag: - var br kvevents.BlockRemoved - err = msgpack.Unmarshal(payloadBytes, &br) - removed = append(removed, br.BlockHashes...) 
-
-		default:
-			Fail("unexpected tag " + tag)
-			continue
-		}
-		Expect(err).NotTo(HaveOccurred())
-	}
-	return stored, removed
-}
-
-func createSub(config *common.Configuration) (*zmq.Socket, string) {
-	zctx, err := zmq.NewContext()
-	Expect(err).NotTo(HaveOccurred())
-	sub, err := zctx.NewSocket(zmq.SUB)
-	Expect(err).NotTo(HaveOccurred())
-	err = sub.Bind(wildcardEndpoint)
-	Expect(err).NotTo(HaveOccurred())
-	// get the actual port
-	endpoint, err := sub.GetLastEndpoint()
-	Expect(err).NotTo(HaveOccurred())
-	config.ZMQEndpoint = endpoint
-	topic := createTopic(config)
-	err = sub.SetSubscribe(topic)
-	Expect(err).NotTo(HaveOccurred())
-	return sub, topic
-}
diff --git a/pkg/kv-cache/kv_test_helper.go b/pkg/kv-cache/kv_test_helper.go
new file mode 100644
index 0000000..c124d0f
--- /dev/null
+++ b/pkg/kv-cache/kv_test_helper.go
@@ -0,0 +1,78 @@
+/*
+Copyright 2025 The llm-d-inference-sim Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package kvcache
+
+import (
+	"encoding/binary"
+
+	"github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvevents"
+	"github.com/onsi/ginkgo/v2"
+	gomega "github.com/onsi/gomega"
+	"github.com/vmihailenco/msgpack/v5"
+)
+
+func ParseKVEvent(parts [][]byte, expectedTopic string, expectedSeq uint64) ([]uint64, []uint64, bool) {
+	// The message should be [topic, seq, payload]
+	gomega.Expect(parts).To(gomega.HaveLen(3))
+
+	gomega.Expect(string(parts[0])).To(gomega.Equal(expectedTopic))
+
+	seq := binary.BigEndian.Uint64(parts[1])
+	gomega.Expect(seq).To(gomega.Equal(expectedSeq))
+
+	removed := make([]uint64, 0)
+	stored := make([]uint64, 0)
+	allCleared := false
+
+	var eventBatch kvevents.EventBatch
+	err := msgpack.Unmarshal(parts[2], &eventBatch)
+	gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	for _, rawEvent := range eventBatch.Events {
+		var taggedUnion []msgpack.RawMessage
+		err := msgpack.Unmarshal(rawEvent, &taggedUnion)
+		gomega.Expect(err).NotTo(gomega.HaveOccurred())
+		gomega.Expect(taggedUnion).ToNot(gomega.BeEmpty())
+
+		payloadBytes, err := msgpack.Marshal(taggedUnion[1:])
+		gomega.Expect(err).NotTo(gomega.HaveOccurred())
+
+		var tag string
+		err = msgpack.Unmarshal(taggedUnion[0], &tag)
+		gomega.Expect(err).NotTo(gomega.HaveOccurred())
+
+		switch tag {
+		case kvevents.BlockStoredEventTag:
+			var bs kvevents.BlockStored
+			err = msgpack.Unmarshal(payloadBytes, &bs)
+			stored = append(stored, bs.BlockHashes...)
+		case kvevents.BlockRemovedEventTag:
+			var br kvevents.BlockRemoved
+			err = msgpack.Unmarshal(payloadBytes, &br)
+			removed = append(removed, br.BlockHashes...)
+		case kvevents.AllBlocksClearedEventTag:
+			var ac kvevents.AllBlocksCleared
+			err = msgpack.Unmarshal(payloadBytes, &ac)
+			allCleared = true
+
+		default:
+			ginkgo.Fail("unexpected tag " + tag)
+			continue
+		}
+		gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	}
+	return stored, removed, allCleared
+}
diff --git a/pkg/llm-d-inference-sim/server.go b/pkg/llm-d-inference-sim/server.go
index 44ce37b..0faf6a3 100644
--- a/pkg/llm-d-inference-sim/server.go
+++ b/pkg/llm-d-inference-sim/server.go
@@ -61,6 +61,9 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener)
 	r.GET("/health", s.HandleHealth)
 	r.GET("/ready", s.HandleReady)
 	r.POST("/tokenize", s.HandleTokenize)
+	r.POST("/sleep", s.HandleSleep)
+	r.POST("/wake_up", s.HandleWakeUp)
+	r.GET("/is_sleeping", s.HandleIsSleeping)
 
 	server := &fasthttp.Server{
 		ErrorHandler: s.HandleError,
@@ -326,3 +329,65 @@ func (s *VllmSimulator) HandleReady(ctx *fasthttp.RequestCtx) {
 	ctx.Response.Header.SetStatusCode(fasthttp.StatusOK)
 	ctx.Response.SetBody([]byte("{}"))
 }
+
+// HandleIsSleeping handles /is_sleeping requests
+func (s *VllmSimulator) HandleIsSleeping(ctx *fasthttp.RequestCtx) {
+	s.logger.V(logging.TRACE).Info("/is_sleeping request received")
+
+	s.sleepMutex.RLock()
+	defer s.sleepMutex.RUnlock()
+	data, err := json.Marshal(map[string]bool{"is_sleeping": s.isSleeping})
+	if err != nil {
+		s.logger.Error(err, "failed to marshal isSleeping response")
+		ctx.Error("Failed to marshal isSleeping response, "+err.Error(), fasthttp.StatusInternalServerError)
+		return
+	}
+
+	ctx.Response.Header.SetContentType("application/json")
+	ctx.Response.Header.SetStatusCode(fasthttp.StatusOK)
+	ctx.Response.SetBody(data)
+}
+
+// HandleSleep http handler for /sleep
+func (s *VllmSimulator) HandleSleep(ctx *fasthttp.RequestCtx) {
+	s.logger.V(logging.INFO).Info("Sleep request received")
+
+	if s.config.EnableSleepMode && s.isInDevMode {
+		s.sleepMutex.Lock()
+		defer s.sleepMutex.Unlock()
+
+		s.isSleeping = true
+		if s.config.EnableKVCache {
+			s.kvcacheHelper.Discard()
+		}
+	}
+
+	ctx.Response.Header.SetStatusCode(fasthttp.StatusOK)
+}
+
+// HandleWakeUp http handler for /wake_up
+func (s *VllmSimulator) HandleWakeUp(ctx *fasthttp.RequestCtx) {
+	s.logger.V(logging.INFO).Info("Wake up request received")
+
+	var wakeUpKVCache bool
+	tags := ctx.QueryArgs().Peek("tags")
+	if tags != nil {
+		if string(tags) == "kv_cache" {
+			wakeUpKVCache = true
+		}
+	} else {
+		wakeUpKVCache = true
+	}
+
+	s.sleepMutex.Lock()
+	defer s.sleepMutex.Unlock()
+
+	// Activate the kv cache if either the tags are "kv_cache" or there are no tags
+	if s.config.EnableKVCache && wakeUpKVCache {
+		s.kvcacheHelper.Activate()
+	}
+
+	s.isSleeping = false
+
+	ctx.Response.Header.SetStatusCode(fasthttp.StatusOK)
+}
diff --git a/pkg/llm-d-inference-sim/server_test.go b/pkg/llm-d-inference-sim/server_test.go
index 0f64868..631b1ce 100644
--- a/pkg/llm-d-inference-sim/server_test.go
+++ b/pkg/llm-d-inference-sim/server_test.go
@@ -23,13 +23,17 @@ import (
 	"net/http"
 	"os"
 	"strings"
+	"time"
 
 	"github.com/llm-d/llm-d-inference-sim/pkg/common"
+	kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
 	vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
 	. "github.com/onsi/ginkgo/v2"
"github.com/onsi/gomega" ) +const tmpDir = "./tests-tmp/" + var _ = Describe("Server", func() { It("Should respond to /health", func() { @@ -53,7 +57,6 @@ var _ = Describe("Server", func() { }) Context("tokenize", Ordered, func() { - tmpDir := "./tests-tmp/" AfterAll(func() { err := os.RemoveAll(tmpDir) Expect(err).NotTo(HaveOccurred()) @@ -208,4 +211,149 @@ var _ = Describe("Server", func() { }) }) + + Context("sleep mode", Ordered, func() { + AfterAll(func() { + err := os.RemoveAll(tmpDir) + Expect(err).NotTo(HaveOccurred()) + }) + + It("Should respond to /is_sleeping", func() { + ctx := context.TODO() + client, err := startServer(ctx, common.ModeRandom) + Expect(err).NotTo(HaveOccurred()) + + checkSimSleeping(client, false) + }) + + It("Should not enter sleep mode without the flag", func() { + ctx := context.TODO() + client, err := startServerWithEnv(ctx, common.ModeRandom, map[string]string{"VLLM_SERVER_DEV_MODE": "1"}) + Expect(err).NotTo(HaveOccurred()) + + resp, err := client.Post("http://localhost/sleep", "", nil) + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + checkSimSleeping(client, false) + }) + + It("Should not enter sleep mode without the env var", func() { + ctx := context.TODO() + client, err := startServerWithArgs(ctx, + []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom, "--enable-sleep-mode"}) + Expect(err).NotTo(HaveOccurred()) + + resp, err := client.Post("http://localhost/sleep", "", nil) + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + checkSimSleeping(client, false) + }) + + It("Should enter sleep mode and wake up", func() { + topic := kvcache.CreateKVEventsTopic(8000, qwenModelName) + sub, endpoint := common.CreateSub(topic) + + ctx := context.TODO() + client, err := startServerWithArgsAndEnv(ctx, common.ModeRandom, + []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom, "--enable-sleep-mode", + "--enable-kvcache", "--v", "5", "--port", "8000", "--zmq-endpoint", endpoint, + "--tokenizers-cache-dir", tmpDir}, + map[string]string{"VLLM_SERVER_DEV_MODE": "1"}) + Expect(err).NotTo(HaveOccurred()) + + //nolint + defer sub.Close() + + // Send a request, check that a kv event BlockStored was sent + go func() { + time.Sleep(200 * time.Millisecond) + sendTextCompletionRequest(ctx, client) + }() + parts, err := sub.RecvMessageBytes(0) + Expect(err).NotTo(HaveOccurred()) + stored, _, _ := kvcache.ParseKVEvent(parts, topic, uint64(1)) + Expect(stored).To(HaveLen(1)) + + // Sleep and check that AllBlocksCleared event was sent + go func() { + time.Sleep(200 * time.Millisecond) + resp, err := client.Post("http://localhost/sleep", "", nil) + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + }() + parts, err = sub.RecvMessageBytes(0) + Expect(err).NotTo(HaveOccurred()) + _, _, allCleared := kvcache.ParseKVEvent(parts, topic, uint64(2)) + Expect(allCleared).To(BeTrue()) + + checkSimSleeping(client, true) + + // Send a request + go sendTextCompletionRequest(ctx, client) + + resp, err := client.Post("http://localhost/wake_up", "", nil) + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + checkSimSleeping(client, false) + + // Send a request, check that a kv event BlockStored was sent, + // this checks that in sleep mode the kv cache was disabled. + // The sequence number of the event is an addition check. 
+			go func() {
+				time.Sleep(200 * time.Millisecond)
+				sendTextCompletionRequest(ctx, client)
+			}()
+			parts, err = sub.RecvMessageBytes(0)
+			Expect(err).NotTo(HaveOccurred())
+			stored, _, _ = kvcache.ParseKVEvent(parts, topic, uint64(3))
+			Expect(stored).To(HaveLen(1))
+
+			// Sleep again and wait for AllBlocksCleared
+			go func() {
+				time.Sleep(200 * time.Millisecond)
+				resp, err := client.Post("http://localhost/sleep", "", nil)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(resp.StatusCode).To(Equal(http.StatusOK))
+			}()
+
+			parts, err = sub.RecvMessageBytes(0)
+			Expect(err).NotTo(HaveOccurred())
+			_, _, allCleared = kvcache.ParseKVEvent(parts, topic, uint64(4))
+			Expect(allCleared).To(BeTrue())
+
+			checkSimSleeping(client, true)
+
+			// Wake up the weights only, the kv cache shouldn't wake up yet
+			resp, err = client.Post("http://localhost/wake_up?tags=weights", "", nil)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			checkSimSleeping(client, false)
+
+			// Send a request
+			go sendTextCompletionRequest(ctx, client)
+
+			// Now wake up the cache
+			resp, err = client.Post("http://localhost/wake_up?tags=kv_cache", "", nil)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			checkSimSleeping(client, false)
+
+			// Send a request and check that a kv event BlockStored was sent;
+			// this checks that the kv cache stayed disabled after waking up with weights only.
+			// The sequence number of the event is an additional check.
+			go func() {
+				time.Sleep(200 * time.Millisecond)
+				sendTextCompletionRequest(ctx, client)
+			}()
+			parts, err = sub.RecvMessageBytes(0)
+			Expect(err).NotTo(HaveOccurred())
+			stored, _, _ = kvcache.ParseKVEvent(parts, topic, uint64(5))
+			Expect(stored).To(HaveLen(1))
+		})
+	})
 })
diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go
index 289a6f5..cb1ab54 100644
--- a/pkg/llm-d-inference-sim/simulator.go
+++ b/pkg/llm-d-inference-sim/simulator.go
@@ -190,6 +190,14 @@ type VllmSimulator struct {
 	// rand with a configurable seed to generate reproducible random responses
 	random *common.Random
 
+	// indicates whether the simulator is sleeping
+	isSleeping bool
+	// indicates whether the simulator is in development mode, set by the environment
+	// variable VLLM_SERVER_DEV_MODE
+	isInDevMode bool
+	// a mutex protecting the sleep/wake-up state
+	sleepMutex sync.RWMutex
+
 	// a channel for free workers
 	freeWorkers chan *worker
 	// a channel to indicate that a worker finished working on a request
@@ -217,6 +225,7 @@ func New(logger logr.Logger) (*VllmSimulator, error) {
 		kvcacheHelper: nil, // kvcache helper will be created only if required after reading configuration
 		namespace:     os.Getenv(podNsEnv),
 		pod:           os.Getenv(podNameEnv),
+		isInDevMode:   os.Getenv("VLLM_SERVER_DEV_MODE") == "1",
 		loras: &lorasUsageInfo{
 			loadedLoras: make(map[string]int),
 		},
diff --git a/pkg/llm-d-inference-sim/test_utils.go b/pkg/llm-d-inference-sim/test_utils.go
index 57ce422..36b2cee 100644
--- a/pkg/llm-d-inference-sim/test_utils.go
+++ b/pkg/llm-d-inference-sim/test_utils.go
@@ -58,6 +58,7 @@ func startServer(ctx context.Context, mode string) (*http.Client, error) {
 }
 
 // Starts server in the given mode and environment variables
+// nolint
 func startServerWithEnv(ctx context.Context, mode string, envs map[string]string) (*http.Client, error) {
 	return startServerWithArgsAndEnv(ctx, mode, nil, envs)
 }
@@ -213,6 +214,15 @@ func sendSimpleChatRequest(envs map[string]string, streaming bool) *http.Respons
 	return httpResp
 }
 
+// sendTextCompletionRequest sends one text completion request
+func sendTextCompletionRequest(ctx context.Context, client *http.Client) {
+	message := "aa bb cc dd ee ff gg hh ii jj aa bb cc dd ee ff gg hh ii jj"
+	openaiclient, params := getOpenAIClientAndTextParams(client, qwenModelName, message, false)
+	resp, err := openaiclient.Completions.New(ctx, params)
+	gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	gomega.Expect(resp).NotTo(gomega.BeNil())
+}
+
 // getOpenAIClientAndChatParams - creates an openai client and params for /chat/completions call based on the given parameters
 func getOpenAIClientAndChatParams(client option.HTTPClient, model string, message string, streaming bool) (openai.Client, openai.ChatCompletionNewParams) {
@@ -513,3 +523,18 @@ func checkLatencyMetrics(client *http.Client, modelName string, numOfInputTokens
 	checkBucketBoundary(metrics, modelName, decodeTimeMetricName, math.Inf(1), lastBoundary, expectedDecodeTimeInSecs)
 	checkBucketBoundary(metrics, modelName, e2eReqLatencyMetricName, math.Inf(1), lastBoundary, expectedE2ELatency)
 }
+
+func checkSimSleeping(client *http.Client, expectedToSleep bool) {
+	resp, err := client.Get("http://localhost/is_sleeping")
+	gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	gomega.Expect(resp.StatusCode).To(gomega.Equal(http.StatusOK))
+	defer func() {
+		err := resp.Body.Close()
+		gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	}()
+
+	body, err := io.ReadAll(resp.Body)
+	gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	expect := fmt.Sprintf("{\"is_sleeping\":%t}", expectedToSleep)
+	gomega.Expect(string(body)).To(gomega.Equal(expect))
+}
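
A minimal usage sketch (not part of the patch): it assumes a simulator is already running on localhost:8000 and was started with --enable-sleep-mode and the VLLM_SERVER_DEV_MODE=1 environment variable, both of which are required for /sleep to take effect in the HandleSleep handler above; the port and base URL are illustrative assumptions only.

package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Assumed base URL of a locally running simulator.
	base := "http://localhost:8000"

	// Put the simulator to sleep; when the kv cache is enabled this also
	// discards the cache and publishes an AllBlocksCleared event.
	resp, err := http.Post(base+"/sleep", "application/json", nil)
	if err != nil {
		panic(err)
	}
	resp.Body.Close()

	// Query the sleep state; the body is {"is_sleeping":true} or {"is_sleeping":false}.
	resp, err = http.Get(base + "/is_sleeping")
	if err != nil {
		panic(err)
	}
	body, _ := io.ReadAll(resp.Body)
	resp.Body.Close()
	fmt.Println(string(body))

	// Wake up; no tags (or tags=kv_cache) re-activates the kv cache,
	// while tags=weights leaves the cache disabled.
	resp, err = http.Post(base+"/wake_up?tags=kv_cache", "application/json", nil)
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
}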