Skip to content

Commit a9dd5c5

Browse files
authored
add metrics tests for latency metrics with remote prefill (#248)
* add metrics tests for latency metrics with remote prefill Signed-off-by: Maya Barnea <[email protected]> * fix pr comments Signed-off-by: Maya Barnea <[email protected]> --------- Signed-off-by: Maya Barnea <[email protected]>
1 parent c7fef65 commit a9dd5c5

File tree

2 files changed

+84
-50
lines changed

2 files changed

+84
-50
lines changed

pkg/llm-d-inference-sim/metrics_test.go

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -774,23 +774,27 @@ var _ = Describe("Simulator metrics", Ordered, func() {
774774
numOfTokens := len(common.Tokenize(testUserMessage))
775775

776776
DescribeTable("should calculate all latency related metrics correctly for a single request",
777-
func(testNamePrefix string, ttft int, prefillTimePerToken int, interTokenLatency int) {
777+
func(testNamePrefix string, ttft int, prefillTimePerToken int, interTokenLatency int,
778+
kvcacheTransferLatency int, kvCacheTransferTimePerToken int, doRemotePrefill bool) {
778779
// send a single request with a prompt of 4 tokens and echo mode, so output tokens number of 4 too
779-
client := startServerAndSendRequest(testModel, testUserMessage, false, ttft, prefillTimePerToken, interTokenLatency)
780-
checkLatencyMertics(client, testModel, numOfTokens, numOfTokens, ttft, prefillTimePerToken, interTokenLatency)
781-
782-
// same in streaming mode
783-
client = startServerAndSendRequest(testModel, testUserMessage, true, ttft, prefillTimePerToken, interTokenLatency)
784-
checkLatencyMertics(client, testModel, numOfTokens, numOfTokens, ttft, prefillTimePerToken, interTokenLatency)
780+
singleRequestLatencyTest(ttft, prefillTimePerToken, interTokenLatency, kvcacheTransferLatency,
781+
kvCacheTransferTimePerToken, false, numOfTokens, doRemotePrefill)
782+
singleRequestLatencyTest(ttft, prefillTimePerToken, interTokenLatency, kvcacheTransferLatency,
783+
kvCacheTransferTimePerToken, true, numOfTokens, doRemotePrefill)
785784
},
786-
func(testNamePrefix string, ttft int, prefillTimePerToken int, interTokenLatency int) string {
787-
return fmt.Sprintf("%s\nttft: %d, prefillTimePerToken: %d, interTokenLatency: %d", testNamePrefix, ttft, prefillTimePerToken, interTokenLatency)
785+
func(testNamePrefix string, ttft int, prefillTimePerToken int, interTokenLatency int,
786+
kvcacheTransferLatency int, kvCacheTransferTimePerToken int, doRemotePrefill bool) string {
787+
return fmt.Sprintf("%s\nttft: %d, prefillTimePerToken: %d, interTokenLatency: %d, kvcacheTransferLatency: %d, kvCacheTransferTimePerToken: %d, doRemotePrefill: %t",
788+
testNamePrefix, ttft, prefillTimePerToken, interTokenLatency, kvcacheTransferLatency, kvCacheTransferTimePerToken, doRemotePrefill)
788789
},
789-
// Params order: testName, ttft, prefillTimePerToken, interTokenLatency
790-
Entry(nil, "constant prefill + inter token time", 0, 0, 100),
791-
Entry(nil, "constant prefill + inter token time", 900, 0, 100),
792-
Entry(nil, "constant prefill + inter token time", 1000, 0, 100),
793-
Entry(nil, "prefill per token + inter token time", 0, 100, 100),
790+
// Params order: testName, ttft, prefillTimePerToken, interTokenLatency, kvcacheTransferLatency, kvCacheTransferTimePerToken, doRemotePrefill)
791+
Entry(nil, "constant prefill + inter token time", 0, 0, 100, 0, 0, false),
792+
Entry(nil, "constant prefill + inter token time", 900, 0, 100, 0, 0, false),
793+
Entry(nil, "constant prefill + inter token time", 1000, 0, 100, 0, 0, false),
794+
Entry(nil, "prefill per token + inter token time", 0, 100, 100, 0, 0, false),
795+
Entry(nil, "remote prefill constant time", 0, 0, 0, 1000, 0, true),
796+
Entry(nil, "remote prefill constant time with non-remote times", 5000, 5000, 0, 1000, 0, true),
797+
Entry(nil, "remote prefill time per transferred token", 0, 0, 0, 0, 100, true),
794798
)
795799
})
796800

pkg/llm-d-inference-sim/test_utils.go

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
package llmdinferencesim
1717

1818
import (
19+
"bufio"
1920
"context"
2021
"crypto/tls"
2122
"errors"
@@ -136,13 +137,12 @@ func startServerWithArgsAndEnv(ctx context.Context, mode string, args []string,
136137
}, nil
137138
}
138139

139-
// startServerAndSendRequest - starts server configured according the given latency parameters in echo mode,
140-
// sends a single request with the given prompt
141-
func startServerAndSendRequest(modelName string, prompt string, isStreaming bool, ttft int, prefillTimePerToken int, interTokenLatency int) *http.Client {
140+
// startServerForLatencyTest - starts server configured according to the given latency parameters in echo mode
141+
func startServerForLatencyTest(modelName string, ttft int, prefillTimePerToken int, interTokenLatency int, kvcacheTransferLatency int, kvCacheTransferTimePerToken int) *http.Client {
142142
ctx := context.TODO()
143143
args := []string{"cmd", "--model", modelName, "--mode", common.ModeEcho,
144-
// "--kv-cache-transfer-latency", strconv.Itoa(kvcacheTransferLatency),
145-
// "--kv-cache-transfer-time-per-token", strconv.Itoa(kvCacheTransferTimePerToken),
144+
"--kv-cache-transfer-latency", strconv.Itoa(kvcacheTransferLatency),
145+
"--kv-cache-transfer-time-per-token", strconv.Itoa(kvCacheTransferTimePerToken),
146146
"--time-to-first-token", strconv.Itoa(ttft),
147147
"--prefill-time-per-token", strconv.Itoa(prefillTimePerToken),
148148
"--inter-token-latency", strconv.Itoa(interTokenLatency),
@@ -151,27 +151,47 @@ func startServerAndSendRequest(modelName string, prompt string, isStreaming bool
151151
client, err := startServerWithArgs(ctx, args)
152152
gomega.Expect(err).NotTo(gomega.HaveOccurred())
153153

154-
openaitextclient, params := getOpenAIClientAndTextParams(client, modelName, prompt, isStreaming)
154+
return client
155+
}
155156

156-
if isStreaming {
157-
// send a single request in a serial way
158-
stream := openaitextclient.Completions.NewStreaming(ctx, params)
159-
chunksCnt := 0
157+
func singleRequestLatencyTest(ttft int, prefillTimePerToken int, interTokenLatency int, kvcacheTransferLatency int,
158+
kvCacheTransferTimePerToken int, isStreaming bool, numOfTokens int, doRemotePrefill bool) {
159+
client := startServerForLatencyTest(testModel, ttft, prefillTimePerToken, interTokenLatency, kvcacheTransferLatency, kvCacheTransferTimePerToken)
160+
sendCompletionRequestForLatencyTest(client, testModel, testUserMessage, isStreaming, doRemotePrefill)
161+
checkLatencyMetrics(client, testModel, numOfTokens, numOfTokens, ttft, prefillTimePerToken, interTokenLatency, kvcacheTransferLatency,
162+
kvCacheTransferTimePerToken, doRemotePrefill)
160163

161-
for stream.Next() {
162-
chunksCnt++
163-
}
164-
if err := stream.Err(); err != nil {
164+
}
165+
166+
// sendCompletionRequestForLatencyTest sends a completion request according to the given parameters;
167+
// uses http.Post and not openai-api function because vllm specific fields should be sent
168+
func sendCompletionRequestForLatencyTest(client *http.Client, modelName string, prompt string, isStreaming bool, doRemotePrefill bool) {
169+
// send completions request using http post because disaggregated PD fields should be included
170+
// Test with raw HTTP to verify the error response format
171+
reqBody := fmt.Sprintf(`{
172+
"prompt": "%s",
173+
"model": "%s",
174+
"stream": %t,
175+
"do_remote_prefill": %t
176+
}`, prompt, modelName, isStreaming, doRemotePrefill)
177+
178+
resp, err := client.Post("http://localhost/v1/completions", "application/json", strings.NewReader(reqBody))
179+
gomega.Expect(err).NotTo(gomega.HaveOccurred())
180+
defer func() {
181+
err := resp.Body.Close()
182+
gomega.Expect(err).NotTo(gomega.HaveOccurred())
183+
}()
184+
185+
if isStreaming {
186+
reader := bufio.NewReader(resp.Body)
187+
for {
188+
_, err := reader.ReadString('\n')
189+
if err == io.EOF {
190+
break
191+
}
165192
gomega.Expect(err).NotTo(gomega.HaveOccurred())
166193
}
167-
// number of chunks is number of tokens + 2 (one chunk with usage info and one closing chunk)
168-
gomega.Expect(chunksCnt).To(gomega.BeNumerically("==", len(common.Tokenize(prompt))+2))
169-
} else {
170-
_, err = openaitextclient.Completions.New(ctx, params)
171-
gomega.Expect(err).NotTo(gomega.HaveOccurred())
172194
}
173-
174-
return client
175195
}
176196

177197
// sendSimpleChatRequest starts server using the given environment variables and sends one chat completions request
@@ -212,6 +232,7 @@ func getOpenAIClientAndChatParams(client option.HTTPClient, model string, messag
212232
return openaiclient, params
213233
}
214234

235+
// nolint
215236
// getOpenAIClientAndTextParams - creates an openai client and params for /completions call based on the given parameters
216237
func getOpenAIClientAndTextParams(client option.HTTPClient, model string, message string, streaming bool) (openai.Client, openai.CompletionNewParams) {
217238
openaiclient := openai.NewClient(
@@ -438,14 +459,15 @@ func checkBucketBoundary(metrics string, modelName string, metricName string, bu
438459
gomega.Expect(metrics).To(gomega.ContainSubstring(getFloatBucketMetricLine(modelName, metricName, bucketBoudary, expectedCount)))
439460
}
440461

441-
// checkLatencyMertics sends /metrics request and checks that latency related values are valid
462+
// checkLatencyMetrics sends /metrics request and checks that latency related values are valid
442463
// client the http client to be used for request send
443464
// modelName the model name
444465
// numOfOutputTokens number of tokens in the output of the completion request we want to validate
445466
// ttft time to first token parameter
446467
// prefillTimePerToken prefill time per input tokens
447468
// interTokenLatency processing time per output token
448-
func checkLatencyMertics(client *http.Client, modelName string, numOfInputTokens int, numOfOutputTokens int, ttft int, prefillTimePerToken int, interTokenLatency int) {
469+
func checkLatencyMetrics(client *http.Client, modelName string, numOfInputTokens int, numOfOutputTokens int, ttft int,
470+
prefillTimePerToken int, interTokenLatency int, kvcacheTransferLatency int, kvCacheTransferTimePerToken int, doRemotePrefill bool) {
449471
// wait a little bit and check metrics
450472
time.Sleep(300 * time.Millisecond)
451473
metricsResp, err := client.Get(metricsUrl)
@@ -456,30 +478,38 @@ func checkLatencyMertics(client *http.Client, modelName string, numOfInputTokens
456478
gomega.Expect(err).NotTo(gomega.HaveOccurred())
457479
metrics := string(data)
458480

459-
var expectedPrefillTime float64
460-
// TODO take into consideration remote prefill
461-
if ttft > 0 {
462-
// time-to-first-token overwrites calculation of prefill time based on number of input tokens
463-
expectedPrefillTime = float64(ttft) / 1000
464-
481+
expectedPrefillTimeInSecs := 0.0
482+
if doRemotePrefill {
483+
// when doRemotePrefill is true, this means that this is decode request and prefill was executed on remote vllm
484+
if kvcacheTransferLatency != 0 {
485+
expectedPrefillTimeInSecs = float64(kvcacheTransferLatency) / 1000
486+
} else {
487+
expectedPrefillTimeInSecs = float64(kvCacheTransferTimePerToken*numOfInputTokens) / 1000
488+
}
465489
} else {
466-
expectedPrefillTime = float64(numOfInputTokens*prefillTimePerToken) / 1000
490+
if ttft > 0 {
491+
// time-to-first-token overwrites calculation of prefill time based on number of input tokens
492+
expectedPrefillTimeInSecs = float64(ttft) / 1000
493+
494+
} else {
495+
expectedPrefillTimeInSecs = float64(numOfInputTokens*prefillTimePerToken) / 1000
496+
}
467497
}
468-
expectedDecodeTime := float64(interTokenLatency*(numOfOutputTokens-1)) / 1000
469-
expectedE2ELatency := expectedPrefillTime + expectedDecodeTime
498+
expectedDecodeTimeInSecs := float64(interTokenLatency*(numOfOutputTokens-1)) / 1000
499+
expectedE2ELatency := expectedPrefillTimeInSecs + expectedDecodeTimeInSecs
470500

471501
prevBoundary := math.Inf(-1)
472502

473503
for _, bucketBoudary := range common.RequestLatencyBucketsBoundaries {
474-
checkBucketBoundary(metrics, modelName, prefillTimeMetricName, bucketBoudary, prevBoundary, expectedPrefillTime)
475-
checkBucketBoundary(metrics, modelName, decodeTimeMetricName, bucketBoudary, prevBoundary, expectedDecodeTime)
504+
checkBucketBoundary(metrics, modelName, prefillTimeMetricName, bucketBoudary, prevBoundary, expectedPrefillTimeInSecs)
505+
checkBucketBoundary(metrics, modelName, decodeTimeMetricName, bucketBoudary, prevBoundary, expectedDecodeTimeInSecs)
476506
checkBucketBoundary(metrics, modelName, e2eReqLatencyMetricName, bucketBoudary, prevBoundary, expectedE2ELatency)
477507

478508
prevBoundary = bucketBoudary
479509
}
480510
// check the last bucket
481511
lastBoundary := common.RequestLatencyBucketsBoundaries[len(common.RequestLatencyBucketsBoundaries)-1]
482-
checkBucketBoundary(metrics, modelName, prefillTimeMetricName, math.Inf(1), lastBoundary, expectedPrefillTime)
483-
checkBucketBoundary(metrics, modelName, decodeTimeMetricName, math.Inf(1), lastBoundary, expectedDecodeTime)
512+
checkBucketBoundary(metrics, modelName, prefillTimeMetricName, math.Inf(1), lastBoundary, expectedPrefillTimeInSecs)
513+
checkBucketBoundary(metrics, modelName, decodeTimeMetricName, math.Inf(1), lastBoundary, expectedDecodeTimeInSecs)
484514
checkBucketBoundary(metrics, modelName, e2eReqLatencyMetricName, math.Inf(1), lastBoundary, expectedE2ELatency)
485515
}

0 commit comments

Comments
 (0)