Skip to content

Commit 1bea160

Browse files
committed
add tests for custom dataset
Signed-off-by: Qifan Deng <[email protected]>
1 parent 60592b1 commit 1bea160

File tree

7 files changed

+111
-40
lines changed

7 files changed

+111
-40
lines changed
0 Bytes
Binary file not shown.

pkg/dataset/custom_dataset.go

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ package dataset
1818

1919
import (
2020
"context"
21+
"crypto/sha256"
2122
"database/sql"
23+
"encoding/hex"
2224
"encoding/json"
2325
"errors"
2426
"fmt"
@@ -28,12 +30,11 @@ import (
2830
"os"
2931
"os/signal"
3032
"path/filepath"
31-
"strconv"
3233
"syscall"
3334
"time"
3435

3536
"github.com/go-logr/logr"
36-
"github.com/google/uuid"
37+
"github.com/llm-d/llm-d-inference-sim/pkg/common"
3738
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
3839
_ "github.com/mattn/go-sqlite3"
3940
)
@@ -338,19 +339,35 @@ func unmarshalAllRecords(rows *sql.Rows) ([][]string, error) {
338339
return tokensList, nil
339340
}
340341

341-
func (d *CustomDataset) getRandomTokens(n_gen_tokens int) []string {
342-
return []string{"<|random_tokens|>", strconv.Itoa(n_gen_tokens)}
342+
func (d *CustomDataset) GetPromptHash(req openaiserverapi.CompletionRequest) []byte {
343+
hashArray := sha256.Sum256([]byte(req.GetFullPrompt()))
344+
return hashArray[:]
343345
}
344346

345-
func (d *CustomDataset) readTokensFromDB(prompt string, n_gen_tokens int) []string {
346-
promptHash := uuid.NewSHA1(uuid.NameSpaceOID, []byte(prompt)).NodeID()
347-
rows, err := d.db.Query("SELECT "+genTokensCol+" FROM "+tableName+" WHERE "+promptHashCol+" = ?;", promptHash)
347+
func (d *CustomDataset) GetPromptHashHex(hashBytes []byte) string {
348+
return hex.EncodeToString(hashBytes)
349+
}
350+
351+
// GetTokens returns tokens and finishReason for the given request and mode (echo or random)
352+
func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) {
353+
if mode == common.ModeEcho {
354+
return d.echo(req)
355+
}
356+
nTokensToGen, finishReason := howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS())
357+
tokens, err := d.GenerateTokens(req, nTokensToGen)
358+
return tokens, finishReason, err
359+
}
360+
361+
func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) {
362+
promptHash := d.GetPromptHash(req)
363+
promptHashHex := d.GetPromptHashHex(promptHash)
364+
rows, err := d.db.Query("SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';")
348365
if err != nil {
349366
if !d.hasWarned {
350367
d.Logger.Error(err, "failed to query database. Ensure the prompt hash exists in the dataset. Will generate random tokens instead.")
351368
d.hasWarned = true
352369
}
353-
return d.getRandomTokens(n_gen_tokens)
370+
return GenPresetRandomTokens(nTokens), nil
354371
}
355372
defer func() {
356373
if cerr := rows.Close(); cerr != nil {
@@ -361,18 +378,13 @@ func (d *CustomDataset) readTokensFromDB(prompt string, n_gen_tokens int) []stri
361378
tokensList, err := unmarshalAllRecords(rows)
362379
if err != nil {
363380
d.Logger.Error(err, "failed to unmarshal records from database")
364-
return d.getRandomTokens(n_gen_tokens)
381+
return GenPresetRandomTokens(nTokens), nil
365382
}
366383

367384
if len(tokensList) == 0 {
368-
return d.getRandomTokens(n_gen_tokens)
385+
return GenPresetRandomTokens(nTokens), nil
369386
}
370387
d.hasWarned = false
371388
randIndex := rand.Intn(len(tokensList))
372-
return tokensList[randIndex]
373-
}
374-
375-
func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) {
376-
tokens := d.readTokensFromDB("", nTokens)
377-
return tokens, nil
389+
return tokensList[randIndex], nil
378390
}

pkg/dataset/custom_dataset_test.go

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,18 @@ import (
2121
"os"
2222

2323
"github.com/go-logr/logr"
24+
"github.com/llm-d/llm-d-inference-sim/pkg/common"
25+
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
2426
. "github.com/onsi/ginkgo/v2"
2527
. "github.com/onsi/gomega"
2628

2729
_ "github.com/mattn/go-sqlite3"
2830
)
2931

32+
const (
33+
testPrompt = "Hello world!"
34+
)
35+
3036
var _ = Describe("CustomDataset", func() {
3137
var (
3238
dataset *CustomDataset
@@ -90,20 +96,20 @@ var _ = Describe("CustomDataset", func() {
9096
err := dataset.Init(validDBPath, "", "")
9197
Expect(err).NotTo(HaveOccurred())
9298

93-
row := dataset.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'b94d27b9934d041c52e5b721d7373f13a07ed5e79179d63c5d8a0c102a9d00b2';")
99+
row := dataset.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';")
94100
var n_gen_tokens int
95101
err = row.Scan(&n_gen_tokens)
96102
Expect(err).NotTo(HaveOccurred())
97-
Expect(n_gen_tokens).To(Equal(3))
103+
Expect(n_gen_tokens).To(Equal(4))
98104

99105
var jsonStr string
100-
row = dataset.db.QueryRow("SELECT gen_tokens FROM llmd WHERE prompt_hash=X'b94d27b9934d041c52e5b721d7373f13a07ed5e79179d63c5d8a0c102a9d00b2';")
106+
row = dataset.db.QueryRow("SELECT gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';")
101107
err = row.Scan(&jsonStr)
102108
Expect(err).NotTo(HaveOccurred())
103109
var tokens []string
104110
err = json.Unmarshal([]byte(jsonStr), &tokens)
105111
Expect(err).NotTo(HaveOccurred())
106-
Expect(tokens).To(Equal([]string{"Hello", "world", "!"}))
112+
Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"}))
107113

108114
})
109115

@@ -136,4 +142,41 @@ var _ = Describe("CustomDataset", func() {
136142
Expect(err).To(HaveOccurred())
137143
Expect(err.Error()).To(ContainSubstring("incorrect type"))
138144
})
145+
146+
It("should return correct prompt hash in bytes", func() {
147+
// b't\xbf\x14\xc0\x9c\x03\x83!\xcb\xa3\x97\x17\xda\xe1\xdcs(#\xaeJ\xbd\x8e\x15YY6v)\xa3\xc1\t\xa8'
148+
expectedHashBytes := []byte{0x74, 0xbf, 0x14, 0xc0, 0x9c, 0x03, 0x83, 0x21, 0xcb, 0xa3, 0x97, 0x17, 0xda, 0xe1, 0xdc, 0x73, 0x28, 0x23, 0xae, 0x4a, 0xbd, 0x8e, 0x15, 0x59, 0x59, 0x36, 0x76, 0x29, 0xa3, 0xc1, 0x09, 0xa8}
149+
150+
req := &openaiserverapi.TextCompletionRequest{
151+
Prompt: testPrompt,
152+
}
153+
154+
hashBytes := dataset.GetPromptHash(req)
155+
Expect(hashBytes).To(Equal(expectedHashBytes))
156+
})
157+
158+
It("should return correct prompt hash in hex", func() {
159+
expectedHashHex := "74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8"
160+
161+
req := &openaiserverapi.TextCompletionRequest{
162+
Prompt: testPrompt,
163+
}
164+
165+
hashBytes := dataset.GetPromptHash(req)
166+
hashHex := dataset.GetPromptHashHex(hashBytes)
167+
Expect(hashHex).To(Equal(expectedHashHex))
168+
})
169+
170+
It("should return tokens for existing prompt", func() {
171+
err := dataset.Init(validDBPath, "", "")
172+
Expect(err).NotTo(HaveOccurred())
173+
174+
req := &openaiserverapi.TextCompletionRequest{
175+
Prompt: testPrompt,
176+
}
177+
tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
178+
Expect(err).NotTo(HaveOccurred())
179+
Expect(finishReason).To(Equal(StopFinishReason))
180+
Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"}))
181+
})
139182
})

pkg/dataset/dataset.go

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -288,21 +288,23 @@ func (d *BaseDataset) Close() error {
288288
return nil
289289
}
290290

291+
func (d *BaseDataset) echo(req openaiserverapi.CompletionRequest) ([]string, string, error) {
292+
nMaxTokens := d.extractMaxTokens(req)
293+
prompt, err := d.extractPrompt(req)
294+
if err != nil {
295+
return nil, "", err
296+
}
297+
tokens, finishReason := EchoResponseTokens(nMaxTokens, prompt)
298+
return tokens, finishReason, nil
299+
}
300+
291301
// GetTokens returns tokens and finishReason for the given request and mode (echo or random)
292302
func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) {
293-
nMaxTokens := d.extractMaxTokens(req)
294303
if mode == common.ModeEcho {
295-
prompt, err := d.extractPrompt(req)
296-
if err != nil {
297-
return nil, "", err
298-
}
299-
tokens, finishReason := EchoResponseTokens(nMaxTokens, prompt)
300-
return tokens, finishReason, nil
304+
return d.echo(req)
301305
}
302-
303-
nTokensToGen, finishReason := howManyTokensToGen(nMaxTokens, req.GetIgnoreEOS())
304-
tokens, err := d.GenerateTokens(req, nTokensToGen)
305-
return tokens, finishReason, err
306+
nTokensToGen, finishReason := howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS())
307+
return GenPresetRandomTokens(nTokensToGen), finishReason, nil
306308
}
307309

308310
// extractMaxTokens extracts the max tokens from the request
@@ -328,10 +330,3 @@ func (d *BaseDataset) extractPrompt(req openaiserverapi.CompletionRequest) (stri
328330
}
329331
return "", errors.New("unknown request type")
330332
}
331-
332-
// GenerateTokens generates random tokens for the required number of tokens
333-
// other dataset types should override this function
334-
func (d *BaseDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) {
335-
tokens := GenPresetRandomTokens(nTokens)
336-
return tokens, nil
337-
}

pkg/dataset/dataset_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ var _ = Describe("Dataset", Ordered, func() {
6969
Expect(finishReason).To(Equal(StopFinishReason))
7070
}
7171
})
72-
72+
7373
It("should return long text", func() {
7474
// return required number of tokens although it is higher than ResponseLenMax
7575
maxCompletionTokens := int64(ResponseLenMax * 5)

pkg/llm-d-inference-sim/simulator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ func (s *VllmSimulator) initDataset() error {
245245
}
246246

247247
if s.config.Dataset.Path == "" && s.config.Dataset.Url == "" && s.config.Dataset.SavePath == "" {
248-
s.logger.Info("No dataset provided, will generate random responses")
248+
s.logger.Info("No dataset provided, will generate random responses from preset text")
249249
s.dataset = randDataset
250250
} else {
251251
s.logger.Info("Custom dataset configuration detected")

pkg/openai-server-api/request.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ type CompletionRequest interface {
6565
// when the field is true, the prefill phase should be done on remote pod,
6666
// whereas decode phase is done on local pod, thus this is a decode request
6767
IsDoRemotePrefill() bool
68+
// GetFullPrompt returns the full prompt including system and user prompts
69+
GetFullPrompt() string
6870
}
6971

7072
// BaseCompletionRequest contains base completion request related information
@@ -236,6 +238,21 @@ func (req *ChatCompletionRequest) GetLastUserMsg() string {
236238
return ""
237239
}
238240

241+
func (req *ChatCompletionRequest) GetFullPrompt() string {
242+
prompt := ""
243+
for _, msg := range req.Messages {
244+
switch msg.Role {
245+
case RoleUser:
246+
prompt += "### user:\n" + msg.Content.Raw + "\n"
247+
case RoleAssistant:
248+
prompt += "### assistant:\n" + msg.Content.Raw + "\n"
249+
default:
250+
prompt += "### unknown:\n" + msg.Content.Raw + "\n"
251+
}
252+
}
253+
return prompt
254+
}
255+
239256
// v1/completion
240257
// TextCompletionRequest defines structure of /completion request
241258
type TextCompletionRequest struct {
@@ -270,3 +287,7 @@ func (c *TextCompletionRequest) GetToolChoice() string {
270287
func (c *TextCompletionRequest) GetMaxCompletionTokens() *int64 {
271288
return c.MaxTokens
272289
}
290+
291+
func (t *TextCompletionRequest) GetFullPrompt() string {
292+
return "### user:\n" + t.Prompt + "\n"
293+
}

0 commit comments

Comments
 (0)