@@ -18,16 +18,11 @@ package kvcache
1818
1919import (
2020 "context"
21- "encoding/binary"
2221 "fmt"
2322 "sync"
2423 "time"
2524
26- zmq "github.com/pebbe/zmq4"
27- "github.com/vmihailenco/msgpack/v5"
28-
2925 "github.com/llm-d/llm-d-inference-sim/pkg/common"
30- "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvevents"
3126 . "github.com/onsi/ginkgo/v2"
3227 . "github.com/onsi/gomega"
3328)
@@ -207,7 +202,9 @@ var _ = Describe("KV cache", Ordered, func() {
207202 EventBatchSize : 1 ,
208203 }
209204
210- sub , topic := createSub (config )
205+ topic := CreateKVEventsTopic (config .Port , config .Model )
206+ sub , endpoint := common .CreateSub (topic )
207+ config .ZMQEndpoint = endpoint
211208 //nolint
212209 defer sub .Close ()
213210
@@ -289,7 +286,7 @@ var _ = Describe("KV cache", Ordered, func() {
289286 for i := range test .expectedRemovedBlocks + test .expectedStoredBlocks {
290287 parts , err := sub .RecvMessageBytes (0 )
291288 Expect (err ).NotTo (HaveOccurred ())
292- stored , removed := parseEvent (parts , topic , uint64 (i + 1 ))
289+ stored , removed , _ := ParseKVEvent (parts , topic , uint64 (i + 1 ))
293290 storedCount += len (stored )
294291 removedCount += len (removed )
295292 }
@@ -309,7 +306,9 @@ var _ = Describe("KV cache", Ordered, func() {
309306 ZMQMaxConnectAttempts : 3 ,
310307 }
311308
312- sub , topic := createSub (config )
309+ topic := CreateKVEventsTopic (config .Port , config .Model )
310+ sub , endpoint := common .CreateSub (topic )
311+ config .ZMQEndpoint = endpoint
313312 //nolint
314313 defer sub .Close ()
315314
@@ -378,7 +377,7 @@ var _ = Describe("KV cache", Ordered, func() {
378377 for {
379378 parts , err := sub .RecvMessageBytes (0 )
380379 Expect (err ).NotTo (HaveOccurred ())
381- stored , removed := parseEvent (parts , topic , count )
380+ stored , removed , _ := ParseKVEvent (parts , topic , count )
382381 storedBlocks = append (storedBlocks , stored ... )
383382 removedBlocks = append (removedBlocks , removed ... )
384383 count ++
@@ -484,67 +483,3 @@ func createRandomArray(minArrLen, maxArrLen int, maxValue uint64, random *common
484483
485484 return arr
486485}
487-
488- func parseEvent (parts [][]byte , expectedTopic string , expectedSeq uint64 ) ([]uint64 , []uint64 ) {
489- // The message should be [topic, seq, payload]
490- Expect (parts ).To (HaveLen (3 ))
491-
492- Expect (string (parts [0 ])).To (Equal (expectedTopic ))
493-
494- seq := binary .BigEndian .Uint64 (parts [1 ])
495- Expect (seq ).To (Equal (expectedSeq ))
496-
497- removed := make ([]uint64 , 0 )
498- stored := make ([]uint64 , 0 )
499-
500- var eventBatch kvevents.EventBatch
501- err := msgpack .Unmarshal (parts [2 ], & eventBatch )
502- Expect (err ).NotTo (HaveOccurred ())
503- for _ , rawEvent := range eventBatch .Events {
504- var taggedUnion []msgpack.RawMessage
505- err := msgpack .Unmarshal (rawEvent , & taggedUnion )
506- Expect (err ).NotTo (HaveOccurred ())
507- Expect (len (taggedUnion )).To (BeNumerically (">" , 1 ))
508-
509- payloadBytes , err := msgpack .Marshal (taggedUnion [1 :])
510- Expect (err ).NotTo (HaveOccurred ())
511-
512- var tag string
513- err = msgpack .Unmarshal (taggedUnion [0 ], & tag )
514- Expect (err ).NotTo (HaveOccurred ())
515-
516- switch tag {
517- case kvevents .BlockStoredEventTag :
518- var bs kvevents.BlockStored
519- err = msgpack .Unmarshal (payloadBytes , & bs )
520- stored = append (stored , bs .BlockHashes ... )
521- case kvevents .BlockRemovedEventTag :
522- var br kvevents.BlockRemoved
523- err = msgpack .Unmarshal (payloadBytes , & br )
524- removed = append (removed , br .BlockHashes ... )
525-
526- default :
527- Fail ("unexpected tag " + tag )
528- continue
529- }
530- Expect (err ).NotTo (HaveOccurred ())
531- }
532- return stored , removed
533- }
534-
535- func createSub (config * common.Configuration ) (* zmq.Socket , string ) {
536- zctx , err := zmq .NewContext ()
537- Expect (err ).NotTo (HaveOccurred ())
538- sub , err := zctx .NewSocket (zmq .SUB )
539- Expect (err ).NotTo (HaveOccurred ())
540- err = sub .Bind (wildcardEndpoint )
541- Expect (err ).NotTo (HaveOccurred ())
542- // get the actual port
543- endpoint , err := sub .GetLastEndpoint ()
544- Expect (err ).NotTo (HaveOccurred ())
545- config .ZMQEndpoint = endpoint
546- topic := createTopic (config )
547- err = sub .SetSubscribe (topic )
548- Expect (err ).NotTo (HaveOccurred ())
549- return sub , topic
550- }
0 commit comments