Skip to content

Commit b3f93d6

Browse files
authored
Sleep mode (#252)
* Sleep mode Signed-off-by: irar2 <[email protected]> * lint Signed-off-by: irar2 <[email protected]> * Disable the new sleep test to check if it causes the problem Signed-off-by: irar2 <[email protected]> * Wait before sending request Signed-off-by: irar2 <[email protected]> * Review comments and test fix Signed-off-by: irar2 <[email protected]> * Log levels and lint Signed-off-by: irar2 <[email protected]> --------- Signed-off-by: irar2 <[email protected]>
1 parent e1e27ea commit b3f93d6

File tree

11 files changed

+435
-77
lines changed

11 files changed

+435
-77
lines changed

pkg/common/config.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ type Configuration struct {
220220
DatasetURL string `yaml:"dataset-url" json:"dataset-url"`
221221
// DatasetInMemory defines whether to load the entire dataset into memory for faster access.
222222
DatasetInMemory bool `yaml:"dataset-in-memory" json:"dataset-in-memory"`
223+
224+
// EnableSleepMode enables sleep mode
225+
EnableSleepMode bool `yaml:"enable-sleep-mode" json:"enable-sleep-mode"`
223226
}
224227

225228
type Metrics struct {
@@ -741,6 +744,8 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
741744
f.StringVar(&config.DatasetURL, "dataset-url", config.DatasetURL, "URL to download the sqlite db file for response generation from a dataset")
742745
f.BoolVar(&config.DatasetInMemory, "dataset-in-memory", config.DatasetInMemory, "Load the entire dataset into memory for faster access")
743746

747+
f.BoolVar(&config.EnableSleepMode, "enable-sleep-mode", config.EnableSleepMode, "Enable sleep mode")
748+
744749
f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures")
745750
failureTypes := getParamValueFromArgs("failure-types")
746751
var dummyFailureTypes multiString

pkg/common/test_utils.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
Copyright 2025 The llm-d-inference-sim Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package common
18+
19+
import (
20+
"github.com/onsi/gomega"
21+
zmq "github.com/pebbe/zmq4"
22+
)
23+
24+
// CreateSub creates a ZMQ SUB socket, subscribes it to the provided topic, and returns the
25+
// socket and the endpoint to publish events on.
26+
func CreateSub(topic string) (*zmq.Socket, string) {
27+
wildcardEndpoint := "tcp://*:*"
28+
zctx, err := zmq.NewContext()
29+
gomega.Expect(err).NotTo(gomega.HaveOccurred())
30+
sub, err := zctx.NewSocket(zmq.SUB)
31+
gomega.Expect(err).NotTo(gomega.HaveOccurred())
32+
err = sub.Bind(wildcardEndpoint)
33+
gomega.Expect(err).NotTo(gomega.HaveOccurred())
34+
// get the actual port
35+
endpoint, err := sub.GetLastEndpoint()
36+
gomega.Expect(err).NotTo(gomega.HaveOccurred())
37+
err = sub.SetSubscribe(topic)
38+
gomega.Expect(err).NotTo(gomega.HaveOccurred())
39+
return sub, endpoint
40+
}

pkg/kv-cache/block_cache.go

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424

2525
"github.com/go-logr/logr"
2626
"github.com/llm-d/llm-d-inference-sim/pkg/common"
27+
"github.com/llm-d/llm-d-inference-sim/pkg/common/logging"
2728
)
2829

2930
const (
@@ -42,6 +43,7 @@ type blockCache struct {
4243
eventChan chan EventData // channel for asynchronous event processing
4344
usageChan chan float64 // channel for usage reporting
4445
logger logr.Logger
46+
disabled bool // indicates whether the cache is disabled
4547
}
4648

4749
// newBlockCache creates a new blockCache with the specified maximum number of blocks
@@ -58,31 +60,66 @@ func newBlockCache(config *common.Configuration, logger logr.Logger, usageChan c
5860
}
5961
}
6062

63+
eventSender := NewKVEventSender(publisher, CreateKVEventsTopic(config.Port, config.Model),
64+
eChan, config.EventBatchSize, delay, logger)
65+
6166
return &blockCache{
6267
requestToBlocks: make(map[string][]uint64),
6368
usedBlocks: make(map[uint64]int),
6469
unusedBlocks: make(map[uint64]time.Time),
6570
maxBlocks: config.KVCacheSize,
6671
eventChan: eChan,
6772
usageChan: usageChan,
68-
eventSender: NewKVEventSender(publisher, createTopic(config), eChan, config.EventBatchSize, delay, logger),
73+
eventSender: eventSender,
6974
logger: logger,
7075
}, nil
7176
}
7277

7378
func (bc *blockCache) start(ctx context.Context) {
79+
bc.logger.V(logging.INFO).Info("Starting KV cache")
7480
err := bc.eventSender.Run(ctx)
7581
if err != nil {
7682
bc.logger.Error(err, "Sender stopped with error")
7783
}
7884
}
7985

86+
func (bc *blockCache) discard() {
87+
bc.logger.V(logging.INFO).Info("Discarding KV cache")
88+
89+
bc.mu.Lock()
90+
defer bc.mu.Unlock()
91+
92+
bc.disabled = true
93+
94+
bc.requestToBlocks = make(map[string][]uint64)
95+
bc.usedBlocks = make(map[uint64]int)
96+
bc.unusedBlocks = make(map[uint64]time.Time)
97+
98+
common.WriteToChannel(bc.eventChan,
99+
EventData{action: eventActionAllBlocksCleared},
100+
bc.logger, "block cache eventChan")
101+
}
102+
103+
func (bc *blockCache) activate() {
104+
bc.logger.V(logging.INFO).Info("Activating KV cache")
105+
106+
bc.mu.Lock()
107+
defer bc.mu.Unlock()
108+
109+
bc.disabled = false
110+
}
111+
80112
// startRequest adds a request with its associated block hashes to the cache
81113
// and returns the number of blocks that were already in the cache
82114
func (bc *blockCache) startRequest(requestID string, blocks []uint64) (int, error) {
83115
bc.mu.Lock()
84116
defer bc.mu.Unlock()
85117

118+
if bc.disabled {
119+
bc.logger.V(logging.TRACE).Info("KV cache is disabled, request is not added to the kv cache")
120+
return 0, nil
121+
}
122+
86123
if _, exists := bc.requestToBlocks[requestID]; exists {
87124
// request with the same id already exists
88125
return 0, fmt.Errorf("request already exists for id %s", requestID)
@@ -167,6 +204,11 @@ func (bc *blockCache) finishRequest(requestID string) error {
167204
bc.mu.Lock()
168205
defer bc.mu.Unlock()
169206

207+
if bc.disabled {
208+
bc.logger.V(logging.TRACE).Info("KV cache is disabled, request completion is not processed by the kv cache")
209+
return nil
210+
}
211+
170212
// Get blocks associated with this request
171213
blockHashes, exists := bc.requestToBlocks[requestID]
172214
if !exists {
@@ -239,6 +281,6 @@ func (bc *blockCache) getBlockInfo(blockHash uint64) (int, bool) {
239281
return 0, false
240282
}
241283

242-
func createTopic(config *common.Configuration) string {
243-
return fmt.Sprintf("kv@$localhost:%d@%s", config.Port, config.Model)
284+
func CreateKVEventsTopic(port int, model string) string {
285+
return fmt.Sprintf("kv@$localhost:%d@%s", port, model)
244286
}

pkg/kv-cache/kv_cache.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ func (h *KVCacheHelper) Run(ctx context.Context) {
6363
h.blockCache.start(ctx)
6464
}
6565

66+
func (h *KVCacheHelper) Discard() {
67+
h.blockCache.discard()
68+
}
69+
70+
func (h *KVCacheHelper) Activate() {
71+
h.blockCache.activate()
72+
}
73+
6674
func (h *KVCacheHelper) OnRequestStart(vllmReq openaiserverapi.CompletionRequest) error {
6775
h.logger.V(logging.TRACE).Info("KV cache - process request")
6876

pkg/kv-cache/kv_cache_sender.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ type EventAction int
3232
const (
3333
eventActionStore EventAction = iota
3434
eventActionRemove
35+
eventActionAllBlocksCleared
3536
)
3637

3738
type EventData struct {
@@ -97,6 +98,8 @@ func (s *KVEventSender) Run(ctx context.Context) error {
9798
payload, err = msgpack.Marshal(kvevents.BlockStored{BlockHashes: eventData.hashValues}.ToTaggedUnion())
9899
case eventActionRemove:
99100
payload, err = msgpack.Marshal(kvevents.BlockRemoved{BlockHashes: eventData.hashValues}.ToTaggedUnion())
101+
case eventActionAllBlocksCleared:
102+
payload, err = msgpack.Marshal(kvevents.AllBlocksCleared{}.ToTaggedUnion())
100103
default:
101104
return fmt.Errorf("invalid event action %d", eventData.action)
102105
}

pkg/kv-cache/kv_cache_test.go

Lines changed: 8 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,11 @@ package kvcache
1818

1919
import (
2020
"context"
21-
"encoding/binary"
2221
"fmt"
2322
"sync"
2423
"time"
2524

26-
zmq "github.com/pebbe/zmq4"
27-
"github.com/vmihailenco/msgpack/v5"
28-
2925
"github.com/llm-d/llm-d-inference-sim/pkg/common"
30-
"github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvevents"
3126
. "github.com/onsi/ginkgo/v2"
3227
. "github.com/onsi/gomega"
3328
)
@@ -207,7 +202,9 @@ var _ = Describe("KV cache", Ordered, func() {
207202
EventBatchSize: 1,
208203
}
209204

210-
sub, topic := createSub(config)
205+
topic := CreateKVEventsTopic(config.Port, config.Model)
206+
sub, endpoint := common.CreateSub(topic)
207+
config.ZMQEndpoint = endpoint
211208
//nolint
212209
defer sub.Close()
213210

@@ -289,7 +286,7 @@ var _ = Describe("KV cache", Ordered, func() {
289286
for i := range test.expectedRemovedBlocks + test.expectedStoredBlocks {
290287
parts, err := sub.RecvMessageBytes(0)
291288
Expect(err).NotTo(HaveOccurred())
292-
stored, removed := parseEvent(parts, topic, uint64(i+1))
289+
stored, removed, _ := ParseKVEvent(parts, topic, uint64(i+1))
293290
storedCount += len(stored)
294291
removedCount += len(removed)
295292
}
@@ -309,7 +306,9 @@ var _ = Describe("KV cache", Ordered, func() {
309306
ZMQMaxConnectAttempts: 3,
310307
}
311308

312-
sub, topic := createSub(config)
309+
topic := CreateKVEventsTopic(config.Port, config.Model)
310+
sub, endpoint := common.CreateSub(topic)
311+
config.ZMQEndpoint = endpoint
313312
//nolint
314313
defer sub.Close()
315314

@@ -378,7 +377,7 @@ var _ = Describe("KV cache", Ordered, func() {
378377
for {
379378
parts, err := sub.RecvMessageBytes(0)
380379
Expect(err).NotTo(HaveOccurred())
381-
stored, removed := parseEvent(parts, topic, count)
380+
stored, removed, _ := ParseKVEvent(parts, topic, count)
382381
storedBlocks = append(storedBlocks, stored...)
383382
removedBlocks = append(removedBlocks, removed...)
384383
count++
@@ -484,67 +483,3 @@ func createRandomArray(minArrLen, maxArrLen int, maxValue uint64, random *common
484483

485484
return arr
486485
}
487-
488-
func parseEvent(parts [][]byte, expectedTopic string, expectedSeq uint64) ([]uint64, []uint64) {
489-
// The message should be [topic, seq, payload]
490-
Expect(parts).To(HaveLen(3))
491-
492-
Expect(string(parts[0])).To(Equal(expectedTopic))
493-
494-
seq := binary.BigEndian.Uint64(parts[1])
495-
Expect(seq).To(Equal(expectedSeq))
496-
497-
removed := make([]uint64, 0)
498-
stored := make([]uint64, 0)
499-
500-
var eventBatch kvevents.EventBatch
501-
err := msgpack.Unmarshal(parts[2], &eventBatch)
502-
Expect(err).NotTo(HaveOccurred())
503-
for _, rawEvent := range eventBatch.Events {
504-
var taggedUnion []msgpack.RawMessage
505-
err := msgpack.Unmarshal(rawEvent, &taggedUnion)
506-
Expect(err).NotTo(HaveOccurred())
507-
Expect(len(taggedUnion)).To(BeNumerically(">", 1))
508-
509-
payloadBytes, err := msgpack.Marshal(taggedUnion[1:])
510-
Expect(err).NotTo(HaveOccurred())
511-
512-
var tag string
513-
err = msgpack.Unmarshal(taggedUnion[0], &tag)
514-
Expect(err).NotTo(HaveOccurred())
515-
516-
switch tag {
517-
case kvevents.BlockStoredEventTag:
518-
var bs kvevents.BlockStored
519-
err = msgpack.Unmarshal(payloadBytes, &bs)
520-
stored = append(stored, bs.BlockHashes...)
521-
case kvevents.BlockRemovedEventTag:
522-
var br kvevents.BlockRemoved
523-
err = msgpack.Unmarshal(payloadBytes, &br)
524-
removed = append(removed, br.BlockHashes...)
525-
526-
default:
527-
Fail("unexpected tag " + tag)
528-
continue
529-
}
530-
Expect(err).NotTo(HaveOccurred())
531-
}
532-
return stored, removed
533-
}
534-
535-
func createSub(config *common.Configuration) (*zmq.Socket, string) {
536-
zctx, err := zmq.NewContext()
537-
Expect(err).NotTo(HaveOccurred())
538-
sub, err := zctx.NewSocket(zmq.SUB)
539-
Expect(err).NotTo(HaveOccurred())
540-
err = sub.Bind(wildcardEndpoint)
541-
Expect(err).NotTo(HaveOccurred())
542-
// get the actual port
543-
endpoint, err := sub.GetLastEndpoint()
544-
Expect(err).NotTo(HaveOccurred())
545-
config.ZMQEndpoint = endpoint
546-
topic := createTopic(config)
547-
err = sub.SetSubscribe(topic)
548-
Expect(err).NotTo(HaveOccurred())
549-
return sub, topic
550-
}

0 commit comments

Comments
 (0)