Skip to content

Commit 76e9d86

Browse files
committed
Fix tests and init dataset when loading sim
Signed-off-by: Qifan Deng <[email protected]>
1 parent e4bf3b3 commit 76e9d86

File tree

4 files changed

+38
-26
lines changed

4 files changed

+38
-26
lines changed
4 KB
Binary file not shown.

pkg/llm-d-inference-sim/dataset.go

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ func (d *Dataset) downloadDataset(url string, savePath string) error {
8181

8282
// Progress reader with context
8383
pr := &progressReader{
84-
Reader: resp.Body,
85-
total: resp.ContentLength,
86-
logger: d.logger,
87-
ctx: ctx,
88-
startTime: time.Now(),
84+
Reader: resp.Body,
85+
total: resp.ContentLength,
86+
logger: d.logger,
87+
ctx: ctx,
88+
startTime: time.Now(),
8989
hasShownSpeed: false,
9090
}
9191

@@ -126,12 +126,12 @@ func (d *Dataset) downloadDataset(url string, savePath string) error {
126126
// progressReader wraps an io.Reader and logs download progress.
127127
type progressReader struct {
128128
io.Reader
129-
total int64
130-
downloaded int64
131-
startTime time.Time
132-
lastPct int
133-
logger logr.Logger
134-
ctx context.Context
129+
total int64
130+
downloaded int64
131+
startTime time.Time
132+
lastPct int
133+
logger logr.Logger
134+
ctx context.Context
135135
hasShownSpeed bool
136136
}
137137

@@ -161,7 +161,7 @@ func (pr *progressReader) Read(p []byte) (int, error) {
161161
func (pr *progressReader) logProgress(pct int) {
162162
elapsedTime := time.Since(pr.startTime).Seconds()
163163
speed := float64(pr.downloaded) / (1024 * 1024 * elapsedTime)
164-
remainingTime := float64(pr.total-pr.downloaded) / (float64(pr.downloaded) / elapsedTime)
164+
remainingTime := float64(pr.total-pr.downloaded) / (float64(pr.downloaded) / elapsedTime)
165165
if pct != 100 {
166166
pr.logger.Info(fmt.Sprintf("Download progress: %d%%, Speed: %.2f MB/s, Remaining time: %.2fs", pct, speed, remainingTime))
167167
} else {
@@ -179,7 +179,17 @@ func (d *Dataset) connectToDB(path string) error {
179179
if err != nil {
180180
return fmt.Errorf("failed to open database: %w", err)
181181
}
182-
// Test the connection
182+
183+
var count int
184+
err = d.db.QueryRow("SELECT COUNT(generated) FROM llmd;").Scan(&count)
185+
if err != nil {
186+
err := d.db.Close()
187+
if err != nil {
188+
d.logger.Error(err, "failed to close database after query failure")
189+
}
190+
return fmt.Errorf("failed to query database: %w", err)
191+
}
192+
d.logger.Info("Database connected successfully", "path", path, "records count", count)
183193

184194
return nil
185195
}

pkg/llm-d-inference-sim/dataset_test.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,18 @@ var _ = Describe("Dataset", func() {
7676
err := dataset.Init(validDBPath, "", "")
7777
Expect(err).NotTo(HaveOccurred())
7878

79-
// read from the db to verify it's valid
80-
row := dataset.db.QueryRow("SELECT * FROM t;")
79+
row := dataset.db.QueryRow("SELECT generated FROM llmd WHERE prompt_hash=X'b94d27b9934d041c52e5b721d7373f13a07ed5e79179d63c5d8a0c102a9d00b2';")
8180
var value string
8281
err = row.Scan(&value)
8382
Expect(err).NotTo(HaveOccurred())
84-
Expect(value).To(Equal("llm-d"))
83+
Expect(value).To(Equal("world!"))
8584
})
8685

8786
It("should raise err with invalid DB content", func() {
8887
err := dataset.connectToDB(file_folder)
89-
Expect(err).NotTo(HaveOccurred())
90-
88+
Expect(err).To(HaveOccurred())
9189
// read from the db to verify it's not valid
92-
row := dataset.db.QueryRow("SELECT * FROM t;")
90+
row := dataset.db.QueryRow("SELECT * FROM llmd;")
9391
var value string
9492
err = row.Scan(&value)
9593
Expect(err).To(HaveOccurred())

pkg/llm-d-inference-sim/simulator.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,18 @@ func (s *VllmSimulator) Start(ctx context.Context) error {
155155
return err
156156
}
157157

158-
dataset := &Dataset{
159-
logger: s.logger,
160-
}
161-
err = dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath)
162-
if err != nil {
163-
return err
158+
if s.config.Dataset.Path == "" && s.config.Dataset.Url == "" && s.config.Dataset.SavePath == "" {
159+
s.dataset = nil
160+
} else {
161+
dataset := &Dataset{
162+
logger: s.logger,
163+
}
164+
err = dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath)
165+
if err != nil {
166+
return err
167+
}
168+
s.dataset = dataset
164169
}
165-
s.dataset = dataset
166170

167171
// For Data Parallel, start data-parallel-size - 1 additional simulators
168172
g, ctx := errgroup.WithContext(ctx)

0 commit comments

Comments
 (0)