Commit c5bfb72

authored by Maya Barnea

update PD support to be compatible with nixlv2 (#258)

* update PD support to be compatible with nixlv2

  Signed-off-by: Maya Barnea <[email protected]>

* fix test + add comment

  Signed-off-by: Maya Barnea <[email protected]>

---------

Signed-off-by: Maya Barnea <[email protected]>
1 parent 719a895 commit c5bfb72

File tree

4 files changed (+36 -34 lines changed):

  pkg/llm-d-inference-sim/simulator.go
  pkg/llm-d-inference-sim/test_utils.go
  pkg/openai-server-api/request.go
  pkg/openai-server-api/response.go


pkg/llm-d-inference-sim/simulator.go

Lines changed: 8 additions & 6 deletions

@@ -576,14 +576,16 @@ func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion
 		time.Now().Unix(), modelName, usageData)
 
 	if doRemoteDecode {
+		baseResp.KVParams = &openaiserverapi.KVTransferParams{}
 		// add special fields related to the prefill pod special behavior
-		baseResp.DoRemoteDecode = true
-		baseResp.DoRemotePrefill = false
+		baseResp.KVParams.DoRemoteDecode = false
+		baseResp.KVParams.DoRemotePrefill = true
 		// currently remote prefill information is hard-coded
-		baseResp.RemoteBlockIds = []string{"DUMMY_ID"}
-		baseResp.RemoteEngineId = "DUMMY_ID"
-		baseResp.RemoteHost = "DUMMY"
-		baseResp.RemotePort = 1234
+		baseResp.KVParams.RemoteBlockIds = []string{"DUMMY_ID"}
+		baseResp.KVParams.RemoteEngineId = "DUMMY_ID"
+		baseResp.KVParams.RemoteHost = "DUMMY"
+		baseResp.KVParams.RemotePort = 1234
+		baseResp.KVParams.TPSize = 1
 	}
 
 	baseChoice := openaiserverapi.CreateBaseResponseChoice(0, finishReason)
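For reference, a minimal standalone sketch (not part of the commit; the KVTransferParams definition is copied from the request.go diff below, and the main/marshaling scaffolding is assumed) showing how the hard-coded values above serialize once grouped under kv_transfer_params:

package main

import (
	"encoding/json"
	"fmt"
)

// Copied from the pkg/openai-server-api/request.go diff in this commit
// so the sketch compiles on its own.
type KVTransferParams struct {
	DoRemoteDecode  bool     `json:"do_remote_decode"`
	DoRemotePrefill bool     `json:"do_remote_prefill"`
	RemoteEngineId  string   `json:"remote_engine_id"`
	RemoteBlockIds  []string `json:"remote_block_ids"`
	RemoteHost      string   `json:"remote_host"`
	RemotePort      int      `json:"remote_port"`
	TPSize          int      `json:"tp_size" default:"1"`
}

func main() {
	// Mirror the assignments made in createCompletionResponse above.
	kv := &KVTransferParams{
		DoRemotePrefill: true,
		RemoteBlockIds:  []string{"DUMMY_ID"},
		RemoteEngineId:  "DUMMY_ID",
		RemoteHost:      "DUMMY",
		RemotePort:      1234,
		TPSize:          1,
	}
	out, _ := json.Marshal(kv)
	fmt.Println(string(out))
	// {"do_remote_decode":false,"do_remote_prefill":true,"remote_engine_id":"DUMMY_ID",
	//  "remote_block_ids":["DUMMY_ID"],"remote_host":"DUMMY","remote_port":1234,"tp_size":1}
}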

pkg/llm-d-inference-sim/test_utils.go

Lines changed: 11 additions & 8 deletions

@@ -19,6 +19,7 @@ import (
 	"bufio"
 	"context"
 	"crypto/tls"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -33,6 +34,7 @@ import (
 	"time"
 
 	"github.com/llm-d/llm-d-inference-sim/pkg/common"
+	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
 	"github.com/openai/openai-go/v3"
 	"github.com/openai/openai-go/v3/option"
 	"github.com/openai/openai-go/v3/packages/param"
@@ -169,14 +171,15 @@ func singleRequestLatencyTest(ttft int, prefillTimePerToken int, interTokenLaten
 func sendCompletionRequestForLatencyTest(client *http.Client, modelName string, prompt string, isStreaming bool, doRemotePrefill bool) {
 	// send completions request using http post because disagregated PD fields should be included
 	// Test with raw HTTP to verify the error response format
-	reqBody := fmt.Sprintf(`{
-		"prompt": "%s",
-		"model": "%s",
-		"stream": %t,
-		"do_remote_prefill": %t
-	}`, prompt, modelName, isStreaming, doRemotePrefill)
-
-	resp, err := client.Post("http://localhost/v1/completions", "application/json", strings.NewReader(reqBody))
+	req := &openaiserverapi.TextCompletionRequest{Prompt: prompt}
+	req.KVParams = &openaiserverapi.KVTransferParams{DoRemotePrefill: doRemotePrefill}
+	req.Model = modelName
+	req.Stream = isStreaming
+
+	body, err := json.Marshal(req)
+	gomega.Expect(err).NotTo(gomega.HaveOccurred())
+
+	resp, err := client.Post("http://localhost/v1/completions", "application/json", strings.NewReader(string(body)))
 	gomega.Expect(err).NotTo(gomega.HaveOccurred())
 	defer func() {
 		err := resp.Body.Close()
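The test helper now builds its body with json.Marshal instead of a hand-written fmt.Sprintf template, which moves do_remote_prefill under the nested kv_transfer_params object. A self-contained sketch (stand-in types trimmed to the json-tagged fields used here; the prompt and model values are hypothetical) of the wire format the test sends after this change:

package main

import (
	"encoding/json"
	"fmt"
)

// Minimal local stand-ins for the simulator's request types, keeping only
// the fields this example needs.
type KVTransferParams struct {
	DoRemotePrefill bool `json:"do_remote_prefill"`
}

type TextCompletionRequest struct {
	Prompt   string            `json:"prompt"`
	Model    string            `json:"model"`
	Stream   bool              `json:"stream"`
	KVParams *KVTransferParams `json:"kv_transfer_params"`
}

func main() {
	req := &TextCompletionRequest{
		Prompt:   "hello",
		Model:    "my-model",
		Stream:   false,
		KVParams: &KVTransferParams{DoRemotePrefill: true},
	}
	body, err := json.Marshal(req)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))
	// {"prompt":"hello","model":"my-model","stream":false,
	//  "kv_transfer_params":{"do_remote_prefill":true}}
}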

pkg/openai-server-api/request.go

Lines changed: 15 additions & 8 deletions

@@ -91,22 +91,29 @@ type baseCompletionRequest struct {
 	StreamOptions StreamOptions `json:"stream_options"`
 	// Model defines Model name to use for "inference", could be base Model name or one of available LoRA adapters
 	Model string `json:"model"`
+	// KVParams kv transfer related fields
+	KVParams *KVTransferParams `json:"kv_transfer_params"`
+	// The number of tokens in the prompt that are in the local KV Cache
+	cachedPromptTokens int
+	// IgnoreEOS is a boolean value, true when the model should ignore end-of-sequence tokens
+	IgnoreEOS bool `json:"ignore_eos"`
+}
+
+type KVTransferParams struct {
 	// DoRemoteDecode boolean value, true when request's decode will be done on remote pod
 	DoRemoteDecode bool `json:"do_remote_decode"`
 	// DoRemotePrefill boolean value, true when request's prefill was done on remote pod
 	DoRemotePrefill bool `json:"do_remote_prefill"`
-	// RemoteBlockIds is a list of block identifiers to process remotely for distributed decoding
-	RemoteBlockIds []string `json:"remote_block_ids"`
 	// RemoteEngineId is an identifier of the remote inference engine or backend to use for processing requests
 	RemoteEngineId string `json:"remote_engine_id"`
+	// RemoteBlockIds is a list of block identifiers to process remotely for distributed decoding
+	RemoteBlockIds []string `json:"remote_block_ids"`
 	// RemoteHost is a hostname or IP address of the remote server handling prefill
 	RemoteHost string `json:"remote_host"`
 	// RemotePort is a port of the remote server handling prefill
 	RemotePort int `json:"remote_port"`
-	// The number of tokens in the prompt that are in the local KV Cache
-	cachedPromptTokens int
-	// IgnoreEOS is a boolean value, true when the model should ignore end-of-sequence tokens
-	IgnoreEOS bool `json:"ignore_eos"`
+	// TPSize is the tensor parallelism size for KV cache transfer
+	TPSize int `json:"tp_size" default:"1"`
 }
 
 // StreamOptions defines streaming options for streaming requests
@@ -132,11 +139,11 @@ func (b *baseCompletionRequest) IncludeUsage() bool {
 }
 
 func (b *baseCompletionRequest) IsDoRemoteDecode() bool {
-	return b.DoRemoteDecode
+	return b.KVParams != nil && b.KVParams.DoRemoteDecode
 }
 
 func (b *baseCompletionRequest) IsDoRemotePrefill() bool {
-	return b.DoRemotePrefill
+	return b.KVParams != nil && b.KVParams.DoRemotePrefill
 }
 
 // GetNumberOfCachedPromptTokens returns the number of tokens in the prompt that are
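Both accessors now guard against a nil KVParams, so a request that omits kv_transfer_params entirely answers false instead of dereferencing a nil pointer. A self-contained sketch (local stand-in types mirroring the diff, trimmed to what the guard needs) demonstrating that behavior:

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-ins for the structs in the diff above.
type KVTransferParams struct {
	DoRemotePrefill bool `json:"do_remote_prefill"`
}

type baseCompletionRequest struct {
	Model    string            `json:"model"`
	KVParams *KVTransferParams `json:"kv_transfer_params"`
}

// IsDoRemotePrefill mirrors the accessor from the diff: the nil check keeps
// requests without kv_transfer_params from panicking.
func (b *baseCompletionRequest) IsDoRemotePrefill() bool {
	return b.KVParams != nil && b.KVParams.DoRemotePrefill
}

func main() {
	var plain, pd baseCompletionRequest
	_ = json.Unmarshal([]byte(`{"model":"m"}`), &plain)
	_ = json.Unmarshal([]byte(`{"model":"m","kv_transfer_params":{"do_remote_prefill":true}}`), &pd)
	fmt.Println(plain.IsDoRemotePrefill(), pd.IsDoRemotePrefill()) // false true
}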

pkg/openai-server-api/response.go

Lines changed: 2 additions & 12 deletions

@@ -40,18 +40,8 @@ type baseCompletionResponse struct {
 	Usage *Usage `json:"usage"`
 	// Object is the Object type, "text_completion", "chat.completion", or "chat.completion.chunk"
 	Object string `json:"object"`
-	// DoRemoteDecode boolean value, true when request's decode will be done on remote pod
-	DoRemoteDecode bool `json:"do_remote_decode"`
-	// DoRemotePrefill boolean value, true when request's prefill was done on remote pod
-	DoRemotePrefill bool `json:"do_remote_prefill"`
-	// RemoteBlockIds is a list of block identifiers to process remotely for distributed decoding
-	RemoteBlockIds []string `json:"remote_block_ids"`
-	// RemoteEngineId is an identifier of the remote inference engine or backend to use for processing requests
-	RemoteEngineId string `json:"remote_engine_id"`
-	// RemoteHost is a hostname or IP address of the remote server handling prefill
-	RemoteHost string `json:"remote_host"`
-	// RemotePort is a port of the remote server handling prefill
-	RemotePort int `json:"remote_port"`
+	// KVParams kv transfer related fields
+	KVParams *KVTransferParams `json:"kv_transfer_params"`
 }
 
 // Usage contains token Usage statistics
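Worth noting: the kv_transfer_params tag carries no omitempty, so a response with a nil KVParams marshals the key as null rather than dropping it. A trimmed sketch (stand-in type, assumed scaffolding) illustrating that:

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-in for baseCompletionResponse, keeping only the fields
// relevant to this example.
type KVTransferParams struct {
	DoRemoteDecode bool `json:"do_remote_decode"`
}

type baseCompletionResponse struct {
	Object   string            `json:"object"`
	KVParams *KVTransferParams `json:"kv_transfer_params"`
}

func main() {
	// Non-PD responses leave KVParams nil; the key still appears as null.
	resp := baseCompletionResponse{Object: "text_completion"}
	out, _ := json.Marshal(resp)
	fmt.Println(string(out))
	// {"object":"text_completion","kv_transfer_params":null}
}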
