diff --git a/README.md b/README.md
index b162c7b..6d1a80b 100644
--- a/README.md
+++ b/README.md
@@ -362,3 +362,53 @@ curl -X POST http://localhost:8000/v1/chat/completions \
]
}'
```
+
+## Response generation
+
+The `/v1/completions` and `/v1/chat/completions` endpoints produce responses based on the simulator configuration and the parameters of each request.
+
+### Echo mode
+In `echo` mode, responses always mirror the request content. The `max_tokens`, `max_completion_tokens`, and `ignore_eos` request parameters are ignored in this mode. A dataset cannot be configured in echo mode; such a configuration fails validation.
+
+### Random mode
+In `random` mode, the `max_tokens`, `max_completion_tokens`, and `ignore_eos` request fields are used during response generation.
+
+#### Use predefined texts for response generation
+The simulator can generate responses from a predefined list of sentences.
+If `max_tokens` or `max_completion_tokens` is specified, the response length is calculated using a histogram with six buckets and the following probabilities: 20%, 30%, 20%, 5%, 10%, 15%.
+For a maximum length ≤ 120, bucket sizes are equal.
+For a maximum length > 120, all buckets except the fourth are of size 20;
+the fourth bucket covers the remaining range.
+After buckets are set, response length is sampled according to these probabilities.
+
+
+Examples:
+- max-len = 120: buckets are 1-20, 21-40, 41-60, 61-80, 81-100, 101-120.
+- max-len = 200: buckets are 1-20, 21-40, 41-60, 61-160, 161-180, 181-200.
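+
+For illustration, here is a minimal Go sketch of the bucket layout described above (it mirrors the documented behavior; the simulator's internal implementation may differ in details):
+
+```go
+// computeBuckets returns the six [start, end] bucket ranges for a given
+// maximum response length, following the rules described above.
+func computeBuckets(maxLen int) [][2]int {
+	const numBuckets = 6
+	const fixedSize = 20
+	buckets := make([][2]int, numBuckets)
+	if maxLen <= numBuckets*fixedSize { // maximum length <= 120: equally sized buckets
+		size := maxLen / numBuckets
+		for i := range buckets {
+			buckets[i] = [2]int{i*size + 1, (i + 1) * size}
+		}
+		buckets[numBuckets-1][1] = maxLen // absorb rounding into the last bucket
+		return buckets
+	}
+	// maximum length > 120: all buckets except the fourth have size 20;
+	// the fourth stretches to cover the remaining range
+	flexSize := maxLen - (numBuckets-1)*fixedSize
+	start := 1
+	for i := range buckets {
+		size := fixedSize
+		if i == 3 { // the fourth bucket
+			size = flexSize
+		}
+		buckets[i] = [2]int{start, start + size - 1}
+		start += size
+	}
+	return buckets
+}
+```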
+
+If the maximum response length is not specified, it defaults to the model's maximum context length minus the number of tokens in the prompt.
+In this case, response length is sampled from a Gaussian distribution with mean 40 and standard deviation 20.
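+
+Values are drawn by rejection sampling: samples outside the range [1, maximum] are discarded and redrawn. A minimal sketch (using Go's `math/rand` here for illustration; the simulator uses its own seeded random helper):
+
+```go
+// illustrative rejection sampling for the default response length
+func sampleResponseLen(maxLen int) int {
+	for {
+		val := rand.NormFloat64()*20 + 40 // stddev 20, mean 40
+		if val >= 1 && val <= float64(maxLen) {
+			return int(math.Round(val))
+		}
+	}
+}
+```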
+
+
+After determining the response length:
+
+- A random sentence from the predefined list is chosen and trimmed if it exceeds the required length.
+- If the sentence is shorter, additional random sentences are concatenated until the required token count is met.
+
+If `ignore_eos` is true, the response always reaches the maximum allowed length.
+
+The `finish_reason` is set to `length` if the response length equals the maximum; otherwise, it is set to `stop`.
+
+
+#### Use responses dataset for response generation
+If `dataset-url` is set on the command line, the dataset is downloaded to the location specified by `dataset-path`. If a file already exists at `dataset-path`, the download is skipped.
+
+If a valid dataset exists at the `dataset-path`, it is used for response selection:
+the request prompt is hashed, and this value is matched against dataset entries.
+If any match is no longer than the maximum response length, one such match is selected at random;
+if all matches are longer, a random match is selected and then trimmed.
+
+If `ignore_eos` is true and no match meets the required length, the response is completed with random tokens from the predefined list.
+
+If the prompt hash is not present in the dataset, a random response of length ≤ maximum is selected;
+if all responses are longer, a random response is chosen and trimmed.
+
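+For illustration, a minimal Go sketch of this lookup against the dataset's `llmd` table (columns `prompt_hash` and `gen_tokens`, as used by the simulator); the dataset path and prompt here are illustrative, and the simulator hashes the full prompt of the request:
+
+```go
+package main
+
+import (
+	"crypto/sha256"
+	"database/sql"
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+
+	_ "github.com/mattn/go-sqlite3"
+)
+
+func main() {
+	db, err := sql.Open("sqlite3", ".llm-d/dataset.sqlite3") // illustrative path
+	if err != nil {
+		panic(err)
+	}
+	defer db.Close()
+
+	// the request prompt is hashed with SHA-256 and matched against prompt_hash
+	hash := sha256.Sum256([]byte("Hello world"))
+	rows, err := db.Query("SELECT gen_tokens FROM llmd WHERE prompt_hash=X'" +
+		hex.EncodeToString(hash[:]) + "';")
+	if err != nil {
+		panic(err)
+	}
+	defer rows.Close()
+
+	// each matching row stores its response as a JSON array of tokens
+	for rows.Next() {
+		var tokensJSON string
+		if err := rows.Scan(&tokensJSON); err != nil {
+			panic(err)
+		}
+		var tokens []string
+		if err := json.Unmarshal([]byte(tokensJSON), &tokens); err != nil {
+			panic(err)
+		}
+		fmt.Println(tokens)
+	}
+}
+```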
diff --git a/pkg/common/config.go b/pkg/common/config.go
index 7929b2b..bc51be9 100644
--- a/pkg/common/config.go
+++ b/pkg/common/config.go
@@ -660,6 +660,10 @@ func (c *Configuration) validate() error {
return errors.New("dataset-path is required when dataset-url is set")
}
+ if c.Mode == ModeEcho && (c.DatasetPath != "" || c.DatasetURL != "") {
+ return errors.New("dataset cannot be defined in echo mode")
+ }
+
return nil
}
diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go
index ac5a42e..37cb180 100644
--- a/pkg/common/config_test.go
+++ b/pkg/common/config_test.go
@@ -532,6 +532,12 @@ var _ = Describe("Simulator configuration", func() {
"--config", "../../manifests/config.yaml"},
expectedError: "fake metrics request-max-generation-tokens cannot contain negative values",
},
+ {
+ name: "invalid echo mode with dataset",
+ args: []string{"random", "--model", "test", "--dataset-path", "my/path",
+ "--mode", "echo"},
+ expectedError: "dataset cannot be defined in echo mode",
+ },
}
for _, test := range invalidTests {
diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go
index 557321d..dc6a552 100644
--- a/pkg/dataset/custom_dataset.go
+++ b/pkg/dataset/custom_dataset.go
@@ -19,17 +19,8 @@ package dataset
import (
"context"
"crypto/sha256"
- "database/sql"
"encoding/hex"
- "encoding/json"
"errors"
- "fmt"
- "io"
- "net/http"
- "os"
- "path/filepath"
- "strconv"
- "time"
"github.com/go-logr/logr"
"github.com/llm-d/llm-d-inference-sim/pkg/common"
@@ -39,469 +30,145 @@ import (
type CustomDataset struct {
BaseDataset
- db *sql.DB
- hasWarned bool
+ sqliteHelper *sqliteHelper
}
-// use constants for expected column names and types
-const (
- tableName = "llmd"
- idCol = "id"
- promptHashCol = "prompt_hash"
- genTokensCol = "gen_tokens"
- nGenTokensCol = "n_gen_tokens"
- idColType = "INTEGER"
- promptHashColType = "BLOB"
- genTokensColType = "JSON"
- nGenTokensColType = "INTEGER"
- progressLogTimeInterval = 5 * time.Second
- progressLogPercentInterval = 10
-)
-
-func (d *CustomDataset) downloadDataset(ctx context.Context, url string, path string) error {
- folder := filepath.Dir(path)
- err := os.MkdirAll(folder, 0755)
- if err != nil {
- return fmt.Errorf("failed to create parent directory: %w", err)
- }
-
- if _, err := os.Stat(path); err == nil {
- // file already exists
- return errors.New("Dataset file already exists, should not download: " + path)
- }
-
- out, err := os.Create(path)
- if err != nil {
+func (d *CustomDataset) Init(ctx context.Context, logger logr.Logger, random *common.Random,
+ path string, useInMemory bool, maxModelLen int) error {
+ if err := d.BaseDataset.Init(ctx, logger, random, maxModelLen); err != nil {
return err
}
- defer func() {
- cerr := out.Close()
- if cerr != nil {
- d.logger.Error(cerr, "failed to close file after download")
- }
- }()
-
- d.logger.V(logging.INFO).Info("Using dataset-url", "dataset-url", url)
- resp, err := http.Get(url)
- if err != nil {
- return err
- }
- defer func() {
- cerr := resp.Body.Close()
- if cerr != nil {
- d.logger.Error(cerr, "failed to close response body after download")
- }
- }()
-
- if resp.StatusCode != http.StatusOK {
- return fmt.Errorf("bad status: %s", resp.Status)
- }
-
- // Progress reader with context
- pr := &progressReader{
- Reader: resp.Body,
- total: resp.ContentLength,
- logger: d.logger,
- ctx: ctx,
- startTime: time.Now(),
- }
-
- written, err := io.Copy(out, pr)
- if err != nil {
- // Remove incomplete file
- cerr := os.Remove(path)
- if cerr != nil {
- d.logger.Error(cerr, "failed to remove incomplete file after download")
- }
- // If context was cancelled, return a specific error
- if errors.Is(err, context.Canceled) {
- return errors.New("download cancelled by user")
- }
- return fmt.Errorf("failed to download file: %w", err)
- }
- // Check if file size is zero
- if written == 0 {
- cerr := os.Remove(path)
- if cerr != nil {
- d.logger.Error(cerr, "failed to remove empty file after download")
- }
- return errors.New("downloaded file is empty")
- }
-
- // Ensure file is fully flushed and closed before returning success
- if err := out.Sync(); err != nil {
- cerr := os.Remove(path)
- if cerr != nil {
- d.logger.Error(cerr, "failed to remove incomplete file after download")
- }
- return fmt.Errorf("failed to sync file: %w", err)
+ if path == "" {
+ return errors.New("no dataset path provided")
}
- return nil
+ d.sqliteHelper = newSqliteHelper(logger)
+ d.logger.V(logging.INFO).Info("Using dataset from", "path", path)
+ return d.sqliteHelper.connectToDB(path, useInMemory)
}
-// progressReader wraps an io.Reader and logs download progress.
-type progressReader struct {
- io.Reader
- total int64
- downloaded int64
- startTime time.Time
- lastPct int
- lastLogTime time.Time
- logger logr.Logger
- ctx context.Context
-}
-
-func (pr *progressReader) Read(p []byte) (int, error) {
- select {
- case <-pr.ctx.Done():
- return 0, pr.ctx.Err()
- default:
- }
- n, err := pr.Reader.Read(p)
- pr.downloaded += int64(n)
- if pr.total > 0 {
- pct := int(float64(pr.downloaded) * 100 / float64(pr.total))
- now := time.Now()
-
- timeSinceLastLog := now.Sub(pr.lastLogTime).Seconds()
- pctDiff := pct - pr.lastPct
-
- if timeSinceLastLog >= progressLogTimeInterval.Seconds() || (pctDiff >= progressLogPercentInterval && pct != pr.lastPct) {
- // progress will be shown every interval seconds or every interval percent of progress
- pr.logProgress(pct)
- pr.lastPct = pct
- pr.lastLogTime = now
- }
- }
- return n, err
+func (d *CustomDataset) getPromptHash(req openaiserverapi.CompletionRequest) []byte {
+ hashArray := sha256.Sum256([]byte(req.GetFullPrompt()))
+ return hashArray[:]
}
-func (pr *progressReader) logProgress(pct int) {
- elapsedTime := time.Since(pr.startTime).Seconds()
- speed := float64(pr.downloaded) / (1024 * 1024 * elapsedTime)
- remainingTime := float64(pr.total-pr.downloaded) / (float64(pr.downloaded) / elapsedTime)
- if pct != 100 {
- pr.logger.V(logging.INFO).Info("Dataset download progress", "%", pct, "speed (MB/s)", speed, "remaining time (s)", remainingTime)
- } else {
- pr.logger.V(logging.INFO).Info("Download completed", "average speed (MB/s)", speed, "total time (s)", elapsedTime)
- }
+func (d *CustomDataset) getPromptHashHex(hashBytes []byte) string {
+ return hex.EncodeToString(hashBytes)
}
-func (d *CustomDataset) verifyDB() error {
- rows, err := d.db.Query("PRAGMA table_info(" + tableName + ");")
- if err != nil {
- return fmt.Errorf("failed to query table info for `%s`: %w", tableName, err)
- }
- defer func() {
- if cerr := rows.Close(); cerr != nil {
- d.logger.Error(cerr, "failed to close rows after querying table info")
+// categorizeResponses receives a list of tokenized responses and the maximum response length,
+// and categorizes the responses into three collections:
+// - shorter than or equal to maxLen
+// - exactly maxLen
+// - longer than maxLen
+func (d *CustomDataset) categorizeResponses(responses [][]string, maxLen int) (shorterOrEqLen [][]string, equalLen [][]string, longerLen [][]string) {
+ for _, respTokens := range responses {
+ switch {
+ case len(respTokens) == maxLen:
+ shorterOrEqLen = append(shorterOrEqLen, respTokens)
+ equalLen = append(equalLen, respTokens)
+ case len(respTokens) < maxLen:
+ shorterOrEqLen = append(shorterOrEqLen, respTokens)
+ default:
+ longerLen = append(longerLen, respTokens)
}
- }()
-
- expectedColumns := map[string]string{
- idCol: idColType,
- promptHashCol: promptHashColType,
- genTokensCol: genTokensColType,
- nGenTokensCol: nGenTokensColType,
}
-
- columnsFound := make(map[string]bool)
-
- var (
- columnName string
- columnType string
- cid int
- notnull int
- dfltValue interface{}
- pk int
- )
-
- for rows.Next() {
- err := rows.Scan(&cid, &columnName, &columnType, ¬null, &dfltValue, &pk)
- if err != nil {
- return fmt.Errorf("failed to scan table info row: %w", err)
- }
- if expectedType, exists := expectedColumns[columnName]; exists {
- if columnType != expectedType {
- return fmt.Errorf("column %s has incorrect type: expected %s, got %s", columnName, expectedType, columnType)
- }
- columnsFound[columnName] = true
- }
- }
-
- for col := range expectedColumns {
- if !columnsFound[col] {
- return fmt.Errorf("missing expected column in %s table: %s", tableName, col)
- }
- }
-
- return nil
+ return
}
-func (d *CustomDataset) getRecordsCount() (int, error) {
- var count int
- err := d.db.QueryRow("SELECT COUNT(" + promptHashCol + ") FROM " + tableName + ";").Scan(&count)
- if err != nil {
- return 0, fmt.Errorf("failed to query database: %w", err)
- }
- return count, nil
+// getRandomResponse returns a randomly selected element from the given array; the array must not be empty
+func (d *CustomDataset) getRandomResponse(responses [][]string) []string {
+ return responses[d.random.RandomInt(0, len(responses)-1)]
}
-func (d *CustomDataset) loadDatabaseInMemory(path string) error {
- d.logger.V(logging.INFO).Info("Loading database into memory...")
- start := time.Now()
-
- // Create in-memory database
- var err error
- d.db, err = sql.Open("sqlite3", ":memory:")
- if err != nil {
- return fmt.Errorf("failed to create in-memory database: %w", err)
- }
-
- // Use ATTACH to copy the database
- attachSQL := fmt.Sprintf("ATTACH DATABASE '%s' AS source", path)
- _, err = d.db.Exec(attachSQL)
- if err != nil {
- if closeErr := d.db.Close(); closeErr != nil {
- d.logger.Error(closeErr, "failed to close in-memory database after attach failure")
- }
- d.db = nil
- return fmt.Errorf("failed to attach source database: %w", err)
- }
-
- // Copy the table structure first
- _, err = d.db.Exec(`CREATE TABLE llmd (
- id INTEGER PRIMARY KEY,
- prompt_hash BLOB,
- gen_tokens JSON,
- n_gen_tokens INTEGER
- )`)
- if err != nil {
- if closeErr := d.db.Close(); closeErr != nil {
- d.logger.Error(closeErr, "failed to close in-memory database after create table failure")
- }
- d.db = nil
- return fmt.Errorf("failed to create table: %w", err)
- }
-
- // Copy the data
- _, err = d.db.Exec("INSERT INTO llmd SELECT * FROM source.llmd")
- if err != nil {
- if closeErr := d.db.Close(); closeErr != nil {
- d.logger.Error(closeErr, "failed to close in-memory database after copy failure")
- }
- d.db = nil
- return fmt.Errorf("failed to copy data: %w", err)
- }
-
- // Detach the source database
- _, err = d.db.Exec("DETACH DATABASE source")
- if err != nil {
- d.logger.Error(err, "failed to detach source database")
- }
-
- loadTime := time.Since(start)
- d.logger.V(logging.INFO).Info("Database loaded into memory", "load_time", loadTime.String())
- return nil
-}
-
-func (d *CustomDataset) connectToDB(path string, useInMemory bool) error {
- if d.db != nil {
- err := d.db.Close()
- if err != nil {
- d.logger.Error(err, "failed to close existing database connection")
- }
- d.db = nil
- }
- // check if file exists
- _, err := os.Stat(path)
- if err != nil {
- return fmt.Errorf("database file does not exist: %w", err)
- }
-
- if useInMemory {
- err = d.loadDatabaseInMemory(path)
- if err != nil {
- return err
+// GetTokens returns tokens and finishReason for the given request and mode (echo or random).
+// In echo mode the prompt is returned.
+// In random mode the following steps are performed:
+// Calculate the maximum response length (based on max-tokens/max-completion-tokens, or on the model length).
+// If the dataset contains responses for the given prompt and some are no longer than
+// the maximum response length - use a random one from the list,
+// otherwise select a random one from the longer responses and trim it as required.
+// If no responses were found in the dataset for the given prompt,
+// get a random record from the dataset with response length equal to or lower than the maximum response length;
+// if there are no records shorter than or equal to the maximum length - get a random response from the dataset
+// and trim it to the required length.
+// If ignore_eos=true the response always has exactly the maximum number of tokens; missing tokens
+// are randomly selected from the hard-coded collection.
+func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) {
+ if mode == common.ModeEcho {
+ return d.getTokensInEchoMode(req)
+ }
+
+ maxResponseLen, _ := d.calculateResponseMaxLen(req)
+ responseTokens := []string{}
+
+	// get all records for the hashed prompt
+ promptHash := d.getPromptHash(req)
+ promptHashHex := d.getPromptHashHex(promptHash)
+ responses, err := d.sqliteHelper.getResponsesForPrompt(promptHashHex)
+ if err != nil {
+ return responseTokens, "", err
+ }
+
+ if len(responses) > 0 {
+ // has responses for the given request
+		d.logger.V(logging.TRACE).Info("Responses were found in the dataset for the request's prompt")
+ shorterOrEqLenResponses, equalLenResponses, longerLenResponses := d.categorizeResponses(responses, maxResponseLen)
+
+ if req.GetIgnoreEOS() {
+ // must return response with exactly calculated max length
+ switch {
+ case len(equalLenResponses) > 0:
+ // has responses with required length - return randomly selected response
+ responseTokens = d.getRandomResponse(equalLenResponses)
+ case len(longerLenResponses) > 0:
+ // has responses longer than required - return randomly selected trimmed response
+ responseTokens = d.getRandomResponse(longerLenResponses)[:maxResponseLen]
+ default:
+ // all responses are shorter than required, select randomly and pad with random tokens
+ responseTokens = d.getRandomResponse(shorterOrEqLenResponses)
+ responseTokens = append(responseTokens, d.generatePresetRandomTokens(maxResponseLen-len(responseTokens))...)
+ }
+ } else {
+			// has responses for the request, return a response shorter than or equal to maxResponseLen
+			if len(shorterOrEqLenResponses) > 0 {
+				// has responses shorter than or equal to the required length - return a randomly selected response
+				responseTokens = d.getRandomResponse(shorterOrEqLenResponses)
+			} else {
+				// all responses are longer than required, use a randomly selected trimmed response
+				responseTokens = d.getRandomResponse(longerLenResponses)[:maxResponseLen]
+ }
}
} else {
- // Use file-based database (original behavior)
- d.db, err = sql.Open("sqlite3", path)
+		// no responses for the given request
+		d.logger.V(logging.TRACE).Info("No responses in the dataset for the request's prompt")
+		// try to find a random response with number of tokens <= the maximum response length
+ randomResponses, err := d.sqliteHelper.getResponsesForLen(maxResponseLen)
if err != nil {
- return fmt.Errorf("failed to open database: %w", err)
+ return responseTokens, "", err
}
-
- // Check if there are other connections to the database
- _, err = d.db.Exec("BEGIN EXCLUSIVE;")
- if err != nil {
- if closeErr := d.db.Close(); closeErr != nil {
- d.logger.Error(closeErr, "failed to close database after failing to acquire exclusive lock")
+ if len(randomResponses) == 0 {
+			// failed to get a response with number of tokens <= maxResponseLen, get a response with any number of tokens
+ randomResponses, err = d.sqliteHelper.getRandomResponse()
+ if err != nil {
+ return responseTokens, "", err
}
- d.db = nil
- return fmt.Errorf("database is locked or has other active connections: %w", err)
- }
- }
-
- err = d.verifyDB()
- if err != nil {
- return fmt.Errorf("failed to verify database: %w", err)
- }
-
- count, err := d.getRecordsCount()
- if err != nil {
- d.logger.Error(err, "failed to get records count")
- return fmt.Errorf("failed to query database: %w", err)
- }
-
- if useInMemory {
- d.logger.V(logging.INFO).Info("In-memory database connected successfully", "path", path, "records count", count)
- } else {
- d.logger.V(logging.INFO).Info("Database connected successfully", "path", path, "records count", count)
- }
- return nil
-}
-
-func (d *CustomDataset) Init(ctx context.Context, logger logr.Logger, path string, url string, useInMemory bool) error {
- d.logger = logger
- if path == "" {
- return errors.New("no dataset path provided")
- }
- d.hasWarned = false
- if url == "" {
- d.logger.V(logging.INFO).Info("Using dataset from", "path", path)
- return d.connectToDB(path, useInMemory)
- }
- _, err := os.Stat(path)
- if err != nil {
- // file does not exist, download it
- err = d.downloadDataset(ctx, url, path)
- if err != nil {
- // if the file is created but incomplete, remove it
- if _, statErr := os.Stat(path); statErr == nil {
- cerr := os.Remove(path)
- if cerr != nil {
- d.logger.Error(cerr, "failed to remove incomplete file after download")
- }
+ if len(randomResponses) == 0 {
+ // shouldn't happen
+				return responseTokens, "", errors.New("dataset is empty")
}
- return fmt.Errorf("failed to download dataset: %w", err)
}
- }
- d.logger.V(logging.INFO).Info("Using dataset path", "dataset-path", path)
-
- return d.connectToDB(path, useInMemory)
-}
-
-func (d *CustomDataset) Close() error {
- // Release db lock (only for file-based databases)
- _, err := d.db.Exec("ROLLBACK;")
- if err != nil {
- if cerr := d.db.Close(); cerr != nil {
- d.logger.Error(cerr, "failed to close database after failing to acquire exclusive lock")
+		// use the first returned response; if it has too many tokens, trim it
+		responseTokens = randomResponses[0]
+		if len(responseTokens) > maxResponseLen {
+			responseTokens = responseTokens[:maxResponseLen]
}
- d.db = nil
- return fmt.Errorf("failed to release exclusive lock: %w", err)
}
- if d.db != nil {
- return d.db.Close()
- }
- return nil
-}
-
-func unmarshalAllRecords(rows *sql.Rows) ([][]string, error) {
- var tokensList [][]string
- for rows.Next() {
- var tokensJSON string
- if err := rows.Scan(&tokensJSON); err != nil {
- return nil, fmt.Errorf("failed to scan row: %w", err)
- }
-
- var tokens []string
- if err := json.Unmarshal([]byte(tokensJSON), &tokens); err != nil {
- return nil, fmt.Errorf("failed to unmarshal tokens JSON: %w", err)
- }
- tokensList = append(tokensList, tokens)
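+	// a response that exactly reaches the maximum allowed length finishes with reason "length"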
+ finishReason := common.StopFinishReason
+ if len(responseTokens) == maxResponseLen {
+ finishReason = common.LengthFinishReason
}
- return tokensList, nil
-}
-func (d *CustomDataset) GetPromptHash(req openaiserverapi.CompletionRequest) []byte {
- hashArray := sha256.Sum256([]byte(req.GetFullPrompt()))
- return hashArray[:]
-}
-
-func (d *CustomDataset) GetPromptHashHex(hashBytes []byte) string {
- return hex.EncodeToString(hashBytes)
-}
-
-// GetTokens returns tokens and finishReason for the given request and mode (echo or random)
-func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string, random *common.Random) ([]string, string, error) {
- if mode == common.ModeEcho {
- return d.echo(req)
- }
- nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS(), random)
- tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason, random)
- return tokens, finishReason, err
-}
-
-func (d *CustomDataset) query(query string, nTokens int, random *common.Random) ([][]string, error) {
- rows, err := d.db.Query(query)
- if err != nil {
- if !d.hasWarned {
- d.logger.Error(err, "failed to query database. Ensure dataset file is still valid. Will generate random tokens instead.")
- d.hasWarned = true
- }
- return [][]string{GenPresetRandomTokens(random, nTokens)}, nil
- }
- defer func() {
- if cerr := rows.Close(); cerr != nil {
- d.logger.Error(cerr, "failed to close rows after query")
- }
- }()
- return unmarshalAllRecords(rows)
-}
-
-func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int, finishReason string,
- random *common.Random) ([]string, error) {
- // query by prompt hash first
- promptHash := d.GetPromptHash(req)
- promptHashHex := d.GetPromptHashHex(promptHash)
- query := "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';"
- tokensList, err := d.query(query, nTokens, random)
-
- // filter out results according to finish reason
- var filteredTokensList [][]string
- if finishReason != LengthFinishReason && finishReason != StopFinishReason {
- d.logger.Error(errors.New("unknown finish reason"), "unexpected finish reason", "reason", finishReason)
- }
- for _, tokens := range tokensList {
- if finishReason == StopFinishReason && len(tokens) <= nTokens {
- filteredTokensList = append(filteredTokensList, tokens)
- } else if finishReason == LengthFinishReason && len(tokens) == nTokens {
- filteredTokensList = append(filteredTokensList, tokens)
- }
- }
- tokensList = filteredTokensList
-
- if err != nil || len(filteredTokensList) == 0 {
- switch finishReason {
- case LengthFinishReason:
- query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv.Itoa(nTokens) + ";"
- tokensList, err = d.query(query, nTokens, random)
- case StopFinishReason:
- query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "<=" + strconv.Itoa(nTokens) + ";"
- tokensList, err = d.query(query, nTokens, random)
- }
- }
-
- if err != nil || len(tokensList) == 0 {
- // if both queries fail or return no results, generate random tokens
- return GenPresetRandomTokens(random, nTokens), nil
- }
- if d.hasWarned {
- d.hasWarned = false
- }
- randIndex := random.RandomInt(0, len(tokensList)-1)
- return tokensList[randIndex], nil
+ return responseTokens, finishReason, nil
}
diff --git a/pkg/dataset/custom_dataset_downloader.go b/pkg/dataset/custom_dataset_downloader.go
new file mode 100644
index 0000000..4a1d766
--- /dev/null
+++ b/pkg/dataset/custom_dataset_downloader.go
@@ -0,0 +1,181 @@
+/*
+Copyright 2025 The llm-d-inference-sim Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package dataset
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "time"
+
+ "github.com/go-logr/logr"
+ "github.com/llm-d/llm-d-inference-sim/pkg/common/logging"
+)
+
+type CustomDatasetDownloader struct {
+ logger logr.Logger
+}
+
+const (
+ progressLogTimeInterval = 5 * time.Second
+ progressLogPercentInterval = 10
+)
+
+// progressReader wraps an io.Reader and logs download progress.
+type progressReader struct {
+ io.Reader
+ total int64
+ downloaded int64
+ startTime time.Time
+ lastPct int
+ lastLogTime time.Time
+ logger logr.Logger
+ ctx context.Context
+}
+
+func NewDsDownloader(logger logr.Logger) *CustomDatasetDownloader {
+ return &CustomDatasetDownloader{logger: logger}
+}
+
+// DownloadDataset downloads the dataset from the given url and stores it at the given path.
+// The error is a named return so that deferred Close failures can be joined into it.
+func (d *CustomDatasetDownloader) DownloadDataset(ctx context.Context, url string, path string) (err error) {
+ folder := filepath.Dir(path)
+ err := os.MkdirAll(folder, 0755)
+ if err != nil {
+ return fmt.Errorf("failed to create parent directory: %w", err)
+ }
+
+ if _, err := os.Stat(path); err == nil {
+		// file already exists, skip the download
+		d.logger.V(logging.INFO).Info("Dataset file already exists, skipping download", "path", path)
+ return nil
+ }
+
+ out, err := os.Create(path)
+ if err != nil {
+ return err
+ }
+ defer func() {
+ cerr := out.Close()
+ if cerr != nil {
+ d.logger.Error(cerr, "failed to close file after download")
+ err = errors.Join(err, cerr)
+ }
+ }()
+
+ d.logger.V(logging.INFO).Info("Using dataset-url", "dataset-url", url)
+ resp, err := http.Get(url)
+ if err != nil {
+ return err
+ }
+ defer func() {
+ cerr := resp.Body.Close()
+ if cerr != nil {
+ d.logger.Error(cerr, "failed to close response body after download")
+ err = errors.Join(err, cerr)
+ }
+ }()
+
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("dataset download bad status: %s", resp.Status)
+ }
+
+ // Progress reader with context
+ pr := &progressReader{
+ Reader: resp.Body,
+ total: resp.ContentLength,
+ logger: d.logger,
+ ctx: ctx,
+ startTime: time.Now(),
+ }
+
+ written, err := io.Copy(out, pr)
+ if err != nil {
+ // Remove incomplete file
+ cerr := os.Remove(path)
+ if cerr != nil {
+ d.logger.Error(cerr, "failed to remove incomplete file after download")
+ }
+ // If context was cancelled, return a specific error
+ if errors.Is(err, context.Canceled) {
+ return errors.New("download cancelled by user")
+ }
+ return fmt.Errorf("failed to download file: %w", err)
+ }
+ // Check if file size is zero
+ if written == 0 {
+ cerr := os.Remove(path)
+ if cerr != nil {
+ d.logger.Error(cerr, "failed to remove empty file after download")
+ }
+ return errors.New("downloaded file is empty")
+ }
+
+ // Ensure file is fully flushed and closed before returning success
+ if err := out.Sync(); err != nil {
+ cerr := os.Remove(path)
+ if cerr != nil {
+ d.logger.Error(cerr, "failed to remove incomplete file after download")
+ }
+ return fmt.Errorf("failed to sync file: %w", err)
+ }
+
+	d.logger.V(logging.INFO).Info("Downloaded dataset", "url", url, "path", path)
+ return nil
+}
+
+func (pr *progressReader) Read(p []byte) (int, error) {
+ select {
+ case <-pr.ctx.Done():
+ return 0, pr.ctx.Err()
+ default:
+ }
+ n, err := pr.Reader.Read(p)
+ pr.downloaded += int64(n)
+ if pr.total > 0 {
+ pct := int(float64(pr.downloaded) * 100 / float64(pr.total))
+ now := time.Now()
+
+ timeSinceLastLog := now.Sub(pr.lastLogTime).Seconds()
+ pctDiff := pct - pr.lastPct
+
+ if timeSinceLastLog >= progressLogTimeInterval.Seconds() || (pctDiff >= progressLogPercentInterval && pct != pr.lastPct) {
+ // progress will be shown every interval seconds or every interval percent of progress
+ pr.logProgress(pct)
+ pr.lastPct = pct
+ pr.lastLogTime = now
+ }
+ }
+ return n, err
+}
+
+func (pr *progressReader) logProgress(pct int) {
+ elapsedTime := time.Since(pr.startTime).Seconds()
+ speed := float64(pr.downloaded) / (1024 * 1024 * elapsedTime)
+ remainingTime := float64(pr.total-pr.downloaded) / (float64(pr.downloaded) / elapsedTime)
+	if pct != 100 {
+		pr.logger.V(logging.INFO).Info("Dataset download progress", "%", pct, "speed (MB/s)", speed, "remaining time (s)", remainingTime)
+	} else {
+		pr.logger.V(logging.INFO).Info("Download completed", "average speed (MB/s)", speed, "total time (s)", elapsedTime)
+	}
+}
diff --git a/pkg/dataset/custom_dataset_test.go b/pkg/dataset/custom_dataset_test.go
index 907a076..ef236ff 100644
--- a/pkg/dataset/custom_dataset_test.go
+++ b/pkg/dataset/custom_dataset_test.go
@@ -22,11 +22,13 @@ import (
"os"
"time"
+ "k8s.io/klog/v2"
+
"github.com/llm-d/llm-d-inference-sim/pkg/common"
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
. "github.com/onsi/ginkgo/v2"
+
. "github.com/onsi/gomega"
- "k8s.io/klog/v2"
_ "github.com/mattn/go-sqlite3"
)
@@ -37,7 +39,8 @@ const (
var _ = Describe("CustomDataset", Ordered, func() {
var (
- dataset *CustomDataset
+ sqliteHelper *sqliteHelper
+ dsDownloader *CustomDatasetDownloader
file_folder string
path string
validDBPath string
@@ -54,7 +57,8 @@ var _ = Describe("CustomDataset", Ordered, func() {
})
BeforeEach(func() {
- dataset = &CustomDataset{}
+ sqliteHelper = newSqliteHelper(klog.Background())
+ dsDownloader = NewDsDownloader(klog.Background())
file_folder = ".llm-d"
path = file_folder + "/test.sqlite3"
err := os.MkdirAll(file_folder, os.ModePerm)
@@ -67,15 +71,8 @@ var _ = Describe("CustomDataset", Ordered, func() {
pathToInvalidTypeDB = file_folder + "/test.invalid.type.sqlite3"
})
- AfterEach(func() {
- if dataset.db != nil {
- err := dataset.db.Close()
- Expect(err).NotTo(HaveOccurred())
- }
- })
-
It("should return error for invalid DB path", func() {
- err := dataset.connectToDB("/invalid/path/to/db.sqlite", false)
+ err := sqliteHelper.connectToDB("/invalid/path/to/db.sqlite", false)
Expect(err).To(HaveOccurred())
})
@@ -86,9 +83,8 @@ var _ = Describe("CustomDataset", Ordered, func() {
err = os.Remove(path)
Expect(err).NotTo(HaveOccurred())
}
-
url := "https://llm-d.ai"
- err = dataset.downloadDataset(context.Background(), url, path)
+ err = dsDownloader.DownloadDataset(context.Background(), url, path)
Expect(err).NotTo(HaveOccurred())
_, err = os.Stat(path)
Expect(err).NotTo(HaveOccurred())
@@ -98,22 +94,23 @@ var _ = Describe("CustomDataset", Ordered, func() {
It("should not download file from url", func() {
url := "https://256.256.256.256" // invalid url
- err := dataset.downloadDataset(context.Background(), url, path)
+ err := dsDownloader.DownloadDataset(context.Background(), url, path)
Expect(err).To(HaveOccurred())
})
It("should successfully init dataset", func() {
- err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", false)
+ dataset := &CustomDataset{}
+ err := dataset.Init(context.Background(), klog.Background(), random, validDBPath, false, 1024)
Expect(err).NotTo(HaveOccurred())
- row := dataset.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';")
+ row := dataset.sqliteHelper.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';")
var n_gen_tokens int
err = row.Scan(&n_gen_tokens)
Expect(err).NotTo(HaveOccurred())
Expect(n_gen_tokens).To(Equal(4))
var jsonStr string
- row = dataset.db.QueryRow("SELECT gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';")
+ row = dataset.sqliteHelper.db.QueryRow("SELECT gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';")
err = row.Scan(&jsonStr)
Expect(err).NotTo(HaveOccurred())
var tokens []string
@@ -121,33 +118,35 @@ var _ = Describe("CustomDataset", Ordered, func() {
Expect(err).NotTo(HaveOccurred())
Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"}))
+ err = dataset.sqliteHelper.db.Close()
+ Expect(err).NotTo(HaveOccurred())
})
It("should return error for non-existing DB path", func() {
- err := dataset.connectToDB(pathNotExist, false)
+ err := sqliteHelper.connectToDB(pathNotExist, false)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("database file does not exist"))
})
It("should return error for invalid DB file", func() {
- err := dataset.connectToDB(pathToInvalidDB, false)
+ err := sqliteHelper.connectToDB(pathToInvalidDB, false)
Expect(err).To(HaveOccurred())
})
It("should return error for DB with invalid table", func() {
- err := dataset.connectToDB(pathToInvalidTableDB, false)
+ err := sqliteHelper.connectToDB(pathToInvalidTableDB, false)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("failed to verify database"))
})
It("should return error for DB with invalid column", func() {
- err := dataset.connectToDB(pathToInvalidColumnDB, false)
+ err := sqliteHelper.connectToDB(pathToInvalidColumnDB, false)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("missing expected column"))
})
It("should return error for DB with invalid column type", func() {
- err := dataset.connectToDB(pathToInvalidTypeDB, false)
+ err := sqliteHelper.connectToDB(pathToInvalidTypeDB, false)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("incorrect type"))
})
@@ -159,8 +158,8 @@ var _ = Describe("CustomDataset", Ordered, func() {
req := &openaiserverapi.TextCompletionRequest{
Prompt: testPrompt,
}
-
- hashBytes := dataset.GetPromptHash(req)
+ dataset := &CustomDataset{}
+ hashBytes := dataset.getPromptHash(req)
Expect(hashBytes).To(Equal(expectedHashBytes))
})
@@ -170,52 +169,98 @@ var _ = Describe("CustomDataset", Ordered, func() {
req := &openaiserverapi.TextCompletionRequest{
Prompt: testPrompt,
}
-
- hashBytes := dataset.GetPromptHash(req)
- hashHex := dataset.GetPromptHashHex(hashBytes)
+ dataset := &CustomDataset{}
+ hashBytes := dataset.getPromptHash(req)
+ hashHex := dataset.getPromptHashHex(hashBytes)
Expect(hashHex).To(Equal(expectedHashHex))
})
It("should return tokens for existing prompt", func() {
- err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", false)
+ dataset := &CustomDataset{}
+ err := dataset.Init(context.Background(), klog.Background(), random, validDBPath, false, 1024)
Expect(err).NotTo(HaveOccurred())
req := &openaiserverapi.TextCompletionRequest{
Prompt: testPrompt,
}
- tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
+ tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
Expect(err).NotTo(HaveOccurred())
- Expect(finishReason).To(Equal(StopFinishReason))
+ Expect(finishReason).To(Equal(common.StopFinishReason))
if len(tokens) >= 4 {
// The number of tokens to generate is random, and if it's less than 4
// we will not get these tokens
Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"}))
}
+ err = dataset.sqliteHelper.db.Close()
+ Expect(err).NotTo(HaveOccurred())
})
It("should return at most 2 tokens for existing prompt", func() {
- err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", false)
+ dataset := &CustomDataset{}
+ err := dataset.Init(context.Background(), klog.Background(), random, validDBPath, false, 1024)
Expect(err).NotTo(HaveOccurred())
n := int64(2)
req := &openaiserverapi.TextCompletionRequest{
Prompt: testPrompt,
MaxTokens: &n,
}
- tokens, _, err := dataset.GetTokens(req, common.ModeRandom, random)
+ tokens, _, err := dataset.GetTokens(req, common.ModeRandom)
Expect(err).NotTo(HaveOccurred())
Expect(len(tokens)).To(BeNumerically("<=", 2))
+ err = dataset.sqliteHelper.db.Close()
+ Expect(err).NotTo(HaveOccurred())
})
It("should successfully init dataset with in-memory option", func() {
- err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", true)
+ dataset := &CustomDataset{}
+ err := dataset.Init(context.Background(), klog.Background(), random, validDBPath, true, 1024)
Expect(err).NotTo(HaveOccurred())
req := &openaiserverapi.TextCompletionRequest{
Prompt: testPrompt,
}
- tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
+ tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
Expect(err).NotTo(HaveOccurred())
- Expect(finishReason).To(Equal(StopFinishReason))
+ Expect(finishReason).To(Equal(common.StopFinishReason))
Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"}))
+ err = dataset.sqliteHelper.db.Close()
+ Expect(err).NotTo(HaveOccurred())
+ })
+})
+
+var _ = Describe("custom dataset for multiple simulators", Ordered, func() {
+ It("should not fail on custom datasets initialization", func() {
+ file_folder := ".llm-d"
+ validDBPath := file_folder + "/test.valid.sqlite3"
+
+ random1 := common.NewRandom(time.Now().UnixNano(), 8081)
+ dataset1 := &CustomDataset{}
+ err := dataset1.Init(context.Background(), klog.Background(), random1, validDBPath, false, 1024)
+ Expect(err).NotTo(HaveOccurred())
+
+ random2 := common.NewRandom(time.Now().UnixNano(), 8082)
+ dataset2 := &CustomDataset{}
+ err = dataset2.Init(context.Background(), klog.Background(), random2, validDBPath, false, 1024)
+ Expect(err).NotTo(HaveOccurred())
+ })
+})
+
+var _ = Describe("download custom dataset from HF", Ordered, func() {
+	// currently there is only one dataset, which is too large;
+	// once we create a small sample dataset, restore this test
+ XIt("should download and save ds", func() {
+ url := "https://huggingface.co/datasets/hf07397/inference-sim-datasets/resolve/91ffa7aafdfd6b3b1af228a517edc1e8f22cd274/huggingface/ShareGPT_Vicuna_unfiltered/conversations.sqlite3"
+ downloader := NewDsDownloader(klog.Background())
+ tempFile := "./ds1.sqlite3"
+
+ if _, err := os.Stat(tempFile); err == nil {
+ err := os.Remove(tempFile)
+ Expect(err).NotTo(HaveOccurred())
+ }
+ err := downloader.DownloadDataset(context.Background(), url, tempFile)
+ Expect(err).NotTo(HaveOccurred())
+
+ err = os.Remove(tempFile)
+ Expect(err).NotTo(HaveOccurred())
})
})
diff --git a/pkg/dataset/dataset.go b/pkg/dataset/dataset.go
index 15c737a..ec7bd6f 100644
--- a/pkg/dataset/dataset.go
+++ b/pkg/dataset/dataset.go
@@ -26,34 +26,8 @@ import (
_ "github.com/mattn/go-sqlite3"
)
-const (
- RoleAssistant = "assistant"
- RoleUser = "user"
-)
-
-const (
- ResponseLenMax = 128
- responseLenMean = 40
- responseLenStddev = 20
- stopFinishReasonProbability = 0.8
-
- StopFinishReason = "stop"
- LengthFinishReason = "length"
- ToolsFinishReason = "tool_calls"
- RemoteDecodeFinishReason = "remote_decode"
-)
-
-// this array defines the probabilities for the buckets to be used for the generation of number of tokens in response
-var respLenBucketsProbabilities = [...]float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15}
-var cumulativeBucketsProbabilities []float64
-
-const (
- flexBucketIndex = 3
- maxFixedBucketSize = 20
-)
-
// list of responses to use in random mode for completion requests
-var chatCompletionFakeResponses = []string{
+var completionFakeResponses = []string{
`Testing@, #testing 1$ ,2%,3^, [4&*5], 6~, 7-_ + (8 : 9) / \ < > .`,
`Testing, testing 1,2,3.`,
`I am fine, how are you today?`,
@@ -68,55 +42,113 @@ var chatCompletionFakeResponses = []string{
}
type Dataset interface {
- // Init initializes the dataset using configs
- Init(ctx context.Context, logger logr.Logger, path string, url string, useInMemory bool) error
// Close closes the dataset
Close() error
// GetTokens returns tokens for the given request and mode (echo or random)
- GetTokens(req openaiserverapi.CompletionRequest, mode string, random *common.Random) ([]string, string, error)
+ GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error)
+}
+
+type BaseDataset struct {
+ logger logr.Logger
+ maxModelLen int
+ random *common.Random
+ histogramHelper *histogramHelper
+}
+
+func (d *BaseDataset) Init(ctx context.Context, logger logr.Logger, random *common.Random, maxModelLen int) error {
+ d.logger = logger
+ d.maxModelLen = maxModelLen
+ d.random = random
+ d.histogramHelper = newHistogramHelper(d.random)
+
+ return nil
+}
+
+func (d *BaseDataset) Close() error {
+ return nil
+}
+
+// GetTokens returns tokens and finishReason for the given request and mode (echo or random)
+func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) {
+ if mode == common.ModeEcho {
+ return d.getTokensInEchoMode(req)
+ }
+
+ maxRespTokens, isMaxTokensInReq := d.calculateResponseMaxLen(req)
+
+ numOfRespTokens := 0
+ finishReason := common.StopFinishReason
+
+ switch {
+ case req.GetIgnoreEOS():
+ // ignore_eos is true - response must have the maximum number of tokens
+ numOfRespTokens = maxRespTokens
+ case isMaxTokensInReq:
+ // max tokens is defined in the request - generate number of tokens in the response based on the histogram
+ numOfRespTokens = d.histogramHelper.getResponseLengthByHistogram(maxRespTokens)
+ default:
+		// no tokens limitation in the request - use a gaussian with hard-coded mean and standard deviation
+ numOfRespTokens = d.getRandomResponseLenByGaussian(maxRespTokens)
+ }
+
+ if numOfRespTokens == maxRespTokens {
+		// if the response is created with the maximum number of tokens, the finish reason is 'length'
+ finishReason = common.LengthFinishReason
+ }
+
+ return d.generatePresetRandomTokens(numOfRespTokens), finishReason, nil
}
-func init() {
- cumulativeBucketsProbabilities = make([]float64, len(respLenBucketsProbabilities))
- sum := 0.0
+func (d *BaseDataset) getTokensInEchoMode(req openaiserverapi.CompletionRequest) ([]string, string, error) {
+ tokens := common.Tokenize(req.ExtractPrompt())
+ maxTokens := req.GetMaxCompletionTokens()
+ finishReason := common.StopFinishReason
- for i, val := range respLenBucketsProbabilities {
- sum += val
- cumulativeBucketsProbabilities[i] = sum
+ if maxTokens != nil && len(tokens) >= int(*maxTokens) {
+ finishReason = common.LengthFinishReason
}
+
+ return tokens, finishReason, nil
}
-// GetRandomResponseLen returns int in range [1, responseLenMax]
+// calculateResponseMaxLen calculates the maximum length of a response to be randomly chosen from the dataset
+// for the given request and the simulator configuration.
+// If max-tokens/max-completion-tokens is defined - use it,
+// otherwise use the maximum model length minus the number of tokens in the prompt.
+// The returned boolean indicates whether a max tokens value was passed in the request.
+func (d *BaseDataset) calculateResponseMaxLen(req openaiserverapi.CompletionRequest) (int, bool) {
+ maxTokens := req.GetMaxCompletionTokens()
+
+ if maxTokens != nil {
+ return int(*maxTokens), true
+ }
+
+ return d.maxModelLen - len(common.Tokenize(req.GetFullPrompt())), false
+}
+
+// getRandomResponseLenByGaussian returns an int in the range [1, maxLen];
-// numbers are chosen according a gaussian distribution with mean responseLenMean, and standard deviation responseLenStddev
+// numbers are chosen according to a gaussian distribution with mean responseLenMean and standard deviation responseLenStddev
-func GetRandomResponseLen(random *common.Random) int {
+func (d *BaseDataset) getRandomResponseLenByGaussian(maxLen int) int {
for {
- val := random.RandomNorm(responseLenMean, responseLenStddev)
- if val >= 1 && val <= ResponseLenMax {
+ val := d.random.RandomNorm(responseLenMean, responseLenStddev)
+ if val >= 1 && val <= float64(maxLen) {
return int(math.Round(val))
}
// else reject and resample
}
}
-// GetRandomFinishReason returns finish reason with the probability for 'stop' as defined by stopFinishReasonProbability
-func GetRandomFinishReason(random *common.Random) string {
- if random.RandomFloat(0, 1) < stopFinishReasonProbability {
- return StopFinishReason
- }
- return LengthFinishReason
-}
-
-// GenPresetRandomTokens generates random tokens for the required number of tokens,
-// select randomly a sentence from chatCompletionFakeResponses,
-// if number of tokens is lower than required - select another sentence,
-// continue until the required number of tokens is achieved
-func GenPresetRandomTokens(random *common.Random, numOfTokens int) []string {
+// generatePresetRandomTokens generates random tokens for the required number of tokens:
+// randomly selects a sentence from completionFakeResponses;
+// if the number of tokens is lower than required, selects another sentence,
+// continuing until the required count is reached;
+// returns exactly numOfTokens tokens
+func (d *BaseDataset) generatePresetRandomTokens(numOfTokens int) []string {
allTokens := make([]string, 0)
for len(allTokens) < numOfTokens {
- index := random.RandomInt(0, len(chatCompletionFakeResponses)-1)
- // create tokens from text, splitting by spaces and special characters
- tokens := common.Tokenize(chatCompletionFakeResponses[index])
+ index := d.random.RandomInt(0, len(completionFakeResponses)-1)
+ tokens := common.Tokenize(completionFakeResponses[index])
remaining := numOfTokens - len(allTokens)
if len(tokens) > remaining {
@@ -134,171 +166,3 @@ func GenPresetRandomTokens(random *common.Random, numOfTokens int) []string {
return allTokens
}
-
-// howManyTokensToGen generates the number of tokens to be returned in a response, and the finish reason (see constants)
-// if maxCompletionTokens is defined
-// - currently, the generated number of words in the text will be equal to it value
-// - in future - need to find statistics about generated tokens distribution and return less tokens in part os requests
-// - finish reason will be chosen randomly from the collection (stop, length) with 80% for stop and 20% for length
-// if maxCompletionTokens is nil
-// - the response text's length is randomly chosen from the range [1, responseLenMax] according additional parameters
-// - finish reason is stop
-// if ignore_eos is true - the response will be generated with exactly maxCompletionTokens tokens
-// - request was validated so that when ignore_eos is true, maxCompletionTokens must be defined
-func howManyTokensToGen(maxCompletionTokens *int64, ignore_eos bool, random *common.Random) (int, string) {
- numOfTokens := 0
- finishReason := StopFinishReason
-
- // no max completion tokens, return text with random length
- if maxCompletionTokens == nil {
- numOfTokens = GetRandomResponseLen(random)
- } else {
- maxTokens := int(*maxCompletionTokens)
- if ignore_eos {
- numOfTokens = maxTokens
- finishReason = LengthFinishReason
- } else {
- // max tokens is defined - generate real length of the response based on it
- numOfTokens = getResponseLengthByHistogram(random, maxTokens)
- if numOfTokens == maxTokens {
- // if response should be create with maximum number of tokens - finish reason will be 'length'
- finishReason = LengthFinishReason
- }
- }
- }
-
- return numOfTokens, finishReason
-}
-
-// getResponseLengthByHistogram calculates the number of tokens to be returned in a response based on the max tokens value and the pre-defined buckets.
-// The response length is distributed according to the probabilities, defined in respLenBucketsProbabilities.
-// The histogram contains equally sized buckets and the last special bucket, which contains only the maxTokens value.
-// The last element of respLenBucketsProbabilities defines the probability of a reposnse with maxToken tokens.
-// Other values define probabilities for the equally sized buckets.
-// If maxToken is small (smaller than number of buckets) - the response length is randomly selected from the range [1, maxTokens]
-func getResponseLengthByHistogram(random *common.Random, maxTokens int) int {
- if maxTokens <= 1 {
- return maxTokens
- }
- // maxTokens is small - no need to use the histogram of probabilities, just select a random value in the range [1, maxTokens]
- if maxTokens <= len(cumulativeBucketsProbabilities) {
- res := random.RandomInt(1, maxTokens)
- return res
- }
-
- r := random.RandomFloat(0, 1)
-
- // check if r is in the last bucket, then maxTokens should be returned
- if r > cumulativeBucketsProbabilities[len(cumulativeBucketsProbabilities)-2] {
- return maxTokens
- }
-
- // determine which bucket to use, the bucket with a cumulative probability larger than r is the bucket to use
- // initialize bucketIndex with the last bucket to handle the case (which should not happen) when the probabilities sum is less than 1
- bucketIndex := len(cumulativeBucketsProbabilities) - 1
- for i, c := range cumulativeBucketsProbabilities {
- if r <= c {
- bucketIndex = i
- break
- }
- }
-
- // calculate the size of all of the buckets (except the special last bucket)
- start, end := calcBucketBoundaries(maxTokens, bucketIndex)
-
- // pick uniformly within the bucket’s range
- return random.RandomInt(start, end)
-}
-
-// calcBucketBoundaries calculates boundaries of a bucket with the given index.
-// Maximum size for equally sized buckets is defined by maxFixedBucketSize.
-// [maxFixedBucketSize*(number-of-buckets-1)+1] is the value of maxTokens for which
-// division to equally size buckets will give buckets with size maxFixedBucketSize.
-// If maxTokens is [maxFixedBucketSize*(number-of-buckets-1)+1] or less,
-// all buckets will be of equal size, except the last bucket, which contains only one value.
-// If maxTokens is higher than [maxFixedBucketSize*(number-of-buckets-1)+1],
-// and flexBucketIndex is valid (between 0 and number of buckets - 1) the buckets sizes will not be equal.
-// In this case, all buckets except the one at flexBucketIndex index will have size 20 (and the last is with size 1),
-// and the bucket at flexBucketIndex index will 'stretch' to cover the remaining range.
-func calcBucketBoundaries(maxTokens int, bucketIndex int) (start int, end int) {
- maxEquallyBucketsSz := maxFixedBucketSize*(len(cumulativeBucketsProbabilities)-1) + 1
-
- if maxTokens <= maxEquallyBucketsSz || flexBucketIndex < 0 || flexBucketIndex >= len(cumulativeBucketsProbabilities)-1 {
- // create equally size buckets
- // calculate the size of all of the buckets (except the special last bucket)
- bucketSize := float64(maxTokens-1) / float64(len(cumulativeBucketsProbabilities)-1)
- start = int(bucketSize*float64(bucketIndex)) + 1
- end = int(bucketSize * float64(bucketIndex+1))
- } else {
- // create non-equally sized buckets and find boundaries of the required bucket
- if bucketIndex < flexBucketIndex {
- // the relevant bucket is before the flex bucket, all buckets are of the same size (maxFixedBucketSize)
- // start is the minimum number in the required bucket
- start = maxFixedBucketSize*bucketIndex + 1
- end = maxFixedBucketSize * (bucketIndex + 1)
- } else {
- flexBucketSize := maxTokens - (maxFixedBucketSize * (len(cumulativeBucketsProbabilities) - 2))
-
- if bucketIndex == flexBucketIndex {
- // the relevant bucket is the flex bucket
- start = int(maxFixedBucketSize*float64(bucketIndex)) + 1
- end = maxFixedBucketSize*bucketIndex + flexBucketSize
- } else {
- // the relevant bucket is one of buckets after the flex bucket
- start = int(maxFixedBucketSize*float64(bucketIndex-1)) + flexBucketSize + 1
- end = maxFixedBucketSize*bucketIndex + flexBucketSize
- }
- }
- }
-
- // sometimes end could be maxTokens because of rounding, change the value to maxToken-1
- if end >= maxTokens {
- end = maxTokens - 1
- }
-
- return start, end
-}
-
-// EchoResponseTokens returns needed tokens, from a given text
-// considering max completion tokens if it is not nil, and a finish reason (stop or length)
-func EchoResponseTokens(maxCompletionTokens *int64, text string) ([]string, string) {
- tokens := common.Tokenize(text)
- // no max completion tokens, return entire text
- if maxCompletionTokens == nil {
- return tokens, StopFinishReason
- }
-
- if *maxCompletionTokens >= int64(len(tokens)) {
- return tokens, StopFinishReason
- }
- // return truncated text
- return tokens[0:*maxCompletionTokens], LengthFinishReason
-}
-
-type BaseDataset struct {
- logger logr.Logger
-}
-
-func (d *BaseDataset) Init(ctx context.Context, logger logr.Logger, path string, url string, useInMemory bool) error {
- d.logger = logger
- return nil
-}
-
-func (d *BaseDataset) Close() error {
- return nil
-}
-
-func (d *BaseDataset) echo(req openaiserverapi.CompletionRequest) ([]string, string, error) {
- tokens, finishReason := EchoResponseTokens(req.ExtractMaxTokens(), req.ExtractPrompt())
- return tokens, finishReason, nil
-}
-
-// GetTokens returns tokens and finishReason for the given request and mode (echo or random)
-func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string,
- random *common.Random) ([]string, string, error) {
- if mode == common.ModeEcho {
- return d.echo(req)
- }
- nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS(), random)
- return GenPresetRandomTokens(random, nTokensToGen), finishReason, nil
-}
diff --git a/pkg/dataset/dataset_test.go b/pkg/dataset/dataset_test.go
index 3e7974a..8efa170 100644
--- a/pkg/dataset/dataset_test.go
+++ b/pkg/dataset/dataset_test.go
@@ -17,39 +17,53 @@ limitations under the License.
package dataset
import (
+ "context"
"fmt"
"strings"
+
"time"
+ "github.com/go-logr/logr"
+ "github.com/go-logr/logr/funcr"
"github.com/llm-d/llm-d-inference-sim/pkg/common"
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
+func NewStdoutLogger() logr.Logger {
+ return funcr.New(func(prefix, args string) {
+ if prefix != "" {
+ fmt.Printf("%s: %s\n", prefix, args)
+ } else {
+ fmt.Println(args)
+ }
+ }, funcr.Options{})
+}
+
var _ = Describe("Dataset", Ordered, func() {
var (
dataset *BaseDataset
- random *common.Random
)
- BeforeAll(func() {
- random = common.NewRandom(time.Now().UnixNano(), 8080)
- })
+ createDataset := func() {
+ dataset = &BaseDataset{}
+ err := dataset.Init(context.Background(), NewStdoutLogger(), common.NewRandom(time.Now().UnixNano(), 8080), 1024)
+ Expect(err).ShouldNot(HaveOccurred())
+ }
BeforeEach(func() {
- dataset = &BaseDataset{}
+ createDataset()
})
Context("GetRandomTokens", func() {
-
It("should return complete text", func() {
req := &openaiserverapi.ChatCompletionRequest{}
- tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
+ tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
Expect(err).ShouldNot(HaveOccurred())
text := strings.Join(tokens, "")
Expect(IsValidText(text)).To(BeTrue())
- Expect(finishReason).Should(Equal(StopFinishReason))
+ Expect(finishReason).Should(Equal(common.StopFinishReason))
})
It("should return short text", func() {
@@ -57,35 +71,34 @@ var _ = Describe("Dataset", Ordered, func() {
req := &openaiserverapi.ChatCompletionRequest{
MaxCompletionTokens: &maxCompletionTokens,
}
- tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
+ tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
Expect(err).ShouldNot(HaveOccurred())
tokensCnt := int64(len(tokens))
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
if tokensCnt == maxCompletionTokens {
- Expect(finishReason).To(Equal(LengthFinishReason))
+ Expect(finishReason).To(Equal(common.LengthFinishReason))
} else {
Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens))
- Expect(finishReason).To(Equal(StopFinishReason))
+ Expect(finishReason).To(Equal(common.StopFinishReason))
}
})
It("should return long text", func() {
- // return required number of tokens although it is higher than ResponseLenMax
- maxCompletionTokens := int64(ResponseLenMax * 5)
+ maxCompletionTokens := int64(1000)
req := &openaiserverapi.ChatCompletionRequest{
MaxTokens: &maxCompletionTokens,
}
- tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
+ tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
Expect(err).ShouldNot(HaveOccurred())
tokensCnt := int64(len(tokens))
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
text := strings.Join(tokens, "")
Expect(IsValidText(text)).To(BeTrue())
if tokensCnt == maxCompletionTokens {
- Expect(finishReason).To(Equal(LengthFinishReason))
+ Expect(finishReason).To(Equal(common.LengthFinishReason))
} else {
Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens))
- Expect(finishReason).To(Equal(StopFinishReason))
+ Expect(finishReason).To(Equal(common.StopFinishReason))
}
})
@@ -96,11 +109,11 @@ var _ = Describe("Dataset", Ordered, func() {
MaxTokens: &n,
}
req.SetIgnoreEOS(true)
- tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
+ tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
Expect(err).ShouldNot(HaveOccurred())
nGenTokens := int64(len(tokens))
Expect(nGenTokens).Should(Equal(n))
- Expect(finishReason).To(Equal(LengthFinishReason))
+ Expect(finishReason).To(Equal(common.LengthFinishReason))
},
func(maxCompletionTokens int) string {
return fmt.Sprintf("maxCompletionTokens: %d", maxCompletionTokens)
@@ -116,22 +129,36 @@ var _ = Describe("Dataset", Ordered, func() {
theText := "Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime"
theTokens := common.Tokenize(theText)
- It("should return the same text since max tokens is not defined", func() {
- tokens, finishReason := EchoResponseTokens(nil, theText)
+ It("should return the same text, max tokens is not defined", func() {
+ req := &openaiserverapi.TextCompletionRequest{
+ Prompt: theText,
+ }
+ tokens, finishReason, err := dataset.getTokensInEchoMode(req)
+ Expect(err).ShouldNot(HaveOccurred())
Expect(tokens).Should(Equal(theTokens))
- Expect(finishReason).Should(Equal(StopFinishReason))
+ Expect(finishReason).Should(Equal(common.StopFinishReason))
})
- It("should return the same text since max tokens is higher than the text length", func() {
- maxCompletionTokens := int64(1000)
- tokens, finishReason := EchoResponseTokens(&maxCompletionTokens, theText)
+ It("should return the same text, max tokens is higher than the text length", func() {
+ maxTokens := int64(1000)
+ req := &openaiserverapi.TextCompletionRequest{
+ Prompt: theText,
+ MaxTokens: &maxTokens,
+ }
+ tokens, finishReason, err := dataset.getTokensInEchoMode(req)
+ Expect(err).ShouldNot(HaveOccurred())
Expect(tokens).Should(Equal(theTokens))
- Expect(finishReason).Should(Equal(StopFinishReason))
+ Expect(finishReason).Should(Equal(common.StopFinishReason))
})
- It("should return partial text", func() {
- maxCompletionTokens := int64(2)
- tokens, finishReason := EchoResponseTokens(&maxCompletionTokens, theText)
- Expect(int64(len(tokens))).Should(Equal(maxCompletionTokens))
- Expect(finishReason).Should(Equal(LengthFinishReason))
+ It("should return the same text, finish reason is stop", func() {
+ maxTokens := int64(2)
+ req := &openaiserverapi.TextCompletionRequest{
+ Prompt: theText,
+ MaxTokens: &maxTokens,
+ }
+ tokens, finishReason, err := dataset.getTokensInEchoMode(req)
+ Expect(err).ShouldNot(HaveOccurred())
+ Expect(tokens).Should(Equal(theTokens))
+ Expect(finishReason).Should(Equal(common.LengthFinishReason))
})
})
@@ -141,7 +168,7 @@ var _ = Describe("Dataset", Ordered, func() {
for _, len := range lenArr {
name := fmt.Sprintf("should return text with %d tokens", len)
It(name, func() {
- tokens := GenPresetRandomTokens(random, len)
+ tokens := dataset.generatePresetRandomTokens(len)
Expect(tokens).Should(HaveLen(len))
})
}
@@ -151,15 +178,15 @@ var _ = Describe("Dataset", Ordered, func() {
validTxts := make([]string, 0)
invalidTxts := make([]string, 0)
- validTxts = append(validTxts, chatCompletionFakeResponses[0][:4])
- validTxts = append(validTxts, chatCompletionFakeResponses[1])
- validTxts = append(validTxts, chatCompletionFakeResponses[1]+" "+chatCompletionFakeResponses[2])
+ validTxts = append(validTxts, completionFakeResponses[0][:4])
+ validTxts = append(validTxts, completionFakeResponses[1])
+ validTxts = append(validTxts, completionFakeResponses[1]+" "+completionFakeResponses[2])
- invalidTxts = append(invalidTxts, (chatCompletionFakeResponses[1] + " " + chatCompletionFakeResponses[2])[3:4])
- invalidTxts = append(invalidTxts, chatCompletionFakeResponses[0][4:])
- invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+"-"+chatCompletionFakeResponses[2])
- invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+" ")
- invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+" "+chatCompletionFakeResponses[2])
+ invalidTxts = append(invalidTxts, (completionFakeResponses[1] + " " + completionFakeResponses[2])[3:4])
+ invalidTxts = append(invalidTxts, completionFakeResponses[0][4:])
+ invalidTxts = append(invalidTxts, completionFakeResponses[1]+"-"+completionFakeResponses[2])
+ invalidTxts = append(invalidTxts, completionFakeResponses[1]+" ")
+ invalidTxts = append(invalidTxts, completionFakeResponses[1]+" "+completionFakeResponses[2])
for _, txt := range validTxts {
It("text should be valid", func() {
@@ -175,6 +202,9 @@ var _ = Describe("Dataset", Ordered, func() {
})
Context("validateBucketsBoundaries", func() {
+	// create the dataset during spec construction, not in BeforeEach, since its histogramHelper is used before any It executes
+ createDataset()
+
type bucketBoundaries struct {
start int
end int
@@ -189,11 +219,11 @@ var _ = Describe("Dataset", Ordered, func() {
{50, []bucketBoundaries{{1, 9}, {10, 19}, {20, 29}, {30, 39}, {40, 49}}}}
for _, test := range tests {
- Expect(test.expectedBuckets).To(HaveLen(len(cumulativeBucketsProbabilities) - 1))
+ Expect(test.expectedBuckets).To(HaveLen(len(dataset.histogramHelper.cumulativeBucketsProbabilities) - 1))
It(fmt.Sprintf("should return bucket boundaries for maxTokens %d", test.maxTokens), func() {
- for i := range len(cumulativeBucketsProbabilities) - 1 {
- start, end := calcBucketBoundaries(test.maxTokens, i)
+ for i := range len(dataset.histogramHelper.cumulativeBucketsProbabilities) - 1 {
+ start, end := dataset.histogramHelper.calcBucketBoundaries(test.maxTokens, i)
Expect(start).To(Equal(test.expectedBuckets[i].start))
Expect(end).To(Equal(test.expectedBuckets[i].end))
}
diff --git a/pkg/dataset/histogram_helper.go b/pkg/dataset/histogram_helper.go
new file mode 100644
index 0000000..8107fad
--- /dev/null
+++ b/pkg/dataset/histogram_helper.go
@@ -0,0 +1,139 @@
+/*
+Copyright 2025 The llm-d-inference-sim Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package dataset
+
+import "github.com/llm-d/llm-d-inference-sim/pkg/common"
+
+const (
+ responseLenMean = 40
+ responseLenStddev = 20
+)
+
+// respLenBucketsProbabilities defines the probabilities of the buckets used to generate the number of tokens in a response
+var respLenBucketsProbabilities = [...]float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15}
+
+const (
+ flexBucketIndex = 3
+ maxFixedBucketSize = 20
+)
+
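+// histogramHelper samples response lengths according to the predefined bucket probabilities.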
+type histogramHelper struct {
+ cumulativeBucketsProbabilities []float64
+ random *common.Random
+}
+
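+// newHistogramHelper precomputes the cumulative probabilities of the response-length buckets,
+// e.g. {0.2, 0.3, 0.2, 0.05, 0.1, 0.15} becomes {0.2, 0.5, 0.7, 0.75, 0.85, 1.0}.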
+func newHistogramHelper(random *common.Random) *histogramHelper {
+ h := histogramHelper{random: random}
+
+ h.cumulativeBucketsProbabilities = make([]float64, len(respLenBucketsProbabilities))
+ sum := 0.0
+
+ for i, val := range respLenBucketsProbabilities {
+ sum += val
+ h.cumulativeBucketsProbabilities[i] = sum
+ }
+ return &h
+}
+
+// getResponseLengthByHistogram calculates the number of tokens to return in a response, based on the max tokens value and the predefined buckets.
+// The response length is distributed according to the probabilities defined in respLenBucketsProbabilities.
+// The histogram contains equally sized buckets and a special last bucket, which contains only the maxTokens value.
+// The last element of respLenBucketsProbabilities defines the probability of a response with exactly maxTokens tokens.
+// The other values define the probabilities for the equally sized buckets.
+// If maxTokens is small (smaller than the number of buckets), the response length is selected uniformly at random from the range [1, maxTokens].
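+// For example, with maxTokens=50 the five regular buckets are [1,9], [10,19], [20,29], [30,39] and [40,49],
+// and the special bucket {50} is selected with probability 0.15.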
+func (hh *histogramHelper) getResponseLengthByHistogram(maxTokens int) int {
+ if maxTokens <= 1 {
+ return maxTokens
+ }
+	// maxTokens is small, so there is no need to use the histogram of probabilities, just select a random value in the range [1, maxTokens]
+ if maxTokens <= len(hh.cumulativeBucketsProbabilities) {
+ res := hh.random.RandomInt(1, maxTokens)
+ return res
+ }
+
+ r := hh.random.RandomFloat(0, 1)
+
+	// if r falls in the last bucket, maxTokens should be returned
+ if r > hh.cumulativeBucketsProbabilities[len(hh.cumulativeBucketsProbabilities)-2] {
+ return maxTokens
+ }
+
+	// determine which bucket to use: the first bucket whose cumulative probability is greater than or equal to r
+	// initialize bucketIndex with the last bucket to handle the case (which should not happen) where the probabilities sum to less than 1
+ bucketIndex := len(hh.cumulativeBucketsProbabilities) - 1
+ for i, c := range hh.cumulativeBucketsProbabilities {
+ if r <= c {
+ bucketIndex = i
+ break
+ }
+ }
+
+	// calculate the boundaries of the selected bucket
+ start, end := hh.calcBucketBoundaries(maxTokens, bucketIndex)
+
+ // pick uniformly within the bucket’s range
+ return hh.random.RandomInt(start, end)
+}
+
+// calcBucketBoundaries calculates the boundaries of the bucket with the given index.
+// The maximum size of the equally sized buckets is defined by maxFixedBucketSize.
+// [maxFixedBucketSize*(number-of-buckets-1)+1] is the value of maxTokens for which
+// division into equally sized buckets gives buckets of size maxFixedBucketSize.
+// If maxTokens is [maxFixedBucketSize*(number-of-buckets-1)+1] or less,
+// all buckets are of equal size, except the last bucket, which contains only one value.
+// If maxTokens is higher than [maxFixedBucketSize*(number-of-buckets-1)+1],
+// and flexBucketIndex is valid (between 0 and number of buckets - 1), the bucket sizes are not equal.
+// In this case, all buckets except the one at index flexBucketIndex have size maxFixedBucketSize (and the last holds a single value),
+// and the bucket at index flexBucketIndex 'stretches' to cover the remaining range.
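+// For example, with maxTokens=101 the five equal buckets are [1,20], [21,40], [41,60], [61,80] and [81,100];
+// with maxTokens=200 the flex bucket at index 3 stretches, giving [1,20], [21,40], [41,60], [61,180] and [181,199].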
+func (hh *histogramHelper) calcBucketBoundaries(maxTokens int, bucketIndex int) (start int, end int) {
+ maxEquallyBucketsSz := maxFixedBucketSize*(len(hh.cumulativeBucketsProbabilities)-1) + 1
+
+ if maxTokens <= maxEquallyBucketsSz || flexBucketIndex < 0 || flexBucketIndex >= len(hh.cumulativeBucketsProbabilities)-1 {
+		// create equally sized buckets
+ // calculate the size of all of the buckets (except the special last bucket)
+ bucketSize := float64(maxTokens-1) / float64(len(hh.cumulativeBucketsProbabilities)-1)
+ start = int(bucketSize*float64(bucketIndex)) + 1
+ end = int(bucketSize * float64(bucketIndex+1))
+ } else {
+ // create non-equally sized buckets and find boundaries of the required bucket
+ if bucketIndex < flexBucketIndex {
+ // the relevant bucket is before the flex bucket, all buckets are of the same size (maxFixedBucketSize)
+ // start is the minimum number in the required bucket
+ start = maxFixedBucketSize*bucketIndex + 1
+ end = maxFixedBucketSize * (bucketIndex + 1)
+ } else {
+ flexBucketSize := maxTokens - (maxFixedBucketSize * (len(hh.cumulativeBucketsProbabilities) - 2))
+
+ if bucketIndex == flexBucketIndex {
+ // the relevant bucket is the flex bucket
+ start = int(maxFixedBucketSize*float64(bucketIndex)) + 1
+ end = maxFixedBucketSize*bucketIndex + flexBucketSize
+ } else {
+ // the relevant bucket is one of buckets after the flex bucket
+ start = int(maxFixedBucketSize*float64(bucketIndex-1)) + flexBucketSize + 1
+ end = maxFixedBucketSize*bucketIndex + flexBucketSize
+ }
+ }
+ }
+
+	// because of rounding, end can sometimes reach maxTokens; cap it at maxTokens-1
+ if end >= maxTokens {
+ end = maxTokens - 1
+ }
+
+ return start, end
+}
diff --git a/pkg/dataset/sqlite_helper.go b/pkg/dataset/sqlite_helper.go
new file mode 100644
index 0000000..ca45dcd
--- /dev/null
+++ b/pkg/dataset/sqlite_helper.go
@@ -0,0 +1,300 @@
+/*
+Copyright 2025 The llm-d-inference-sim Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package dataset
+
+import (
+ "database/sql"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "os"
+ "strconv"
+ "time"
+
+ "github.com/go-logr/logr"
+ "github.com/llm-d/llm-d-inference-sim/pkg/common/logging"
+)
+
+// expected column names and types of the dataset table
+const (
+ tableName = "llmd"
+ idCol = "id"
+ promptHashCol = "prompt_hash"
+ genTokensCol = "gen_tokens"
+ nGenTokensCol = "n_gen_tokens"
+ idColType = "INTEGER"
+ promptHashColType = "BLOB"
+ genTokensColType = "JSON"
+ nGenTokensColType = "INTEGER"
+)
+
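+// sqliteHelper wraps a read-only connection to the responses dataset database.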
+type sqliteHelper struct {
+ logger logr.Logger
+ db *sql.DB
+}
+
+func newSqliteHelper(logger logr.Logger) *sqliteHelper {
+ return &sqliteHelper{logger: logger}
+}
+
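+// connectToDB opens the SQLite dataset at path, either fully loaded into memory or as a
+// read-only file connection guarded by an exclusive lock, and then verifies its schema.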
+func (s *sqliteHelper) connectToDB(path string, useInMemory bool) error {
+ if s.db != nil {
+ err := s.db.Close()
+ if err != nil {
+ s.logger.Error(err, "failed to close existing database connection")
+ }
+ s.db = nil
+ }
+ // check if file exists
+ _, err := os.Stat(path)
+ if err != nil {
+ return fmt.Errorf("database file does not exist: %w", err)
+ }
+
+ if useInMemory {
+ err = s.loadDatabaseInMemory(path)
+ if err != nil {
+ return err
+ }
+ } else {
+ // Use file-based database (original behavior)
+ s.db, err = sql.Open("sqlite3", "file:"+path+"?mode=ro")
+ if err != nil {
+ return fmt.Errorf("failed to open database: %w", err)
+ }
+
+ // Check if there are other connections to the database
+ _, err = s.db.Exec("BEGIN EXCLUSIVE;")
+ if err != nil {
+ if closeErr := s.db.Close(); closeErr != nil {
+ s.logger.Error(closeErr, "failed to close database after failing to acquire exclusive lock")
+ }
+ s.db = nil
+ return fmt.Errorf("database is locked or has other active connections: %w", err)
+ }
+ }
+
+ err = s.verifyDB()
+ if err != nil {
+ return fmt.Errorf("failed to verify database: %w", err)
+ }
+
+ count, err := s.getRecordsCount()
+ if err != nil {
+ s.logger.Error(err, "failed to get records count")
+ return fmt.Errorf("failed to query database: %w", err)
+ }
+
+ if useInMemory {
+ s.logger.V(logging.INFO).Info("In-memory database connected successfully", "path", path, "records count", count)
+ } else {
+ s.logger.V(logging.INFO).Info("Database connected successfully", "path", path, "records count", count)
+ }
+ return nil
+}
+
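+// loadDatabaseInMemory creates an in-memory SQLite database and copies the dataset table
+// into it via ATTACH, so that subsequent queries avoid disk I/O.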
+func (s *sqliteHelper) loadDatabaseInMemory(path string) error {
+ s.logger.V(logging.INFO).Info("Loading database into memory...")
+ start := time.Now()
+
+ // Create in-memory database
+ var err error
+ s.db, err = sql.Open("sqlite3", ":memory:")
+ if err != nil {
+ return fmt.Errorf("failed to create in-memory database: %w", err)
+ }
+
+ // Use ATTACH to copy the database
+ attachSQL := fmt.Sprintf("ATTACH DATABASE '%s' AS source", path)
+ _, err = s.db.Exec(attachSQL)
+ if err != nil {
+ if closeErr := s.db.Close(); closeErr != nil {
+ s.logger.Error(closeErr, "failed to close in-memory database after attach failure")
+ }
+ s.db = nil
+ return fmt.Errorf("failed to attach source database: %w", err)
+ }
+
+ // Copy the table structure first
+ createTableStmt := fmt.Sprintf(`CREATE TABLE %s (
+ id INTEGER PRIMARY KEY,
+ prompt_hash BLOB,
+ gen_tokens JSON,
+ n_gen_tokens INTEGER
+ )`, tableName)
+ _, err = s.db.Exec(createTableStmt)
+ if err != nil {
+ if closeErr := s.db.Close(); closeErr != nil {
+ s.logger.Error(closeErr, "failed to close in-memory database after create table failure")
+ }
+ s.db = nil
+ return fmt.Errorf("failed to create table: %w", err)
+ }
+
+ // Copy the data
+ _, err = s.db.Exec("INSERT INTO " + tableName + " SELECT * FROM source." + tableName)
+ if err != nil {
+ if closeErr := s.db.Close(); closeErr != nil {
+ s.logger.Error(closeErr, "failed to close in-memory database after copy failure")
+ }
+ s.db = nil
+ return fmt.Errorf("failed to copy data: %w", err)
+ }
+
+ // Detach the source database
+ _, err = s.db.Exec("DETACH DATABASE source")
+ if err != nil {
+ s.logger.Error(err, "failed to detach source database")
+ }
+
+ loadTime := time.Since(start)
+ s.logger.V(logging.INFO).Info("Database loaded into memory", "load_time", loadTime.String())
+ return nil
+}
+
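+// verifyDB checks that the llmd table contains the expected columns with the expected types.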
+func (s *sqliteHelper) verifyDB() error {
+ rows, err := s.db.Query("PRAGMA table_info(" + tableName + ");")
+ if err != nil {
+ return fmt.Errorf("failed to query table info for `%s`: %w", tableName, err)
+ }
+ defer func() {
+ if cerr := rows.Close(); cerr != nil {
+ s.logger.Error(cerr, "failed to close rows after querying table info")
+ }
+ }()
+
+ expectedColumns := map[string]string{
+ idCol: idColType,
+ promptHashCol: promptHashColType,
+ genTokensCol: genTokensColType,
+ nGenTokensCol: nGenTokensColType,
+ }
+
+ columnsFound := make(map[string]bool)
+
+ var (
+ columnName string
+ columnType string
+ cid int
+ notnull int
+ dfltValue interface{}
+ pk int
+ )
+
+ for rows.Next() {
+		err := rows.Scan(&cid, &columnName, &columnType, &notnull, &dfltValue, &pk)
+ if err != nil {
+ return fmt.Errorf("failed to scan table info row: %w", err)
+ }
+ if expectedType, exists := expectedColumns[columnName]; exists {
+ if columnType != expectedType {
+ return fmt.Errorf("column %s has incorrect type: expected %s, got %s", columnName, expectedType, columnType)
+ }
+ columnsFound[columnName] = true
+ }
+ }
+
+ for col := range expectedColumns {
+ if !columnsFound[col] {
+ return fmt.Errorf("missing expected column in %s table: %s", tableName, col)
+ }
+ }
+
+ return nil
+}
+
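+// getRecordsCount returns the number of rows in the dataset table.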
+func (s *sqliteHelper) getRecordsCount() (int, error) {
+ var count int
+ err := s.db.QueryRow("SELECT COUNT(" + promptHashCol + ") FROM " + tableName + ";").Scan(&count)
+ if err != nil {
+ return 0, fmt.Errorf("failed to query database: %w", err)
+ }
+ return count, nil
+}
+
+// query runs a SQL query which retrieves response tokens as an array of strings
+// and returns multiple responses
+func (s *sqliteHelper) query(query string) (responses [][]string, err error) {
+	rows, err := s.db.Query(query)
+	if err != nil {
+		s.logger.Error(err, "failed to query database. Ensure dataset file is still valid. Will generate random tokens instead.")
+		return nil, err
+	}
+	defer func() {
+		if cerr := rows.Close(); cerr != nil {
+			s.logger.Error(cerr, "failed to close rows after query")
+			// named return values let this close error propagate to the caller
+			err = errors.Join(err, cerr)
+		}
+	}()
+	return unmarshalAllRecords(rows)
+}
+
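+// unmarshalAllRecords scans each row's gen_tokens JSON value into a []string and collects all of them.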
+func unmarshalAllRecords(rows *sql.Rows) ([][]string, error) {
+ var responses [][]string
+
+ for rows.Next() {
+ var responseJSON string
+ if err := rows.Scan(&responseJSON); err != nil {
+ return nil, fmt.Errorf("failed to scan row: %w", err)
+ }
+
+ var tokens []string
+ if err := json.Unmarshal([]byte(responseJSON), &tokens); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal tokens JSON: %w", err)
+ }
+ responses = append(responses, tokens)
+ }
+ return responses, nil
+}
+
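+// buildQuery assembles a SELECT over the gen_tokens column with an optional WHERE clause,
+// optional random ordering and an optional LIMIT 1. For example,
+// buildQuery(nGenTokensCol+"<=100", true, true) yields
+// "SELECT gen_tokens FROM llmd WHERE n_gen_tokens<=100 ORDER BY RANDOM() LIMIT 1;".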
+func (s *sqliteHelper) buildQuery(where string, isRand bool, isLimitOne bool) string {
+ query := "SELECT " + genTokensCol + " FROM " + tableName
+
+ if where != "" {
+ query += " WHERE " + where
+ }
+
+ if isRand {
+ query += " ORDER BY RANDOM()"
+ }
+
+ if isLimitOne {
+ query += " LIMIT 1"
+ }
+ query += ";"
+
+ return query
+}
+
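+// getResponsesForPrompt returns all responses whose prompt hash equals promptHashHex.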
+func (s *sqliteHelper) getResponsesForPrompt(promptHashHex string) ([][]string, error) {
+ query := s.buildQuery(promptHashCol+"=X'"+promptHashHex+"'", false, false)
+ return s.query(query)
+}
+
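+// getResponsesForLen returns one random response with at most maxLen tokens.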
+func (s *sqliteHelper) getResponsesForLen(maxLen int) ([][]string, error) {
+ query := s.buildQuery(nGenTokensCol+"<="+strconv.Itoa(maxLen), true, true)
+ return s.query(query)
+}
+
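+// getRandomResponse returns one response chosen at random from the whole dataset.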
+func (s *sqliteHelper) getRandomResponse() ([][]string, error) {
+ query := s.buildQuery("", true, true)
+ return s.query(query)
+}
diff --git a/pkg/dataset/test_helpers.go b/pkg/dataset/test_helpers.go
index ed6b6f1..7cde861 100644
--- a/pkg/dataset/test_helpers.go
+++ b/pkg/dataset/test_helpers.go
@@ -27,7 +27,7 @@ func IsValidText(text string) bool {
textToCheck := text[charsTested:]
found := false
- for _, fakeSentence := range chatCompletionFakeResponses {
+ for _, fakeSentence := range completionFakeResponses {
if len(textToCheck) <= len(fakeSentence) {
if strings.HasPrefix(fakeSentence, textToCheck) {
found = true
diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go
index cb1ab54..c0ca219 100644
--- a/pkg/llm-d-inference-sim/simulator.go
+++ b/pkg/llm-d-inference-sim/simulator.go
@@ -247,6 +247,14 @@ func (s *VllmSimulator) Start(ctx context.Context) error {
return err
}
+ if s.config.DatasetURL != "" {
+		// if a remote responses dataset should be used, download it first (this can take time)
+ downloader := dataset.NewDsDownloader(s.logger)
+ if err := downloader.DownloadDataset(ctx, s.config.DatasetURL, s.config.DatasetPath); err != nil {
+ return err
+ }
+ }
+
// For Data Parallel, start data-parallel-size - 1 additional simulators
g, ctx := errgroup.WithContext(ctx)
if s.config.DPSize > 1 {
@@ -371,32 +379,27 @@ func (s *VllmSimulator) initializeSim(ctx context.Context) error {
}
func (s *VllmSimulator) initDataset(ctx context.Context) error {
- randDataset := &dataset.BaseDataset{}
- err := randDataset.Init(ctx, s.logger, "", "", false)
- if err != nil {
- return fmt.Errorf("failed to initialize random dataset: %w", err)
- }
-
if s.config.DatasetPath == "" && s.config.DatasetURL == "" {
+ // use predefined sentences as responses
+ randDataset := &dataset.BaseDataset{}
+ err := randDataset.Init(ctx, s.logger, s.random, s.config.MaxModelLen)
+ if err != nil {
+ return fmt.Errorf("failed to initialize random dataset: %w", err)
+ }
s.logger.V(logging.INFO).Info("No dataset path or URL provided, using random text for responses")
s.dataset = randDataset
return nil
}
+ // use dataset containing responses
custDataset := &dataset.CustomDataset{}
- err = custDataset.Init(ctx, s.logger, s.config.DatasetPath, s.config.DatasetURL, s.config.DatasetInMemory)
+ err := custDataset.Init(ctx, s.logger, s.random, s.config.DatasetPath, s.config.DatasetInMemory, s.config.MaxModelLen)
if err == nil {
s.dataset = custDataset
return nil
}
- if strings.HasPrefix(err.Error(), "database is locked") {
- s.logger.V(logging.WARN).Info("Database is locked by another process, will use preset text for responses instead")
- s.dataset = randDataset
- return nil
- }
-
return err
}
diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go
index 38e0ad9..a08d946 100644
--- a/pkg/llm-d-inference-sim/simulator_test.go
+++ b/pkg/llm-d-inference-sim/simulator_test.go
@@ -180,16 +180,16 @@ var _ = Describe("Simulator", func() {
msg := resp.Choices[0].Message.Content
Expect(msg).ShouldNot(BeEmpty())
- if numTokens > 0 {
- tokens := common.Tokenize(msg)
- Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens))
+ if mode == common.ModeEcho {
+ // in case of echo mode check that the text is returned as-is
+ Expect(msg).Should(Equal(testUserMessage))
} else {
- if mode == common.ModeRandom {
+ if numTokens > 0 {
+ tokens := common.Tokenize(msg)
+ Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens))
+ } else {
// in case of random mode ensure that the returned message could be output of the random text generator
Expect(dataset.IsValidText(msg)).To(BeTrue())
- } else {
- // in case of echo mode check that the text is returned as-is
- Expect(msg).Should(Equal(testUserMessage))
}
}
},
@@ -251,16 +251,16 @@ var _ = Describe("Simulator", func() {
text := resp.Choices[0].Text
Expect(text).ShouldNot(BeEmpty())
- if numTokens != 0 {
- tokens := common.Tokenize(text)
- Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens))
+ if mode == common.ModeEcho {
+ // in case of echo mode check that the text is returned as-is
+ Expect(text).Should(Equal(testUserMessage))
} else {
- if mode == common.ModeRandom {
+ if numTokens != 0 {
+ tokens := common.Tokenize(text)
+ Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens))
+ } else {
// in case of random mode ensure that the returned message could be output of the random text generator
Expect(dataset.IsValidText(text)).To(BeTrue())
- } else {
- // in case of echo mode check that the text is returned as-is
- Expect(text).Should(Equal(testUserMessage))
}
}
},
diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go
index 918408a..7c96fa9 100644
--- a/pkg/llm-d-inference-sim/streaming.go
+++ b/pkg/llm-d-inference-sim/streaming.go
@@ -26,7 +26,6 @@ import (
"github.com/llm-d/llm-d-inference-sim/pkg/common"
"github.com/llm-d/llm-d-inference-sim/pkg/common/logging"
- "github.com/llm-d/llm-d-inference-sim/pkg/dataset"
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
"github.com/valyala/fasthttp"
)
@@ -142,7 +141,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ
var chunk openaiserverapi.CompletionRespChunk
var finishReasonToSend *string
- if i == len(genTokens)-1 && (finishReason == dataset.LengthFinishReason || finishReason == dataset.ToolsFinishReason) {
+ if i == len(genTokens)-1 && (finishReason == common.LengthFinishReason || finishReason == common.ToolsFinishReason) {
finishReasonToSend = &finishReason
}
if context.isChatCompletion {
@@ -161,7 +160,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ
// send the last chunk if finish reason is stop
var chunk openaiserverapi.CompletionRespChunk
- if finishReason == dataset.StopFinishReason {
+ if finishReason == common.StopFinishReason {
if context.isChatCompletion {
chunk = s.createChatCompletionChunk(context, "", nil, "", &finishReason)
} else {
diff --git a/pkg/llm-d-inference-sim/tools_test.go b/pkg/llm-d-inference-sim/tools_test.go
index aa6f54c..9edab45 100644
--- a/pkg/llm-d-inference-sim/tools_test.go
+++ b/pkg/llm-d-inference-sim/tools_test.go
@@ -24,7 +24,6 @@ import (
"strings"
"github.com/llm-d/llm-d-inference-sim/pkg/common"
- "github.com/llm-d/llm-d-inference-sim/pkg/dataset"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/openai/openai-go/v3"
@@ -397,7 +396,7 @@ var _ = Describe("Simulator for request with tools", func() {
for _, choice := range chunk.Choices {
if choice.Delta.Role != "" {
role = choice.Delta.Role
- } else if choice.FinishReason == "" || choice.FinishReason == dataset.ToolsFinishReason {
+ } else if choice.FinishReason == "" || choice.FinishReason == common.ToolsFinishReason {
toolCalls := choice.Delta.ToolCalls
Expect(toolCalls).To(HaveLen(1))
tc := toolCalls[0]
diff --git a/pkg/llm-d-inference-sim/worker.go b/pkg/llm-d-inference-sim/worker.go
index a256d19..041b3b4 100644
--- a/pkg/llm-d-inference-sim/worker.go
+++ b/pkg/llm-d-inference-sim/worker.go
@@ -25,7 +25,6 @@ import (
"github.com/go-logr/logr"
"github.com/llm-d/llm-d-inference-sim/pkg/common"
"github.com/llm-d/llm-d-inference-sim/pkg/common/logging"
- "github.com/llm-d/llm-d-inference-sim/pkg/dataset"
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
"github.com/valyala/fasthttp"
)
@@ -109,12 +108,12 @@ func (s *VllmSimulator) processRequestAsync(reqCtx *openaiserverapi.CompletionRe
req.GetTools() != nil {
toolCalls, completionTokens, err =
common.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config, s.random)
- finishReason = dataset.ToolsFinishReason
+ finishReason = common.ToolsFinishReason
}
if toolCalls == nil && err == nil {
// Either no tool calls were defined, or we randomly chose not to create tool calls,
// so we generate a response text.
- responseTokens, finishReason, err = s.dataset.GetTokens(req, s.config.Mode, s.random)
+ responseTokens, finishReason, err = s.dataset.GetTokens(req, s.config.Mode)
completionTokens += len(responseTokens)
}
if err != nil {
@@ -154,7 +153,7 @@ func (s *VllmSimulator) processRequestAsync(reqCtx *openaiserverapi.CompletionRe
} else {
if req.IsDoRemoteDecode() {
// in case this is prefill pod processing, return special finish reason
- finishReason = dataset.RemoteDecodeFinishReason
+ finishReason = common.RemoteDecodeFinishReason
}
s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &usageData)
wg.Done()