From 97b4230ce53e62e0013e4cac8106882e914c230a Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Sat, 13 Sep 2025 21:53:11 +1000 Subject: [PATCH 01/34] Show config in yaml Signed-off-by: Qifan Deng --- pkg/llm-d-inference-sim/simulator.go | 87 ++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index ab55fea2..bb04bf64 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -29,6 +29,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/valyala/fasthttp" "golang.org/x/sync/errgroup" + "gopkg.in/yaml.v3" "k8s.io/klog/v2" "github.com/llm-d/llm-d-inference-sim/pkg/common" @@ -497,3 +498,89 @@ func (s *VllmSimulator) createModelsResponse() *vllmapi.ModelsResponse { return &modelsResp } +<<<<<<< HEAD +======= + +// HandleHealth http handler for /health +func (s *VllmSimulator) HandleHealth(ctx *fasthttp.RequestCtx) { + s.logger.V(4).Info("health request received") + ctx.Response.Header.SetContentType("application/json") + ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) + ctx.Response.SetBody([]byte("{}")) +} + +// HandleReady http handler for /ready +func (s *VllmSimulator) HandleReady(ctx *fasthttp.RequestCtx) { + s.logger.V(4).Info("readiness request received") + ctx.Response.Header.SetContentType("application/json") + ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) + ctx.Response.SetBody([]byte("{}")) +} + +// getDisplayedModelName returns the model name that must appear in API +// responses. LoRA adapters keep their explicit name, while all base-model +// requests are surfaced as the first alias from --served-model-name. 
+func (s *VllmSimulator) getDisplayedModelName(reqModel string) string { + if s.isLora(reqModel) { + return reqModel + } + return s.config.ServedModelNames[0] +} + +func (s *VllmSimulator) showConfig(dp bool) error { + cfgYAML, err := yaml.Marshal(s.config) + if err != nil { + return fmt.Errorf("failed to marshal configuration to YAML: %w", err) + } + + var m map[string]interface{} + err = yaml.Unmarshal(cfgYAML, &m) + if err != nil { + return fmt.Errorf("failed to unmarshal YAML to map: %w", err) + } + if dp { + // remove the port + delete(m, "port") + } + // clean LoraModulesString field + m["lora-modules"] = m["LoraModules"] + delete(m, "LoraModules") + delete(m, "LoraModulesString") + + // clean fake-metrics field + if field, ok := m["fake-metrics"].(map[string]interface{}); ok { + delete(field, "LorasString") + } + + // show in YAML + cfgYAML, err = yaml.Marshal(m) + if err != nil { + return fmt.Errorf("failed to marshal configuration to YAML: %w", err) + } + s.logger.Info("Configuration:", "", string(cfgYAML)) + return nil +} + +func (s *VllmSimulator) getCurrFactor() float64 { + if s.config.MaxNumSeqs <= 1 { + return 1.0 + } + return 1 + (s.config.TimeFactorUnderLoad-1)*float64(s.nRunningReqs-1)/float64(s.config.MaxNumSeqs-1) +} + +func (s *VllmSimulator) GetTimeToFirstToken() int { + return int(float64(s.config.TimeToFirstToken) * s.getCurrFactor()) +} + +func (s *VllmSimulator) GetPrefillOverhead() int { + return int(float64(s.config.PrefillOverhead) * s.getCurrFactor()) +} + +func (s *VllmSimulator) GetPrefillTimePerToken() int { + return int(float64(s.config.PrefillTimePerToken) * s.getCurrFactor()) +} + +func (s *VllmSimulator) GetInterTokenLatency() int { + return int(float64(s.config.InterTokenLatency) * s.getCurrFactor()) +} +>>>>>>> 482434e (Show config in yaml) From b001ac68d58d5556afe6267cd47fa56f5e1fe4ca Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Sun, 14 Sep 2025 00:53:52 +1000 Subject: [PATCH 02/34] load or download response dataset 
Signed-off-by: Qifan Deng --- .gitignore | 2 + Makefile | 2 +- go.mod | 1 + go.sum | 2 + pkg/common/config.go | 13 ++ .../.llm-d/test.valid.sqlite3 | Bin 0 -> 8192 bytes pkg/llm-d-inference-sim/dataset.go | 197 ++++++++++++++++++ pkg/llm-d-inference-sim/dataset_test.go | 97 +++++++++ pkg/llm-d-inference-sim/simulator.go | 87 -------- 9 files changed, 313 insertions(+), 88 deletions(-) create mode 100644 pkg/llm-d-inference-sim/.llm-d/test.valid.sqlite3 create mode 100644 pkg/llm-d-inference-sim/dataset.go create mode 100644 pkg/llm-d-inference-sim/dataset_test.go diff --git a/.gitignore b/.gitignore index 950b0cb4..419af304 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ vendor .DS_Store *.test manifests/dev-config.yaml +pkg/llm-d-inference-sim/.llm-d +.llm-d/ diff --git a/Makefile b/Makefile index 40392e9a..a71b78c6 100644 --- a/Makefile +++ b/Makefile @@ -85,7 +85,7 @@ format: ## Format Go source files test: $(GINKGO) download-tokenizer download-zmq ## Run tests @printf "\033[33;1m==== Running tests ====\033[0m\n" ifdef GINKGO_FOCUS - CGO_ENABLED=1 $(GINKGO) -ldflags="$(GO_LDFLAGS)" -v -r --focus="$(GINKGO_FOCUS)" + CGO_ENABLED=1 ginkgo -ldflags="$(GO_LDFLAGS)" -v -r -- -ginkgo.v -ginkgo.focus="$(GINKGO_FOCUS)" else CGO_ENABLED=1 $(GINKGO) -ldflags="$(GO_LDFLAGS)" -v -r endif diff --git a/go.mod b/go.mod index b56c7987..e3ca6e85 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/mattn/go-sqlite3 v1.14.32 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect diff --git a/go.sum b/go.sum index 4db4fd05..dd46dc12 100644 --- a/go.sum +++ b/go.sum @@ -72,6 +72,8 @@ github.com/llm-d/llm-d-kv-cache-manager v0.3.0-rc1 
h1:SDLiNrcreDcA9m9wfXAumFARDH github.com/llm-d/llm-d-kv-cache-manager v0.3.0-rc1/go.mod h1:tN80/D0Faf6pE2ocwFgTNoCxKPsqdsa2XnjQUqOaZ8Q= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/pkg/common/config.go b/pkg/common/config.go index ca8e5aa5..9878ff4b 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -181,6 +181,19 @@ type Configuration struct { SSLKeyFile string `yaml:"ssl-keyfile" json:"ssl-keyfile"` // SelfSignedCerts enables automatic generation of self-signed certificates for HTTPS SelfSignedCerts bool `yaml:"self-signed-certs" json:"self-signed-certs"` + // Dataset configuration for response generation from a dataset. sqlite db file is expected. 
+ Dataset Dataset +} + +type Dataset struct { + // Path is the local path to the sqlite db file, default is empty + // when path is empty Url will be checked + Path string `yaml:"path" json:"path"` + // Url is the URL to download the sqlite db file if set, default is empty + Url string `yaml:"url" json:"url"` + // SavePath is the local path to save the downloaded sqlite db file + // if Url is set but SavePath is not, "~/.llmd/dataset.db" will be used + SavePath string `yaml:"save-path" json:"save-path"` } type Metrics struct { diff --git a/pkg/llm-d-inference-sim/.llm-d/test.valid.sqlite3 b/pkg/llm-d-inference-sim/.llm-d/test.valid.sqlite3 new file mode 100644 index 0000000000000000000000000000000000000000..dda347c45c66a76d4a699987e488ce81c36b0321 GIT binary patch literal 8192 zcmeI#u?oU46ouiNB05Rjx@Oct7hk|yvUO22=vc9WkRnKQ^eugsuBO`ErGvwNNOBXn zB=Bw7ZHh~%%=vn&%V3r=5v62hjEE%NjO7y**Fm>$OMCv6L>AA(ICd%hk~jzg2q1s} z0tg_000IagfB*sr^d&G;!&GaPhw|vF5R!arGiQy<)`c~}OdqAcSH{`Bn|=T0jp(MY i$GHaq1Q0*~0R#|0009ILKmY**{zPD;lGOWpetH8NRTvil literal 0 HcmV?d00001 diff --git a/pkg/llm-d-inference-sim/dataset.go b/pkg/llm-d-inference-sim/dataset.go new file mode 100644 index 00000000..949e445c --- /dev/null +++ b/pkg/llm-d-inference-sim/dataset.go @@ -0,0 +1,197 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package llmdinferencesim + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io" + "net/http" + "os" + "os/signal" + "path/filepath" + "syscall" + + "github.com/go-logr/logr" +) + +type Dataset struct { + db *sql.DB + logger logr.Logger +} + +func (d *Dataset) downloadDataset(url string, savePath string) error { + // Set up signal handling for Ctrl+C (SIGINT) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigs) + + // Goroutine to listen for signal + go func() { + <-sigs + d.logger.Info("Interrupt signal received, cancelling download...") + cancel() + }() + + out, err := os.Create(savePath) + if err != nil { + return err + } + defer func() { + cerr := out.Close() + if cerr != nil { + d.logger.Error(cerr, "failed to close file after download") + } + }() + + resp, err := http.Get(url) + if err != nil { + return err + } + defer func() { + cerr := resp.Body.Close() + if cerr != nil { + d.logger.Error(cerr, "failed to close response body after download") + } + }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("bad status: %s", resp.Status) + } + + // Progress reader with context + pr := &progressReader{ + Reader: resp.Body, + total: resp.ContentLength, + logger: d.logger, + ctx: ctx, + } + + written, err := io.Copy(out, pr) + if err != nil { + // Remove incomplete file + cerr := os.Remove(savePath) + if cerr != nil { + d.logger.Error(cerr, "failed to remove incomplete file after download") + } + // If context was cancelled, return a specific error + if errors.Is(err, context.Canceled) { + return errors.New("download cancelled by user") + } + return fmt.Errorf("failed to download file: %w", err) + } + // Check if file size is zero or suspiciously small + if written == 0 { + cerr := os.Remove(savePath) + if cerr != nil { + d.logger.Error(cerr, "failed to remove empty file after download") + } + return 
errors.New("downloaded file is empty") + } + + // Ensure file is fully flushed and closed before returning success + if err := out.Sync(); err != nil { + cerr := os.Remove(savePath) + if cerr != nil { + d.logger.Error(cerr, "failed to remove incomplete file after download") + } + return fmt.Errorf("failed to sync file: %w", err) + } + + return nil +} + +// progressReader wraps an io.Reader and logs download progress. +type progressReader struct { + io.Reader + total int64 + downloaded int64 + lastPct int + logger logr.Logger + ctx context.Context +} + +func (pr *progressReader) Read(p []byte) (int, error) { + select { + case <-pr.ctx.Done(): + return 0, pr.ctx.Err() + default: + } + n, err := pr.Reader.Read(p) + pr.downloaded += int64(n) + if pr.total > 0 { + pct := int(float64(pr.downloaded) * 100 / float64(pr.total)) + if pct != pr.lastPct && pct%10 == 0 { // log every 10% + pr.logger.Info(fmt.Sprintf("Download progress: %d%%", pct)) + pr.lastPct = pct + } + } + return n, err +} +func (d *Dataset) connectToDB(path string) error { + // check if file exists + _, err := os.Stat(path) + if err != nil { + return fmt.Errorf("database file does not exist: %w", err) + } + d.db, err = sql.Open("sqlite3", path) + if err != nil { + return fmt.Errorf("failed to open database: %w", err) + } + return nil +} + +func (d *Dataset) Init(path string, url string, savePath string) error { + if path != "" { + return d.connectToDB(path) + } + if url != "" { + if savePath == "" { + savePath = "~/.llmd/dataset.sqlite3" + } + + _, err := os.Stat(savePath) + if err != nil { + // file does not exist, download it + folder := filepath.Dir(savePath) + err := os.MkdirAll(folder, 0755) + if err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + d.logger.Info("Downloading dataset from URL", "url", url, "to", savePath) + err = d.downloadDataset(url, savePath) + if err != nil { + return fmt.Errorf("failed to download dataset: %w", err) + } + } + d.logger.Info("Using 
dataset from", "path", savePath) + + return d.connectToDB(savePath) + } + return errors.New("no dataset path or url provided") +} + +func (d *Dataset) Close() error { + if d.db != nil { + return d.db.Close() + } + return nil +} diff --git a/pkg/llm-d-inference-sim/dataset_test.go b/pkg/llm-d-inference-sim/dataset_test.go new file mode 100644 index 00000000..4e226237 --- /dev/null +++ b/pkg/llm-d-inference-sim/dataset_test.go @@ -0,0 +1,97 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "os" + + "github.com/go-logr/logr" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + _ "github.com/mattn/go-sqlite3" +) + +var _ = Describe("Dataset", func() { + var ( + dataset *Dataset + file_folder string + savePath string + ) + + BeforeEach(func() { + dataset = &Dataset{ + logger: logr.Discard(), + } + file_folder = "./.llm-d" + savePath = file_folder + "/test.sqlite3" + err := os.MkdirAll(file_folder, os.ModePerm) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + if dataset.db != nil { + err := dataset.db.Close() + Expect(err).NotTo(HaveOccurred()) + } + }) + + It("should return error for invalid DB path", func() { + err := dataset.connectToDB("/invalid/path/to/db.sqlite") + Expect(err).To(HaveOccurred()) + }) + + It("should download file from url", func() { + url := "https://llm-d.ai" + err := dataset.downloadDataset(url, savePath) + Expect(err).NotTo(HaveOccurred()) + _, err = os.Stat(savePath) + Expect(err).NotTo(HaveOccurred()) + err = os.Remove(savePath) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should not download file from url", func() { + url := "https://256.256.256.256" // invalid url + err := dataset.downloadDataset(url, savePath) + Expect(err).To(HaveOccurred()) + }) + + It("should successfully init dataset", func() { + validDBPath := file_folder + "/test.valid.sqlite3" + err := dataset.Init(validDBPath, "", "") + Expect(err).NotTo(HaveOccurred()) + + // read from the db to verify it's valid + row := dataset.db.QueryRow("SELECT * FROM t;") + var value string + err = row.Scan(&value) + Expect(err).NotTo(HaveOccurred()) + Expect(value).To(Equal("llm-d")) + }) + + It("should raise err with invalid DB content", func() { + err := dataset.connectToDB(file_folder) + Expect(err).NotTo(HaveOccurred()) + + // read from the db to verify it's not valid + row := dataset.db.QueryRow("SELECT * FROM t;") + var value string + err = row.Scan(&value) + Expect(err).To(HaveOccurred()) + }) +}) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 
bb04bf64..ab55fea2 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -29,7 +29,6 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/valyala/fasthttp" "golang.org/x/sync/errgroup" - "gopkg.in/yaml.v3" "k8s.io/klog/v2" "github.com/llm-d/llm-d-inference-sim/pkg/common" @@ -498,89 +497,3 @@ func (s *VllmSimulator) createModelsResponse() *vllmapi.ModelsResponse { return &modelsResp } -<<<<<<< HEAD -======= - -// HandleHealth http handler for /health -func (s *VllmSimulator) HandleHealth(ctx *fasthttp.RequestCtx) { - s.logger.V(4).Info("health request received") - ctx.Response.Header.SetContentType("application/json") - ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) - ctx.Response.SetBody([]byte("{}")) -} - -// HandleReady http handler for /ready -func (s *VllmSimulator) HandleReady(ctx *fasthttp.RequestCtx) { - s.logger.V(4).Info("readiness request received") - ctx.Response.Header.SetContentType("application/json") - ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) - ctx.Response.SetBody([]byte("{}")) -} - -// getDisplayedModelName returns the model name that must appear in API -// responses. LoRA adapters keep their explicit name, while all base-model -// requests are surfaced as the first alias from --served-model-name. 
-func (s *VllmSimulator) getDisplayedModelName(reqModel string) string { - if s.isLora(reqModel) { - return reqModel - } - return s.config.ServedModelNames[0] -} - -func (s *VllmSimulator) showConfig(dp bool) error { - cfgYAML, err := yaml.Marshal(s.config) - if err != nil { - return fmt.Errorf("failed to marshal configuration to YAML: %w", err) - } - - var m map[string]interface{} - err = yaml.Unmarshal(cfgYAML, &m) - if err != nil { - return fmt.Errorf("failed to unmarshal YAML to map: %w", err) - } - if dp { - // remove the port - delete(m, "port") - } - // clean LoraModulesString field - m["lora-modules"] = m["LoraModules"] - delete(m, "LoraModules") - delete(m, "LoraModulesString") - - // clean fake-metrics field - if field, ok := m["fake-metrics"].(map[string]interface{}); ok { - delete(field, "LorasString") - } - - // show in YAML - cfgYAML, err = yaml.Marshal(m) - if err != nil { - return fmt.Errorf("failed to marshal configuration to YAML: %w", err) - } - s.logger.Info("Configuration:", "", string(cfgYAML)) - return nil -} - -func (s *VllmSimulator) getCurrFactor() float64 { - if s.config.MaxNumSeqs <= 1 { - return 1.0 - } - return 1 + (s.config.TimeFactorUnderLoad-1)*float64(s.nRunningReqs-1)/float64(s.config.MaxNumSeqs-1) -} - -func (s *VllmSimulator) GetTimeToFirstToken() int { - return int(float64(s.config.TimeToFirstToken) * s.getCurrFactor()) -} - -func (s *VllmSimulator) GetPrefillOverhead() int { - return int(float64(s.config.PrefillOverhead) * s.getCurrFactor()) -} - -func (s *VllmSimulator) GetPrefillTimePerToken() int { - return int(float64(s.config.PrefillTimePerToken) * s.getCurrFactor()) -} - -func (s *VllmSimulator) GetInterTokenLatency() int { - return int(float64(s.config.InterTokenLatency) * s.getCurrFactor()) -} ->>>>>>> 482434e (Show config in yaml) From b4adcccf0bc1dbbec54ad0b6bc316d65f714eb42 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Sun, 14 Sep 2025 15:41:43 +1000 Subject: [PATCH 03/34] Init dataset when sim starts and show 
downloading speed of url Signed-off-by: Qifan Deng --- pkg/common/config.go | 2 +- pkg/llm-d-inference-sim/dataset.go | 43 +++++++++++++++++++++++----- pkg/llm-d-inference-sim/simulator.go | 11 +++++++ 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/pkg/common/config.go b/pkg/common/config.go index 9878ff4b..73d0cbde 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -192,7 +192,7 @@ type Dataset struct { // Url is the URL to download the sqlite db file if set, default is empty Url string `yaml:"url" json:"url"` // SavePath is the local path to save the downloaded sqlite db file - // if Url is set but SavePath is not, "~/.llmd/dataset.db" will be used + // if Url is set but SavePath is not, "USER_HOME/.llm-d/dataset.db" will be used SavePath string `yaml:"save-path" json:"save-path"` } diff --git a/pkg/llm-d-inference-sim/dataset.go b/pkg/llm-d-inference-sim/dataset.go index 949e445c..a5b65d83 100644 --- a/pkg/llm-d-inference-sim/dataset.go +++ b/pkg/llm-d-inference-sim/dataset.go @@ -27,8 +27,10 @@ import ( "os/signal" "path/filepath" "syscall" + "time" "github.com/go-logr/logr" + _ "github.com/mattn/go-sqlite3" ) type Dataset struct { @@ -79,10 +81,12 @@ func (d *Dataset) downloadDataset(url string, savePath string) error { // Progress reader with context pr := &progressReader{ - Reader: resp.Body, - total: resp.ContentLength, - logger: d.logger, - ctx: ctx, + Reader: resp.Body, + total: resp.ContentLength, + logger: d.logger, + ctx: ctx, + startTime: time.Now(), + hasShownSpeed: false, } written, err := io.Copy(out, pr) @@ -124,9 +128,11 @@ type progressReader struct { io.Reader total int64 downloaded int64 + startTime time.Time lastPct int logger logr.Logger ctx context.Context + hasShownSpeed bool } func (pr *progressReader) Read(p []byte) (int, error) { @@ -139,13 +145,30 @@ func (pr *progressReader) Read(p []byte) (int, error) { pr.downloaded += int64(n) if pr.total > 0 { pct := int(float64(pr.downloaded) * 100 / 
float64(pr.total)) - if pct != pr.lastPct && pct%10 == 0 { // log every 10% - pr.logger.Info(fmt.Sprintf("Download progress: %d%%", pct)) + if !pr.hasShownSpeed && time.Since(pr.startTime).Seconds() > 2 { + pr.hasShownSpeed = true + pr.logProgress(pct) + pr.lastPct = pct + } + if pct != pr.lastPct && pct%10 == 0 { + pr.logProgress(pct) pr.lastPct = pct } } return n, err } + +func (pr *progressReader) logProgress(pct int) { + elapsedTime := time.Since(pr.startTime).Seconds() + speed := float64(pr.downloaded) / (1024 * 1024 * elapsedTime) + remainingTime := float64(pr.total-pr.downloaded) / (float64(pr.downloaded) / elapsedTime) + if pct != 100 { + pr.logger.Info(fmt.Sprintf("Download progress: %d%%, Speed: %.2f MB/s, Remaining time: %.2fs", pct, speed, remainingTime)) + } else { + pr.logger.Info(fmt.Sprintf("Download completed: 100%%, Average Speed: %.2f MB/s, Total time: %.2fs", speed, elapsedTime)) + } +} + func (d *Dataset) connectToDB(path string) error { // check if file exists _, err := os.Stat(path) @@ -156,6 +179,8 @@ func (d *Dataset) connectToDB(path string) error { if err != nil { return fmt.Errorf("failed to open database: %w", err) } + // Test the connection + return nil } @@ -165,7 +190,11 @@ func (d *Dataset) Init(path string, url string, savePath string) error { } if url != "" { if savePath == "" { - savePath = "~/.llmd/dataset.sqlite3" + user, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("failed to get user home directory: %w", err) + } + savePath = filepath.Join(user, ".llm-d", "dataset.sqlite3") } _, err := os.Stat(savePath) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index ab55fea2..12ae1e72 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -115,6 +115,8 @@ type VllmSimulator struct { pod string // tokenizer is currently used in kv-cache and in /tokenize tokenizer tokenization.Tokenizer + // dataset is used for managing dataset files + 
dataset *Dataset } // New creates a new VllmSimulator instance with the given logger @@ -152,6 +154,15 @@ func (s *VllmSimulator) Start(ctx context.Context) error { return err } + dataset := &Dataset{ + logger: s.logger, + } + err = dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath) + if err != nil { + return err + } + s.dataset = dataset + // For Data Parallel, start data-parallel-size - 1 additional simulators g, ctx := errgroup.WithContext(ctx) if s.config.DPSize > 1 { From 1112b0dd1e800ed8bc6923064b9133e1d33fba2d Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Sun, 14 Sep 2025 16:18:04 +1000 Subject: [PATCH 04/34] Fix tests and init dataset when loading sim Signed-off-by: Qifan Deng --- .../.llm-d/test.valid.sqlite3 | Bin 8192 -> 12288 bytes pkg/llm-d-inference-sim/dataset.go | 36 +++++++++++------- pkg/llm-d-inference-sim/dataset_test.go | 10 ++--- pkg/llm-d-inference-sim/simulator.go | 18 +++++---- 4 files changed, 38 insertions(+), 26 deletions(-) diff --git a/pkg/llm-d-inference-sim/.llm-d/test.valid.sqlite3 b/pkg/llm-d-inference-sim/.llm-d/test.valid.sqlite3 index dda347c45c66a76d4a699987e488ce81c36b0321..2ddcf77e7a2ab1b13cddc50bb545b6e8a9c00c0d 100644 GIT binary patch delta 346 zcmZp0Xh@hKEy%^dz`zW|Fu*!d#~3K6rxU=-|AT>vKaqhyk^lV0!h3x6Kz=n$8I;ap zWEU3|Wo)!ANleN~&B@740bwTRAXmo_SA`HqCm&Y@kcff?mjVzJ6y@g@l*DHw7H23V z<>V&;1u}CJiz*eeQ!90#vI-@s6(u?f>8W|CMTsS;DPTTGx2=MvIuo0?wJ0M)W?o8a zMR8$HW=U#%VrfY}m>&czyQK9z&KIIn4gJ3uP%}oD8$Hrm4W~2W&$SH}=ng%C$4A6JDE1&y-AoYGW<5Z8zhO^DX-4E*0W3o4xEpV*+p$ Date: Sun, 14 Sep 2025 17:08:17 +1000 Subject: [PATCH 05/34] Move dataset init to startSim Signed-off-by: Qifan Deng --- pkg/llm-d-inference-sim/simulator.go | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 90ed4de9..16727223 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ 
-154,19 +154,6 @@ func (s *VllmSimulator) Start(ctx context.Context) error { return err } - if s.config.Dataset.Path == "" && s.config.Dataset.Url == "" && s.config.Dataset.SavePath == "" { - s.dataset = nil - } else { - dataset := &Dataset{ - logger: s.logger, - } - err = dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath) - if err != nil { - return err - } - s.dataset = dataset - } - // For Data Parallel, start data-parallel-size - 1 additional simulators g, ctx := errgroup.WithContext(ctx) if s.config.DPSize > 1 { @@ -228,6 +215,20 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { go s.kvcacheHelper.Run(ctx) } + if s.config.Dataset.Path == "" && s.config.Dataset.Url == "" && s.config.Dataset.SavePath == "" { + s.dataset = nil + s.logger.Info("No dataset provided, will generate random responses") + } else { + dataset := &Dataset{ + logger: s.logger, + } + err = dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath) + if err != nil { + return err + } + s.dataset = dataset + } + // run request processing workers for i := 1; i <= s.config.MaxNumSeqs; i++ { go s.reqProcessingWorker(ctx, i) From 1840d68c1a135ff939cab2ac8abd128f2510ec36 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Sun, 14 Sep 2025 21:38:15 +1000 Subject: [PATCH 06/34] Change db structure and add test cases Signed-off-by: Qifan Deng --- .../.llm-d/test.invalid.column.sqlite3 | Bin 0 -> 12288 bytes .../.llm-d/test.invalid.sqlite3 | 1 + .../.llm-d/test.invalid.table.sqlite3 | Bin 0 -> 12288 bytes .../.llm-d/test.invalid.type.sqlite3 | Bin 0 -> 12288 bytes .../.llm-d/test.valid.sqlite3 | Bin 12288 -> 12288 bytes pkg/llm-d-inference-sim/dataset.go | 91 ++++++++++++++++-- pkg/llm-d-inference-sim/dataset_test.go | 66 ++++++++++--- pkg/llm-d-inference-sim/simulator.go | 2 +- 8 files changed, 138 insertions(+), 22 deletions(-) create mode 100644 pkg/llm-d-inference-sim/.llm-d/test.invalid.column.sqlite3 create mode 100644 
pkg/llm-d-inference-sim/.llm-d/test.invalid.sqlite3 create mode 100644 pkg/llm-d-inference-sim/.llm-d/test.invalid.table.sqlite3 create mode 100644 pkg/llm-d-inference-sim/.llm-d/test.invalid.type.sqlite3 diff --git a/pkg/llm-d-inference-sim/.llm-d/test.invalid.column.sqlite3 b/pkg/llm-d-inference-sim/.llm-d/test.invalid.column.sqlite3 new file mode 100644 index 0000000000000000000000000000000000000000..b35ad60d695896fe161ccb4a957cca62a20b375a GIT binary patch literal 12288 zcmeI$ze~eF6bJBkZShC3jUc6q!!-_7DS|jSw4kj7t)@RryR=e5rO_7C1e=J8BDlJV zv*6$0B3=CloYcit{1+U(R0>^+tA+35-sRmTkmR$yEA6T+*Wd?sr(JjX3QZGZw9YvZ zg+=;B_G3a6f&M~V*+4z44AWAg7?QrpCsj$RN-x76#6AQd009U<00Izz00bZa0SG`~ zbOc1BdqHBl6rb#{@n|H%uCs1^-!x3K-SE=4p(*LI!prH5s&H?I&j-1PX5DG>j9SR> zqL$62wF=)+Dv92$L&MTt`^c~yzFjKhe>W|CU^|;HE1QZIT!{O|qR9xOmenxMoMTh; z_ImxqwR>ybd!4@Q&0b&c_45=G$)slz0|5aDKmY;|fB*y_009U<00Izzz`qqpv4EOn z-CVqTmkZ2lpHH#()wP+2tJkmF^S4y(W@>uzo*wt)EpY~6%Ly)@+8wju6*0Ny4~qX{ l(g%rwfB*y_009U<00Izz00bZa0SG`~1O?{&>f~SN0zYe4Vio`Z literal 0 HcmV?d00001 diff --git a/pkg/llm-d-inference-sim/.llm-d/test.invalid.sqlite3 b/pkg/llm-d-inference-sim/.llm-d/test.invalid.sqlite3 new file mode 100644 index 00000000..cd087558 --- /dev/null +++ b/pkg/llm-d-inference-sim/.llm-d/test.invalid.sqlite3 @@ -0,0 +1 @@ +Hello world! 
diff --git a/pkg/llm-d-inference-sim/.llm-d/test.invalid.table.sqlite3 b/pkg/llm-d-inference-sim/.llm-d/test.invalid.table.sqlite3 new file mode 100644 index 0000000000000000000000000000000000000000..b059e36bdfe6951db3910277841886be3e66f10d GIT binary patch literal 12288 zcmeI$zfZzI6bJBkKv0QL6NfGiM;#PHOpJqr7zNWPS_&2xVn|a%6=I-h%wsYoQPtg z2Sp#mq)0-8g_zk;J}QmTaaTdbx$( zeJ(U-d_KqD*EVM#uliqimv5=c?exsj13euoyW%~MKw~~hG;5tr6bfrf9-&4=FG_HEFWldT==)aB9 ztZWT#w7OJj`_|Dxu+uw>IGdJZ`%%Bb|qH&?ZDvKbV0(nE(7{L4kXGd_aL}hAKvO zaZypm#_-8m{6-E=KK@P$0YRR=jzN(M-mZ~4Tna#to|+e5lAoQLSFGR_?C%E=%!@}7 z_Vf#Jb$1QontYDmg5?kc{~@3aJNVgb7+HL*87Du`SCEKS@<`3e$yd@*P%6(a%1Hq+ M6qRC`7byq;0FKBmumAu6 delta 174 zcmZojXh@hKEy%^dz`zW|Fu*!d#~3K6rxU=-|AT>vKaqhyk^lT=L4kXGGC+Z9m^vt( z!^kc!D$3YsJvoNo$S5f%KS`mWC^I*)s8S(2wNeMj$si diff --git a/pkg/llm-d-inference-sim/dataset.go b/pkg/llm-d-inference-sim/dataset.go index a597a96b..b5cc0931 100644 --- a/pkg/llm-d-inference-sim/dataset.go +++ b/pkg/llm-d-inference-sim/dataset.go @@ -38,6 +38,17 @@ type Dataset struct { logger logr.Logger } +// use constants for expected column names and types +const ( + tableName = "llmd" + promptHashCol = "prompt_hash" + genTokensCol = "gen_tokens" + nGenTokensCol = "n_gen_tokens" + promptHashColType = "BLOB" + genTokensColType = "JSON" + nGenTokensColType = "INTEGER" +) + func (d *Dataset) downloadDataset(url string, savePath string) error { // Set up signal handling for Ctrl+C (SIGINT) ctx, cancel := context.WithCancel(context.Background()) @@ -169,7 +180,73 @@ func (pr *progressReader) logProgress(pct int) { } } +func (d *Dataset) verifyDB() error { + rows, err := d.db.Query("PRAGMA table_info(" + tableName + ");") + if err != nil { + return fmt.Errorf("failed to query table info for `%s`: %w", tableName, err) + } + defer func() { + if cerr := rows.Close(); cerr != nil { + d.logger.Error(cerr, "failed to close rows after querying table info") + } + }() + + expectedColumns := map[string]string{ + promptHashCol: promptHashColType, + 
genTokensCol: genTokensColType, + nGenTokensCol: nGenTokensColType, + } + + columnsFound := make(map[string]bool) + + var ( + columnName string + columnType string + cid int + notnull int + dfltValue interface{} + pk int + ) + + for rows.Next() { + err := rows.Scan(&cid, &columnName, &columnType, &notnull, &dfltValue, &pk) + if err != nil { + return fmt.Errorf("failed to scan table info row: %w", err) + } + if expectedType, exists := expectedColumns[columnName]; exists { + if columnType != expectedType { + return fmt.Errorf("column %s has incorrect type: expected %s, got %s", columnName, expectedType, columnType) + } + columnsFound[columnName] = true + } + } + + for col := range expectedColumns { + if !columnsFound[col] { + return fmt.Errorf("missing expected column in %s table: %s", tableName, col) + } + } + + return nil +} + +func (d *Dataset) getRecordsCount() (int, error) { + var count int + err := d.db.QueryRow("SELECT COUNT(" + promptHashCol + ") FROM " + tableName + ";").Scan(&count) + if err != nil { + return 0, fmt.Errorf("failed to query database: %w", err) + } + return count, nil +} + func (d *Dataset) connectToDB(path string) error { + if d.db != nil { + err := d.db.Close() + if err != nil { + d.logger.Error(err, "failed to close existing database connection") + } + d.db = nil + } // check if file exists _, err := os.Stat(path) if err != nil { @@ -180,13 +257,15 @@ func (d *Dataset) connectToDB(path string) error { return fmt.Errorf("failed to open database: %w", err) } - var count int - err = d.db.QueryRow("SELECT COUNT(generated) FROM llmd;").Scan(&count) + err = d.verifyDB() + if err != nil { - err := d.db.Close() - if err != nil { - d.logger.Error(err, "failed to close database after query failure") - } + return fmt.Errorf("failed to verify database: %w", err) + } + + count, err := d.getRecordsCount() + if err != nil { + d.logger.Error(err, "failed to get records count") return fmt.Errorf("failed to query database: %w", err) } d.logger.Info("Database
connected successfully", "path", path, "records count", count) diff --git a/pkg/llm-d-inference-sim/dataset_test.go b/pkg/llm-d-inference-sim/dataset_test.go index 24738e5e..1e49bed4 100644 --- a/pkg/llm-d-inference-sim/dataset_test.go +++ b/pkg/llm-d-inference-sim/dataset_test.go @@ -17,6 +17,7 @@ limitations under the License. package llmdinferencesim import ( + "fmt" "os" "github.com/go-logr/logr" @@ -28,19 +29,31 @@ import ( var _ = Describe("Dataset", func() { var ( - dataset *Dataset - file_folder string - savePath string + dataset *Dataset + file_folder string + savePath string + validDBPath string + pathToInvalidDB string + pathNotExist string + pathToInvalidTableDB string + pathToInvalidColumnDB string + pathToInvalidTypeDB string ) BeforeEach(func() { dataset = &Dataset{ logger: logr.Discard(), } - file_folder = "./.llm-d" + file_folder = ".llm-d" savePath = file_folder + "/test.sqlite3" err := os.MkdirAll(file_folder, os.ModePerm) Expect(err).NotTo(HaveOccurred()) + validDBPath = file_folder + "/test.valid.sqlite3" + pathNotExist = file_folder + "/test.notexist.sqlite3" + pathToInvalidDB = file_folder + "/test.invalid.sqlite3" + pathToInvalidTableDB = file_folder + "/test.invalid.table.sqlite3" + pathToInvalidColumnDB = file_folder + "/test.invalid.column.sqlite3" + pathToInvalidTypeDB = file_folder + "/test.invalid.type.sqlite3" }) AfterEach(func() { @@ -72,24 +85,47 @@ var _ = Describe("Dataset", func() { }) It("should successfully init dataset", func() { - validDBPath := file_folder + "/test.valid.sqlite3" err := dataset.Init(validDBPath, "", "") + // debug: get the realpath + wd, _ := os.Getwd() + realpath := fmt.Sprintf("%s/%s", wd, validDBPath) + fmt.Println("Using realpath:", realpath) Expect(err).NotTo(HaveOccurred()) - row := dataset.db.QueryRow("SELECT generated FROM llmd WHERE prompt_hash=X'b94d27b9934d041c52e5b721d7373f13a07ed5e79179d63c5d8a0c102a9d00b2';") - var value string - err = row.Scan(&value) + row := dataset.db.QueryRow("SELECT 
n_gen_tokens FROM llmd WHERE prompt_hash=X'b94d27b9934d041c52e5b721d7373f13a07ed5e79179d63c5d8a0c102a9d00b2';") + var n_gen_tokens int + err = row.Scan(&n_gen_tokens) Expect(err).NotTo(HaveOccurred()) - Expect(value).To(Equal("world!")) + Expect(n_gen_tokens).To(Equal(3)) }) - It("should raise err with invalid DB content", func() { - err := dataset.connectToDB(file_folder) + It("should return error for non-existing DB path", func() { + err := dataset.connectToDB(pathNotExist) Expect(err).To(HaveOccurred()) - // read from the db to verify it's not valid - row := dataset.db.QueryRow("SELECT * FROM llmd;") - var value string - err = row.Scan(&value) + Expect(err.Error()).To(ContainSubstring("database file does not exist")) + }) + + It("should return error for invalid DB file", func() { + err := dataset.connectToDB(pathToInvalidDB) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("file is not a database")) + }) + + It("should return error for DB with invalid table", func() { + err := dataset.connectToDB(pathToInvalidTableDB) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to verify database")) + }) + + It("should return error for DB with invalid column", func() { + err := dataset.connectToDB(pathToInvalidColumnDB) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("missing expected column")) + }) + + It("should return error for DB with invalid column type", func() { + err := dataset.connectToDB(pathToInvalidTypeDB) Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("incorrect type")) }) }) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 16727223..0ade7f47 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -220,7 +220,7 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { s.logger.Info("No dataset provided, will generate random responses") } else { dataset := 
&Dataset{ - logger: s.logger, + logger: s.logger, } err = dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath) if err != nil { From ce81267e9fd954eca5d4280d9e77d0bc1b98f3cc Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Sun, 14 Sep 2025 21:57:27 +1000 Subject: [PATCH 07/34] fix test Signed-off-by: Qifan Deng --- pkg/llm-d-inference-sim/dataset_test.go | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/pkg/llm-d-inference-sim/dataset_test.go b/pkg/llm-d-inference-sim/dataset_test.go index 1e49bed4..aed7baf3 100644 --- a/pkg/llm-d-inference-sim/dataset_test.go +++ b/pkg/llm-d-inference-sim/dataset_test.go @@ -17,7 +17,7 @@ limitations under the License. package llmdinferencesim import ( - "fmt" + "encoding/json" "os" "github.com/go-logr/logr" @@ -86,10 +86,6 @@ var _ = Describe("Dataset", func() { It("should successfully init dataset", func() { err := dataset.Init(validDBPath, "", "") - // debug: get the realpath - wd, _ := os.Getwd() - realpath := fmt.Sprintf("%s/%s", wd, validDBPath) - fmt.Println("Using realpath:", realpath) Expect(err).NotTo(HaveOccurred()) row := dataset.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'b94d27b9934d041c52e5b721d7373f13a07ed5e79179d63c5d8a0c102a9d00b2';") @@ -97,6 +93,16 @@ var _ = Describe("Dataset", func() { err = row.Scan(&n_gen_tokens) Expect(err).NotTo(HaveOccurred()) Expect(n_gen_tokens).To(Equal(3)) + + var jsonStr string + row = dataset.db.QueryRow("SELECT gen_tokens FROM llmd WHERE prompt_hash=X'b94d27b9934d041c52e5b721d7373f13a07ed5e79179d63c5d8a0c102a9d00b2';") + err = row.Scan(&jsonStr) + Expect(err).NotTo(HaveOccurred()) + var tokens []string + err = json.Unmarshal([]byte(jsonStr), &tokens) + Expect(err).NotTo(HaveOccurred()) + Expect(tokens).To(Equal([]string{"Hello", "world", "!"})) + }) It("should return error for non-existing DB path", func() { From 44100d51b895b50eb6fd37225b39fd765f4989a3 Mon Sep 17 00:00:00 2001 From: Qifan Deng 
Date: Sun, 14 Sep 2025 22:33:49 +1000 Subject: [PATCH 08/34] remove duplicates in request.go Signed-off-by: Qifan Deng --- pkg/llm-d-inference-sim/dataset_test.go | 6 +++--- pkg/openai-server-api/request.go | 26 +++++++++---------------- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/pkg/llm-d-inference-sim/dataset_test.go b/pkg/llm-d-inference-sim/dataset_test.go index aed7baf3..6eb0c88a 100644 --- a/pkg/llm-d-inference-sim/dataset_test.go +++ b/pkg/llm-d-inference-sim/dataset_test.go @@ -110,13 +110,13 @@ var _ = Describe("Dataset", func() { Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("database file does not exist")) }) - + It("should return error for invalid DB file", func() { err := dataset.connectToDB(pathToInvalidDB) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("file is not a database")) }) - + It("should return error for DB with invalid table", func() { err := dataset.connectToDB(pathToInvalidTableDB) Expect(err).To(HaveOccurred()) @@ -128,7 +128,7 @@ var _ = Describe("Dataset", func() { Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("missing expected column")) }) - + It("should return error for DB with invalid column type", func() { err := dataset.connectToDB(pathToInvalidTypeDB) Expect(err).To(HaveOccurred()) diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go index e7d5fb3b..81b698c0 100644 --- a/pkg/openai-server-api/request.go +++ b/pkg/openai-server-api/request.go @@ -244,16 +244,21 @@ func (req *ChatCompletionRequest) getLastUserMsg() string { // i.e., an array of generated tokens, the finish reason, and the number of created // tokens func (req ChatCompletionRequest) CreateResponseText(mode string) ([]string, string, int, error) { - maxTokens, err := common.GetMaxTokens(req.MaxCompletionTokens, req.MaxTokens) + return generateResponseText(mode, req.GetMaxCompletionTokens(), req.getLastUserMsg(), req.GetIgnoreEOS()) +} + +// 
Helper function to generate response text +func generateResponseText(mode string, maxTokens *int64, prompt string, ignoreEOS bool) ([]string, string, int, error) { + maxTokensValue, err := common.GetMaxTokens(nil, maxTokens) if err != nil { return nil, "", 0, err } var text, finishReason string if mode == common.ModeEcho { - text, finishReason = common.GetResponseText(maxTokens, req.getLastUserMsg()) + text, finishReason = common.GetResponseText(maxTokensValue, prompt) } else { - text, finishReason = common.GetRandomResponseText(maxTokens, req.GetIgnoreEOS()) + text, finishReason = common.GetRandomResponseText(maxTokensValue, ignoreEOS) } tokens := common.Tokenize(text) @@ -299,18 +304,5 @@ func (c *TextCompletionRequest) GetMaxCompletionTokens() *int64 { // i.e., an array of generated tokens, the finish reason, and the number of created // tokens func (req TextCompletionRequest) CreateResponseText(mode string) ([]string, string, int, error) { - maxTokens, err := common.GetMaxTokens(nil, req.MaxTokens) - if err != nil { - return nil, "", 0, err - } - - var text, finishReason string - if mode == common.ModeEcho { - text, finishReason = common.GetResponseText(maxTokens, req.Prompt) - } else { - text, finishReason = common.GetRandomResponseText(maxTokens, req.GetIgnoreEOS()) - } - - tokens := common.Tokenize(text) - return tokens, finishReason, len(tokens), nil + return generateResponseText(mode, req.MaxTokens, req.Prompt, req.GetIgnoreEOS()) } From 5a144f158f63beed4ce0ba4c42e808feaa1e13d2 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Sun, 14 Sep 2025 23:05:55 +1000 Subject: [PATCH 09/34] Move token generation to simulator Signed-off-by: Qifan Deng --- .gitignore | 2 +- pkg/llm-d-inference-sim/simulator.go | 2 +- pkg/openai-server-api/request.go | 38 +--------------------------- 3 files changed, 3 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index 419af304..ad03feaf 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,4 @@ vendor *.test 
manifests/dev-config.yaml pkg/llm-d-inference-sim/.llm-d -.llm-d/ +pkg/llm-d-inference-sim/tests-tmp/ diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 0ade7f47..e0a523a0 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -338,7 +338,7 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { if toolCalls == nil && err == nil { // Either no tool calls were defined, or we randomly chose not to create tool calls, // so we generate a response text. - responseTokens, finishReason, completionTokens, err = req.CreateResponseText(s.config.Mode) + responseTokens, finishReason, completionTokens, err = s.generateTokens(req) } if err != nil { prefix := "" diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go index 81b698c0..a7dcdb63 100644 --- a/pkg/openai-server-api/request.go +++ b/pkg/openai-server-api/request.go @@ -33,10 +33,6 @@ const ( type CompletionRequest interface { // GetRequestID returns the unique request id GetRequestID() string - // CreateResponseText creates and returns response payload based on this request, - // i.e., an array of generated tokens, the finish reason, and the number of created - // tokens - CreateResponseText(mode string) ([]string, string, int, error) // IsStream returns boolean that defines is response should be streamed IsStream() bool // GetModel returns model name as defined in the request @@ -230,7 +226,7 @@ func (c *ChatCompletionRequest) GetMaxCompletionTokens() *int64 { // getLastUserMsg returns last message from this request's messages with user role, // if does not exist - returns an empty string -func (req *ChatCompletionRequest) getLastUserMsg() string { +func (req *ChatCompletionRequest) GetLastUserMsg() string { for i := len(req.Messages) - 1; i >= 0; i-- { if req.Messages[i].Role == RoleUser { return req.Messages[i].Content.PlainText() @@ -240,31 +236,6 @@ func (req *ChatCompletionRequest) 
getLastUserMsg() string { return "" } -// CreateResponseText creates and returns response payload based on this request, -// i.e., an array of generated tokens, the finish reason, and the number of created -// tokens -func (req ChatCompletionRequest) CreateResponseText(mode string) ([]string, string, int, error) { - return generateResponseText(mode, req.GetMaxCompletionTokens(), req.getLastUserMsg(), req.GetIgnoreEOS()) -} - -// Helper function to generate response text -func generateResponseText(mode string, maxTokens *int64, prompt string, ignoreEOS bool) ([]string, string, int, error) { - maxTokensValue, err := common.GetMaxTokens(nil, maxTokens) - if err != nil { - return nil, "", 0, err - } - - var text, finishReason string - if mode == common.ModeEcho { - text, finishReason = common.GetResponseText(maxTokensValue, prompt) - } else { - text, finishReason = common.GetRandomResponseText(maxTokensValue, ignoreEOS) - } - - tokens := common.Tokenize(text) - return tokens, finishReason, len(tokens), nil -} - // v1/completion // TextCompletionRequest defines structure of /completion request type TextCompletionRequest struct { @@ -299,10 +270,3 @@ func (c *TextCompletionRequest) GetToolChoice() string { func (c *TextCompletionRequest) GetMaxCompletionTokens() *int64 { return c.MaxTokens } - -// CreateResponseText creates and returns response payload based on this request, -// i.e., an array of generated tokens, the finish reason, and the number of created -// tokens -func (req TextCompletionRequest) CreateResponseText(mode string) ([]string, string, int, error) { - return generateResponseText(mode, req.MaxTokens, req.Prompt, req.GetIgnoreEOS()) -} From a4cd9a8fea044e690283a5ef20b771b6e7b8d833 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Sun, 14 Sep 2025 23:34:46 +1000 Subject: [PATCH 10/34] Generate tokens instead of strings Signed-off-by: Qifan Deng --- pkg/common/utils.go | 43 +++++++++------------------------------- pkg/common/utils_test.go | 34 
+++++++++++++++++-------------- 2 files changed, 28 insertions(+), 49 deletions(-) diff --git a/pkg/common/utils.go b/pkg/common/utils.go index fa853e26..5d5f5848 100644 --- a/pkg/common/utils.go +++ b/pkg/common/utils.go @@ -17,7 +17,6 @@ limitations under the License. package common import ( - "fmt" "math" "math/rand" "regexp" @@ -73,26 +72,6 @@ func init() { } } -// returns the max tokens or error if incorrect -func GetMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) { - var typeToken string - var tokens *int64 - // if both arguments are passed, - // use maxCompletionTokens - // as in the real vllm - if maxCompletionTokens != nil { - tokens = maxCompletionTokens - typeToken = "max_completion_tokens" - } else if maxTokens != nil { - tokens = maxTokens - typeToken = "max_tokens" - } - if tokens != nil && *tokens < 1 { - return nil, fmt.Errorf("%s must be at least 1, got %d", typeToken, *tokens) - } - return tokens, nil -} - // ValidateContextWindow checks if the request fits within the model's context window // Returns validation result, actual completion tokens, and total tokens func ValidateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) (bool, int64, int64) { @@ -157,7 +136,7 @@ func GetRandomText(numOfTokens int) string { return strings.Join(allTokens, "") } -// GetRandomResponseText generates text to be returned in a response, and the finish reason (stop or length) +// GetRandomTokens generates tokens to be returned in a response, and the finish reason (stop or length) // if maxCompletionTokens is defined // - currently, the generated number of words in the text will be equal to it value // - in future - need to find statistics about generated tokens distribution and return less tokens in part os requests @@ -167,7 +146,7 @@ func GetRandomText(numOfTokens int) string { // - finish reason is stop // if ignore_eos is true - the response will be generated with exactly maxCompletionTokens tokens // - request was 
validated so that when ignore_eos is true, maxCompletionTokens must be defined -func GetRandomResponseText(maxCompletionTokens *int64, ignore_eos bool) (string, string) { +func GetRandomTokens(maxCompletionTokens *int64, ignore_eos bool) ([]string, string) { numOfTokens := 0 finishReason := StopFinishReason @@ -189,8 +168,7 @@ func GetRandomResponseText(maxCompletionTokens *int64, ignore_eos bool) (string, } } - text := GetRandomText(numOfTokens) - return text, finishReason + return Tokenize(GetRandomText(numOfTokens)), finishReason } // getResponseLengthByHistogram calculates the number of tokens to be returned in a response based on the max tokens value and the pre-defined buckets. @@ -282,23 +260,20 @@ func calcBucketBoundaries(maxTokens int, bucketIndex int) (start int, end int) { return start, end } -// GetResponseText returns response text, from a given text +// GetResponseTokens returns needed tokens, from a given text // considering max completion tokens if it is not nil, and a finish reason (stop or length) -func GetResponseText(maxCompletionTokens *int64, text string) (string, string) { +func GetResponseTokens(maxCompletionTokens *int64, text string) ([]string, string) { + tokens := Tokenize(text) // no max completion tokens, return entire text if maxCompletionTokens == nil { - return text, StopFinishReason + return tokens, StopFinishReason } - // create tokens from text, splitting by spaces - tokens := Tokenize(text) - - // return entire text if *maxCompletionTokens >= int64(len(tokens)) { - return text, StopFinishReason + return tokens, StopFinishReason } // return truncated text - return strings.Join(tokens[0:*maxCompletionTokens], " "), LengthFinishReason + return tokens[0:*maxCompletionTokens], LengthFinishReason } func RandomNumericString(length int) string { diff --git a/pkg/common/utils_test.go b/pkg/common/utils_test.go index d847df35..21a69e6b 100644 --- a/pkg/common/utils_test.go +++ b/pkg/common/utils_test.go @@ -18,6 +18,7 @@ package common 
import ( "fmt" + "strings" "time" . "github.com/onsi/ginkgo/v2" @@ -29,16 +30,17 @@ var _ = Describe("Utils", Ordered, func() { InitRandom(time.Now().UnixNano()) }) - Context("GetRandomResponseText", func() { + Context("GetRandomTokens", func() { It("should return complete text", func() { - text, finishReason := GetRandomResponseText(nil, false) + tokens, finishReason := GetRandomTokens(nil, false) + text := strings.Join(tokens, "") Expect(IsValidText(text)).To(BeTrue()) Expect(finishReason).Should(Equal(StopFinishReason)) }) It("should return short text", func() { maxCompletionTokens := int64(2) - text, finishReason := GetRandomResponseText(&maxCompletionTokens, false) - tokensCnt := int64(len(Tokenize(text))) + tokens, finishReason := GetRandomTokens(&maxCompletionTokens, false) + tokensCnt := int64(len(tokens)) Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) if tokensCnt == maxCompletionTokens { Expect(finishReason).To(Equal(LengthFinishReason)) @@ -50,9 +52,10 @@ var _ = Describe("Utils", Ordered, func() { It("should return long text", func() { // return required number of tokens although it is higher than ResponseLenMax maxCompletionTokens := int64(ResponseLenMax * 5) - text, finishReason := GetRandomResponseText(&maxCompletionTokens, false) - tokensCnt := int64(len(Tokenize(text))) + tokens, finishReason := GetRandomTokens(&maxCompletionTokens, false) + tokensCnt := int64(len(tokens)) Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) + text := strings.Join(tokens, "") Expect(IsValidText(text)).To(BeTrue()) if tokensCnt == maxCompletionTokens { Expect(finishReason).To(Equal(LengthFinishReason)) @@ -65,8 +68,8 @@ var _ = Describe("Utils", Ordered, func() { DescribeTable("should return exact num of tokens", func(maxCompletionTokens int) { n := int64(maxCompletionTokens) - text, finishReason := GetRandomResponseText(&n, true) - nGenTokens := int64(len(Tokenize(text))) + tokens, finishReason := GetRandomTokens(&n, true) + 
nGenTokens := int64(len(tokens)) Expect(nGenTokens).Should(Equal(n)) Expect(finishReason).To(Equal(LengthFinishReason)) }, @@ -80,24 +83,25 @@ var _ = Describe("Utils", Ordered, func() { ) }) - Context("GetResponseText", func() { + Context("GetResponseTokens", func() { theText := "Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime" + theTokens := Tokenize(theText) It("should return the same text since max tokens is not defined", func() { - text, finishReason := GetResponseText(nil, theText) - Expect(text).Should(Equal(theText)) + tokens, finishReason := GetResponseTokens(nil, theText) + Expect(tokens).Should(Equal(theTokens)) Expect(finishReason).Should(Equal(StopFinishReason)) }) It("should return the same text since max tokens is higher than the text length", func() { maxCompletionTokens := int64(1000) - text, finishReason := GetResponseText(&maxCompletionTokens, theText) - Expect(text).Should(Equal(theText)) + tokens, finishReason := GetResponseTokens(&maxCompletionTokens, theText) + Expect(tokens).Should(Equal(theTokens)) Expect(finishReason).Should(Equal(StopFinishReason)) }) It("should return partial text", func() { maxCompletionTokens := int64(2) - text, finishReason := GetResponseText(&maxCompletionTokens, theText) - Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens)) + tokens, finishReason := GetResponseTokens(&maxCompletionTokens, theText) + Expect(int64(len(tokens))).Should(Equal(maxCompletionTokens)) Expect(finishReason).Should(Equal(LengthFinishReason)) }) }) From ac5a575932f6df130182d999b284fa232c2634e7 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Mon, 15 Sep 2025 00:06:38 +1000 Subject: [PATCH 11/34] Move dataset.go to common Signed-off-by: Qifan Deng --- .gitignore | 2 +- .../.llm-d/test.invalid.column.sqlite3 | Bin .../.llm-d/test.invalid.sqlite3 | 0 .../.llm-d/test.invalid.table.sqlite3 | Bin .../.llm-d/test.invalid.type.sqlite3 | Bin .../.llm-d/test.valid.sqlite3 | Bin 
pkg/common/config.go | 4 +-- .../dataset.go | 30 +++++++++--------- .../dataset_test.go | 4 +-- pkg/common/utils.go | 6 ++-- pkg/common/utils_test.go | 14 ++++---- pkg/llm-d-inference-sim/simulator.go | 6 ++-- 12 files changed, 33 insertions(+), 33 deletions(-) rename pkg/{llm-d-inference-sim => common}/.llm-d/test.invalid.column.sqlite3 (100%) rename pkg/{llm-d-inference-sim => common}/.llm-d/test.invalid.sqlite3 (100%) rename pkg/{llm-d-inference-sim => common}/.llm-d/test.invalid.table.sqlite3 (100%) rename pkg/{llm-d-inference-sim => common}/.llm-d/test.invalid.type.sqlite3 (100%) rename pkg/{llm-d-inference-sim => common}/.llm-d/test.valid.sqlite3 (100%) rename pkg/{llm-d-inference-sim => common}/dataset.go (89%) rename pkg/{llm-d-inference-sim => common}/dataset_test.go (98%) diff --git a/.gitignore b/.gitignore index ad03feaf..9c4df263 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,5 @@ vendor .DS_Store *.test manifests/dev-config.yaml -pkg/llm-d-inference-sim/.llm-d +pkg/common/.llm-d pkg/llm-d-inference-sim/tests-tmp/ diff --git a/pkg/llm-d-inference-sim/.llm-d/test.invalid.column.sqlite3 b/pkg/common/.llm-d/test.invalid.column.sqlite3 similarity index 100% rename from pkg/llm-d-inference-sim/.llm-d/test.invalid.column.sqlite3 rename to pkg/common/.llm-d/test.invalid.column.sqlite3 diff --git a/pkg/llm-d-inference-sim/.llm-d/test.invalid.sqlite3 b/pkg/common/.llm-d/test.invalid.sqlite3 similarity index 100% rename from pkg/llm-d-inference-sim/.llm-d/test.invalid.sqlite3 rename to pkg/common/.llm-d/test.invalid.sqlite3 diff --git a/pkg/llm-d-inference-sim/.llm-d/test.invalid.table.sqlite3 b/pkg/common/.llm-d/test.invalid.table.sqlite3 similarity index 100% rename from pkg/llm-d-inference-sim/.llm-d/test.invalid.table.sqlite3 rename to pkg/common/.llm-d/test.invalid.table.sqlite3 diff --git a/pkg/llm-d-inference-sim/.llm-d/test.invalid.type.sqlite3 b/pkg/common/.llm-d/test.invalid.type.sqlite3 similarity index 100% rename from 
pkg/llm-d-inference-sim/.llm-d/test.invalid.type.sqlite3 rename to pkg/common/.llm-d/test.invalid.type.sqlite3 diff --git a/pkg/llm-d-inference-sim/.llm-d/test.valid.sqlite3 b/pkg/common/.llm-d/test.valid.sqlite3 similarity index 100% rename from pkg/llm-d-inference-sim/.llm-d/test.valid.sqlite3 rename to pkg/common/.llm-d/test.valid.sqlite3 diff --git a/pkg/common/config.go b/pkg/common/config.go index 73d0cbde..7d8f4ca7 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -182,10 +182,10 @@ type Configuration struct { // SelfSignedCerts enables automatic generation of self-signed certificates for HTTPS SelfSignedCerts bool `yaml:"self-signed-certs" json:"self-signed-certs"` // Dataset configuration for response generation from a dataset. sqlite db file is expected. - Dataset Dataset + Dataset DatasetConf } -type Dataset struct { +type DatasetConf struct { // Path is the local path to the sqlite db file, default is empty // when path is empty Url will be checked Path string `yaml:"path" json:"path"` diff --git a/pkg/llm-d-inference-sim/dataset.go b/pkg/common/dataset.go similarity index 89% rename from pkg/llm-d-inference-sim/dataset.go rename to pkg/common/dataset.go index b5cc0931..ebebc9bf 100644 --- a/pkg/llm-d-inference-sim/dataset.go +++ b/pkg/common/dataset.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -package llmdinferencesim +package common import ( "context" @@ -35,7 +35,7 @@ import ( type Dataset struct { db *sql.DB - logger logr.Logger + Logger logr.Logger } // use constants for expected column names and types @@ -60,7 +60,7 @@ func (d *Dataset) downloadDataset(url string, savePath string) error { // Goroutine to listen for signal go func() { <-sigs - d.logger.Info("Interrupt signal received, cancelling download...") + d.Logger.Info("Interrupt signal received, cancelling download...") cancel() }() @@ -71,7 +71,7 @@ func (d *Dataset) downloadDataset(url string, savePath string) error { defer func() { cerr := out.Close() if cerr != nil { - d.logger.Error(cerr, "failed to close file after download") + d.Logger.Error(cerr, "failed to close file after download") } }() @@ -82,7 +82,7 @@ func (d *Dataset) downloadDataset(url string, savePath string) error { defer func() { cerr := resp.Body.Close() if cerr != nil { - d.logger.Error(cerr, "failed to close response body after download") + d.Logger.Error(cerr, "failed to close response body after download") } }() @@ -94,7 +94,7 @@ func (d *Dataset) downloadDataset(url string, savePath string) error { pr := &progressReader{ Reader: resp.Body, total: resp.ContentLength, - logger: d.logger, + logger: d.Logger, ctx: ctx, startTime: time.Now(), hasShownSpeed: false, @@ -105,7 +105,7 @@ func (d *Dataset) downloadDataset(url string, savePath string) error { // Remove incomplete file cerr := os.Remove(savePath) if cerr != nil { - d.logger.Error(cerr, "failed to remove incomplete file after download") + d.Logger.Error(cerr, "failed to remove incomplete file after download") } // If context was cancelled, return a specific error if errors.Is(err, context.Canceled) { @@ -117,7 +117,7 @@ func (d *Dataset) downloadDataset(url string, savePath string) error { if written == 0 { cerr := os.Remove(savePath) if cerr != nil { - d.logger.Error(cerr, "failed to remove empty file after download") + d.Logger.Error(cerr, "failed to remove 
empty file after download") } return errors.New("downloaded file is empty") } @@ -126,7 +126,7 @@ func (d *Dataset) downloadDataset(url string, savePath string) error { if err := out.Sync(); err != nil { cerr := os.Remove(savePath) if cerr != nil { - d.logger.Error(cerr, "failed to remove incomplete file after download") + d.Logger.Error(cerr, "failed to remove incomplete file after download") } return fmt.Errorf("failed to sync file: %w", err) } @@ -187,7 +187,7 @@ func (d *Dataset) verifyDB() error { } defer func() { if cerr := rows.Close(); cerr != nil { - d.logger.Error(cerr, "failed to close rows after querying table info") + d.Logger.Error(cerr, "failed to close rows after querying table info") } }() @@ -243,7 +243,7 @@ func (d *Dataset) connectToDB(path string) error { if d.db != nil { err := d.db.Close() if err != nil { - d.logger.Error(err, "failed to close existing database connection") + d.Logger.Error(err, "failed to close existing database connection") } d.db = nil } @@ -265,10 +265,10 @@ func (d *Dataset) connectToDB(path string) error { count, err := d.getRecordsCount() if err != nil { - d.logger.Error(err, "failed to get records count") + d.Logger.Error(err, "failed to get records count") return fmt.Errorf("failed to query database: %w", err) } - d.logger.Info("Database connected successfully", "path", path, "records count", count) + d.Logger.Info("Database connected successfully", "path", path, "records count", count) return nil } @@ -294,13 +294,13 @@ func (d *Dataset) Init(path string, url string, savePath string) error { if err != nil { return fmt.Errorf("failed to create parent directory: %w", err) } - d.logger.Info("Downloading dataset from URL", "url", url, "to", savePath) + d.Logger.Info("Downloading dataset from URL", "url", url, "to", savePath) err = d.downloadDataset(url, savePath) if err != nil { return fmt.Errorf("failed to download dataset: %w", err) } } - d.logger.Info("Using dataset from", "path", savePath) + d.Logger.Info("Using 
dataset from", "path", savePath) return d.connectToDB(savePath) } diff --git a/pkg/llm-d-inference-sim/dataset_test.go b/pkg/common/dataset_test.go similarity index 98% rename from pkg/llm-d-inference-sim/dataset_test.go rename to pkg/common/dataset_test.go index 6eb0c88a..43d7e6a2 100644 --- a/pkg/llm-d-inference-sim/dataset_test.go +++ b/pkg/common/dataset_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package llmdinferencesim +package common import ( "encoding/json" @@ -42,7 +42,7 @@ var _ = Describe("Dataset", func() { BeforeEach(func() { dataset = &Dataset{ - logger: logr.Discard(), + Logger: logr.Discard(), } file_folder = ".llm-d" savePath = file_folder + "/test.sqlite3" diff --git a/pkg/common/utils.go b/pkg/common/utils.go index 5d5f5848..8b5fcc5b 100644 --- a/pkg/common/utils.go +++ b/pkg/common/utils.go @@ -146,7 +146,7 @@ func GetRandomText(numOfTokens int) string { // - finish reason is stop // if ignore_eos is true - the response will be generated with exactly maxCompletionTokens tokens // - request was validated so that when ignore_eos is true, maxCompletionTokens must be defined -func GetRandomTokens(maxCompletionTokens *int64, ignore_eos bool) ([]string, string) { +func GetRandomTokens(maxCompletionTokens *int64, ignore_eos bool, dataset *Dataset) ([]string, string) { numOfTokens := 0 finishReason := StopFinishReason @@ -260,9 +260,9 @@ func calcBucketBoundaries(maxTokens int, bucketIndex int) (start int, end int) { return start, end } -// GetResponseTokens returns needed tokens, from a given text +// EchoResponseTokens returns needed tokens, from a given text // considering max completion tokens if it is not nil, and a finish reason (stop or length) -func GetResponseTokens(maxCompletionTokens *int64, text string) ([]string, string) { +func EchoResponseTokens(maxCompletionTokens *int64, text string) ([]string, string) { tokens := Tokenize(text) // no max completion 
tokens, return entire text if maxCompletionTokens == nil { diff --git a/pkg/common/utils_test.go b/pkg/common/utils_test.go index 21a69e6b..c76f42bd 100644 --- a/pkg/common/utils_test.go +++ b/pkg/common/utils_test.go @@ -32,14 +32,14 @@ var _ = Describe("Utils", Ordered, func() { Context("GetRandomTokens", func() { It("should return complete text", func() { - tokens, finishReason := GetRandomTokens(nil, false) + tokens, finishReason := GetRandomTokens(nil, false, nil) text := strings.Join(tokens, "") Expect(IsValidText(text)).To(BeTrue()) Expect(finishReason).Should(Equal(StopFinishReason)) }) It("should return short text", func() { maxCompletionTokens := int64(2) - tokens, finishReason := GetRandomTokens(&maxCompletionTokens, false) + tokens, finishReason := GetRandomTokens(&maxCompletionTokens, false, nil) tokensCnt := int64(len(tokens)) Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) if tokensCnt == maxCompletionTokens { @@ -52,7 +52,7 @@ var _ = Describe("Utils", Ordered, func() { It("should return long text", func() { // return required number of tokens although it is higher than ResponseLenMax maxCompletionTokens := int64(ResponseLenMax * 5) - tokens, finishReason := GetRandomTokens(&maxCompletionTokens, false) + tokens, finishReason := GetRandomTokens(&maxCompletionTokens, false, nil) tokensCnt := int64(len(tokens)) Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) text := strings.Join(tokens, "") @@ -68,7 +68,7 @@ var _ = Describe("Utils", Ordered, func() { DescribeTable("should return exact num of tokens", func(maxCompletionTokens int) { n := int64(maxCompletionTokens) - tokens, finishReason := GetRandomTokens(&n, true) + tokens, finishReason := GetRandomTokens(&n, true, nil) nGenTokens := int64(len(tokens)) Expect(nGenTokens).Should(Equal(n)) Expect(finishReason).To(Equal(LengthFinishReason)) @@ -88,19 +88,19 @@ var _ = Describe("Utils", Ordered, func() { theTokens := Tokenize(theText) It("should return the same text 
since max tokens is not defined", func() { - tokens, finishReason := GetResponseTokens(nil, theText) + tokens, finishReason := EchoResponseTokens(nil, theText) Expect(tokens).Should(Equal(theTokens)) Expect(finishReason).Should(Equal(StopFinishReason)) }) It("should return the same text since max tokens is higher than the text length", func() { maxCompletionTokens := int64(1000) - tokens, finishReason := GetResponseTokens(&maxCompletionTokens, theText) + tokens, finishReason := EchoResponseTokens(&maxCompletionTokens, theText) Expect(tokens).Should(Equal(theTokens)) Expect(finishReason).Should(Equal(StopFinishReason)) }) It("should return partial text", func() { maxCompletionTokens := int64(2) - tokens, finishReason := GetResponseTokens(&maxCompletionTokens, theText) + tokens, finishReason := EchoResponseTokens(&maxCompletionTokens, theText) Expect(int64(len(tokens))).Should(Equal(maxCompletionTokens)) Expect(finishReason).Should(Equal(LengthFinishReason)) }) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index e0a523a0..4207cdcd 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -116,7 +116,7 @@ type VllmSimulator struct { // tokenizer is currently used in kv-cache and in /tokenize tokenizer tokenization.Tokenizer // dataset is used for managing dataset files - dataset *Dataset + dataset *common.Dataset } // New creates a new VllmSimulator instance with the given logger @@ -219,8 +219,8 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { s.dataset = nil s.logger.Info("No dataset provided, will generate random responses") } else { - dataset := &Dataset{ - logger: s.logger, + dataset := &common.Dataset{ + Logger: s.logger, } err = dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath) if err != nil { From c56728a9771b36d2da01db41688742f842801d31 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Mon, 15 Sep 2025 13:13:59 +1000 Subject: 
[PATCH 12/34] Refactor: abstract dataset and move response generation from common to dataset Signed-off-by: Qifan Deng --- .gitignore | 2 +- pkg/common/utils.go | 238 ------------- pkg/common/utils_test.go | 143 -------- .../.llm-d/test.invalid.column.sqlite3 | Bin .../.llm-d/test.invalid.sqlite3 | 0 .../.llm-d/test.invalid.table.sqlite3 | Bin .../.llm-d/test.invalid.type.sqlite3 | Bin .../.llm-d/test.valid.sqlite3 | Bin .../dataset.go => dataset/custom_dataset.go} | 76 ++++- .../custom_dataset_test.go} | 10 +- pkg/dataset/dataset.go | 319 ++++++++++++++++++ pkg/dataset/dataset_test.go | 203 +++++++++++ pkg/{common => dataset}/test_helpers.go | 2 +- pkg/llm-d-inference-sim/simulator.go | 13 +- pkg/llm-d-inference-sim/simulator_test.go | 9 +- pkg/llm-d-inference-sim/streaming.go | 5 +- pkg/llm-d-inference-sim/tools_test.go | 3 +- pkg/openai-server-api/request.go | 26 +- pkg/openai-server-api/tools_utils.go | 10 +- 19 files changed, 632 insertions(+), 427 deletions(-) rename pkg/{common => dataset}/.llm-d/test.invalid.column.sqlite3 (100%) rename pkg/{common => dataset}/.llm-d/test.invalid.sqlite3 (100%) rename pkg/{common => dataset}/.llm-d/test.invalid.table.sqlite3 (100%) rename pkg/{common => dataset}/.llm-d/test.invalid.type.sqlite3 (100%) rename pkg/{common => dataset}/.llm-d/test.valid.sqlite3 (100%) rename pkg/{common/dataset.go => dataset/custom_dataset.go} (78%) rename pkg/{common/dataset_test.go => dataset/custom_dataset_test.go} (96%) create mode 100644 pkg/dataset/dataset.go create mode 100644 pkg/dataset/dataset_test.go rename pkg/{common => dataset}/test_helpers.go (98%) diff --git a/.gitignore b/.gitignore index 9c4df263..1ee731ff 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,5 @@ vendor .DS_Store *.test manifests/dev-config.yaml -pkg/common/.llm-d +pkg/dataset/.llm-d pkg/llm-d-inference-sim/tests-tmp/ diff --git a/pkg/common/utils.go b/pkg/common/utils.go index 8b5fcc5b..20f0cca8 100644 --- a/pkg/common/utils.go +++ b/pkg/common/utils.go @@ 
-17,61 +17,13 @@ limitations under the License. package common import ( - "math" "math/rand" "regexp" - "strings" "sync" "github.com/google/uuid" ) -const ( - ResponseLenMax = 128 - responseLenMean = 40 - responseLenStddev = 20 - stopFinishReasonProbability = 0.8 - - StopFinishReason = "stop" - LengthFinishReason = "length" - ToolsFinishReason = "tool_calls" - RemoteDecodeFinishReason = "remote_decode" -) - -// this array defines the probabilities for the buckets to be used for the generation of number of tokens in response -var respLenBucketsProbabilities = [...]float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15} -var cumulativeBucketsProbabilities []float64 - -const ( - flexBucketIndex = 3 - maxFixedBucketSize = 20 -) - -// list of responses to use in random mode for comepltion requests -var chatCompletionFakeResponses = []string{ - `Testing@, #testing 1$ ,2%,3^, [4&*5], 6~, 7-_ + (8 : 9) / \ < > .`, - `Testing, testing 1,2,3.`, - `I am fine, how are you today?`, - `I am your AI assistant, how can I help you today?`, - `Today is a nice sunny day.`, - `The temperature here is twenty-five degrees centigrade.`, - `Today it is partially cloudy and raining.`, - `To be or not to be that is the question.`, - `Alas, poor Yorick! I knew him, Horatio: A fellow of infinite jest`, - `The rest is silence. 
`, - `Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime`, -} - -func init() { - cumulativeBucketsProbabilities = make([]float64, len(respLenBucketsProbabilities)) - sum := 0.0 - - for i, val := range respLenBucketsProbabilities { - sum += val - cumulativeBucketsProbabilities[i] = sum - } -} - // ValidateContextWindow checks if the request fits within the model's context window // Returns validation result, actual completion tokens, and total tokens func ValidateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) (bool, int64, int64) { @@ -86,196 +38,6 @@ func ValidateContextWindow(promptTokens int, maxCompletionTokens *int64, maxMode return isValid, completionTokens, totalTokens } -// GetRandomResponseLen returns int in range [1, responseLenMax] -// numbers are chosen according a gaussian distribution with mean responseLenMean, and standard deviation responseLenStddev -func GetRandomResponseLen() int { - for { - val := rand.NormFloat64()*responseLenStddev + responseLenMean - if val >= 1 && val <= ResponseLenMax { - return int(math.Round(val)) - } - // else reject and resample - } -} - -// GetRandomFinishReason returns finish reason with the probability for 'stop' as defined by stopFinishReasonProbability -func GetRandomFinishReason() string { - if rand.Float64() < stopFinishReasonProbability { - return StopFinishReason - } - return LengthFinishReason -} - -// GetRandomText generates random text for the required number of tokens, -// select randomly a sentence from chatCompletionFakeResponses, -// if number of tokens is lower than required - select another sentence, -// continue until the required number of tokens is achieved -func GetRandomText(numOfTokens int) string { - allTokens := make([]string, 0) - - for len(allTokens) < numOfTokens { - index := RandomInt(0, len(chatCompletionFakeResponses)-1) - // create tokens from text, splitting by spaces and special characters - tokens := 
Tokenize(chatCompletionFakeResponses[index]) - remaining := numOfTokens - len(allTokens) - - if len(tokens) > remaining { - // there is too many tokens, append only the relevant part - tokens = tokens[:remaining] - } - - if len(allTokens) > 0 { - // for not first sentences add space to the first token to separate between sentences without adding an additional token - tokens[0] = " " + tokens[0] - } - - allTokens = append(allTokens, tokens...) - } - - // return all tokens as text - return strings.Join(allTokens, "") -} - -// GetRandomTokens generates tokens to be returned in a response, and the finish reason (stop or length) -// if maxCompletionTokens is defined -// - currently, the generated number of words in the text will be equal to it value -// - in future - need to find statistics about generated tokens distribution and return less tokens in part os requests -// - finish reason will be chosen randomly from the collection (stop, length) with 80% for stop and 20% for length -// if maxCompletionTokens is nil -// - the response text's length is randomly chosen from the range [1, responseLenMax] according additional parameters -// - finish reason is stop -// if ignore_eos is true - the response will be generated with exactly maxCompletionTokens tokens -// - request was validated so that when ignore_eos is true, maxCompletionTokens must be defined -func GetRandomTokens(maxCompletionTokens *int64, ignore_eos bool, dataset *Dataset) ([]string, string) { - numOfTokens := 0 - finishReason := StopFinishReason - - // no max completion tokens, return text with random length - if maxCompletionTokens == nil { - numOfTokens = GetRandomResponseLen() - } else { - maxTokens := int(*maxCompletionTokens) - if ignore_eos { - numOfTokens = maxTokens - finishReason = LengthFinishReason - } else { - // max tokens is defined - generate real length of the response based on it - numOfTokens = getResponseLengthByHistogram(maxTokens) - if numOfTokens == maxTokens { - // if response should 
be create with maximum number of tokens - finish reason will be 'length' - finishReason = LengthFinishReason - } - } - } - - return Tokenize(GetRandomText(numOfTokens)), finishReason -} - -// getResponseLengthByHistogram calculates the number of tokens to be returned in a response based on the max tokens value and the pre-defined buckets. -// The response length is distributed according to the probabilities, defined in respLenBucketsProbabilities. -// The histogram contains equally sized buckets and the last special bucket, which contains only the maxTokens value. -// The last element of respLenBucketsProbabilities defines the probability of a reposnse with maxToken tokens. -// Other values define probabilities for the equally sized buckets. -// If maxToken is small (smaller than number of buckets) - the response length is randomly selected from the range [1, maxTokens] -func getResponseLengthByHistogram(maxTokens int) int { - if maxTokens <= 1 { - return maxTokens - } - // maxTokens is small - no need to use the histogram of probabilities, just select a random value in the range [1, maxTokens] - if maxTokens <= len(cumulativeBucketsProbabilities) { - res := RandomInt(1, maxTokens) - return res - } - - r := RandomFloat(0, 1) - - // check if r is in the last bucket, then maxTokens should be returned - if r > cumulativeBucketsProbabilities[len(cumulativeBucketsProbabilities)-2] { - return maxTokens - } - - // determine which bucket to use, the bucket with a cumulative probability larger than r is the bucket to use - // initialize bucketIndex with the last bucket to handle the case (which should not happen) when the probabilities sum is less than 1 - bucketIndex := len(cumulativeBucketsProbabilities) - 1 - for i, c := range cumulativeBucketsProbabilities { - if r <= c { - bucketIndex = i - break - } - } - - // calculate the size of all of the buckets (except the special last bucket) - start, end := calcBucketBoundaries(maxTokens, bucketIndex) - - // pick uniformly 
within the bucket’s range - return RandomInt(start, end) -} - -// calcBucketBoundaries calculates boundaries of a bucket with the given index. -// Maximum size for equally sized buckets is defined by maxFixedBucketSize. -// [maxFixedBucketSize*(number-of-buckets-1)+1] is the value of maxTokens for which -// division to equally size buckets will give buckets with size maxFixedBucketSize. -// If maxTokens is [maxFixedBucketSize*(number-of-buckets-1)+1] or less, -// all buckets will be of equal size, except the last bucket, which contains only one value. -// If maxTokens is higher than [maxFixedBucketSize*(number-of-buckets-1)+1], -// and flexBucketIndex is valid (between 0 and number of buckets - 1) the buckets sizes will not be equal. -// In this case, all buckets except the one at flexBucketIndex index will have size 20 (and the last is with size 1), -// and the bucket at flexBucketIndex index will 'stretch' to cover the remaining range. -func calcBucketBoundaries(maxTokens int, bucketIndex int) (start int, end int) { - maxEquallyBucketsSz := maxFixedBucketSize*(len(cumulativeBucketsProbabilities)-1) + 1 - - if maxTokens <= maxEquallyBucketsSz || flexBucketIndex < 0 || flexBucketIndex >= len(cumulativeBucketsProbabilities)-1 { - // create equally size buckets - // calculate the size of all of the buckets (except the special last bucket) - bucketSize := float64(maxTokens-1) / float64(len(cumulativeBucketsProbabilities)-1) - start = int(bucketSize*float64(bucketIndex)) + 1 - end = int(bucketSize * float64(bucketIndex+1)) - } else { - // create non-equally sized buckets and find boundaries of the required bucket - if bucketIndex < flexBucketIndex { - // the relevant bucket is before the flex bucket, all buckets are of the same size (maxFixedBucketSize) - // start is the minimum number in the required bucket - start = maxFixedBucketSize*bucketIndex + 1 - end = maxFixedBucketSize * (bucketIndex + 1) - } else { - flexBucketSize := maxTokens - (maxFixedBucketSize * 
(len(cumulativeBucketsProbabilities) - 2)) - - if bucketIndex == flexBucketIndex { - // the relevant bucket is the flex bucket - start = int(maxFixedBucketSize*float64(bucketIndex)) + 1 - end = maxFixedBucketSize*bucketIndex + flexBucketSize - } else { - // the relevant bucket is one of buckets after the flex bucket - start = int(maxFixedBucketSize*float64(bucketIndex-1)) + flexBucketSize + 1 - end = maxFixedBucketSize*bucketIndex + flexBucketSize - } - } - } - - // sometimes end could be maxTokens because of rounding, change the value to maxToken-1 - if end >= maxTokens { - end = maxTokens - 1 - } - - return start, end -} - -// EchoResponseTokens returns needed tokens, from a given text -// considering max completion tokens if it is not nil, and a finish reason (stop or length) -func EchoResponseTokens(maxCompletionTokens *int64, text string) ([]string, string) { - tokens := Tokenize(text) - // no max completion tokens, return entire text - if maxCompletionTokens == nil { - return tokens, StopFinishReason - } - - if *maxCompletionTokens >= int64(len(tokens)) { - return tokens, StopFinishReason - } - // return truncated text - return tokens[0:*maxCompletionTokens], LengthFinishReason -} - func RandomNumericString(length int) string { digits := "0123456789" result := make([]byte, length) diff --git a/pkg/common/utils_test.go b/pkg/common/utils_test.go index c76f42bd..9a0af043 100644 --- a/pkg/common/utils_test.go +++ b/pkg/common/utils_test.go @@ -17,8 +17,6 @@ limitations under the License. package common import ( - "fmt" - "strings" "time" . 
"github.com/onsi/ginkgo/v2" @@ -30,82 +28,6 @@ var _ = Describe("Utils", Ordered, func() { InitRandom(time.Now().UnixNano()) }) - Context("GetRandomTokens", func() { - It("should return complete text", func() { - tokens, finishReason := GetRandomTokens(nil, false, nil) - text := strings.Join(tokens, "") - Expect(IsValidText(text)).To(BeTrue()) - Expect(finishReason).Should(Equal(StopFinishReason)) - }) - It("should return short text", func() { - maxCompletionTokens := int64(2) - tokens, finishReason := GetRandomTokens(&maxCompletionTokens, false, nil) - tokensCnt := int64(len(tokens)) - Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) - if tokensCnt == maxCompletionTokens { - Expect(finishReason).To(Equal(LengthFinishReason)) - } else { - Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens)) - Expect(finishReason).To(Equal(StopFinishReason)) - } - }) - It("should return long text", func() { - // return required number of tokens although it is higher than ResponseLenMax - maxCompletionTokens := int64(ResponseLenMax * 5) - tokens, finishReason := GetRandomTokens(&maxCompletionTokens, false, nil) - tokensCnt := int64(len(tokens)) - Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) - text := strings.Join(tokens, "") - Expect(IsValidText(text)).To(BeTrue()) - if tokensCnt == maxCompletionTokens { - Expect(finishReason).To(Equal(LengthFinishReason)) - } else { - Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens)) - Expect(finishReason).To(Equal(StopFinishReason)) - } - }) - - DescribeTable("should return exact num of tokens", - func(maxCompletionTokens int) { - n := int64(maxCompletionTokens) - tokens, finishReason := GetRandomTokens(&n, true, nil) - nGenTokens := int64(len(tokens)) - Expect(nGenTokens).Should(Equal(n)) - Expect(finishReason).To(Equal(LengthFinishReason)) - }, - func(maxCompletionTokens int) string { - return fmt.Sprintf("maxCompletionTokens: %d", maxCompletionTokens) - }, - Entry("1", 1), - 
Entry("42", 42), - Entry("99", 99), - Entry("10000", 10000), - ) - }) - - Context("GetResponseTokens", func() { - theText := "Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime" - theTokens := Tokenize(theText) - - It("should return the same text since max tokens is not defined", func() { - tokens, finishReason := EchoResponseTokens(nil, theText) - Expect(tokens).Should(Equal(theTokens)) - Expect(finishReason).Should(Equal(StopFinishReason)) - }) - It("should return the same text since max tokens is higher than the text length", func() { - maxCompletionTokens := int64(1000) - tokens, finishReason := EchoResponseTokens(&maxCompletionTokens, theText) - Expect(tokens).Should(Equal(theTokens)) - Expect(finishReason).Should(Equal(StopFinishReason)) - }) - It("should return partial text", func() { - maxCompletionTokens := int64(2) - tokens, finishReason := EchoResponseTokens(&maxCompletionTokens, theText) - Expect(int64(len(tokens))).Should(Equal(maxCompletionTokens)) - Expect(finishReason).Should(Equal(LengthFinishReason)) - }) - }) - Context("validateContextWindow", func() { It("should pass when total tokens are within limit", func() { promptTokens := 100 @@ -150,69 +72,4 @@ var _ = Describe("Utils", Ordered, func() { }) }) - Context("GetRandomText", func() { - lenArr := []int{5, 20, 50, 150} - - for _, len := range lenArr { - name := fmt.Sprintf("should return text with %d tokens", len) - It(name, func() { - text := GetRandomText(len) - Expect(Tokenize(text)).Should(HaveLen(len)) - }) - } - }) - - Context("IsValidText", func() { - validTxts := make([]string, 0) - invalidTxts := make([]string, 0) - - validTxts = append(validTxts, chatCompletionFakeResponses[0][:4]) - validTxts = append(validTxts, chatCompletionFakeResponses[1]) - validTxts = append(validTxts, chatCompletionFakeResponses[1]+" "+chatCompletionFakeResponses[2]) - - invalidTxts = append(invalidTxts, (chatCompletionFakeResponses[1] + " " + 
chatCompletionFakeResponses[2])[3:4]) - invalidTxts = append(invalidTxts, chatCompletionFakeResponses[0][4:]) - invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+"-"+chatCompletionFakeResponses[2]) - invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+" ") - invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+" "+chatCompletionFakeResponses[2]) - - for _, txt := range validTxts { - It("text should be valid", func() { - Expect(IsValidText(txt)).To(BeTrue()) - }) - } - - for _, txt := range invalidTxts { - It("text should be invalid", func() { - Expect(IsValidText(txt)).To(BeFalse()) - }) - } - }) - - Context("validateBucketsBoundaries", func() { - type bucketBoundaries struct { - start int - end int - } - type bucketTest struct { - maxTokens int - expectedBuckets []bucketBoundaries - } - - tests := []bucketTest{{500, []bucketBoundaries{{1, 20}, {21, 40}, {41, 60}, {61, 480}, {481, 499}}}, - {47, []bucketBoundaries{{1, 9}, {10, 18}, {19, 27}, {28, 36}, {37, 46}}}, - {50, []bucketBoundaries{{1, 9}, {10, 19}, {20, 29}, {30, 39}, {40, 49}}}} - - for _, test := range tests { - Expect(test.expectedBuckets).To(HaveLen(len(cumulativeBucketsProbabilities) - 1)) - - It(fmt.Sprintf("should return bucket boundaries for maxTokens %d", test.maxTokens), func() { - for i := range len(cumulativeBucketsProbabilities) - 1 { - start, end := calcBucketBoundaries(test.maxTokens, i) - Expect(start).To(Equal(test.expectedBuckets[i].start)) - Expect(end).To(Equal(test.expectedBuckets[i].end)) - } - }) - } - }) }) diff --git a/pkg/common/.llm-d/test.invalid.column.sqlite3 b/pkg/dataset/.llm-d/test.invalid.column.sqlite3 similarity index 100% rename from pkg/common/.llm-d/test.invalid.column.sqlite3 rename to pkg/dataset/.llm-d/test.invalid.column.sqlite3 diff --git a/pkg/common/.llm-d/test.invalid.sqlite3 b/pkg/dataset/.llm-d/test.invalid.sqlite3 similarity index 100% rename from pkg/common/.llm-d/test.invalid.sqlite3 rename to 
pkg/dataset/.llm-d/test.invalid.sqlite3 diff --git a/pkg/common/.llm-d/test.invalid.table.sqlite3 b/pkg/dataset/.llm-d/test.invalid.table.sqlite3 similarity index 100% rename from pkg/common/.llm-d/test.invalid.table.sqlite3 rename to pkg/dataset/.llm-d/test.invalid.table.sqlite3 diff --git a/pkg/common/.llm-d/test.invalid.type.sqlite3 b/pkg/dataset/.llm-d/test.invalid.type.sqlite3 similarity index 100% rename from pkg/common/.llm-d/test.invalid.type.sqlite3 rename to pkg/dataset/.llm-d/test.invalid.type.sqlite3 diff --git a/pkg/common/.llm-d/test.valid.sqlite3 b/pkg/dataset/.llm-d/test.valid.sqlite3 similarity index 100% rename from pkg/common/.llm-d/test.valid.sqlite3 rename to pkg/dataset/.llm-d/test.valid.sqlite3 diff --git a/pkg/common/dataset.go b/pkg/dataset/custom_dataset.go similarity index 78% rename from pkg/common/dataset.go rename to pkg/dataset/custom_dataset.go index ebebc9bf..1ec89615 100644 --- a/pkg/common/dataset.go +++ b/pkg/dataset/custom_dataset.go @@ -14,14 +14,16 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -package common +package dataset import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "io" + "math/rand" "net/http" "os" "os/signal" @@ -30,12 +32,14 @@ import ( "time" "github.com/go-logr/logr" + "github.com/google/uuid" _ "github.com/mattn/go-sqlite3" ) -type Dataset struct { - db *sql.DB - Logger logr.Logger +type CustomDataset struct { + Dataset + db *sql.DB + hasWarned bool } // use constants for expected column names and types @@ -49,7 +53,7 @@ const ( nGenTokensColType = "INTEGER" ) -func (d *Dataset) downloadDataset(url string, savePath string) error { +func (d CustomDataset) downloadDataset(url string, savePath string) error { // Set up signal handling for Ctrl+C (SIGINT) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -180,7 +184,7 @@ func (pr *progressReader) logProgress(pct int) { } } -func (d *Dataset) verifyDB() error { +func (d CustomDataset) verifyDB() error { rows, err := d.db.Query("PRAGMA table_info(" + tableName + ");") if err != nil { return fmt.Errorf("failed to query table info for `%s`: %w", tableName, err) @@ -230,7 +234,7 @@ func (d *Dataset) verifyDB() error { return nil } -func (d *Dataset) getRecordsCount() (int, error) { +func (d CustomDataset) getRecordsCount() (int, error) { var count int err := d.db.QueryRow("SELECT COUNT(" + promptHashCol + ") FROM " + tableName + ";").Scan(&count) if err != nil { @@ -239,7 +243,7 @@ func (d *Dataset) getRecordsCount() (int, error) { return count, nil } -func (d *Dataset) connectToDB(path string) error { +func (d CustomDataset) connectToDB(path string) error { if d.db != nil { err := d.db.Close() if err != nil { @@ -273,7 +277,8 @@ func (d *Dataset) connectToDB(path string) error { return nil } -func (d *Dataset) Init(path string, url string, savePath string) error { +func (d CustomDataset) Init(path string, url string, savePath string) error { + d.hasWarned = false if path != "" { return d.connectToDB(path) } @@ -307,9 +312,60 @@ func (d *Dataset) 
Init(path string, url string, savePath string) error { return errors.New("no dataset path or url provided") } -func (d *Dataset) Close() error { +func (d CustomDataset) Close() error { if d.db != nil { return d.db.Close() } return nil } + +func unmarshalAllRecords(rows *sql.Rows) ([][]string, error) { + var tokensList [][]string + for rows.Next() { + var tokensJSON string + if err := rows.Scan(&tokensJSON); err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + + var tokens []string + if err := json.Unmarshal([]byte(tokensJSON), &tokens); err != nil { + return nil, fmt.Errorf("failed to unmarshal tokens JSON: %w", err) + } + tokensList = append(tokensList, tokens) + } + return tokensList, nil +} + +func (d CustomDataset) getRandomTokens(n_gen_tokens int) []string { + return nil +} + +func (d *CustomDataset) GetTokens(prompt string, n_gen_tokens int) []string { + promptHash := uuid.NewSHA1(uuid.NameSpaceOID, []byte(prompt)).NodeID() + rows, err := d.db.Query("SELECT "+genTokensCol+" FROM "+tableName+" WHERE "+promptHashCol+" = ?;", promptHash) + if err != nil { + if !d.hasWarned { + d.Logger.Error(err, "failed to query database. Ensure the prompt hash exists in the dataset. 
Will generate random tokens instead.") + d.hasWarned = true + } + return d.getRandomTokens(n_gen_tokens) + } + defer func() { + if cerr := rows.Close(); cerr != nil { + d.Logger.Error(cerr, "failed to close rows after query") + } + }() + + tokensList, err := unmarshalAllRecords(rows) + if err != nil { + d.Logger.Error(err, "failed to unmarshal records from database") + return d.getRandomTokens(n_gen_tokens) + } + + if len(tokensList) == 0 { + return d.getRandomTokens(n_gen_tokens) + } + d.hasWarned = false + randIndex := rand.Intn(len(tokensList)) + return tokensList[randIndex] +} diff --git a/pkg/common/dataset_test.go b/pkg/dataset/custom_dataset_test.go similarity index 96% rename from pkg/common/dataset_test.go rename to pkg/dataset/custom_dataset_test.go index 43d7e6a2..a43656de 100644 --- a/pkg/common/dataset_test.go +++ b/pkg/dataset/custom_dataset_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package common +package dataset import ( "encoding/json" @@ -29,7 +29,7 @@ import ( var _ = Describe("Dataset", func() { var ( - dataset *Dataset + dataset *CustomDataset file_folder string savePath string validDBPath string @@ -41,8 +41,10 @@ var _ = Describe("Dataset", func() { ) BeforeEach(func() { - dataset = &Dataset{ - Logger: logr.Discard(), + dataset = &CustomDataset{ + Dataset: Dataset{ + Logger: logr.Discard(), + }, } file_folder = ".llm-d" savePath = file_folder + "/test.sqlite3" diff --git a/pkg/dataset/dataset.go b/pkg/dataset/dataset.go new file mode 100644 index 00000000..7f293f74 --- /dev/null +++ b/pkg/dataset/dataset.go @@ -0,0 +1,319 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dataset + +import ( + "errors" + "math" + "math/rand" + + "github.com/go-logr/logr" + "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" + _ "github.com/mattn/go-sqlite3" +) + +const ( + RoleAssistant = "assistant" + RoleUser = "user" +) + +const ( + ResponseLenMax = 128 + responseLenMean = 40 + responseLenStddev = 20 + stopFinishReasonProbability = 0.8 + + StopFinishReason = "stop" + LengthFinishReason = "length" + ToolsFinishReason = "tool_calls" + RemoteDecodeFinishReason = "remote_decode" +) + +// this array defines the probabilities for the buckets to be used for the generation of number of tokens in response +var respLenBucketsProbabilities = [...]float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15} +var cumulativeBucketsProbabilities []float64 + +const ( + flexBucketIndex = 3 + maxFixedBucketSize = 20 +) + +// list of responses to use in random mode for completion requests +var chatCompletionFakeResponses = []string{ + `Testing@, #testing 1$ ,2%,3^, [4&*5], 6~, 7-_ + (8 : 9) / \ < > .`, + `Testing, testing 1,2,3.`, + `I am fine, how are you today?`, + `I am your AI assistant, how can I help you today?`, + `Today is a nice sunny day.`, + `The temperature here is twenty-five degrees centigrade.`, + `Today it is partially cloudy and raining.`, + `To be or not to be that is the question.`, + `Alas, poor Yorick! I knew him, Horatio: A fellow of infinite jest`, + `The rest is silence. 
`, + `Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime`, +} + +func init() { + cumulativeBucketsProbabilities = make([]float64, len(respLenBucketsProbabilities)) + sum := 0.0 + + for i, val := range respLenBucketsProbabilities { + sum += val + cumulativeBucketsProbabilities[i] = sum + } +} + +// GetRandomResponseLen returns int in range [1, responseLenMax] +// numbers are chosen according a gaussian distribution with mean responseLenMean, and standard deviation responseLenStddev +func GetRandomResponseLen() int { + for { + val := rand.NormFloat64()*responseLenStddev + responseLenMean + if val >= 1 && val <= ResponseLenMax { + return int(math.Round(val)) + } + // else reject and resample + } +} + +// GetRandomFinishReason returns finish reason with the probability for 'stop' as defined by stopFinishReasonProbability +func GetRandomFinishReason() string { + if rand.Float64() < stopFinishReasonProbability { + return StopFinishReason + } + return LengthFinishReason +} + +// GenPresetRandomTokens generates random tokens for the required number of tokens, +// select randomly a sentence from chatCompletionFakeResponses, +// if number of tokens is lower than required - select another sentence, +// continue until the required number of tokens is achieved +func GenPresetRandomTokens(numOfTokens int) []string { + allTokens := make([]string, 0) + + for len(allTokens) < numOfTokens { + index := common.RandomInt(0, len(chatCompletionFakeResponses)-1) + // create tokens from text, splitting by spaces and special characters + tokens := common.Tokenize(chatCompletionFakeResponses[index]) + remaining := numOfTokens - len(allTokens) + + if len(tokens) > remaining { + // there is too many tokens, append only the relevant part + tokens = tokens[:remaining] + } + + if len(allTokens) > 0 { + // for not first sentences add space to the first token to separate between sentences without adding an additional token + tokens[0] = " " + tokens[0] 
+ } + + allTokens = append(allTokens, tokens...) + } + + return allTokens +} + +// howManyTokensToGen generates the number of tokens to be returned in a response, and the finish reason (see constants) +// if maxCompletionTokens is defined +// - currently, the generated number of words in the text will be equal to it value +// - in future - need to find statistics about generated tokens distribution and return less tokens in part os requests +// - finish reason will be chosen randomly from the collection (stop, length) with 80% for stop and 20% for length +// if maxCompletionTokens is nil +// - the response text's length is randomly chosen from the range [1, responseLenMax] according additional parameters +// - finish reason is stop +// if ignore_eos is true - the response will be generated with exactly maxCompletionTokens tokens +// - request was validated so that when ignore_eos is true, maxCompletionTokens must be defined +func howManyTokensToGen(maxCompletionTokens *int64, ignore_eos bool) (int, string) { + numOfTokens := 0 + finishReason := StopFinishReason + + // no max completion tokens, return text with random length + if maxCompletionTokens == nil { + numOfTokens = GetRandomResponseLen() + } else { + maxTokens := int(*maxCompletionTokens) + if ignore_eos { + numOfTokens = maxTokens + finishReason = LengthFinishReason + } else { + // max tokens is defined - generate real length of the response based on it + numOfTokens = getResponseLengthByHistogram(maxTokens) + if numOfTokens == maxTokens { + // if response should be create with maximum number of tokens - finish reason will be 'length' + finishReason = LengthFinishReason + } + } + } + + return numOfTokens, finishReason +} + +// getResponseLengthByHistogram calculates the number of tokens to be returned in a response based on the max tokens value and the pre-defined buckets. +// The response length is distributed according to the probabilities, defined in respLenBucketsProbabilities. 
+// The histogram contains equally sized buckets and the last special bucket, which contains only the maxTokens value. +// The last element of respLenBucketsProbabilities defines the probability of a reposnse with maxToken tokens. +// Other values define probabilities for the equally sized buckets. +// If maxToken is small (smaller than number of buckets) - the response length is randomly selected from the range [1, maxTokens] +func getResponseLengthByHistogram(maxTokens int) int { + if maxTokens <= 1 { + return maxTokens + } + // maxTokens is small - no need to use the histogram of probabilities, just select a random value in the range [1, maxTokens] + if maxTokens <= len(cumulativeBucketsProbabilities) { + res := common.RandomInt(1, maxTokens) + return res + } + + r := common.RandomFloat(0, 1) + + // check if r is in the last bucket, then maxTokens should be returned + if r > cumulativeBucketsProbabilities[len(cumulativeBucketsProbabilities)-2] { + return maxTokens + } + + // determine which bucket to use, the bucket with a cumulative probability larger than r is the bucket to use + // initialize bucketIndex with the last bucket to handle the case (which should not happen) when the probabilities sum is less than 1 + bucketIndex := len(cumulativeBucketsProbabilities) - 1 + for i, c := range cumulativeBucketsProbabilities { + if r <= c { + bucketIndex = i + break + } + } + + // calculate the size of all of the buckets (except the special last bucket) + start, end := calcBucketBoundaries(maxTokens, bucketIndex) + + // pick uniformly within the bucket’s range + return common.RandomInt(start, end) +} + +// calcBucketBoundaries calculates boundaries of a bucket with the given index. +// Maximum size for equally sized buckets is defined by maxFixedBucketSize. +// [maxFixedBucketSize*(number-of-buckets-1)+1] is the value of maxTokens for which +// division to equally size buckets will give buckets with size maxFixedBucketSize. 
+// If maxTokens is [maxFixedBucketSize*(number-of-buckets-1)+1] or less, +// all buckets will be of equal size, except the last bucket, which contains only one value. +// If maxTokens is higher than [maxFixedBucketSize*(number-of-buckets-1)+1], +// and flexBucketIndex is valid (between 0 and number of buckets - 1) the buckets sizes will not be equal. +// In this case, all buckets except the one at flexBucketIndex index will have size 20 (and the last is with size 1), +// and the bucket at flexBucketIndex index will 'stretch' to cover the remaining range. +func calcBucketBoundaries(maxTokens int, bucketIndex int) (start int, end int) { + maxEquallyBucketsSz := maxFixedBucketSize*(len(cumulativeBucketsProbabilities)-1) + 1 + + if maxTokens <= maxEquallyBucketsSz || flexBucketIndex < 0 || flexBucketIndex >= len(cumulativeBucketsProbabilities)-1 { + // create equally size buckets + // calculate the size of all of the buckets (except the special last bucket) + bucketSize := float64(maxTokens-1) / float64(len(cumulativeBucketsProbabilities)-1) + start = int(bucketSize*float64(bucketIndex)) + 1 + end = int(bucketSize * float64(bucketIndex+1)) + } else { + // create non-equally sized buckets and find boundaries of the required bucket + if bucketIndex < flexBucketIndex { + // the relevant bucket is before the flex bucket, all buckets are of the same size (maxFixedBucketSize) + // start is the minimum number in the required bucket + start = maxFixedBucketSize*bucketIndex + 1 + end = maxFixedBucketSize * (bucketIndex + 1) + } else { + flexBucketSize := maxTokens - (maxFixedBucketSize * (len(cumulativeBucketsProbabilities) - 2)) + + if bucketIndex == flexBucketIndex { + // the relevant bucket is the flex bucket + start = int(maxFixedBucketSize*float64(bucketIndex)) + 1 + end = maxFixedBucketSize*bucketIndex + flexBucketSize + } else { + // the relevant bucket is one of buckets after the flex bucket + start = int(maxFixedBucketSize*float64(bucketIndex-1)) + flexBucketSize + 1 
+ end = maxFixedBucketSize*bucketIndex + flexBucketSize + } + } + } + + // sometimes end could be maxTokens because of rounding, change the value to maxToken-1 + if end >= maxTokens { + end = maxTokens - 1 + } + + return start, end +} + +// EchoResponseTokens returns needed tokens, from a given text +// considering max completion tokens if it is not nil, and a finish reason (stop or length) +func EchoResponseTokens(maxCompletionTokens *int64, text string) ([]string, string) { + tokens := common.Tokenize(text) + // no max completion tokens, return entire text + if maxCompletionTokens == nil { + return tokens, StopFinishReason + } + + if *maxCompletionTokens >= int64(len(tokens)) { + return tokens, StopFinishReason + } + // return truncated text + return tokens[0:*maxCompletionTokens], LengthFinishReason +} + +type Dataset struct { + Logger logr.Logger +} + +func (d *Dataset) Init(path string, url string, savePath string) error { + return nil +} + +func (d *Dataset) Close() error { + return nil +} + +func (d *Dataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) { + nMaxTokens := d.extractMaxTokens(req) + if mode == common.ModeEcho { + prompt, err := d.extractPrompt(req) + if err != nil { + return nil, "", err + } + tokens, finishReason := EchoResponseTokens(nMaxTokens, prompt) + return tokens, finishReason, nil + } + + nTokensToGen, finishReason := howManyTokensToGen(nMaxTokens, req.GetIgnoreEOS()) + tokens, err := d.GenerateTokens(req, nTokensToGen) + return tokens, finishReason, err +} + +func (d *Dataset) extractMaxTokens(req openaiserverapi.CompletionRequest) *int64 { + if chatReq, ok := req.(*openaiserverapi.ChatCompletionRequest); ok { + return chatReq.GetMaxCompletionTokens() + } else if textReq, ok := req.(*openaiserverapi.TextCompletionRequest); ok { + return textReq.MaxTokens + } + return nil +} + +func (d *Dataset) extractPrompt(req openaiserverapi.CompletionRequest) (string, error) { + if chatReq, ok := 
req.(*openaiserverapi.ChatCompletionRequest); ok { + return chatReq.GetLastUserMsg(), nil + } else if textReq, ok := req.(*openaiserverapi.TextCompletionRequest); ok { + return textReq.GetPrompt(), nil + } + return "", errors.New("unknown request type") +} + +func (d *Dataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) { + tokens := GenPresetRandomTokens(nTokens) + return tokens, nil +} diff --git a/pkg/dataset/dataset_test.go b/pkg/dataset/dataset_test.go new file mode 100644 index 00000000..37a8abba --- /dev/null +++ b/pkg/dataset/dataset_test.go @@ -0,0 +1,203 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dataset + +import ( + "fmt" + "strings" + + "github.com/go-logr/logr" + "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("Utils", Ordered, func() { + var ( + dataset *Dataset + ) + + BeforeEach(func() { + dataset = &Dataset{ + Logger: logr.Discard(), + } + }) + + Context("GetRandomTokens", func() { + It("should return complete text", func() { + var n int64 + req := &openaiserverapi.ChatCompletionRequest{ + MaxTokens: &n, + MaxCompletionTokens: &n, + } + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) + Expect(err).ShouldNot(HaveOccurred()) + text := strings.Join(tokens, "") + Expect(IsValidText(text)).To(BeTrue()) + Expect(finishReason).Should(Equal(StopFinishReason)) + }) + It("should return short text", func() { + maxCompletionTokens := int64(2) + req := &openaiserverapi.ChatCompletionRequest{ + MaxCompletionTokens: &maxCompletionTokens, + } + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) + Expect(err).ShouldNot(HaveOccurred()) + tokensCnt := int64(len(tokens)) + Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) + if tokensCnt == maxCompletionTokens { + Expect(finishReason).To(Equal(LengthFinishReason)) + } else { + Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens)) + Expect(finishReason).To(Equal(StopFinishReason)) + } + }) + It("should return long text", func() { + // return required number of tokens although it is higher than ResponseLenMax + maxCompletionTokens := int64(ResponseLenMax * 5) + req := &openaiserverapi.ChatCompletionRequest{ + MaxTokens: &maxCompletionTokens, + } + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) + Expect(err).ShouldNot(HaveOccurred()) + tokensCnt := int64(len(tokens)) + Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) + text := strings.Join(tokens, "") + Expect(IsValidText(text)).To(BeTrue()) + if tokensCnt == maxCompletionTokens { + Expect(finishReason).To(Equal(LengthFinishReason)) + } else { + Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens)) + 
Expect(finishReason).To(Equal(StopFinishReason)) + } + }) + + DescribeTable("should return exact num of tokens", + func(maxCompletionTokens int) { + n := int64(maxCompletionTokens) + req := &openaiserverapi.ChatCompletionRequest{ + BaseCompletionRequest: openaiserverapi.BaseCompletionRequest{ + IgnoreEOS: true, + }, + MaxTokens: &n, + } + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) + Expect(err).ShouldNot(HaveOccurred()) + nGenTokens := int64(len(tokens)) + Expect(nGenTokens).Should(Equal(n)) + Expect(finishReason).To(Equal(LengthFinishReason)) + }, + func(maxCompletionTokens int) string { + return fmt.Sprintf("maxCompletionTokens: %d", maxCompletionTokens) + }, + Entry("1", 1), + Entry("42", 42), + Entry("99", 99), + Entry("10000", 10000), + ) + }) + + Context("GetResponseTokens", func() { + theText := "Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime" + theTokens := common.Tokenize(theText) + + It("should return the same text since max tokens is not defined", func() { + tokens, finishReason := EchoResponseTokens(nil, theText) + Expect(tokens).Should(Equal(theTokens)) + Expect(finishReason).Should(Equal(StopFinishReason)) + }) + It("should return the same text since max tokens is higher than the text length", func() { + maxCompletionTokens := int64(1000) + tokens, finishReason := EchoResponseTokens(&maxCompletionTokens, theText) + Expect(tokens).Should(Equal(theTokens)) + Expect(finishReason).Should(Equal(StopFinishReason)) + }) + It("should return partial text", func() { + maxCompletionTokens := int64(2) + tokens, finishReason := EchoResponseTokens(&maxCompletionTokens, theText) + Expect(int64(len(tokens))).Should(Equal(maxCompletionTokens)) + Expect(finishReason).Should(Equal(LengthFinishReason)) + }) + }) + + Context("GetRandomTokens", func() { + lenArr := []int{5, 20, 50, 150} + + for _, len := range lenArr { + name := fmt.Sprintf("should return text with %d tokens", len) + It(name, 
func() { + tokens := GenPresetRandomTokens(len) + Expect(tokens).Should(HaveLen(len)) + }) + } + }) + + Context("IsValidText", func() { + validTxts := make([]string, 0) + invalidTxts := make([]string, 0) + + validTxts = append(validTxts, chatCompletionFakeResponses[0][:4]) + validTxts = append(validTxts, chatCompletionFakeResponses[1]) + validTxts = append(validTxts, chatCompletionFakeResponses[1]+" "+chatCompletionFakeResponses[2]) + + invalidTxts = append(invalidTxts, (chatCompletionFakeResponses[1] + " " + chatCompletionFakeResponses[2])[3:4]) + invalidTxts = append(invalidTxts, chatCompletionFakeResponses[0][4:]) + invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+"-"+chatCompletionFakeResponses[2]) + invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+" ") + invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+" "+chatCompletionFakeResponses[2]) + + for _, txt := range validTxts { + It("text should be valid", func() { + Expect(IsValidText(txt)).To(BeTrue()) + }) + } + + for _, txt := range invalidTxts { + It("text should be invalid", func() { + Expect(IsValidText(txt)).To(BeFalse()) + }) + } + }) + + Context("validateBucketsBoundaries", func() { + type bucketBoundaries struct { + start int + end int + } + type bucketTest struct { + maxTokens int + expectedBuckets []bucketBoundaries + } + + tests := []bucketTest{{500, []bucketBoundaries{{1, 20}, {21, 40}, {41, 60}, {61, 480}, {481, 499}}}, + {47, []bucketBoundaries{{1, 9}, {10, 18}, {19, 27}, {28, 36}, {37, 46}}}, + {50, []bucketBoundaries{{1, 9}, {10, 19}, {20, 29}, {30, 39}, {40, 49}}}} + + for _, test := range tests { + Expect(test.expectedBuckets).To(HaveLen(len(cumulativeBucketsProbabilities) - 1)) + + It(fmt.Sprintf("should return bucket boundaries for maxTokens %d", test.maxTokens), func() { + for i := range len(cumulativeBucketsProbabilities) - 1 { + start, end := calcBucketBoundaries(test.maxTokens, i) + Expect(start).To(Equal(test.expectedBuckets[i].start)) + 
Expect(end).To(Equal(test.expectedBuckets[i].end)) + } + }) + } + }) +}) diff --git a/pkg/common/test_helpers.go b/pkg/dataset/test_helpers.go similarity index 98% rename from pkg/common/test_helpers.go rename to pkg/dataset/test_helpers.go index 31ff4bd5..ed6b6f1b 100644 --- a/pkg/common/test_helpers.go +++ b/pkg/dataset/test_helpers.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package common +package dataset import "strings" diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 4207cdcd..4b098cc4 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -32,6 +32,7 @@ import ( "k8s.io/klog/v2" "github.com/llm-d/llm-d-inference-sim/pkg/common" + "github.com/llm-d/llm-d-inference-sim/pkg/dataset" kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api" @@ -116,7 +117,7 @@ type VllmSimulator struct { // tokenizer is currently used in kv-cache and in /tokenize tokenizer tokenization.Tokenizer // dataset is used for managing dataset files - dataset *common.Dataset + dataset *dataset.Dataset } // New creates a new VllmSimulator instance with the given logger @@ -219,7 +220,7 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { s.dataset = nil s.logger.Info("No dataset provided, will generate random responses") } else { - dataset := &common.Dataset{ + dataset := &dataset.Dataset{ Logger: s.logger, } err = dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath) @@ -332,13 +333,15 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { if reqCtx.IsChatCompletion && req.GetToolChoice() != openaiserverapi.ToolChoiceNone && req.GetTools() != nil { - toolCalls, finishReason, completionTokens, err = + toolCalls, 
completionTokens, err = openaiserverapi.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config) + finishReason = dataset.ToolsFinishReason } if toolCalls == nil && err == nil { // Either no tool calls were defined, or we randomly chose not to create tool calls, // so we generate a response text. - responseTokens, finishReason, completionTokens, err = s.generateTokens(req) + responseTokens, finishReason, err = s.dataset.GetTokens(req, s.config.Mode) + completionTokens += len(responseTokens) } if err != nil { prefix := "" @@ -374,7 +377,7 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { } else { if req.IsDoRemoteDecode() { // in case this is prefill pod processing, return special finish reason - finishReason = common.RemoteDecodeFinishReason + finishReason = dataset.RemoteDecodeFinishReason } s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &usageData) diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index e504c5d5..e2df226f 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -28,6 +28,7 @@ import ( "strings" "github.com/llm-d/llm-d-inference-sim/pkg/common" + "github.com/llm-d/llm-d-inference-sim/pkg/dataset" kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache" "github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization" . 
"github.com/onsi/ginkgo/v2" @@ -190,7 +191,7 @@ var _ = Describe("Simulator", func() { msg := strings.Join(tokens, "") if mode == common.ModeRandom { // in case of random mode ensure that the returned message could be output of the random text generator - Expect(common.IsValidText(msg)).To(BeTrue()) + Expect(dataset.IsValidText(msg)).To(BeTrue()) } else { // in case of echo mode check that the text is returned as-is Expect(msg).Should(Equal(userMessage)) @@ -239,7 +240,7 @@ var _ = Describe("Simulator", func() { text := strings.Join(tokens, "") if mode == common.ModeRandom { // in case of random mode ensure that the returned message could be output of the random text generator - Expect(common.IsValidText(text)).To(BeTrue()) + Expect(dataset.IsValidText(text)).To(BeTrue()) } else { // in case of echo mode check that the text is returned as-is Expect(text).Should(Equal(userMessage)) @@ -300,7 +301,7 @@ var _ = Describe("Simulator", func() { } else { if mode == common.ModeRandom { // in case of random mode ensure that the returned message could be output of the random text generator - Expect(common.IsValidText(msg)).To(BeTrue()) + Expect(dataset.IsValidText(msg)).To(BeTrue()) } else { // in case of echo mode check that the text is returned as-is Expect(msg).Should(Equal(userMessage)) @@ -371,7 +372,7 @@ var _ = Describe("Simulator", func() { } else { if mode == common.ModeRandom { // in case of random mode ensure that the returned message could be output of the random text generator - Expect(common.IsValidText(text)).To(BeTrue()) + Expect(dataset.IsValidText(text)).To(BeTrue()) } else { // in case of echo mode check that the text is returned as-is Expect(text).Should(Equal(userMessage)) diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go index 2508298d..c64affc8 100644 --- a/pkg/llm-d-inference-sim/streaming.go +++ b/pkg/llm-d-inference-sim/streaming.go @@ -23,6 +23,7 @@ import ( "time" 
"github.com/llm-d/llm-d-inference-sim/pkg/common" + "github.com/llm-d/llm-d-inference-sim/pkg/dataset" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" "github.com/valyala/fasthttp" ) @@ -124,7 +125,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ var chunk openaiserverapi.CompletionRespChunk var finishReasonToSend *string - if i == len(genTokens)-1 && (finishReason == common.LengthFinishReason || finishReason == common.ToolsFinishReason) { + if i == len(genTokens)-1 && (finishReason == dataset.LengthFinishReason || finishReason == dataset.ToolsFinishReason) { finishReasonToSend = &finishReason } if context.isChatCompletion { @@ -141,7 +142,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ // send the last chunk if finish reason is stop var chunk openaiserverapi.CompletionRespChunk - if finishReason == common.StopFinishReason { + if finishReason == dataset.StopFinishReason { if context.isChatCompletion { chunk = s.createChatCompletionChunk(context, "", nil, "", &finishReason) } else { diff --git a/pkg/llm-d-inference-sim/tools_test.go b/pkg/llm-d-inference-sim/tools_test.go index ae22a7f6..bffb7eea 100644 --- a/pkg/llm-d-inference-sim/tools_test.go +++ b/pkg/llm-d-inference-sim/tools_test.go @@ -24,6 +24,7 @@ import ( "strings" "github.com/llm-d/llm-d-inference-sim/pkg/common" + "github.com/llm-d/llm-d-inference-sim/pkg/dataset" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" "github.com/openai/openai-go" @@ -365,7 +366,7 @@ var _ = Describe("Simulator for request with tools", func() { for _, choice := range chunk.Choices { if choice.Delta.Role != "" { role = choice.Delta.Role - } else if choice.FinishReason == "" || choice.FinishReason == common.ToolsFinishReason { + } else if choice.FinishReason == "" || choice.FinishReason == dataset.ToolsFinishReason { toolCalls := choice.Delta.ToolCalls Expect(toolCalls).To(HaveLen(1)) tc := toolCalls[0] diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go index a7dcdb63..6ed0814a 100644 --- a/pkg/openai-server-api/request.go +++ b/pkg/openai-server-api/request.go @@ -67,8 +67,8 @@ type CompletionRequest interface { IsDoRemotePrefill() bool } -// baseCompletionRequest contains base completion request related information -type baseCompletionRequest struct { +// BaseCompletionRequest contains base completion request related information +type BaseCompletionRequest struct { // RequestID is the unique id of this request RequestID string // Stream is a boolean value, defines whether response should be sent as a Stream @@ -101,44 +101,44 @@ type StreamOptions struct { IncludeUsage bool `json:"include_usage"` } -func (b *baseCompletionRequest) GetRequestID() string { +func (b *BaseCompletionRequest) GetRequestID() string { return b.RequestID } -func (b *baseCompletionRequest) IsStream() bool { +func (b *BaseCompletionRequest) IsStream() bool { return b.Stream } -func (b *baseCompletionRequest) GetModel() string { +func (b *BaseCompletionRequest) GetModel() string { return b.Model } -func (b *baseCompletionRequest) IncludeUsage() bool { +func (b *BaseCompletionRequest) IncludeUsage() bool { return !b.Stream || b.StreamOptions.IncludeUsage } -func (b *baseCompletionRequest) IsDoRemoteDecode() bool { +func (b *BaseCompletionRequest) IsDoRemoteDecode() bool { return b.DoRemoteDecode } -func (b *baseCompletionRequest) IsDoRemotePrefill() bool { +func (b 
*BaseCompletionRequest) IsDoRemotePrefill() bool { return b.DoRemotePrefill } // GetNumberOfCachedPromptTokens returns the number of tokens in the prompt that are // in the local KV Cache -func (b *baseCompletionRequest) GetNumberOfCachedPromptTokens() int { +func (b *BaseCompletionRequest) GetNumberOfCachedPromptTokens() int { return b.cachedPromptTokens } // GetIgnoreEOS returns the value of IgnoreEOS -func (b *baseCompletionRequest) GetIgnoreEOS() bool { +func (b *BaseCompletionRequest) GetIgnoreEOS() bool { return b.IgnoreEOS } // SetNumberOfCachedPromptTokens sets the number of tokens in the prompt that are // in the local KV Cache -func (b *baseCompletionRequest) SetNumberOfCachedPromptTokens(cachedPromptTokens int) { +func (b *BaseCompletionRequest) SetNumberOfCachedPromptTokens(cachedPromptTokens int) { b.cachedPromptTokens = cachedPromptTokens } @@ -153,7 +153,7 @@ type CompletionReqCtx struct { // ChatCompletionRequest defines structure of /chat/completion request type ChatCompletionRequest struct { - baseCompletionRequest + BaseCompletionRequest // Messages list of request's Messages Messages []Message `json:"messages"` @@ -239,7 +239,7 @@ func (req *ChatCompletionRequest) GetLastUserMsg() string { // v1/completion // TextCompletionRequest defines structure of /completion request type TextCompletionRequest struct { - baseCompletionRequest + BaseCompletionRequest // Prompt defines request's content Prompt string `json:"prompt"` diff --git a/pkg/openai-server-api/tools_utils.go b/pkg/openai-server-api/tools_utils.go index 3546aa9d..58f3a0df 100644 --- a/pkg/openai-server-api/tools_utils.go +++ b/pkg/openai-server-api/tools_utils.go @@ -55,7 +55,7 @@ var fakeStringArguments = []string{ // CreateToolCalls creates and returns response payload based on this request // (tool calls or nothing in case we randomly choose not to generate calls), // and the number of generated completion token sand the finish reason -func CreateToolCalls(tools []Tool, toolChoice 
string, config *common.Configuration) ([]ToolCall, string, int, error) { +func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configuration) ([]ToolCall, int, error) { // This function is called if tool choice is either 'required' or 'auto'. // In case of 'required' at least one tool call has to be created, and we randomly choose // the number of calls starting from one. Otherwise, we start from 0, and in case we randomly @@ -66,7 +66,7 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati } numberOfCalls := common.RandomInt(min, len(tools)) if numberOfCalls == 0 { - return nil, "", 0, nil + return nil, 0, nil } calls := make([]ToolCall, 0) @@ -75,11 +75,11 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati index := common.RandomInt(0, len(tools)-1) args, err := GenerateToolArguments(tools[index], config) if err != nil { - return nil, "", 0, err + return nil, 0, err } argsJson, err := json.Marshal(args) if err != nil { - return nil, "", 0, err + return nil, 0, err } call := ToolCall{ @@ -95,7 +95,7 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati calls = append(calls, call) } - return calls, common.ToolsFinishReason, CountTokensForToolCalls(calls), nil + return calls, CountTokensForToolCalls(calls), nil } func GetRequiredAsMap(property map[string]any) map[string]struct{} { From 39a9d24216dd387f4e30f05131b1ee493e39f13c Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Mon, 15 Sep 2025 15:48:13 +1000 Subject: [PATCH 13/34] fix dataset tests Signed-off-by: Qifan Deng --- pkg/dataset/custom_dataset.go | 27 ++++-- pkg/dataset/custom_dataset_test.go | 4 +- pkg/dataset/dataset.go | 32 +++++-- pkg/dataset/dataset_suite_test.go | 13 +++ pkg/dataset/dataset_test.go | 19 ++-- pkg/llm-d-inference-sim/simulator.go | 106 +++++++++++++++++++--- pkg/llm-d-inference-sim/simulator_test.go | 5 + 7 files changed, 164 insertions(+), 42 deletions(-) create mode 100644 
pkg/dataset/dataset_suite_test.go diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go index 1ec89615..36119450 100644 --- a/pkg/dataset/custom_dataset.go +++ b/pkg/dataset/custom_dataset.go @@ -28,16 +28,18 @@ import ( "os" "os/signal" "path/filepath" + "strconv" "syscall" "time" "github.com/go-logr/logr" "github.com/google/uuid" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" _ "github.com/mattn/go-sqlite3" ) type CustomDataset struct { - Dataset + BaseDataset db *sql.DB hasWarned bool } @@ -53,7 +55,7 @@ const ( nGenTokensColType = "INTEGER" ) -func (d CustomDataset) downloadDataset(url string, savePath string) error { +func (d *CustomDataset) downloadDataset(url string, savePath string) error { // Set up signal handling for Ctrl+C (SIGINT) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -184,7 +186,7 @@ func (pr *progressReader) logProgress(pct int) { } } -func (d CustomDataset) verifyDB() error { +func (d *CustomDataset) verifyDB() error { rows, err := d.db.Query("PRAGMA table_info(" + tableName + ");") if err != nil { return fmt.Errorf("failed to query table info for `%s`: %w", tableName, err) @@ -234,7 +236,7 @@ func (d CustomDataset) verifyDB() error { return nil } -func (d CustomDataset) getRecordsCount() (int, error) { +func (d *CustomDataset) getRecordsCount() (int, error) { var count int err := d.db.QueryRow("SELECT COUNT(" + promptHashCol + ") FROM " + tableName + ";").Scan(&count) if err != nil { @@ -243,7 +245,7 @@ func (d CustomDataset) getRecordsCount() (int, error) { return count, nil } -func (d CustomDataset) connectToDB(path string) error { +func (d *CustomDataset) connectToDB(path string) error { if d.db != nil { err := d.db.Close() if err != nil { @@ -277,7 +279,7 @@ func (d CustomDataset) connectToDB(path string) error { return nil } -func (d CustomDataset) Init(path string, url string, savePath string) error { +func (d *CustomDataset) Init(path string, url string, 
savePath string) error { d.hasWarned = false if path != "" { return d.connectToDB(path) @@ -312,7 +314,7 @@ func (d CustomDataset) Init(path string, url string, savePath string) error { return errors.New("no dataset path or url provided") } -func (d CustomDataset) Close() error { +func (d *CustomDataset) Close() error { if d.db != nil { return d.db.Close() } @@ -336,11 +338,11 @@ func unmarshalAllRecords(rows *sql.Rows) ([][]string, error) { return tokensList, nil } -func (d CustomDataset) getRandomTokens(n_gen_tokens int) []string { - return nil +func (d *CustomDataset) getRandomTokens(n_gen_tokens int) []string { + return []string{"<|random_tokens|>", strconv.Itoa(n_gen_tokens)} } -func (d *CustomDataset) GetTokens(prompt string, n_gen_tokens int) []string { +func (d *CustomDataset) readTokensFromDB(prompt string, n_gen_tokens int) []string { promptHash := uuid.NewSHA1(uuid.NameSpaceOID, []byte(prompt)).NodeID() rows, err := d.db.Query("SELECT "+genTokensCol+" FROM "+tableName+" WHERE "+promptHashCol+" = ?;", promptHash) if err != nil { @@ -369,3 +371,8 @@ func (d *CustomDataset) GetTokens(prompt string, n_gen_tokens int) []string { randIndex := rand.Intn(len(tokensList)) return tokensList[randIndex] } + +func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) { + tokens := d.readTokensFromDB("", nTokens) + return tokens, nil +} diff --git a/pkg/dataset/custom_dataset_test.go b/pkg/dataset/custom_dataset_test.go index a43656de..9d61747e 100644 --- a/pkg/dataset/custom_dataset_test.go +++ b/pkg/dataset/custom_dataset_test.go @@ -27,7 +27,7 @@ import ( _ "github.com/mattn/go-sqlite3" ) -var _ = Describe("Dataset", func() { +var _ = Describe("CustomDataset", func() { var ( dataset *CustomDataset file_folder string @@ -42,7 +42,7 @@ var _ = Describe("Dataset", func() { BeforeEach(func() { dataset = &CustomDataset{ - Dataset: Dataset{ + BaseDataset: BaseDataset{ Logger: logr.Discard(), }, } diff --git 
a/pkg/dataset/dataset.go b/pkg/dataset/dataset.go index 7f293f74..a221320a 100644 --- a/pkg/dataset/dataset.go +++ b/pkg/dataset/dataset.go @@ -68,6 +68,15 @@ var chatCompletionFakeResponses = []string{ `Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime`, } +type Dataset interface { + // Init initializes the dataset using configs + Init(path string, url string, savePath string) error + // Close closes the dataset + Close() error + // GetTokens returns tokens for the given request and mode (echo or random) + GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) +} + func init() { cumulativeBucketsProbabilities = make([]float64, len(respLenBucketsProbabilities)) sum := 0.0 @@ -267,19 +276,20 @@ func EchoResponseTokens(maxCompletionTokens *int64, text string) ([]string, stri return tokens[0:*maxCompletionTokens], LengthFinishReason } -type Dataset struct { +type BaseDataset struct { Logger logr.Logger } -func (d *Dataset) Init(path string, url string, savePath string) error { +func (d *BaseDataset) Init(path string, url string, savePath string) error { return nil } -func (d *Dataset) Close() error { +func (d *BaseDataset) Close() error { return nil } -func (d *Dataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) { +// GetTokens returns tokens and finishReason for the given request and mode (echo or random) +func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) { nMaxTokens := d.extractMaxTokens(req) if mode == common.ModeEcho { prompt, err := d.extractPrompt(req) @@ -295,7 +305,10 @@ func (d *Dataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) return tokens, finishReason, err } -func (d *Dataset) extractMaxTokens(req openaiserverapi.CompletionRequest) *int64 { +// extractMaxTokens extracts the max tokens from the request +// for chat completion - max_completion_tokens 
field is used +// for text completion - max_tokens field is used +func (d *BaseDataset) extractMaxTokens(req openaiserverapi.CompletionRequest) *int64 { if chatReq, ok := req.(*openaiserverapi.ChatCompletionRequest); ok { return chatReq.GetMaxCompletionTokens() } else if textReq, ok := req.(*openaiserverapi.TextCompletionRequest); ok { @@ -304,7 +317,10 @@ func (d *Dataset) extractMaxTokens(req openaiserverapi.CompletionRequest) *int64 return nil } -func (d *Dataset) extractPrompt(req openaiserverapi.CompletionRequest) (string, error) { +// extractPrompt extracts the prompt from the request +// for chat completion - the last user message is used as the prompt +// for text completion - the prompt field is used +func (d *BaseDataset) extractPrompt(req openaiserverapi.CompletionRequest) (string, error) { if chatReq, ok := req.(*openaiserverapi.ChatCompletionRequest); ok { return chatReq.GetLastUserMsg(), nil } else if textReq, ok := req.(*openaiserverapi.TextCompletionRequest); ok { @@ -313,7 +329,9 @@ func (d *Dataset) extractPrompt(req openaiserverapi.CompletionRequest) (string, return "", errors.New("unknown request type") } -func (d *Dataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) { +// GenerateTokens generates random tokens for the required number of tokens +// other dataset types should override this function +func (d *BaseDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) { tokens := GenPresetRandomTokens(nTokens) return tokens, nil } diff --git a/pkg/dataset/dataset_suite_test.go b/pkg/dataset/dataset_suite_test.go new file mode 100644 index 00000000..c9dea52c --- /dev/null +++ b/pkg/dataset/dataset_suite_test.go @@ -0,0 +1,13 @@ +package dataset_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +func TestDataset(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Dataset Suite") +} diff --git a/pkg/dataset/dataset_test.go b/pkg/dataset/dataset_test.go index 37a8abba..f890ff9f 100644 --- a/pkg/dataset/dataset_test.go +++ b/pkg/dataset/dataset_test.go @@ -19,6 +19,7 @@ package dataset import ( "fmt" "strings" + "time" "github.com/go-logr/logr" "github.com/llm-d/llm-d-inference-sim/pkg/common" @@ -27,30 +28,31 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Utils", Ordered, func() { +var _ = Describe("Dataset", Ordered, func() { var ( - dataset *Dataset + dataset *BaseDataset ) + BeforeAll(func() { + common.InitRandom(time.Now().UnixNano()) + }) BeforeEach(func() { - dataset = &Dataset{ + dataset = &BaseDataset{ Logger: logr.Discard(), } }) Context("GetRandomTokens", func() { + It("should return complete text", func() { - var n int64 - req := &openaiserverapi.ChatCompletionRequest{ - MaxTokens: &n, - MaxCompletionTokens: &n, - } + req := &openaiserverapi.ChatCompletionRequest{} tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) Expect(err).ShouldNot(HaveOccurred()) text := strings.Join(tokens, "") Expect(IsValidText(text)).To(BeTrue()) Expect(finishReason).Should(Equal(StopFinishReason)) }) + It("should return short text", func() { maxCompletionTokens := int64(2) req := &openaiserverapi.ChatCompletionRequest{ @@ -67,6 +69,7 @@ var _ = Describe("Utils", Ordered, func() { Expect(finishReason).To(Equal(StopFinishReason)) } }) + It("should return long text", func() { // return required number of tokens although it is higher than ResponseLenMax maxCompletionTokens := int64(ResponseLenMax * 5) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 4b098cc4..91da1b42 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -116,8 +116,8 @@ type VllmSimulator struct { pod string // tokenizer is currently used in kv-cache 
and in /tokenize tokenizer tokenization.Tokenizer - // dataset is used for managing dataset files - dataset *dataset.Dataset + // dataset is used for token generation in responses + dataset dataset.Dataset } // New creates a new VllmSimulator instance with the given logger @@ -216,18 +216,9 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { go s.kvcacheHelper.Run(ctx) } - if s.config.Dataset.Path == "" && s.config.Dataset.Url == "" && s.config.Dataset.SavePath == "" { - s.dataset = nil - s.logger.Info("No dataset provided, will generate random responses") - } else { - dataset := &dataset.Dataset{ - Logger: s.logger, - } - err = dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath) - if err != nil { - return err - } - s.dataset = dataset + err = s.initDataset() + if err != nil { + return fmt.Errorf("dataset initialization error: %w", err) } // run request processing workers @@ -239,13 +230,98 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { listener, err := s.newListener() if err != nil { - return err + s.logger.Error(err, "Failed to create listener") + return fmt.Errorf("listener creation error: %w", err) } // start the http server with context support return s.startServer(ctx, listener) } +func (s *VllmSimulator) initDataset() error { + randDataset := &dataset.BaseDataset{ + Logger: s.logger, + } + + if s.config.Dataset.Path == "" && s.config.Dataset.Url == "" && s.config.Dataset.SavePath == "" { + s.logger.Info("No dataset provided, will generate random responses") + s.dataset = randDataset + } else { + s.logger.Info("Custom dataset configuration detected") + s.dataset = &dataset.CustomDataset{ + BaseDataset: *randDataset, + } + } + + if err := s.dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath); err != nil { + return fmt.Errorf("dataset initialization error: %w", err) + } + return nil +} + +func (s *VllmSimulator) newListener() (net.Listener, error) { + 
s.logger.Info("Server starting", "port", s.config.Port) + listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", s.config.Port)) + if err != nil { + return nil, err + } + return listener, nil +} + +// startServer starts http server on port defined in command line +func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) error { + r := fasthttprouter.New() + + // support completion APIs + r.POST("/v1/chat/completions", s.HandleChatCompletions) + r.POST("/v1/completions", s.HandleTextCompletions) + // supports /models API + r.GET("/v1/models", s.HandleModels) + // support load/unload of lora adapter + r.POST("/v1/load_lora_adapter", s.HandleLoadLora) + r.POST("/v1/unload_lora_adapter", s.HandleUnloadLora) + // supports /metrics prometheus API + r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.HandlerFor(s.registry, promhttp.HandlerOpts{}))) + // supports standard Kubernetes health and readiness checks + r.GET("/health", s.HandleHealth) + r.GET("/ready", s.HandleReady) + r.POST("/tokenize", s.HandleTokenize) + + server := fasthttp.Server{ + ErrorHandler: s.HandleError, + Handler: r.Handler, + Logger: s, + } + + // Start server in a goroutine + serverErr := make(chan error, 1) + go func() { + s.logger.Info("HTTP server starting") + serverErr <- server.Serve(listener) + }() + + // Wait for either context cancellation or server error + select { + case <-ctx.Done(): + s.logger.Info("Shutdown signal received, shutting down HTTP server gracefully") + + // Gracefully shutdown the server + if err := server.Shutdown(); err != nil { + s.logger.Error(err, "Error during server shutdown") + return err + } + + s.logger.Info("HTTP server stopped") + return nil + + case err := <-serverErr: + if err != nil { + s.logger.Error(err, "HTTP server failed") + } + return err + } +} + // Print prints to a log, implementation of fasthttp.Logger func (s *VllmSimulator) Printf(format string, args ...interface{}) { s.logger.Info("Server error", "msg", 
fmt.Sprintf(format, args...)) diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index e2df226f..9326f8a9 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -118,6 +118,11 @@ func startServerWithArgs(ctx context.Context, mode string, args []string, envs m go s.kvcacheHelper.Run(ctx) } + err = s.initDataset() + if err != nil { + return nil, fmt.Errorf("dataset initialization error: %w", err) + } + // calculate number of tokens for user message, // must be activated after parseCommandParamsAndLoadConfig since it initializes the random engine userMsgTokens = int64(len(common.Tokenize(userMessage))) From 6cf6effdfbe7ad553d9c49d52be862ae233f16e3 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Mon, 15 Sep 2025 19:21:03 +1000 Subject: [PATCH 14/34] add tests for custom dataset Signed-off-by: Qifan Deng --- pkg/dataset/.llm-d/test.valid.sqlite3 | Bin 12288 -> 12288 bytes pkg/dataset/custom_dataset.go | 44 ++++++++++++++-------- pkg/dataset/custom_dataset_test.go | 51 ++++++++++++++++++++++++-- pkg/dataset/dataset.go | 31 +++++++--------- pkg/dataset/dataset_test.go | 2 +- pkg/llm-d-inference-sim/simulator.go | 2 +- pkg/openai-server-api/request.go | 21 +++++++++++ 7 files changed, 111 insertions(+), 40 deletions(-) diff --git a/pkg/dataset/.llm-d/test.valid.sqlite3 b/pkg/dataset/.llm-d/test.valid.sqlite3 index 3cac3f9a0fdfe6d5e8db04cace5cce0cddd0463f..847a6257842080335691a3de3943f0152a9382e1 100644 GIT binary patch delta 172 zcmZojXh@hK&B!-V#+i|CW5N=C4krE@2L40*HJb$m3iy3om{@$g8B6wy9GJu0tay6y zbn#mc?-Xk&uk+g5CmI=PR;Ic5Am@r`C6CmcoO~r6C54=vT-_7}AgesTC?^F-D=NjZ r*f6sASTjz3psz6bg}e;VitP;i+xhPRtys;ks>19;&;q8(cl1>N-Zwg4 delta 66 zcmZojXh@hK&B!}Z#+i|KW5N=CHb(wK4E%>S3o0z*pV%P8X2ZzhW6e1Efxg1z7xFSd R**gsUcc8L5lke!O004|J6kPxS diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go index 36119450..22da7a36 100644 --- a/pkg/dataset/custom_dataset.go +++ 
b/pkg/dataset/custom_dataset.go @@ -18,7 +18,9 @@ package dataset import ( "context" + "crypto/sha256" "database/sql" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -28,12 +30,11 @@ import ( "os" "os/signal" "path/filepath" - "strconv" "syscall" "time" "github.com/go-logr/logr" - "github.com/google/uuid" + "github.com/llm-d/llm-d-inference-sim/pkg/common" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" _ "github.com/mattn/go-sqlite3" ) @@ -338,19 +339,35 @@ func unmarshalAllRecords(rows *sql.Rows) ([][]string, error) { return tokensList, nil } -func (d *CustomDataset) getRandomTokens(n_gen_tokens int) []string { - return []string{"<|random_tokens|>", strconv.Itoa(n_gen_tokens)} +func (d *CustomDataset) GetPromptHash(req openaiserverapi.CompletionRequest) []byte { + hashArray := sha256.Sum256([]byte(req.GetFullPrompt())) + return hashArray[:] } -func (d *CustomDataset) readTokensFromDB(prompt string, n_gen_tokens int) []string { - promptHash := uuid.NewSHA1(uuid.NameSpaceOID, []byte(prompt)).NodeID() - rows, err := d.db.Query("SELECT "+genTokensCol+" FROM "+tableName+" WHERE "+promptHashCol+" = ?;", promptHash) +func (d *CustomDataset) GetPromptHashHex(hashBytes []byte) string { + return hex.EncodeToString(hashBytes) +} + +// GetTokens returns tokens and finishReason for the given request and mode (echo or random) +func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) { + if mode == common.ModeEcho { + return d.echo(req) + } + nTokensToGen, finishReason := howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS()) + tokens, err := d.GenerateTokens(req, nTokensToGen) + return tokens, finishReason, err +} + +func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) { + promptHash := d.GetPromptHash(req) + promptHashHex := d.GetPromptHashHex(promptHash) + rows, err := d.db.Query("SELECT " + genTokensCol + " FROM " + 
tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';") if err != nil { if !d.hasWarned { d.Logger.Error(err, "failed to query database. Ensure the prompt hash exists in the dataset. Will generate random tokens instead.") d.hasWarned = true } - return d.getRandomTokens(n_gen_tokens) + return GenPresetRandomTokens(nTokens), nil } defer func() { if cerr := rows.Close(); cerr != nil { @@ -361,18 +378,13 @@ func (d *CustomDataset) readTokensFromDB(prompt string, n_gen_tokens int) []stri tokensList, err := unmarshalAllRecords(rows) if err != nil { d.Logger.Error(err, "failed to unmarshal records from database") - return d.getRandomTokens(n_gen_tokens) + return GenPresetRandomTokens(nTokens), nil } if len(tokensList) == 0 { - return d.getRandomTokens(n_gen_tokens) + return GenPresetRandomTokens(nTokens), nil } d.hasWarned = false randIndex := rand.Intn(len(tokensList)) - return tokensList[randIndex] -} - -func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) { - tokens := d.readTokensFromDB("", nTokens) - return tokens, nil + return tokensList[randIndex], nil } diff --git a/pkg/dataset/custom_dataset_test.go b/pkg/dataset/custom_dataset_test.go index 9d61747e..c0a6daeb 100644 --- a/pkg/dataset/custom_dataset_test.go +++ b/pkg/dataset/custom_dataset_test.go @@ -21,12 +21,18 @@ import ( "os" "github.com/go-logr/logr" + "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" _ "github.com/mattn/go-sqlite3" ) +const ( + testPrompt = "Hello world!" 
+) + var _ = Describe("CustomDataset", func() { var ( dataset *CustomDataset @@ -90,20 +96,20 @@ var _ = Describe("CustomDataset", func() { err := dataset.Init(validDBPath, "", "") Expect(err).NotTo(HaveOccurred()) - row := dataset.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'b94d27b9934d041c52e5b721d7373f13a07ed5e79179d63c5d8a0c102a9d00b2';") + row := dataset.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';") var n_gen_tokens int err = row.Scan(&n_gen_tokens) Expect(err).NotTo(HaveOccurred()) - Expect(n_gen_tokens).To(Equal(3)) + Expect(n_gen_tokens).To(Equal(4)) var jsonStr string - row = dataset.db.QueryRow("SELECT gen_tokens FROM llmd WHERE prompt_hash=X'b94d27b9934d041c52e5b721d7373f13a07ed5e79179d63c5d8a0c102a9d00b2';") + row = dataset.db.QueryRow("SELECT gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';") err = row.Scan(&jsonStr) Expect(err).NotTo(HaveOccurred()) var tokens []string err = json.Unmarshal([]byte(jsonStr), &tokens) Expect(err).NotTo(HaveOccurred()) - Expect(tokens).To(Equal([]string{"Hello", "world", "!"})) + Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"})) }) @@ -136,4 +142,41 @@ var _ = Describe("CustomDataset", func() { Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("incorrect type")) }) + + It("should return correct prompt hash in bytes", func() { + // b't\xbf\x14\xc0\x9c\x03\x83!\xcb\xa3\x97\x17\xda\xe1\xdcs(#\xaeJ\xbd\x8e\x15YY6v)\xa3\xc1\t\xa8' + expectedHashBytes := []byte{0x74, 0xbf, 0x14, 0xc0, 0x9c, 0x03, 0x83, 0x21, 0xcb, 0xa3, 0x97, 0x17, 0xda, 0xe1, 0xdc, 0x73, 0x28, 0x23, 0xae, 0x4a, 0xbd, 0x8e, 0x15, 0x59, 0x59, 0x36, 0x76, 0x29, 0xa3, 0xc1, 0x09, 0xa8} + + req := &openaiserverapi.TextCompletionRequest{ + Prompt: testPrompt, + } + + hashBytes := dataset.GetPromptHash(req) + Expect(hashBytes).To(Equal(expectedHashBytes)) + }) + + 
It("should return correct prompt hash in hex", func() { + expectedHashHex := "74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8" + + req := &openaiserverapi.TextCompletionRequest{ + Prompt: testPrompt, + } + + hashBytes := dataset.GetPromptHash(req) + hashHex := dataset.GetPromptHashHex(hashBytes) + Expect(hashHex).To(Equal(expectedHashHex)) + }) + + It("should return tokens for existing prompt", func() { + err := dataset.Init(validDBPath, "", "") + Expect(err).NotTo(HaveOccurred()) + + req := &openaiserverapi.TextCompletionRequest{ + Prompt: testPrompt, + } + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) + Expect(err).NotTo(HaveOccurred()) + Expect(finishReason).To(Equal(StopFinishReason)) + Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"})) + }) }) diff --git a/pkg/dataset/dataset.go b/pkg/dataset/dataset.go index a221320a..6d633171 100644 --- a/pkg/dataset/dataset.go +++ b/pkg/dataset/dataset.go @@ -288,21 +288,23 @@ func (d *BaseDataset) Close() error { return nil } +func (d *BaseDataset) echo(req openaiserverapi.CompletionRequest) ([]string, string, error) { + nMaxTokens := d.extractMaxTokens(req) + prompt, err := d.extractPrompt(req) + if err != nil { + return nil, "", err + } + tokens, finishReason := EchoResponseTokens(nMaxTokens, prompt) + return tokens, finishReason, nil +} + // GetTokens returns tokens and finishReason for the given request and mode (echo or random) func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) { - nMaxTokens := d.extractMaxTokens(req) if mode == common.ModeEcho { - prompt, err := d.extractPrompt(req) - if err != nil { - return nil, "", err - } - tokens, finishReason := EchoResponseTokens(nMaxTokens, prompt) - return tokens, finishReason, nil + return d.echo(req) } - - nTokensToGen, finishReason := howManyTokensToGen(nMaxTokens, req.GetIgnoreEOS()) - tokens, err := d.GenerateTokens(req, nTokensToGen) - return tokens, 
finishReason, err + nTokensToGen, finishReason := howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS()) + return GenPresetRandomTokens(nTokensToGen), finishReason, nil } // extractMaxTokens extracts the max tokens from the request @@ -328,10 +330,3 @@ func (d *BaseDataset) extractPrompt(req openaiserverapi.CompletionRequest) (stri } return "", errors.New("unknown request type") } - -// GenerateTokens generates random tokens for the required number of tokens -// other dataset types should override this function -func (d *BaseDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) { - tokens := GenPresetRandomTokens(nTokens) - return tokens, nil -} diff --git a/pkg/dataset/dataset_test.go b/pkg/dataset/dataset_test.go index f890ff9f..c6304941 100644 --- a/pkg/dataset/dataset_test.go +++ b/pkg/dataset/dataset_test.go @@ -69,7 +69,7 @@ var _ = Describe("Dataset", Ordered, func() { Expect(finishReason).To(Equal(StopFinishReason)) } }) - + It("should return long text", func() { // return required number of tokens although it is higher than ResponseLenMax maxCompletionTokens := int64(ResponseLenMax * 5) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 91da1b42..39fa8409 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -244,7 +244,7 @@ func (s *VllmSimulator) initDataset() error { } if s.config.Dataset.Path == "" && s.config.Dataset.Url == "" && s.config.Dataset.SavePath == "" { - s.logger.Info("No dataset provided, will generate random responses") + s.logger.Info("No dataset provided, will generate random responses from preset text") s.dataset = randDataset } else { s.logger.Info("Custom dataset configuration detected") diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go index 6ed0814a..34db0ee6 100644 --- a/pkg/openai-server-api/request.go +++ b/pkg/openai-server-api/request.go @@ -65,6 +65,8 @@ type 
CompletionRequest interface { // when the field is true, the prefill phase should be done on remote pod, // whereas decode phase is done on local pod, thus this is a decode request IsDoRemotePrefill() bool + // GetFullPrompt returns the full prompt including system and user prompts + GetFullPrompt() string } // BaseCompletionRequest contains base completion request related information @@ -236,6 +238,21 @@ func (req *ChatCompletionRequest) GetLastUserMsg() string { return "" } +func (req *ChatCompletionRequest) GetFullPrompt() string { + prompt := "" + for _, msg := range req.Messages { + switch msg.Role { + case RoleUser: + prompt += "### user:\n" + msg.Content.Raw + "\n" + case RoleAssistant: + prompt += "### assistant:\n" + msg.Content.Raw + "\n" + default: + prompt += "### unknown:\n" + msg.Content.Raw + "\n" + } + } + return prompt +} + // v1/completion // TextCompletionRequest defines structure of /completion request type TextCompletionRequest struct { @@ -270,3 +287,7 @@ func (c *TextCompletionRequest) GetToolChoice() string { func (c *TextCompletionRequest) GetMaxCompletionTokens() *int64 { return c.MaxTokens } + +func (t *TextCompletionRequest) GetFullPrompt() string { + return "### user:\n" + t.Prompt + "\n" +} From 46076d68a6f011271608750f1691ec7d938ccb24 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Mon, 15 Sep 2025 21:07:46 +1000 Subject: [PATCH 15/34] fix custom dataset test case Signed-off-by: Qifan Deng --- pkg/dataset/custom_dataset.go | 3 +-- pkg/dataset/custom_dataset_test.go | 7 ++++++- pkg/dataset/dataset_test.go | 1 + 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go index 22da7a36..eff93d52 100644 --- a/pkg/dataset/custom_dataset.go +++ b/pkg/dataset/custom_dataset.go @@ -25,7 +25,6 @@ import ( "errors" "fmt" "io" - "math/rand" "net/http" "os" "os/signal" @@ -385,6 +384,6 @@ func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nT return 
GenPresetRandomTokens(nTokens), nil } d.hasWarned = false - randIndex := rand.Intn(len(tokensList)) + randIndex := common.RandomInt(0, len(tokensList)-1) return tokensList[randIndex], nil } diff --git a/pkg/dataset/custom_dataset_test.go b/pkg/dataset/custom_dataset_test.go index c0a6daeb..fbfea9dd 100644 --- a/pkg/dataset/custom_dataset_test.go +++ b/pkg/dataset/custom_dataset_test.go @@ -19,6 +19,7 @@ package dataset import ( "encoding/json" "os" + "time" "github.com/go-logr/logr" "github.com/llm-d/llm-d-inference-sim/pkg/common" @@ -33,7 +34,7 @@ const ( testPrompt = "Hello world!" ) -var _ = Describe("CustomDataset", func() { +var _ = Describe("CustomDataset", Ordered, func() { var ( dataset *CustomDataset file_folder string @@ -46,6 +47,10 @@ var _ = Describe("CustomDataset", func() { pathToInvalidTypeDB string ) + BeforeAll(func() { + common.InitRandom(time.Now().UnixNano()) + }) + BeforeEach(func() { dataset = &CustomDataset{ BaseDataset: BaseDataset{ diff --git a/pkg/dataset/dataset_test.go b/pkg/dataset/dataset_test.go index c6304941..2e01463a 100644 --- a/pkg/dataset/dataset_test.go +++ b/pkg/dataset/dataset_test.go @@ -36,6 +36,7 @@ var _ = Describe("Dataset", Ordered, func() { BeforeAll(func() { common.InitRandom(time.Now().UnixNano()) }) + BeforeEach(func() { dataset = &BaseDataset{ Logger: logr.Discard(), From e1a555a1769c54fd00fe402bef8b1b46fe4ad726 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Mon, 15 Sep 2025 21:52:35 +1000 Subject: [PATCH 16/34] Remove unnecessary config Signed-off-by: Qifan Deng --- pkg/common/config.go | 5 ++-- pkg/dataset/custom_dataset.go | 34 +++++++++++++++------------- pkg/dataset/custom_dataset_test.go | 16 ++++++------- pkg/dataset/dataset.go | 4 ++-- pkg/llm-d-inference-sim/simulator.go | 4 ++-- 5 files changed, 32 insertions(+), 31 deletions(-) diff --git a/pkg/common/config.go b/pkg/common/config.go index 7d8f4ca7..79046c70 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -190,10 +190,9 @@ type 
DatasetConf struct { // when path is empty Url will be checked Path string `yaml:"path" json:"path"` // Url is the URL to download the sqlite db file if set, default is empty + // if Path is not provided and Url is provided, the file will be downloaded + // to "USER_HOME/.llm-d/dataset.db" Url string `yaml:"url" json:"url"` - // SavePath is the local path to save the downloaded sqlite db file - // if Url is set but SavePath is not, "USER_HOME/.llm-d/dataset.db" will be used - SavePath string `yaml:"save-path" json:"save-path"` } type Metrics struct { diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go index eff93d52..90e4889e 100644 --- a/pkg/dataset/custom_dataset.go +++ b/pkg/dataset/custom_dataset.go @@ -55,7 +55,7 @@ const ( nGenTokensColType = "INTEGER" ) -func (d *CustomDataset) downloadDataset(url string, savePath string) error { +func (d *CustomDataset) downloadDataset(url string, path string) error { // Set up signal handling for Ctrl+C (SIGINT) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -70,7 +70,7 @@ func (d *CustomDataset) downloadDataset(url string, savePath string) error { cancel() }() - out, err := os.Create(savePath) + out, err := os.Create(path) if err != nil { return err } @@ -109,7 +109,7 @@ func (d *CustomDataset) downloadDataset(url string, savePath string) error { written, err := io.Copy(out, pr) if err != nil { // Remove incomplete file - cerr := os.Remove(savePath) + cerr := os.Remove(path) if cerr != nil { d.Logger.Error(cerr, "failed to remove incomplete file after download") } @@ -121,7 +121,7 @@ func (d *CustomDataset) downloadDataset(url string, savePath string) error { } // Check if file size is zero or suspiciously small if written == 0 { - cerr := os.Remove(savePath) + cerr := os.Remove(path) if cerr != nil { d.Logger.Error(cerr, "failed to remove empty file after download") } @@ -130,7 +130,7 @@ func (d *CustomDataset) downloadDataset(url string, savePath string) error { // 
Ensure file is fully flushed and closed before returning success if err := out.Sync(); err != nil { - cerr := os.Remove(savePath) + cerr := os.Remove(path) if cerr != nil { d.Logger.Error(cerr, "failed to remove incomplete file after download") } @@ -279,37 +279,39 @@ func (d *CustomDataset) connectToDB(path string) error { return nil } -func (d *CustomDataset) Init(path string, url string, savePath string) error { +func (d *CustomDataset) Init(path string, url string) error { d.hasWarned = false - if path != "" { + if path != "" && url == "" { + d.Logger.Info("Using dataset from", "path", path) return d.connectToDB(path) } if url != "" { - if savePath == "" { + d.Logger.Info("Url detected", "url", url) + if path == "" { user, err := os.UserHomeDir() if err != nil { return fmt.Errorf("failed to get user home directory: %w", err) } - savePath = filepath.Join(user, ".llm-d", "dataset.sqlite3") + path = filepath.Join(user, ".llm-d", "dataset.sqlite3") + d.Logger.Info("Using default for dataset", "path", path) + } else { + d.Logger.Info("Using provided path for dataset", "path", path) } - _, err := os.Stat(savePath) + _, err := os.Stat(path) if err != nil { // file does not exist, download it - folder := filepath.Dir(savePath) + folder := filepath.Dir(path) err := os.MkdirAll(folder, 0755) if err != nil { return fmt.Errorf("failed to create parent directory: %w", err) } - d.Logger.Info("Downloading dataset from URL", "url", url, "to", savePath) - err = d.downloadDataset(url, savePath) + err = d.downloadDataset(url, path) if err != nil { return fmt.Errorf("failed to download dataset: %w", err) } } - d.Logger.Info("Using dataset from", "path", savePath) - - return d.connectToDB(savePath) + return d.connectToDB(path) } return errors.New("no dataset path or url provided") } diff --git a/pkg/dataset/custom_dataset_test.go b/pkg/dataset/custom_dataset_test.go index fbfea9dd..a1392016 100644 --- a/pkg/dataset/custom_dataset_test.go +++ b/pkg/dataset/custom_dataset_test.go @@ 
-38,7 +38,7 @@ var _ = Describe("CustomDataset", Ordered, func() { var ( dataset *CustomDataset file_folder string - savePath string + path string validDBPath string pathToInvalidDB string pathNotExist string @@ -58,7 +58,7 @@ var _ = Describe("CustomDataset", Ordered, func() { }, } file_folder = ".llm-d" - savePath = file_folder + "/test.sqlite3" + path = file_folder + "/test.sqlite3" err := os.MkdirAll(file_folder, os.ModePerm) Expect(err).NotTo(HaveOccurred()) validDBPath = file_folder + "/test.valid.sqlite3" @@ -83,22 +83,22 @@ var _ = Describe("CustomDataset", Ordered, func() { It("should download file from url", func() { url := "https://llm-d.ai" - err := dataset.downloadDataset(url, savePath) + err := dataset.downloadDataset(url, path) Expect(err).NotTo(HaveOccurred()) - _, err = os.Stat(savePath) + _, err = os.Stat(path) Expect(err).NotTo(HaveOccurred()) - err = os.Remove(savePath) + err = os.Remove(path) Expect(err).NotTo(HaveOccurred()) }) It("should not download file from url", func() { url := "https://256.256.256.256" // invalid url - err := dataset.downloadDataset(url, savePath) + err := dataset.downloadDataset(url, path) Expect(err).To(HaveOccurred()) }) It("should successfully init dataset", func() { - err := dataset.Init(validDBPath, "", "") + err := dataset.Init(validDBPath, "") Expect(err).NotTo(HaveOccurred()) row := dataset.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';") @@ -173,7 +173,7 @@ var _ = Describe("CustomDataset", Ordered, func() { }) It("should return tokens for existing prompt", func() { - err := dataset.Init(validDBPath, "", "") + err := dataset.Init(validDBPath, "") Expect(err).NotTo(HaveOccurred()) req := &openaiserverapi.TextCompletionRequest{ diff --git a/pkg/dataset/dataset.go b/pkg/dataset/dataset.go index 6d633171..cf90df49 100644 --- a/pkg/dataset/dataset.go +++ b/pkg/dataset/dataset.go @@ -70,7 +70,7 @@ var chatCompletionFakeResponses = 
[]string{ type Dataset interface { // Init initializes the dataset using configs - Init(path string, url string, savePath string) error + Init(path string, url string) error // Close closes the dataset Close() error // GetTokens returns tokens for the given request and mode (echo or random) @@ -280,7 +280,7 @@ type BaseDataset struct { Logger logr.Logger } -func (d *BaseDataset) Init(path string, url string, savePath string) error { +func (d *BaseDataset) Init(path string, url string) error { return nil } diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 39fa8409..86dbd06b 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -243,7 +243,7 @@ func (s *VllmSimulator) initDataset() error { Logger: s.logger, } - if s.config.Dataset.Path == "" && s.config.Dataset.Url == "" && s.config.Dataset.SavePath == "" { + if s.config.Dataset.Path == "" && s.config.Dataset.Url == "" { s.logger.Info("No dataset provided, will generate random responses from preset text") s.dataset = randDataset } else { @@ -253,7 +253,7 @@ func (s *VllmSimulator) initDataset() error { } } - if err := s.dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath); err != nil { + if err := s.dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url); err != nil { return fmt.Errorf("dataset initialization error: %w", err) } return nil From 2356d8098e65ce254738899d8c03b367ce3ea740 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Mon, 15 Sep 2025 22:16:49 +1000 Subject: [PATCH 17/34] Add cli arg of dataset path and url, also update readme Signed-off-by: Qifan Deng --- README.md | 5 ++++- pkg/common/config.go | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8e63b793..8c4da70e 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,10 @@ For more details see the /.llm-d/dataset.sqlite3` will be used. 
If neither dataset-path nor dataset-url are set, response is randomly generated. See [llm-d converted ShareGPT](https://huggingface.co/datasets/llm-d/inference-sim-datasets/tree/980e326f222e3e7390eef9df02a4f5e77d2a6da0/huggingface/ShareGPT_Vicuna_unfiltered) for more details on the expected format of the sqlite db file. +- `dataset-url`: URL to download the sqlite db file for response generation from a dataset, optional, if set, the sqlite db file will be downloaded to the path specified by `dataset-path`. If the file already exists at that path, it will not be downloaded again. Example url: `https://huggingface.co/datasets/llm-d/inference-sim-datasets/resolve/980e326f222e3e7390eef9df02a4f5e77d2a6da0/huggingface/ShareGPT_Vicuna_unfiltered/conversations.sqlite3?download=true` +--- In addition, as we are using klog, the following parameters are available: - `add_dir_header`: if true, adds the file directory to the header of the log messages - `alsologtostderr`: log to standard error as well as files (no effect when -logtostderr=true) diff --git a/pkg/common/config.go b/pkg/common/config.go index 79046c70..b82a7d72 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -576,6 +576,9 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") f.IntVar(&config.DPSize, "data-parallel-size", config.DPSize, "Number of ranks to run") + f.StringVar(&config.Dataset.Path, "dataset-path", config.Dataset.Path, "Local path to the sqlite db file for response generation from a dataset") + f.StringVar(&config.Dataset.Url, "dataset-url", config.Dataset.Url, "URL to download the sqlite db file for response generation from a dataset") + f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures") failureTypes := getParamValueFromArgs("failure-types") var 
dummyFailureTypes multiString From 51c6f49283d3934225a3583cf53abcada5d39e9a Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Tue, 16 Sep 2025 11:45:54 +1000 Subject: [PATCH 18/34] Return random from dataset if prmopt hash does not hit Signed-off-by: Qifan Deng --- pkg/dataset/custom_dataset.go | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go index 90e4889e..f0e9638a 100644 --- a/pkg/dataset/custom_dataset.go +++ b/pkg/dataset/custom_dataset.go @@ -29,6 +29,7 @@ import ( "os" "os/signal" "path/filepath" + "strconv" "syscall" "time" @@ -359,33 +360,43 @@ func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode st return tokens, finishReason, err } -func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) { - promptHash := d.GetPromptHash(req) - promptHashHex := d.GetPromptHashHex(promptHash) - rows, err := d.db.Query("SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';") +func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) { + rows, err := d.db.Query(query) if err != nil { if !d.hasWarned { - d.Logger.Error(err, "failed to query database. Ensure the prompt hash exists in the dataset. Will generate random tokens instead.") + d.Logger.Error(err, "Failed to query database. Ensure dataset file is still valid. 
Will generate random tokens instead.") d.hasWarned = true } - return GenPresetRandomTokens(nTokens), nil + return [][]string{GenPresetRandomTokens(nTokens)}, nil } defer func() { if cerr := rows.Close(); cerr != nil { d.Logger.Error(cerr, "failed to close rows after query") } }() + return unmarshalAllRecords(rows) +} - tokensList, err := unmarshalAllRecords(rows) - if err != nil { - d.Logger.Error(err, "failed to unmarshal records from database") - return GenPresetRandomTokens(nTokens), nil +func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) { + // query by prompt hash first + promptHash := d.GetPromptHash(req) + promptHashHex := d.GetPromptHashHex(promptHash) + query := "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';" + tokensList, err := d.query(query, nTokens) + + if err != nil || len(tokensList) == 0 { + // if query by prompt hash fails, fallback to query by number of tokens + query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv.Itoa(nTokens) + ";" + tokensList, err = d.query(query, nTokens) } - if len(tokensList) == 0 { + if err != nil || len(tokensList) == 0 { + // if both queries fail or return no results, generate random tokens return GenPresetRandomTokens(nTokens), nil } - d.hasWarned = false + if d.hasWarned { + d.hasWarned = false + } randIndex := common.RandomInt(0, len(tokensList)-1) return tokensList[randIndex], nil } From 789a4f27337fff0a864375473ca34fbbb38902ed Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Tue, 16 Sep 2025 14:48:47 +1000 Subject: [PATCH 19/34] Respect maxTokens Signed-off-by: Qifan Deng --- pkg/dataset/custom_dataset.go | 31 ++++++++++++++++++++++++------ pkg/dataset/custom_dataset_test.go | 14 ++++++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go index f0e9638a..1576460c 100644 --- 
a/pkg/dataset/custom_dataset.go +++ b/pkg/dataset/custom_dataset.go @@ -356,7 +356,7 @@ func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode st return d.echo(req) } nTokensToGen, finishReason := howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS()) - tokens, err := d.GenerateTokens(req, nTokensToGen) + tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason) return tokens, finishReason, err } @@ -377,17 +377,36 @@ func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) { return unmarshalAllRecords(rows) } -func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) { +func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int, finishReason string) ([]string, error) { // query by prompt hash first promptHash := d.GetPromptHash(req) promptHashHex := d.GetPromptHashHex(promptHash) query := "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';" tokensList, err := d.query(query, nTokens) - if err != nil || len(tokensList) == 0 { - // if query by prompt hash fails, fallback to query by number of tokens - query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv.Itoa(nTokens) + ";" - tokensList, err = d.query(query, nTokens) + // filter out results according to finish reason + var filteredTokensList [][]string + if finishReason != LengthFinishReason && finishReason != StopFinishReason { + d.Logger.Error(errors.New("unknown finish reason"), "Unexpected finish reason", "reason", finishReason) + } + for _, tokens := range tokensList { + if finishReason == StopFinishReason && len(tokens) <= nTokens { + filteredTokensList = append(filteredTokensList, tokens) + } else if finishReason == LengthFinishReason && len(tokens) == nTokens { + filteredTokensList = append(filteredTokensList, tokens) + } + } + tokensList = filteredTokensList + + if err != 
nil || len(filteredTokensList) == 0 { + switch finishReason { + case LengthFinishReason: + query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv.Itoa(nTokens) + ";" + tokensList, err = d.query(query, nTokens) + case StopFinishReason: + query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "<=" + strconv.Itoa(nTokens) + ";" + tokensList, err = d.query(query, nTokens) + } } if err != nil || len(tokensList) == 0 { diff --git a/pkg/dataset/custom_dataset_test.go b/pkg/dataset/custom_dataset_test.go index a1392016..120ad292 100644 --- a/pkg/dataset/custom_dataset_test.go +++ b/pkg/dataset/custom_dataset_test.go @@ -184,4 +184,18 @@ var _ = Describe("CustomDataset", Ordered, func() { Expect(finishReason).To(Equal(StopFinishReason)) Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"})) }) + + It("should return at most 2 tokens for existing prompt", func() { + err := dataset.Init(validDBPath, "") + Expect(err).NotTo(HaveOccurred()) + n := int64(2) + req := &openaiserverapi.TextCompletionRequest{ + Prompt: testPrompt, + MaxTokens: &n, + } + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) + Expect(err).NotTo(HaveOccurred()) + Expect(finishReason).To(Equal(LengthFinishReason)) + Expect(len(tokens)).To(BeNumerically("<=", 2)) + }) }) From 1baeb9ee5f05d8b67ebe875a98b187039b4586e5 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Tue, 16 Sep 2025 21:21:44 +1000 Subject: [PATCH 20/34] Resolve conflicts and fix test case Signed-off-by: Qifan Deng --- pkg/dataset/custom_dataset_test.go | 3 +- pkg/llm-d-inference-sim/simulator.go | 63 ---------------------------- 2 files changed, 1 insertion(+), 65 deletions(-) diff --git a/pkg/dataset/custom_dataset_test.go b/pkg/dataset/custom_dataset_test.go index 120ad292..7a021dd8 100644 --- a/pkg/dataset/custom_dataset_test.go +++ b/pkg/dataset/custom_dataset_test.go @@ -193,9 +193,8 @@ var _ = Describe("CustomDataset", Ordered, 
func() { Prompt: testPrompt, MaxTokens: &n, } - tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) + tokens, _, err := dataset.GetTokens(req, common.ModeRandom) Expect(err).NotTo(HaveOccurred()) - Expect(finishReason).To(Equal(LengthFinishReason)) Expect(len(tokens)).To(BeNumerically("<=", 2)) }) }) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 86dbd06b..3eb67db3 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -259,69 +259,6 @@ func (s *VllmSimulator) initDataset() error { return nil } -func (s *VllmSimulator) newListener() (net.Listener, error) { - s.logger.Info("Server starting", "port", s.config.Port) - listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", s.config.Port)) - if err != nil { - return nil, err - } - return listener, nil -} - -// startServer starts http server on port defined in command line -func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) error { - r := fasthttprouter.New() - - // support completion APIs - r.POST("/v1/chat/completions", s.HandleChatCompletions) - r.POST("/v1/completions", s.HandleTextCompletions) - // supports /models API - r.GET("/v1/models", s.HandleModels) - // support load/unload of lora adapter - r.POST("/v1/load_lora_adapter", s.HandleLoadLora) - r.POST("/v1/unload_lora_adapter", s.HandleUnloadLora) - // supports /metrics prometheus API - r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.HandlerFor(s.registry, promhttp.HandlerOpts{}))) - // supports standard Kubernetes health and readiness checks - r.GET("/health", s.HandleHealth) - r.GET("/ready", s.HandleReady) - r.POST("/tokenize", s.HandleTokenize) - - server := fasthttp.Server{ - ErrorHandler: s.HandleError, - Handler: r.Handler, - Logger: s, - } - - // Start server in a goroutine - serverErr := make(chan error, 1) - go func() { - s.logger.Info("HTTP server starting") - serverErr <- server.Serve(listener) - }() 
- - // Wait for either context cancellation or server error - select { - case <-ctx.Done(): - s.logger.Info("Shutdown signal received, shutting down HTTP server gracefully") - - // Gracefully shutdown the server - if err := server.Shutdown(); err != nil { - s.logger.Error(err, "Error during server shutdown") - return err - } - - s.logger.Info("HTTP server stopped") - return nil - - case err := <-serverErr: - if err != nil { - s.logger.Error(err, "HTTP server failed") - } - return err - } -} - // Print prints to a log, implementation of fasthttp.Logger func (s *VllmSimulator) Printf(format string, args ...interface{}) { s.logger.Info("Server error", "msg", fmt.Sprintf(format, args...)) From d6429758692a25df4c794ad46ae2b14843a7b8b5 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Wed, 17 Sep 2025 14:59:40 +1000 Subject: [PATCH 21/34] Update readme Signed-off-by: Qifan Deng --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8c4da70e..764cf790 100644 --- a/README.md +++ b/README.md @@ -150,8 +150,8 @@ For more details see the Date: Tue, 23 Sep 2025 14:25:27 +1000 Subject: [PATCH 23/34] Ignore test temp folder Signed-off-by: Qifan Deng --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 1ee731ff..d24a1264 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ vendor manifests/dev-config.yaml pkg/dataset/.llm-d pkg/llm-d-inference-sim/tests-tmp/ +pkg/llm-d-inference-sim/.llm-d/ From e13352b948551f708ec992732821c5d71c3ab654 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Thu, 25 Sep 2025 15:43:57 +1000 Subject: [PATCH 24/34] Update README Signed-off-by: Qifan Deng --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 764cf790..525701f3 100644 --- a/README.md +++ b/README.md @@ -150,8 +150,8 @@ For more details see the Date: Fri, 26 Sep 2025 16:24:53 +1000 Subject: [PATCH 26/34] Use ctx in main Signed-off-by: 
Qifan Deng --- pkg/dataset/custom_dataset.go | 21 +++------------------ pkg/dataset/custom_dataset_test.go | 11 ++++++----- pkg/dataset/dataset.go | 5 +++-- pkg/llm-d-inference-sim/simulator.go | 6 +++--- pkg/llm-d-inference-sim/simulator_test.go | 2 +- 5 files changed, 16 insertions(+), 29 deletions(-) diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go index 3af1547b..05c41379 100644 --- a/pkg/dataset/custom_dataset.go +++ b/pkg/dataset/custom_dataset.go @@ -27,10 +27,8 @@ import ( "io" "net/http" "os" - "os/signal" "path/filepath" "strconv" - "syscall" "time" "github.com/go-logr/logr" @@ -56,20 +54,7 @@ const ( nGenTokensColType = "INTEGER" ) -func (d *CustomDataset) downloadDataset(url string, path string) error { - // Set up signal handling for Ctrl+C (SIGINT) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, os.Interrupt, syscall.SIGTERM) - defer signal.Stop(sigs) - - // Goroutine to listen for signal - go func() { - <-sigs - d.Logger.Info("Interrupt signal received, cancelling download...") - cancel() - }() +func (d *CustomDataset) downloadDataset(ctx context.Context, url string, path string) error { out, err := os.Create(path) if err != nil { @@ -280,7 +265,7 @@ func (d *CustomDataset) connectToDB(path string) error { return nil } -func (d *CustomDataset) Init(path string, url string) error { +func (d *CustomDataset) Init(ctx context.Context, path string, url string) error { d.hasWarned = false if path != "" && url == "" { d.Logger.Info("Using dataset from", "path", path) @@ -304,7 +289,7 @@ func (d *CustomDataset) Init(path string, url string) error { if err != nil { return fmt.Errorf("failed to create parent directory: %w", err) } - err = d.downloadDataset(url, path) + err = d.downloadDataset(ctx, url, path) if err != nil { return fmt.Errorf("failed to download dataset: %w", err) } diff --git a/pkg/dataset/custom_dataset_test.go 
b/pkg/dataset/custom_dataset_test.go index 7a021dd8..7872e378 100644 --- a/pkg/dataset/custom_dataset_test.go +++ b/pkg/dataset/custom_dataset_test.go @@ -17,6 +17,7 @@ limitations under the License. package dataset import ( + "context" "encoding/json" "os" "time" @@ -83,7 +84,7 @@ var _ = Describe("CustomDataset", Ordered, func() { It("should download file from url", func() { url := "https://llm-d.ai" - err := dataset.downloadDataset(url, path) + err := dataset.downloadDataset(context.Background(), url, path) Expect(err).NotTo(HaveOccurred()) _, err = os.Stat(path) Expect(err).NotTo(HaveOccurred()) @@ -93,12 +94,12 @@ var _ = Describe("CustomDataset", Ordered, func() { It("should not download file from url", func() { url := "https://256.256.256.256" // invalid url - err := dataset.downloadDataset(url, path) + err := dataset.downloadDataset(context.Background(), url, path) Expect(err).To(HaveOccurred()) }) It("should successfully init dataset", func() { - err := dataset.Init(validDBPath, "") + err := dataset.Init(context.Background(), validDBPath, "") Expect(err).NotTo(HaveOccurred()) row := dataset.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';") @@ -173,7 +174,7 @@ var _ = Describe("CustomDataset", Ordered, func() { }) It("should return tokens for existing prompt", func() { - err := dataset.Init(validDBPath, "") + err := dataset.Init(context.Background(), validDBPath, "") Expect(err).NotTo(HaveOccurred()) req := &openaiserverapi.TextCompletionRequest{ @@ -186,7 +187,7 @@ var _ = Describe("CustomDataset", Ordered, func() { }) It("should return at most 2 tokens for existing prompt", func() { - err := dataset.Init(validDBPath, "") + err := dataset.Init(context.Background(), validDBPath, "") Expect(err).NotTo(HaveOccurred()) n := int64(2) req := &openaiserverapi.TextCompletionRequest{ diff --git a/pkg/dataset/dataset.go b/pkg/dataset/dataset.go index cf90df49..14138cc5 100644 --- 
a/pkg/dataset/dataset.go +++ b/pkg/dataset/dataset.go @@ -17,6 +17,7 @@ limitations under the License. package dataset import ( + "context" "errors" "math" "math/rand" @@ -70,7 +71,7 @@ var chatCompletionFakeResponses = []string{ type Dataset interface { // Init initializes the dataset using configs - Init(path string, url string) error + Init(ctx context.Context, path string, url string) error // Close closes the dataset Close() error // GetTokens returns tokens for the given request and mode (echo or random) @@ -280,7 +281,7 @@ type BaseDataset struct { Logger logr.Logger } -func (d *BaseDataset) Init(path string, url string) error { +func (d *BaseDataset) Init(ctx context.Context, path string, url string) error { return nil } diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index f54915e3..118717ba 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -216,7 +216,7 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { go s.kvcacheHelper.Run(ctx) } - err = s.initDataset() + err = s.initDataset(ctx) if err != nil { return fmt.Errorf("dataset initialization error: %w", err) } @@ -238,7 +238,7 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { return s.startServer(ctx, listener) } -func (s *VllmSimulator) initDataset() error { +func (s *VllmSimulator) initDataset(ctx context.Context) error { randDataset := &dataset.BaseDataset{ Logger: s.logger, } @@ -253,7 +253,7 @@ func (s *VllmSimulator) initDataset() error { } } - if err := s.dataset.Init(s.config.DatasetPath, s.config.DatasetURL); err != nil { + if err := s.dataset.Init(ctx, s.config.DatasetPath, s.config.DatasetURL); err != nil { return fmt.Errorf("dataset initialization error: %w", err) } return nil diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 9326f8a9..e8e57e41 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ 
b/pkg/llm-d-inference-sim/simulator_test.go @@ -118,7 +118,7 @@ func startServerWithArgs(ctx context.Context, mode string, args []string, envs m go s.kvcacheHelper.Run(ctx) } - err = s.initDataset() + err = s.initDataset(ctx) if err != nil { return nil, fmt.Errorf("dataset initialization error: %w", err) } From d93174f25b01c0c70ecac2acd89395e15a90a667 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Fri, 26 Sep 2025 17:51:52 +1000 Subject: [PATCH 27/34] Update readme and dataset downloading logic Signed-off-by: Qifan Deng --- README.md | 13 ++++- pkg/common/config.go | 20 ++++++-- pkg/dataset/custom_dataset.go | 74 +++++++++++++++++++--------- pkg/dataset/custom_dataset_test.go | 1 - pkg/llm-d-inference-sim/simulator.go | 27 ++++++---- 5 files changed, 94 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 525701f3..84de13a2 100644 --- a/README.md +++ b/README.md @@ -150,8 +150,17 @@ For more details see the Date: Fri, 26 Sep 2025 18:11:53 +1000 Subject: [PATCH 28/34] Pass logger when init dataset Signed-off-by: Qifan Deng --- .gitignore | 1 + pkg/dataset/custom_dataset.go | 43 ++++++++++++++-------------- pkg/dataset/custom_dataset_test.go | 21 ++++++++------ pkg/dataset/dataset.go | 7 +++-- pkg/dataset/dataset_test.go | 5 +--- pkg/llm-d-inference-sim/simulator.go | 12 ++++---- 6 files changed, 46 insertions(+), 43 deletions(-) diff --git a/.gitignore b/.gitignore index d24a1264..d684889b 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ manifests/dev-config.yaml pkg/dataset/.llm-d pkg/llm-d-inference-sim/tests-tmp/ pkg/llm-d-inference-sim/.llm-d/ +.llm-d/ diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go index 7614c0d0..522a9591 100644 --- a/pkg/dataset/custom_dataset.go +++ b/pkg/dataset/custom_dataset.go @@ -63,7 +63,7 @@ func (d *CustomDataset) downloadDataset(ctx context.Context, url string, path st if _, err := os.Stat(path); err == nil { // file already exists - return errors.New("Dataset file already 
exists, should not download: " + path) + return errors.New("Dataset file already exists, should not download: " + path) } out, err := os.Create(path) @@ -73,11 +73,11 @@ func (d *CustomDataset) downloadDataset(ctx context.Context, url string, path st defer func() { cerr := out.Close() if cerr != nil { - d.Logger.Error(cerr, "failed to close file after download") + d.logger.Error(cerr, "failed to close file after download") } }() - d.Logger.Info("Using dataset-url", "dataset-url", url) + d.logger.Info("Using dataset-url", "dataset-url", url) resp, err := http.Get(url) if err != nil { return err @@ -85,7 +85,7 @@ func (d *CustomDataset) downloadDataset(ctx context.Context, url string, path st defer func() { cerr := resp.Body.Close() if cerr != nil { - d.Logger.Error(cerr, "failed to close response body after download") + d.logger.Error(cerr, "failed to close response body after download") } }() @@ -97,7 +97,7 @@ func (d *CustomDataset) downloadDataset(ctx context.Context, url string, path st pr := &progressReader{ Reader: resp.Body, total: resp.ContentLength, - logger: d.Logger, + logger: d.logger, ctx: ctx, startTime: time.Now(), hasShownSpeed: false, @@ -108,7 +108,7 @@ func (d *CustomDataset) downloadDataset(ctx context.Context, url string, path st // Remove incomplete file cerr := os.Remove(path) if cerr != nil { - d.Logger.Error(cerr, "failed to remove incomplete file after download") + d.logger.Error(cerr, "failed to remove incomplete file after download") } // If context was cancelled, return a specific error if errors.Is(err, context.Canceled) { @@ -120,7 +120,7 @@ func (d *CustomDataset) downloadDataset(ctx context.Context, url string, path st if written == 0 { cerr := os.Remove(path) if cerr != nil { - d.Logger.Error(cerr, "failed to remove empty file after download") + d.logger.Error(cerr, "failed to remove empty file after download") } return errors.New("downloaded file is empty") } @@ -129,7 +129,7 @@ func (d *CustomDataset) downloadDataset(ctx 
context.Context, url string, path st if err := out.Sync(); err != nil { cerr := os.Remove(path) if cerr != nil { - d.Logger.Error(cerr, "failed to remove incomplete file after download") + d.logger.Error(cerr, "failed to remove incomplete file after download") } return fmt.Errorf("failed to sync file: %w", err) } @@ -190,7 +190,7 @@ func (d *CustomDataset) verifyDB() error { } defer func() { if cerr := rows.Close(); cerr != nil { - d.Logger.Error(cerr, "failed to close rows after querying table info") + d.logger.Error(cerr, "failed to close rows after querying table info") } }() @@ -246,7 +246,7 @@ func (d *CustomDataset) connectToDB(path string) error { if d.db != nil { err := d.db.Close() if err != nil { - d.Logger.Error(err, "failed to close existing database connection") + d.logger.Error(err, "failed to close existing database connection") } d.db = nil } @@ -265,7 +265,7 @@ func (d *CustomDataset) connectToDB(path string) error { if err != nil { err := d.db.Close() if err != nil { - d.Logger.Error(err, "failed to close database after failing to acquire exclusive lock") + d.logger.Error(err, "failed to close database after failing to acquire exclusive lock") } d.db = nil return fmt.Errorf("database is locked or has other active connections: %w", err) @@ -279,20 +279,21 @@ func (d *CustomDataset) connectToDB(path string) error { count, err := d.getRecordsCount() if err != nil { - d.Logger.Error(err, "failed to get records count") + d.logger.Error(err, "failed to get records count") return fmt.Errorf("failed to query database: %w", err) } - d.Logger.Info("Database connected successfully", "path", path, "records count", count) + d.logger.Info("Database connected successfully", "path", path, "records count", count) return nil } -func (d *CustomDataset) Init(ctx context.Context, path string, url string) error { +func (d *CustomDataset) Init(ctx context.Context, logger logr.Logger, path string, url string) error { + d.logger = logger if path == "" { return 
errors.New("no dataset path provided") } d.hasWarned = false if url == "" { - d.Logger.Info("Using dataset from", "path", path) + d.logger.Info("Using dataset from", "path", path) return d.connectToDB(path) } _, err := os.Stat(path) @@ -304,13 +305,13 @@ func (d *CustomDataset) Init(ctx context.Context, path string, url string) error if _, statErr := os.Stat(path); statErr == nil { cerr := os.Remove(path) if cerr != nil { - d.Logger.Error(cerr, "failed to remove incomplete file after download") + d.logger.Error(cerr, "failed to remove incomplete file after download") } } return fmt.Errorf("failed to download dataset: %w", err) } } - d.Logger.Info("Using dataset path", "dataset-path", path) + d.logger.Info("Using dataset path", "dataset-path", path) return d.connectToDB(path) } @@ -320,7 +321,7 @@ func (d *CustomDataset) Close() error { _, err := d.db.Exec("ROLLBACK;") if err != nil { if cerr := d.db.Close(); cerr != nil { - d.Logger.Error(cerr, "failed to close database after failing to acquire exclusive lock") + d.logger.Error(cerr, "failed to close database after failing to acquire exclusive lock") } d.db = nil return fmt.Errorf("failed to release exclusive lock: %w", err) @@ -372,14 +373,14 @@ func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) { rows, err := d.db.Query(query) if err != nil { if !d.hasWarned { - d.Logger.Error(err, "Failed to query database. Ensure dataset file is still valid. Will generate random tokens instead.") + d.logger.Error(err, "Failed to query database. Ensure dataset file is still valid. 
Will generate random tokens instead.") d.hasWarned = true } return [][]string{GenPresetRandomTokens(nTokens)}, nil } defer func() { if cerr := rows.Close(); cerr != nil { - d.Logger.Error(cerr, "failed to close rows after query") + d.logger.Error(cerr, "failed to close rows after query") } }() return unmarshalAllRecords(rows) @@ -395,7 +396,7 @@ func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nT // filter out results according to finish reason var filteredTokensList [][]string if finishReason != LengthFinishReason && finishReason != StopFinishReason { - d.Logger.Error(errors.New("unknown finish reason"), "Unexpected finish reason", "reason", finishReason) + d.logger.Error(errors.New("unknown finish reason"), "Unexpected finish reason", "reason", finishReason) } for _, tokens := range tokensList { if finishReason == StopFinishReason && len(tokens) <= nTokens { diff --git a/pkg/dataset/custom_dataset_test.go b/pkg/dataset/custom_dataset_test.go index 1c0053d4..758aa678 100644 --- a/pkg/dataset/custom_dataset_test.go +++ b/pkg/dataset/custom_dataset_test.go @@ -53,11 +53,7 @@ var _ = Describe("CustomDataset", Ordered, func() { }) BeforeEach(func() { - dataset = &CustomDataset{ - BaseDataset: BaseDataset{ - Logger: logr.Discard(), - }, - } + dataset = &CustomDataset{} file_folder = ".llm-d" path = file_folder + "/test.sqlite3" err := os.MkdirAll(file_folder, os.ModePerm) @@ -83,8 +79,15 @@ var _ = Describe("CustomDataset", Ordered, func() { }) It("should download file from url", func() { + // remove file if it exists + _, err := os.Stat(path) + if err == nil { + err = os.Remove(path) + Expect(err).NotTo(HaveOccurred()) + } + url := "https://llm-d.ai" - err := dataset.downloadDataset(context.Background(), url, path) + err = dataset.downloadDataset(context.Background(), url, path) Expect(err).NotTo(HaveOccurred()) _, err = os.Stat(path) Expect(err).NotTo(HaveOccurred()) @@ -99,7 +102,7 @@ var _ = Describe("CustomDataset", Ordered, func() { 
}) It("should successfully init dataset", func() { - err := dataset.Init(context.Background(), validDBPath, "") + err := dataset.Init(context.Background(), logr.Discard(), validDBPath, "") Expect(err).NotTo(HaveOccurred()) row := dataset.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';") @@ -173,7 +176,7 @@ var _ = Describe("CustomDataset", Ordered, func() { }) It("should return tokens for existing prompt", func() { - err := dataset.Init(context.Background(), validDBPath, "") + err := dataset.Init(context.Background(), logr.Discard(), validDBPath, "") Expect(err).NotTo(HaveOccurred()) req := &openaiserverapi.TextCompletionRequest{ @@ -186,7 +189,7 @@ var _ = Describe("CustomDataset", Ordered, func() { }) It("should return at most 2 tokens for existing prompt", func() { - err := dataset.Init(context.Background(), validDBPath, "") + err := dataset.Init(context.Background(), logr.Discard(), validDBPath, "") Expect(err).NotTo(HaveOccurred()) n := int64(2) req := &openaiserverapi.TextCompletionRequest{ diff --git a/pkg/dataset/dataset.go b/pkg/dataset/dataset.go index 14138cc5..589c4f2b 100644 --- a/pkg/dataset/dataset.go +++ b/pkg/dataset/dataset.go @@ -71,7 +71,7 @@ var chatCompletionFakeResponses = []string{ type Dataset interface { // Init initializes the dataset using configs - Init(ctx context.Context, path string, url string) error + Init(ctx context.Context, logger logr.Logger, path string, url string) error // Close closes the dataset Close() error // GetTokens returns tokens for the given request and mode (echo or random) @@ -278,10 +278,11 @@ func EchoResponseTokens(maxCompletionTokens *int64, text string) ([]string, stri } type BaseDataset struct { - Logger logr.Logger + logger logr.Logger } -func (d *BaseDataset) Init(ctx context.Context, path string, url string) error { +func (d *BaseDataset) Init(ctx context.Context, logger logr.Logger, path string, url string) error { + 
d.logger = logger return nil } diff --git a/pkg/dataset/dataset_test.go b/pkg/dataset/dataset_test.go index 2e01463a..83a2953b 100644 --- a/pkg/dataset/dataset_test.go +++ b/pkg/dataset/dataset_test.go @@ -21,7 +21,6 @@ import ( "strings" "time" - "github.com/go-logr/logr" "github.com/llm-d/llm-d-inference-sim/pkg/common" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" . "github.com/onsi/ginkgo/v2" @@ -38,9 +37,7 @@ var _ = Describe("Dataset", Ordered, func() { }) BeforeEach(func() { - dataset = &BaseDataset{ - Logger: logr.Discard(), - } + dataset = &BaseDataset{} }) Context("GetRandomTokens", func() { diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 0302e71a..0a7b301f 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -239,8 +239,10 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { } func (s *VllmSimulator) initDataset(ctx context.Context) error { - randDataset := &dataset.BaseDataset{ - Logger: s.logger, + randDataset := &dataset.BaseDataset{} + err := randDataset.Init(ctx, s.logger, "", "") + if err != nil { + return fmt.Errorf("failed to initialize random dataset: %w", err) } if s.config.DatasetPath == "" && s.config.DatasetURL == "" { @@ -249,10 +251,8 @@ func (s *VllmSimulator) initDataset(ctx context.Context) error { return nil } - custDataset := &dataset.CustomDataset{ - BaseDataset: *randDataset, - } - err := custDataset.Init(ctx, s.config.DatasetPath, s.config.DatasetURL) + custDataset := &dataset.CustomDataset{} + err = custDataset.Init(ctx, s.logger, s.config.DatasetPath, s.config.DatasetURL) if err == nil { s.dataset = custDataset From 63c184ed456a2fc98210b44bda859db72a87464c Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Fri, 26 Sep 2025 18:27:57 +1000 Subject: [PATCH 29/34] Improve progress logging, show it every 5 seconds or 10% Signed-off-by: Qifan Deng --- go.mod | 2 +- pkg/dataset/custom_dataset.go | 57 
++++++++++++++++++----------------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/go.mod b/go.mod index e3ca6e85..32fe86a8 100644 --- a/go.mod +++ b/go.mod @@ -45,7 +45,7 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect - github.com/mattn/go-sqlite3 v1.14.32 // indirect + github.com/mattn/go-sqlite3 v1.14.32 // direct github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go index 522a9591..396c84a1 100644 --- a/pkg/dataset/custom_dataset.go +++ b/pkg/dataset/custom_dataset.go @@ -45,13 +45,15 @@ type CustomDataset struct { // use constants for expected column names and types const ( - tableName = "llmd" - promptHashCol = "prompt_hash" - genTokensCol = "gen_tokens" - nGenTokensCol = "n_gen_tokens" - promptHashColType = "BLOB" - genTokensColType = "JSON" - nGenTokensColType = "INTEGER" + tableName = "llmd" + promptHashCol = "prompt_hash" + genTokensCol = "gen_tokens" + nGenTokensCol = "n_gen_tokens" + promptHashColType = "BLOB" + genTokensColType = "JSON" + nGenTokensColType = "INTEGER" + progressLogTimeInterval = 5 * time.Second + progressLogPercentInterval = 10 ) func (d *CustomDataset) downloadDataset(ctx context.Context, url string, path string) error { @@ -95,12 +97,11 @@ func (d *CustomDataset) downloadDataset(ctx context.Context, url string, path st // Progress reader with context pr := &progressReader{ - Reader: resp.Body, - total: resp.ContentLength, - logger: d.logger, - ctx: ctx, - startTime: time.Now(), - hasShownSpeed: false, + Reader: resp.Body, + total: resp.ContentLength, + logger: d.logger, + ctx: ctx, + startTime: time.Now(), } written, err := io.Copy(out, pr) @@ -116,7 +117,7 @@ func (d *CustomDataset) 
downloadDataset(ctx context.Context, url string, path st } return fmt.Errorf("failed to download file: %w", err) } - // Check if file size is zero or suspiciously small + // Check if file size is zero if written == 0 { cerr := os.Remove(path) if cerr != nil { @@ -140,13 +141,13 @@ func (d *CustomDataset) downloadDataset(ctx context.Context, url string, path st // progressReader wraps an io.Reader and logs download progress. type progressReader struct { io.Reader - total int64 - downloaded int64 - startTime time.Time - lastPct int - logger logr.Logger - ctx context.Context - hasShownSpeed bool + total int64 + downloaded int64 + startTime time.Time + lastPct int + lastLogTime time.Time + logger logr.Logger + ctx context.Context } func (pr *progressReader) Read(p []byte) (int, error) { @@ -159,14 +160,16 @@ func (pr *progressReader) Read(p []byte) (int, error) { pr.downloaded += int64(n) if pr.total > 0 { pct := int(float64(pr.downloaded) * 100 / float64(pr.total)) - if !pr.hasShownSpeed && time.Since(pr.startTime).Seconds() > 2 { - pr.hasShownSpeed = true - pr.logProgress(pct) - pr.lastPct = pct - } - if pct != pr.lastPct && pct%10 == 0 { + now := time.Now() + + timeSinceLastLog := now.Sub(pr.lastLogTime).Seconds() + pctDiff := pct - pr.lastPct + + if timeSinceLastLog >= progressLogTimeInterval.Seconds() || (pctDiff >= progressLogPercentInterval && pct != pr.lastPct) { + // progress will be shown every interval seconds or every interval percent of progress pr.logProgress(pct) pr.lastPct = pct + pr.lastLogTime = now } } return n, err From e46ae874668498f7d599636c030ea9346cc464df Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Fri, 26 Sep 2025 19:25:21 +1000 Subject: [PATCH 30/34] Use in memory database Signed-off-by: Qifan Deng --- README.md | 1 + pkg/common/config.go | 3 + pkg/dataset/custom_dataset.go | 128 +++++++++++++++++++++++---- pkg/dataset/custom_dataset_test.go | 31 +++++-- pkg/dataset/dataset.go | 4 +- pkg/llm-d-inference-sim/simulator.go | 4 +- 6 files 
changed, 140 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 84de13a2..85975c6e 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,7 @@ For more details see the Date: Fri, 26 Sep 2025 19:57:29 +1000 Subject: [PATCH 31/34] Use backup api to load dataset from disk to memory Signed-off-by: Qifan Deng --- pkg/dataset/.llm-d/test.valid.sqlite3 | Bin 12288 -> 16384 bytes pkg/dataset/custom_dataset.go | 68 +++++++++----------------- 2 files changed, 24 insertions(+), 44 deletions(-) diff --git a/pkg/dataset/.llm-d/test.valid.sqlite3 b/pkg/dataset/.llm-d/test.valid.sqlite3 index 847a6257842080335691a3de3943f0152a9382e1..f6e6601e1f4b636883a3fc205eb066a2d0635860 100644 GIT binary patch delta 251 zcmZojXlP)ZAT21wz`(!)#LPg<1jIZOb&Q3D81!@kc!4}d{$d9HV!mB`yEYbj@i8&F zO)Oj|I2|Yf1dUvb?Bb%LjE&`!WBAo{GE)>h{X$&bU4s+?f;@d4gCZ5YT_bh46o8Vi;wft%_1x5L}1tswriNzTTPCoul3IRc$zK%hW3f``flh5-T+Nv|LiCc>@ zGGykZq*fFc=46(n#wV7Rb_V|K{C9vxuI6V}VfFzUsL0Bs% Date: Mon, 29 Sep 2025 20:52:58 +1000 Subject: [PATCH 32/34] Remove duplicated log of Server starting Signed-off-by: Qifan Deng --- pkg/llm-d-inference-sim/server.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/llm-d-inference-sim/server.go b/pkg/llm-d-inference-sim/server.go index 1c6284a1..6384f28d 100644 --- a/pkg/llm-d-inference-sim/server.go +++ b/pkg/llm-d-inference-sim/server.go @@ -34,7 +34,6 @@ import ( ) func (s *VllmSimulator) newListener() (net.Listener, error) { - s.logger.Info("Server starting", "port", s.config.Port) listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", s.config.Port)) if err != nil { return nil, err From eb591efbb373862e571419cf89632d5873b16e8b Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Mon, 29 Sep 2025 20:57:10 +1000 Subject: [PATCH 33/34] use klog Signed-off-by: Qifan Deng --- pkg/dataset/custom_dataset_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/dataset/custom_dataset_test.go b/pkg/dataset/custom_dataset_test.go index 
30dc6815..afd734a2 100644 --- a/pkg/dataset/custom_dataset_test.go +++ b/pkg/dataset/custom_dataset_test.go @@ -22,11 +22,11 @@ import ( "os" "time" - "github.com/go-logr/logr" "github.com/llm-d/llm-d-inference-sim/pkg/common" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "k8s.io/klog/v2" _ "github.com/mattn/go-sqlite3" ) @@ -102,7 +102,7 @@ var _ = Describe("CustomDataset", Ordered, func() { }) It("should successfully init dataset", func() { - err := dataset.Init(context.Background(), logr.Discard(), validDBPath, "", false) + err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", false) Expect(err).NotTo(HaveOccurred()) row := dataset.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';") @@ -176,7 +176,7 @@ var _ = Describe("CustomDataset", Ordered, func() { }) It("should return tokens for existing prompt", func() { - err := dataset.Init(context.Background(), logr.Discard(), validDBPath, "", false) + err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", false) Expect(err).NotTo(HaveOccurred()) req := &openaiserverapi.TextCompletionRequest{ @@ -189,7 +189,7 @@ var _ = Describe("CustomDataset", Ordered, func() { }) It("should return at most 2 tokens for existing prompt", func() { - err := dataset.Init(context.Background(), logr.Discard(), validDBPath, "", false) + err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", false) Expect(err).NotTo(HaveOccurred()) n := int64(2) req := &openaiserverapi.TextCompletionRequest{ @@ -202,7 +202,7 @@ var _ = Describe("CustomDataset", Ordered, func() { }) It("should successfully init dataset with in-memory option", func() { - err := dataset.Init(context.Background(), logr.Discard(), validDBPath, "", true) + err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", true) 
Expect(err).NotTo(HaveOccurred()) req := &openaiserverapi.TextCompletionRequest{ From b3ac9f51b4ffb3a42987c1700a2e9124efa9fa65 Mon Sep 17 00:00:00 2001 From: Qifan Deng Date: Mon, 29 Sep 2025 21:04:26 +1000 Subject: [PATCH 34/34] update readme Signed-off-by: Qifan Deng --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 85975c6e..fa4dfde2 100644 --- a/README.md +++ b/README.md @@ -155,7 +155,7 @@ For more details see the