Skip to content

Commit 39a9d24

Browse files
committed
fix dataset tests
Signed-off-by: Qifan Deng <[email protected]>
1 parent c56728a commit 39a9d24

File tree

7 files changed

+164
-42
lines changed

7 files changed

+164
-42
lines changed

pkg/dataset/custom_dataset.go

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,18 @@ import (
2828
"os"
2929
"os/signal"
3030
"path/filepath"
31+
"strconv"
3132
"syscall"
3233
"time"
3334

3435
"github.com/go-logr/logr"
3536
"github.com/google/uuid"
37+
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
3638
_ "github.com/mattn/go-sqlite3"
3739
)
3840

3941
type CustomDataset struct {
40-
Dataset
42+
BaseDataset
4143
db *sql.DB
4244
hasWarned bool
4345
}
@@ -53,7 +55,7 @@ const (
5355
nGenTokensColType = "INTEGER"
5456
)
5557

56-
func (d CustomDataset) downloadDataset(url string, savePath string) error {
58+
func (d *CustomDataset) downloadDataset(url string, savePath string) error {
5759
// Set up signal handling for Ctrl+C (SIGINT)
5860
ctx, cancel := context.WithCancel(context.Background())
5961
defer cancel()
@@ -184,7 +186,7 @@ func (pr *progressReader) logProgress(pct int) {
184186
}
185187
}
186188

187-
func (d CustomDataset) verifyDB() error {
189+
func (d *CustomDataset) verifyDB() error {
188190
rows, err := d.db.Query("PRAGMA table_info(" + tableName + ");")
189191
if err != nil {
190192
return fmt.Errorf("failed to query table info for `%s`: %w", tableName, err)
@@ -234,7 +236,7 @@ func (d CustomDataset) verifyDB() error {
234236
return nil
235237
}
236238

237-
func (d CustomDataset) getRecordsCount() (int, error) {
239+
func (d *CustomDataset) getRecordsCount() (int, error) {
238240
var count int
239241
err := d.db.QueryRow("SELECT COUNT(" + promptHashCol + ") FROM " + tableName + ";").Scan(&count)
240242
if err != nil {
@@ -243,7 +245,7 @@ func (d CustomDataset) getRecordsCount() (int, error) {
243245
return count, nil
244246
}
245247

246-
func (d CustomDataset) connectToDB(path string) error {
248+
func (d *CustomDataset) connectToDB(path string) error {
247249
if d.db != nil {
248250
err := d.db.Close()
249251
if err != nil {
@@ -277,7 +279,7 @@ func (d CustomDataset) connectToDB(path string) error {
277279
return nil
278280
}
279281

280-
func (d CustomDataset) Init(path string, url string, savePath string) error {
282+
func (d *CustomDataset) Init(path string, url string, savePath string) error {
281283
d.hasWarned = false
282284
if path != "" {
283285
return d.connectToDB(path)
@@ -312,7 +314,7 @@ func (d CustomDataset) Init(path string, url string, savePath string) error {
312314
return errors.New("no dataset path or url provided")
313315
}
314316

315-
func (d CustomDataset) Close() error {
317+
func (d *CustomDataset) Close() error {
316318
if d.db != nil {
317319
return d.db.Close()
318320
}
@@ -336,11 +338,11 @@ func unmarshalAllRecords(rows *sql.Rows) ([][]string, error) {
336338
return tokensList, nil
337339
}
338340

339-
func (d CustomDataset) getRandomTokens(n_gen_tokens int) []string {
340-
return nil
341+
func (d *CustomDataset) getRandomTokens(n_gen_tokens int) []string {
342+
return []string{"<|random_tokens|>", strconv.Itoa(n_gen_tokens)}
341343
}
342344

343-
func (d *CustomDataset) GetTokens(prompt string, n_gen_tokens int) []string {
345+
func (d *CustomDataset) readTokensFromDB(prompt string, n_gen_tokens int) []string {
344346
promptHash := uuid.NewSHA1(uuid.NameSpaceOID, []byte(prompt)).NodeID()
345347
rows, err := d.db.Query("SELECT "+genTokensCol+" FROM "+tableName+" WHERE "+promptHashCol+" = ?;", promptHash)
346348
if err != nil {
@@ -369,3 +371,8 @@ func (d *CustomDataset) GetTokens(prompt string, n_gen_tokens int) []string {
369371
randIndex := rand.Intn(len(tokensList))
370372
return tokensList[randIndex]
371373
}
374+
375+
func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) {
376+
tokens := d.readTokensFromDB("", nTokens)
377+
return tokens, nil
378+
}

pkg/dataset/custom_dataset_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import (
2727
_ "github.com/mattn/go-sqlite3"
2828
)
2929

30-
var _ = Describe("Dataset", func() {
30+
var _ = Describe("CustomDataset", func() {
3131
var (
3232
dataset *CustomDataset
3333
file_folder string
@@ -42,7 +42,7 @@ var _ = Describe("Dataset", func() {
4242

4343
BeforeEach(func() {
4444
dataset = &CustomDataset{
45-
Dataset: Dataset{
45+
BaseDataset: BaseDataset{
4646
Logger: logr.Discard(),
4747
},
4848
}

pkg/dataset/dataset.go

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ var chatCompletionFakeResponses = []string{
6868
`Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime`,
6969
}
7070

71+
// Dataset is the interface used to obtain response tokens, independent of the
// concrete dataset implementation behind it.
type Dataset interface {
72+
// Init initializes the dataset using configs
73+
Init(path string, url string, savePath string) error
74+
// Close closes the dataset
75+
Close() error
76+
// GetTokens returns tokens and the finish reason for the given request and mode (echo or random)
77+
GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error)
78+
}
79+
7180
func init() {
7281
cumulativeBucketsProbabilities = make([]float64, len(respLenBucketsProbabilities))
7382
sum := 0.0
@@ -267,19 +276,20 @@ func EchoResponseTokens(maxCompletionTokens *int64, text string) ([]string, stri
267276
return tokens[0:*maxCompletionTokens], LengthFinishReason
268277
}
269278

270-
type Dataset struct {
279+
type BaseDataset struct {
271280
Logger logr.Logger
272281
}
273282

274-
func (d *Dataset) Init(path string, url string, savePath string) error {
283+
func (d *BaseDataset) Init(path string, url string, savePath string) error {
275284
return nil
276285
}
277286

278-
func (d *Dataset) Close() error {
287+
func (d *BaseDataset) Close() error {
279288
return nil
280289
}
281290

282-
func (d *Dataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) {
291+
// GetTokens returns tokens and finishReason for the given request and mode (echo or random)
292+
func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) {
283293
nMaxTokens := d.extractMaxTokens(req)
284294
if mode == common.ModeEcho {
285295
prompt, err := d.extractPrompt(req)
@@ -295,7 +305,10 @@ func (d *Dataset) GetTokens(req openaiserverapi.CompletionRequest, mode string)
295305
return tokens, finishReason, err
296306
}
297307

298-
func (d *Dataset) extractMaxTokens(req openaiserverapi.CompletionRequest) *int64 {
308+
// extractMaxTokens extracts the max tokens from the request
309+
// for chat completion - max_completion_tokens field is used
310+
// for text completion - max_tokens field is used
311+
func (d *BaseDataset) extractMaxTokens(req openaiserverapi.CompletionRequest) *int64 {
299312
if chatReq, ok := req.(*openaiserverapi.ChatCompletionRequest); ok {
300313
return chatReq.GetMaxCompletionTokens()
301314
} else if textReq, ok := req.(*openaiserverapi.TextCompletionRequest); ok {
@@ -304,7 +317,10 @@ func (d *Dataset) extractMaxTokens(req openaiserverapi.CompletionRequest) *int64
304317
return nil
305318
}
306319

307-
func (d *Dataset) extractPrompt(req openaiserverapi.CompletionRequest) (string, error) {
320+
// extractPrompt extracts the prompt from the request
321+
// for chat completion - the last user message is used as the prompt
322+
// for text completion - the prompt field is used
323+
func (d *BaseDataset) extractPrompt(req openaiserverapi.CompletionRequest) (string, error) {
308324
if chatReq, ok := req.(*openaiserverapi.ChatCompletionRequest); ok {
309325
return chatReq.GetLastUserMsg(), nil
310326
} else if textReq, ok := req.(*openaiserverapi.TextCompletionRequest); ok {
@@ -313,7 +329,9 @@ func (d *Dataset) extractPrompt(req openaiserverapi.CompletionRequest) (string,
313329
return "", errors.New("unknown request type")
314330
}
315331

316-
func (d *Dataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) {
332+
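// GenerateTokens generates random tokens for the required number of tokens.
// Other dataset types may define their own GenerateTokens; note that Go has no
// method overriding, so methods of BaseDataset itself keep calling this version.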
334+
func (d *BaseDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) {
317335
tokens := GenPresetRandomTokens(nTokens)
318336
return tokens, nil
319337
}

pkg/dataset/dataset_suite_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package dataset_test
2+
3+
import (
4+
"testing"
5+
6+
. "github.com/onsi/ginkgo/v2"
7+
. "github.com/onsi/gomega"
8+
)
9+
10+
func TestDataset(t *testing.T) {
11+
RegisterFailHandler(Fail)
12+
RunSpecs(t, "Dataset Suite")
13+
}

pkg/dataset/dataset_test.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package dataset
1919
import (
2020
"fmt"
2121
"strings"
22+
"time"
2223

2324
"github.com/go-logr/logr"
2425
"github.com/llm-d/llm-d-inference-sim/pkg/common"
@@ -27,30 +28,31 @@ import (
2728
. "github.com/onsi/gomega"
2829
)
2930

30-
var _ = Describe("Utils", Ordered, func() {
31+
var _ = Describe("Dataset", Ordered, func() {
3132
var (
32-
dataset *Dataset
33+
dataset *BaseDataset
3334
)
3435

36+
BeforeAll(func() {
37+
common.InitRandom(time.Now().UnixNano())
38+
})
3539
BeforeEach(func() {
36-
dataset = &Dataset{
40+
dataset = &BaseDataset{
3741
Logger: logr.Discard(),
3842
}
3943
})
4044

4145
Context("GetRandomTokens", func() {
46+
4247
It("should return complete text", func() {
43-
var n int64
44-
req := &openaiserverapi.ChatCompletionRequest{
45-
MaxTokens: &n,
46-
MaxCompletionTokens: &n,
47-
}
48+
req := &openaiserverapi.ChatCompletionRequest{}
4849
tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
4950
Expect(err).ShouldNot(HaveOccurred())
5051
text := strings.Join(tokens, "")
5152
Expect(IsValidText(text)).To(BeTrue())
5253
Expect(finishReason).Should(Equal(StopFinishReason))
5354
})
55+
5456
It("should return short text", func() {
5557
maxCompletionTokens := int64(2)
5658
req := &openaiserverapi.ChatCompletionRequest{
@@ -67,6 +69,7 @@ var _ = Describe("Utils", Ordered, func() {
6769
Expect(finishReason).To(Equal(StopFinishReason))
6870
}
6971
})
72+
7073
It("should return long text", func() {
7174
// return required number of tokens although it is higher than ResponseLenMax
7275
maxCompletionTokens := int64(ResponseLenMax * 5)

pkg/llm-d-inference-sim/simulator.go

Lines changed: 91 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ type VllmSimulator struct {
116116
pod string
117117
// tokenizer is currently used in kv-cache and in /tokenize
118118
tokenizer tokenization.Tokenizer
119-
// dataset is used for managing dataset files
120-
dataset *dataset.Dataset
119+
// dataset is used for token generation in responses
120+
dataset dataset.Dataset
121121
}
122122

123123
// New creates a new VllmSimulator instance with the given logger
@@ -216,18 +216,9 @@ func (s *VllmSimulator) startSim(ctx context.Context) error {
216216
go s.kvcacheHelper.Run(ctx)
217217
}
218218

219-
if s.config.Dataset.Path == "" && s.config.Dataset.Url == "" && s.config.Dataset.SavePath == "" {
220-
s.dataset = nil
221-
s.logger.Info("No dataset provided, will generate random responses")
222-
} else {
223-
dataset := &dataset.Dataset{
224-
Logger: s.logger,
225-
}
226-
err = dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath)
227-
if err != nil {
228-
return err
229-
}
230-
s.dataset = dataset
219+
err = s.initDataset()
220+
if err != nil {
221+
return fmt.Errorf("dataset initialization error: %w", err)
231222
}
232223

233224
// run request processing workers
@@ -239,13 +230,98 @@ func (s *VllmSimulator) startSim(ctx context.Context) error {
239230

240231
listener, err := s.newListener()
241232
if err != nil {
242-
return err
233+
s.logger.Error(err, "Failed to create listener")
234+
return fmt.Errorf("listener creation error: %w", err)
243235
}
244236

245237
// start the http server with context support
246238
return s.startServer(ctx, listener)
247239
}
248240

241+
func (s *VllmSimulator) initDataset() error {
242+
randDataset := &dataset.BaseDataset{
243+
Logger: s.logger,
244+
}
245+
246+
if s.config.Dataset.Path == "" && s.config.Dataset.Url == "" && s.config.Dataset.SavePath == "" {
247+
s.logger.Info("No dataset provided, will generate random responses")
248+
s.dataset = randDataset
249+
} else {
250+
s.logger.Info("Custom dataset configuration detected")
251+
s.dataset = &dataset.CustomDataset{
252+
BaseDataset: *randDataset,
253+
}
254+
}
255+
256+
if err := s.dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath); err != nil {
257+
return fmt.Errorf("dataset initialization error: %w", err)
258+
}
259+
return nil
260+
}
261+
262+
func (s *VllmSimulator) newListener() (net.Listener, error) {
263+
s.logger.Info("Server starting", "port", s.config.Port)
264+
listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", s.config.Port))
265+
if err != nil {
266+
return nil, err
267+
}
268+
return listener, nil
269+
}
270+
271+
// startServer starts http server on port defined in command line
272+
func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) error {
273+
r := fasthttprouter.New()
274+
275+
// support completion APIs
276+
r.POST("/v1/chat/completions", s.HandleChatCompletions)
277+
r.POST("/v1/completions", s.HandleTextCompletions)
278+
// supports /models API
279+
r.GET("/v1/models", s.HandleModels)
280+
// support load/unload of lora adapter
281+
r.POST("/v1/load_lora_adapter", s.HandleLoadLora)
282+
r.POST("/v1/unload_lora_adapter", s.HandleUnloadLora)
283+
// supports /metrics prometheus API
284+
r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.HandlerFor(s.registry, promhttp.HandlerOpts{})))
285+
// supports standard Kubernetes health and readiness checks
286+
r.GET("/health", s.HandleHealth)
287+
r.GET("/ready", s.HandleReady)
288+
r.POST("/tokenize", s.HandleTokenize)
289+
290+
server := fasthttp.Server{
291+
ErrorHandler: s.HandleError,
292+
Handler: r.Handler,
293+
Logger: s,
294+
}
295+
296+
// Start server in a goroutine
297+
serverErr := make(chan error, 1)
298+
go func() {
299+
s.logger.Info("HTTP server starting")
300+
serverErr <- server.Serve(listener)
301+
}()
302+
303+
// Wait for either context cancellation or server error
304+
select {
305+
case <-ctx.Done():
306+
s.logger.Info("Shutdown signal received, shutting down HTTP server gracefully")
307+
308+
// Gracefully shutdown the server
309+
if err := server.Shutdown(); err != nil {
310+
s.logger.Error(err, "Error during server shutdown")
311+
return err
312+
}
313+
314+
s.logger.Info("HTTP server stopped")
315+
return nil
316+
317+
case err := <-serverErr:
318+
if err != nil {
319+
s.logger.Error(err, "HTTP server failed")
320+
}
321+
return err
322+
}
323+
}
324+
249325
// Print prints to a log, implementation of fasthttp.Logger
250326
func (s *VllmSimulator) Printf(format string, args ...interface{}) {
251327
s.logger.Info("Server error", "msg", fmt.Sprintf(format, args...))

0 commit comments

Comments
 (0)