Skip to content

Commit 58544fe

Browse files
committed
Calc kv cache transfer overhead based on prompt length
Signed-off-by: Qifan Deng <[email protected]>
1 parent 65d27c1 commit 58544fe

File tree

5 files changed

+142
-19
lines changed

5 files changed

+142
-19
lines changed

pkg/common/config.go

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ type Configuration struct {
6767
TimeToFirstTokenStdDev int `yaml:"time-to-first-token-std-dev" json:"time-to-first-token-std-dev"`
6868

6969
// PrefillOverhead time taken to prefill the context, in milliseconds
70-
PrefillOverhead int `yaml:"prefill-overhead" json:"prefill-overhead"`
70+
// PrefillOverhead along with PrefillComplexity defines the time taken to prefill the context
71+
PrefillOverhead int `yaml:"prefill-overhead" json:"prefill-overhead"`
72+
// options are "n^2" and "nlog(n)"
7173
PrefillComplexity string `yaml:"prefill-complexity" json:"prefill-complexity"`
7274

7375
// InterTokenLatency time between generated tokens, in milliseconds
@@ -85,6 +87,13 @@ type Configuration struct {
8587
// KVCacheTransferLatency
8688
KVCacheTransferLatencyStdDev int `yaml:"kv-cache-transfer-latency-std-dev" json:"kv-cache-transfer-latency-std-dev"`
8789

90+
// KVCacheTransfer overhead time taken to transfer kv-cache from another vLLM instance in case P/D is activated,
91+
// in milliseconds.
92+
// KVCacheTransferOverhead along with KVCacheTransferComplexity defines the time taken to transfer kv-cache.
93+
KVCacheTransferOverhead int `yaml:"kv-cache-transfer-overhead" json:"kv-cache-transfer-overhead"`
94+
// options are "linear" and "in-place", default is "linear"
95+
KVCacheTransferComplexity string `yaml:"kv-cache-transfer-complexity" json:"kv-cache-transfer-complexity"`
96+
8897
// Mode defines the simulator response generation mode, valid values: echo, random
8998
Mode string `yaml:"mode" json:"mode"`
9099
// Seed defines random seed for operations
@@ -319,6 +328,17 @@ func (c *Configuration) validate() error {
319328
if float32(c.KVCacheTransferLatencyStdDev) > 0.3*float32(c.KVCacheTransferLatency) {
320329
return errors.New("kv-cache transfer standard deviation cannot be more than 30% of kv-cache transfer")
321330
}
331+
if c.KVCacheTransferOverhead < 0 {
332+
return errors.New("kv-cache transfer overhead cannot be negative")
333+
} else if c.KVCacheTransferOverhead == 0 {
334+
if c.KVCacheTransferComplexity != "" {
335+
return errors.New("kv-cache transfer complexity is set, but kv-cache transfer overhead is 0")
336+
}
337+
}
338+
if c.KVCacheTransferComplexity != "" && c.KVCacheTransferComplexity != "linear" && c.KVCacheTransferComplexity != "in-place" {
339+
return errors.New("kv-cache transfer complexity should be either \"linear\" or \"in-place\"")
340+
}
341+
322342
if c.MaxLoras < 1 {
323343
return errors.New("max LoRAs cannot be less than 1")
324344
}
@@ -422,6 +442,8 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
422442
f.IntVar(&config.TimeToFirstTokenStdDev, "time-to-first-token-std-dev", config.TimeToFirstTokenStdDev, "Standard deviation for time before the first token will be returned (in milliseconds)")
423443
f.IntVar(&config.KVCacheTransferLatencyStdDev, "kv-cache-transfer-latency-std-dev", config.KVCacheTransferLatencyStdDev, "Standard deviation for time for KV-cache transfer from a remote vLLM (in milliseconds)")
424444
f.Int64Var(&config.Seed, "seed", config.Seed, "Random seed for operations (if not set, current Unix time in nanoseconds is used)")
445+
f.IntVar(&config.KVCacheTransferOverhead, "kv-cache-transfer-overhead", config.KVCacheTransferOverhead, "Time to transfer kv-cache in milliseconds. This argument is ignored if <kv-cache-transfer-latency> is set.")
446+
f.StringVar(&config.KVCacheTransferComplexity, "kv-cache-transfer-complexity", config.KVCacheTransferComplexity, "Complexity of kv-cache transfer based on token length. Options are \"linear\" and \"in-place\". Default is \"linear\".")
425447

426448
f.IntVar(&config.MaxToolCallIntegerParam, "max-tool-call-integer-param", config.MaxToolCallIntegerParam, "Maximum possible value of integer parameters in a tool call")
427449
f.IntVar(&config.MinToolCallIntegerParam, "min-tool-call-integer-param", config.MinToolCallIntegerParam, "Minimum possible value of integer parameters in a tool call")

pkg/common/config_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,18 @@ var _ = Describe("Simulator configuration", func() {
392392
name: "<prefill-overhead> must be set when <prefill-complexity> is set",
393393
args: []string{"cmd", "--prefill-complexity", "n^2", "--config", "../../manifests/config.yaml"},
394394
},
395+
{
396+
name: "<prefill-complexity> should not be 'xxx'",
397+
args: []string{"cmd", "--prefill-complexity", "xxx", "--config", "../../manifests/config.yaml"},
398+
},
399+
{
400+
name: "<kv-cache-transfer-overhead> must be set when <kv-cache-transfer-complexity> is set",
401+
args: []string{"cmd", "--kv-cache-transfer-complexity", "linear", "--config", "../../manifests/config.yaml"},
402+
},
403+
{
404+
name: "<kv-cache-transfer-complexity> should not be 'xxx'",
405+
args: []string{"cmd", "--kv-cache-transfer-complexity", "xxx", "--config", "../../manifests/config.yaml"},
406+
},
395407
}
396408

397409
for _, test := range invalidTests {

pkg/llm-d-inference-sim/simulator.go

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques
636636
// calculate how long to wait before returning the response, time is based on number of tokens
637637
nPromptTokens := usageData.PromptTokens
638638
nGenTokens := usageData.CompletionTokens
639-
totalMillisToWait := s.getTimeToFirstToken(doRemotePrefill, nPromptTokens) + s.getTotalInterTokenLatency(nGenTokens)
639+
totalMillisToWait := s.getTimeToFirstToken(nPromptTokens, doRemotePrefill) + s.getTotalInterTokenLatency(nGenTokens)
640640
time.Sleep(time.Duration(totalMillisToWait) * time.Millisecond)
641641

642642
ctx.Response.Header.SetContentType("application/json")
@@ -654,13 +654,17 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques
654654
}
655655

656656
// returns time to first token based on the current request's doRemotePrefill
657-
func (s *VllmSimulator) getTimeToFirstToken(doRemotePrefill bool, nPromptTokens int) int {
657+
func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, doRemotePrefill bool) int {
658658
if s.config.TimeToFirstToken == 0 && s.config.PrefillOverhead != 0 {
659659
if nPromptTokens <= 1 {
660-
return s.config.PrefillOverhead
660+
if !doRemotePrefill {
661+
return s.config.PrefillOverhead
662+
}
663+
return s.config.KVCacheTransferOverhead
661664
}
662-
return s.calcPrefillOverhead(nPromptTokens)
665+
return s.calcPrefillOverhead(nPromptTokens, doRemotePrefill)
663666
}
667+
fmt.Printf("get time to first token %d, nPromptTokens %d, doRemotePrefill %v\n", s.config.TimeToFirstToken, nPromptTokens, doRemotePrefill)
664668

665669
mean := float64(s.config.TimeToFirstToken)
666670
stddev := float64(s.config.TimeToFirstTokenStdDev)
@@ -688,7 +692,10 @@ func (s *VllmSimulator) getTotalInterTokenLatency(numOfTokens int) int {
688692
}
689693

690694
// calc the prefill overhead against number of tokens
691-
func (s *VllmSimulator) calcPrefillOverhead(nPromptTokens int) int {
695+
func (s *VllmSimulator) calcPrefillOverhead(nPromptTokens int, doRemotePrefill bool) int {
696+
if doRemotePrefill {
697+
return s.calcRemotePrefillOverhead(nPromptTokens)
698+
}
692699
pfOverhead := s.config.PrefillOverhead
693700
complexity := s.config.PrefillComplexity
694701
// policies of different complexities of prefill implementation
@@ -699,7 +706,24 @@ func (s *VllmSimulator) calcPrefillOverhead(nPromptTokens int) int {
699706
case "nlog(n)":
700707
return int(float64(pfOverhead) * (float64(nPromptTokens) * math.Log2(float64(nPromptTokens))))
701708
}
709+
// should never reach here
710+
return 0
711+
}
702712

713+
// calc the remote prefill overhead against number of tokens
714+
func (s *VllmSimulator) calcRemotePrefillOverhead(nPromptTokens int) int {
715+
overhead := s.config.KVCacheTransferOverhead
716+
complexity := s.config.KVCacheTransferComplexity
717+
switch complexity {
718+
case "linear", "":
719+
fmt.Printf("linear complexity, overhead %d, nPromptTokens %d\n", overhead, nPromptTokens)
720+
return overhead * nPromptTokens
721+
case "in-place":
722+
// when the context is already filled
723+
// this is a simple implementation which return a defined overhead
724+
return overhead
725+
}
726+
// should never reach here
703727
return 0
704728
}
705729

pkg/llm-d-inference-sim/simulator_test.go

Lines changed: 77 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,7 @@ var _ = Describe("Simulator", func() {
802802
simulator.config.TimeToFirstTokenStdDev = timeToFirstTokenStdDev
803803
simulator.config.KVCacheTransferLatency = kvCacheLatency
804804
simulator.config.KVCacheTransferLatencyStdDev = kvCacheLatencyStdDev
805-
timeToFirst := simulator.getTimeToFirstToken(doREmotePrefill, 1)
805+
timeToFirst := simulator.getTimeToFirstToken(1, doREmotePrefill)
806806
if doREmotePrefill {
807807
Expect(timeToFirst).To(BeNumerically(">=", int(float32(kvCacheLatency)*0.3)))
808808
Expect(timeToFirst).To(BeNumerically("<=", int(float32(kvCacheLatency)*1.7)))
@@ -826,29 +826,30 @@ var _ = Describe("Simulator", func() {
826826

827827
It("when <time-to-first-token> is not 0, ignore <prefill-overhead>", func() {
828828
timeToFirstToken := 10000
829-
prefillOverhead := 100
830829
simulator.config.TimeToFirstToken = timeToFirstToken
831-
simulator.config.PrefillOverhead = prefillOverhead
832-
timeToFirst := simulator.getTimeToFirstToken(false, 1)
830+
simulator.config.PrefillOverhead = 100
831+
timeToFirst := simulator.getTimeToFirstToken(1, false)
833832
Expect(timeToFirst).To(BeNumerically(">=", int(float32(timeToFirstToken)*0.3)))
834833
Expect(timeToFirst).To(BeNumerically("<=", int(float32(timeToFirstToken)*1.7)))
835834
})
836835

837-
It("when <time-to-first-token> is 0, use <prefill-overhead>", func() {
836+
It("when <time-to-first-token> is 0, and <prefill-overhead> is not 0, use <prefill-overhead>", func() {
838837
simulator.config.TimeToFirstToken = 0
839838
simulator.config.PrefillOverhead = 100
840-
timeToFirst := simulator.getTimeToFirstToken(false, 1)
839+
timeToFirst := simulator.getTimeToFirstToken(1, false)
841840
Expect(timeToFirst).To(BeNumerically(">=", 100))
842841
})
843842

844843
DescribeTable("time to first token is super linear of prefill against number of prompt tokens",
845844
func(prefillOverhead int, tolerance float64, minNTokens int, maxNTokens int) {
845+
simulator.config.PrefillComplexity = "n^2"
846846
for nTokens := minNTokens; nTokens <= maxNTokens; nTokens++ {
847-
square := prefillOverhead * nTokens * nTokens
848847
simulator.config.PrefillOverhead = prefillOverhead
849-
timeToFirst := simulator.getTimeToFirstToken(false, nTokens)
848+
timeToFirst := simulator.getTimeToFirstToken(nTokens, false)
849+
850+
square := prefillOverhead * nTokens * nTokens
850851
diffRatio := math.Abs(float64(timeToFirst-square)) / float64(square)
851-
Expect(diffRatio).To(BeNumerically("<", tolerance))
852+
Expect(diffRatio).To(BeNumerically("<=", tolerance))
852853
}
853854
},
854855
func(prefillOverhead int, tolerance float64, minNTokens int, maxNTokens int) string {
@@ -865,11 +866,12 @@ var _ = Describe("Simulator", func() {
865866
simulator.config.PrefillComplexity = "nlog(n)"
866867

867868
for nTokens := minNTokens; nTokens <= maxNTokens; nTokens++ {
868-
nlogn := int(float64(prefillOverhead) * float64(nTokens) * math.Log2(float64(nTokens)))
869869
simulator.config.PrefillOverhead = prefillOverhead
870-
timeToFirst := simulator.getTimeToFirstToken(false, nTokens)
870+
timeToFirst := simulator.getTimeToFirstToken(nTokens, false)
871+
872+
nlogn := int(float64(prefillOverhead) * float64(nTokens) * math.Log2(float64(nTokens)))
871873
diffRatio := math.Abs(float64(timeToFirst-nlogn)) / float64(nlogn)
872-
Expect(diffRatio).To(BeNumerically("<", tolerance))
874+
Expect(diffRatio).To(BeNumerically("<=", tolerance))
873875
}
874876
},
875877
func(prefillOverhead int, tolerance float64, minNTokens int, maxNTokens int) string {
@@ -880,6 +882,69 @@ var _ = Describe("Simulator", func() {
880882
Entry("medium numbers, larger range", 200, 0.1, 50, 100),
881883
Entry("large numbers", 150, 0.05, 20000, 20010),
882884
)
885+
886+
It("when <kv-cache-transfer-latency> not 0, ignore <kv-cache-transfer-overhead>", func() {
887+
overhead := 100
888+
simulator.config.KVCacheTransferLatency = 1000
889+
simulator.config.KVCacheTransferOverhead = overhead
890+
timeToFirst := simulator.getTimeToFirstToken(1, false)
891+
Expect(timeToFirst).To(BeNumerically(">=", overhead))
892+
})
893+
894+
It("when <kv-cache-transfer-latency> is 0, and <kv-cache-transfer-overhead> is not 0, use <kv-cache-transfer-overhead>", func() {
895+
overhead := 100
896+
simulator.config.KVCacheTransferLatency = 0
897+
simulator.config.KVCacheTransferOverhead = overhead
898+
timeToFirst := simulator.getTimeToFirstToken(1, false)
899+
Expect(timeToFirst).To(BeNumerically(">", 0))
900+
})
901+
902+
DescribeTable("When remote kv cache transfer is enabled with \"linear\" policy, time to first token is linear of kv cache transfer against number of prompt tokens",
903+
func(kvCacheOverhead int, tolerance float64, minNTokens int, maxNTokens int) {
904+
simulator.config.TimeToFirstToken = 0
905+
simulator.config.PrefillOverhead = 1
906+
simulator.config.KVCacheTransferComplexity = "linear"
907+
908+
for nTokens := minNTokens; nTokens <= maxNTokens; nTokens++ {
909+
simulator.config.KVCacheTransferOverhead = kvCacheOverhead
910+
timeToFirst := simulator.getTimeToFirstToken(nTokens, true)
911+
912+
linear := kvCacheOverhead * nTokens
913+
diffRatio := math.Abs(float64(timeToFirst-linear)) / float64(linear)
914+
Expect(diffRatio).To(BeNumerically("<=", tolerance))
915+
}
916+
},
917+
func(kvCacheOverhead int, tolerance float64, minNTokens int, maxNTokens int) string {
918+
return fmt.Sprintf("kvCacheOverhead: %d tolerance: %f minNTokens: %d maxNTokens: %d",
919+
kvCacheOverhead, tolerance, minNTokens, maxNTokens)
920+
},
921+
Entry("small numbers", 100, 0.1, 1, 10),
922+
Entry("medium numbers, larger range", 200, 0.1, 50, 100),
923+
Entry("large numbers", 150, 0.05, 20000, 20010),
924+
)
925+
926+
DescribeTable("When remote kv cache transfer is enabled with \"in-place\" policy, time to first token should not be impacted by number of prompt tokens",
927+
func(kvCacheOverhead int, tolerance float64, minNTokens int, maxNTokens int) {
928+
simulator.config.TimeToFirstToken = 0
929+
simulator.config.PrefillOverhead = 1
930+
simulator.config.KVCacheTransferComplexity = "in-place"
931+
for nTokens := minNTokens; nTokens <= maxNTokens; nTokens++ {
932+
simulator.config.KVCacheTransferOverhead = kvCacheOverhead
933+
timeToFirst := simulator.getTimeToFirstToken(nTokens, true)
934+
935+
inPlace := kvCacheOverhead
936+
diffRatio := math.Abs(float64(timeToFirst-inPlace)) / float64(inPlace)
937+
Expect(diffRatio).To(BeNumerically("<=", tolerance))
938+
}
939+
},
940+
func(kvCacheOverhead int, tolerance float64, minNTokens int, maxNTokens int) string {
941+
return fmt.Sprintf("kvCacheOverhead: %d tolerance: %f minNTokens: %d maxNTokens: %d",
942+
kvCacheOverhead, tolerance, minNTokens, maxNTokens)
943+
},
944+
Entry("small numbers", 100, 0.1, 1, 10),
945+
Entry("medium numbers, larger range", 200, 0.1, 50, 100),
946+
Entry("large numbers", 150, 0.05, 20000, 20010),
947+
)
883948
})
884949

885950
Context("fake metrics", func() {

pkg/llm-d-inference-sim/streaming.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, nPrompt
9696
// sendTokenChunks creates and sends response chunks
9797
func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writer, nPromptTokens int, genTokens []string, tc *openaiserverapi.ToolCall, finishReason string) {
9898
// time to first token delay
99-
time.Sleep(time.Duration(s.getTimeToFirstToken(context.doRemotePrefill, nPromptTokens)) * time.Millisecond)
99+
time.Sleep(time.Duration(s.getTimeToFirstToken(nPromptTokens, context.doRemotePrefill)) * time.Millisecond)
100100

101101
for i, token := range genTokens {
102102
if i != 0 {

0 commit comments

Comments (0)