Skip to content

Commit 395619f

Browse files
committed
load or download response dataset
1 parent 8edbb20 commit 395619f

File tree

8 files changed

+314
-1
lines changed

8 files changed

+314
-1
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@ vendor
77
.DS_Store
88
*.test
99
manifests/dev-config.yaml
10+
pkg/llm-d-inference-sim/.llm-d
11+
.llm-d/

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ format: ## Format Go source files
7878
test: check-ginkgo download-tokenizer download-zmq ## Run tests
7979
@printf "\033[33;1m==== Running tests ====\033[0m\n"
8080
ifdef GINKGO_FOCUS
81-
CGO_ENABLED=1 ginkgo -ldflags="$(GO_LDFLAGS)" -v -r --focus="$(GINKGO_FOCUS)"
81+
CGO_ENABLED=1 ginkgo -ldflags="$(GO_LDFLAGS)" -v -r -- -ginkgo.v -ginkgo.focus="$(GINKGO_FOCUS)"
8282
else
8383
CGO_ENABLED=1 ginkgo -ldflags="$(GO_LDFLAGS)" -v -r
8484
endif

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ require (
4545
github.com/json-iterator/go v1.1.12 // indirect
4646
github.com/klauspost/compress v1.18.0 // indirect
4747
github.com/mailru/easyjson v0.7.7 // indirect
48+
github.com/mattn/go-sqlite3 v1.14.32 // indirect
4849
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
4950
github.com/modern-go/reflect2 v1.0.2 // indirect
5051
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ github.com/llm-d/llm-d-kv-cache-manager v0.3.0-rc1 h1:SDLiNrcreDcA9m9wfXAumFARDH
7272
github.com/llm-d/llm-d-kv-cache-manager v0.3.0-rc1/go.mod h1:tN80/D0Faf6pE2ocwFgTNoCxKPsqdsa2XnjQUqOaZ8Q=
7373
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
7474
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
75+
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
76+
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
7577
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
7678
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
7779
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

pkg/common/config.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,20 @@ type Configuration struct {
174174

175175
// DPSize is data parallel size - a number of ranks to run, minimum is 1, maximum is 8, default is 1
176176
DPSize int `yaml:"data-parallel-size" json:"data-parallel-size"`
177+
178+
// Dataset configures response generation from a dataset; a sqlite db file is expected.
179+
Dataset Dataset
180+
}
181+
182+
// Dataset describes where the sqlite db file used for response generation
// is located: either a local Path, or a Url to download it from.
type Dataset struct {
183+
// Path is the local path to the sqlite db file, default is empty
184+
// when path is empty Url will be checked
185+
Path string `yaml:"path" json:"path"`
186+
// Url is the URL to download the sqlite db file if set, default is empty
187+
Url string `yaml:"url" json:"url"`
188+
// SavePath is the local path to save the downloaded sqlite db file
189+
// if Url is set but SavePath is not, "~/.llmd/dataset.sqlite3" will be used
190+
SavePath string `yaml:"save-path" json:"save-path"`
177191
}
178192

179193
type Metrics struct {
8 KB
Binary file not shown.

pkg/llm-d-inference-sim/dataset.go

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
/*
2+
Copyright 2025 The llm-d-inference-sim Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package llmdinferencesim
18+
19+
import (
20+
"context"
21+
"database/sql"
22+
"errors"
23+
"fmt"
24+
"io"
25+
"net/http"
26+
"os"
27+
"os/signal"
28+
"path/filepath"
29+
"syscall"
30+
31+
"github.com/go-logr/logr"
32+
)
33+
34+
// Dataset gives access to a response dataset stored in a sqlite db file.
type Dataset struct {
35+
// db is the database handle set by connectToDB and released by Close
db *sql.DB
36+
// logger receives download progress and error messages
logger logr.Logger
37+
}
38+
39+
func (d *Dataset) downloadDataset(url string, savePath string) error {
40+
// Set up signal handling for Ctrl+C (SIGINT)
41+
ctx, cancel := context.WithCancel(context.Background())
42+
defer cancel()
43+
sigs := make(chan os.Signal, 1)
44+
signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
45+
defer signal.Stop(sigs)
46+
47+
// Goroutine to listen for signal
48+
go func() {
49+
<-sigs
50+
d.logger.Info("Interrupt signal received, cancelling download...")
51+
cancel()
52+
}()
53+
54+
out, err := os.Create(savePath)
55+
if err != nil {
56+
return err
57+
}
58+
defer func() {
59+
cerr := out.Close()
60+
if cerr != nil {
61+
d.logger.Error(cerr, "failed to close file after download")
62+
}
63+
}()
64+
65+
resp, err := http.Get(url)
66+
if err != nil {
67+
return err
68+
}
69+
defer func() {
70+
cerr := resp.Body.Close()
71+
if cerr != nil {
72+
d.logger.Error(cerr, "failed to close response body after download")
73+
}
74+
}()
75+
76+
if resp.StatusCode != http.StatusOK {
77+
return fmt.Errorf("bad status: %s", resp.Status)
78+
}
79+
80+
// Progress reader with context
81+
pr := &progressReader{
82+
Reader: resp.Body,
83+
total: resp.ContentLength,
84+
logger: d.logger,
85+
ctx: ctx,
86+
}
87+
88+
written, err := io.Copy(out, pr)
89+
if err != nil {
90+
// Remove incomplete file
91+
cerr := os.Remove(savePath)
92+
if cerr != nil {
93+
d.logger.Error(cerr, "failed to remove incomplete file after download")
94+
}
95+
// If context was cancelled, return a specific error
96+
if errors.Is(err, context.Canceled) {
97+
return errors.New("download cancelled by user")
98+
}
99+
return fmt.Errorf("failed to download file: %w", err)
100+
}
101+
// Check if file size is zero or suspiciously small
102+
if written == 0 {
103+
cerr := os.Remove(savePath)
104+
if cerr != nil {
105+
d.logger.Error(cerr, "failed to remove empty file after download")
106+
}
107+
return errors.New("downloaded file is empty")
108+
}
109+
110+
// Ensure file is fully flushed and closed before returning success
111+
if err := out.Sync(); err != nil {
112+
cerr := os.Remove(savePath)
113+
if cerr != nil {
114+
d.logger.Error(cerr, "failed to remove incomplete file after download")
115+
}
116+
return fmt.Errorf("failed to sync file: %w", err)
117+
}
118+
119+
return nil
120+
}
121+
122+
// progressReader wraps an io.Reader and logs download progress.
123+
type progressReader struct {
124+
io.Reader
125+
total int64
126+
downloaded int64
127+
lastPct int
128+
logger logr.Logger
129+
ctx context.Context
130+
}
131+
132+
func (pr *progressReader) Read(p []byte) (int, error) {
133+
select {
134+
case <-pr.ctx.Done():
135+
return 0, pr.ctx.Err()
136+
default:
137+
}
138+
n, err := pr.Reader.Read(p)
139+
pr.downloaded += int64(n)
140+
if pr.total > 0 {
141+
pct := int(float64(pr.downloaded) * 100 / float64(pr.total))
142+
if pct != pr.lastPct && pct%10 == 0 { // log every 10%
143+
pr.logger.Info(fmt.Sprintf("Download progress: %d%%", pct))
144+
pr.lastPct = pct
145+
}
146+
}
147+
return n, err
148+
}
149+
func (d *Dataset) connectToDB(path string) error {
150+
// check if file exists
151+
_, err := os.Stat(path)
152+
if err != nil {
153+
return fmt.Errorf("database file does not exist: %w", err)
154+
}
155+
d.db, err = sql.Open("sqlite3", path)
156+
if err != nil {
157+
return fmt.Errorf("failed to open database: %w", err)
158+
}
159+
return nil
160+
}
161+
162+
func (d *Dataset) Init(path string, url string, savePath string) error {
163+
if path != "" {
164+
return d.connectToDB(path)
165+
}
166+
if url != "" {
167+
if savePath == "" {
168+
savePath = "~/.llmd/dataset.sqlite3"
169+
}
170+
171+
_, err := os.Stat(savePath)
172+
if err != nil {
173+
// file does not exist, download it
174+
folder := filepath.Dir(savePath)
175+
err := os.MkdirAll(folder, 0755)
176+
if err != nil {
177+
return fmt.Errorf("failed to create parent directory: %w", err)
178+
}
179+
d.logger.Info("Downloading dataset from URL", "url", url, "to", savePath)
180+
err = d.downloadDataset(url, savePath)
181+
if err != nil {
182+
return fmt.Errorf("failed to download dataset: %w", err)
183+
}
184+
}
185+
d.logger.Info("Using dataset from", "path", savePath)
186+
187+
return d.connectToDB(savePath)
188+
}
189+
return errors.New("no dataset path or url provided")
190+
}
191+
192+
func (d *Dataset) Close() error {
193+
if d.db != nil {
194+
return d.db.Close()
195+
}
196+
return nil
197+
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
Copyright 2025 The llm-d-inference-sim Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package llmdinferencesim
18+
19+
import (
20+
"os"
21+
22+
"github.com/go-logr/logr"
23+
. "github.com/onsi/ginkgo/v2"
24+
. "github.com/onsi/gomega"
25+
26+
_ "github.com/mattn/go-sqlite3"
27+
)
28+
29+
var _ = Describe("Dataset", func() {
30+
var (
31+
dataset *Dataset
32+
file_folder string
33+
savePath string
34+
)
35+
36+
BeforeEach(func() {
37+
dataset = &Dataset{
38+
logger: logr.Discard(),
39+
}
40+
file_folder = "./.llm-d"
41+
savePath = file_folder + "/test.sqlite3"
42+
err := os.MkdirAll(file_folder, os.ModePerm)
43+
Expect(err).NotTo(HaveOccurred())
44+
})
45+
46+
AfterEach(func() {
47+
if dataset.db != nil {
48+
err := dataset.db.Close()
49+
Expect(err).NotTo(HaveOccurred())
50+
}
51+
})
52+
53+
It("should return error for invalid DB path", func() {
54+
err := dataset.connectToDB("/invalid/path/to/db.sqlite")
55+
Expect(err).To(HaveOccurred())
56+
})
57+
58+
It("should download file from url", func() {
59+
url := "https://llm-d.ai"
60+
err := dataset.downloadDataset(url, savePath)
61+
Expect(err).NotTo(HaveOccurred())
62+
_, err = os.Stat(savePath)
63+
Expect(err).NotTo(HaveOccurred())
64+
err = os.Remove(savePath)
65+
Expect(err).NotTo(HaveOccurred())
66+
})
67+
68+
It("should not download file from url", func() {
69+
url := "https://256.256.256.256" // invalid url
70+
err := dataset.downloadDataset(url, savePath)
71+
Expect(err).To(HaveOccurred())
72+
})
73+
74+
It("should successfully init dataset", func() {
75+
validDBPath := file_folder + "/test.valid.sqlite3"
76+
err := dataset.Init(validDBPath, "", "")
77+
Expect(err).NotTo(HaveOccurred())
78+
79+
// read from the db to verify it's valid
80+
row := dataset.db.QueryRow("SELECT * FROM t;")
81+
var value string
82+
err = row.Scan(&value)
83+
Expect(err).NotTo(HaveOccurred())
84+
Expect(value).To(Equal("llm-d"))
85+
})
86+
87+
It("should raise err with invalid DB content", func() {
88+
err := dataset.connectToDB(file_folder)
89+
Expect(err).NotTo(HaveOccurred())
90+
91+
// read from the db to verify it's not valid
92+
row := dataset.db.QueryRow("SELECT * FROM t;")
93+
var value string
94+
err = row.Scan(&value)
95+
Expect(err).To(HaveOccurred())
96+
})
97+
})

0 commit comments

Comments
 (0)