Skip to content

Commit 711721e

Browse files
committed
Change db structure and add test cases
Signed-off-by: Qifan Deng <[email protected]>
1 parent 1b56969 commit 711721e

File tree

8 files changed

+138
-22
lines changed

8 files changed

+138
-22
lines changed
12 KB
Binary file not shown.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Hello world!
12 KB
Binary file not shown.
12 KB
Binary file not shown.
0 Bytes
Binary file not shown.

pkg/llm-d-inference-sim/dataset.go

Lines changed: 85 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,17 @@ type Dataset struct {
3838
logger logr.Logger
3939
}
4040

41+
// use constants for expected column names and types
42+
const (
43+
tableName = "llmd"
44+
promptHashCol = "prompt_hash"
45+
genTokensCol = "gen_tokens"
46+
nGenTokensCol = "n_gen_tokens"
47+
promptHashColType = "BLOB"
48+
genTokensColType = "JSON"
49+
nGenTokensColType = "INTEGER"
50+
)
51+
4152
func (d *Dataset) downloadDataset(url string, savePath string) error {
4253
// Set up signal handling for Ctrl+C (SIGINT)
4354
ctx, cancel := context.WithCancel(context.Background())
@@ -169,7 +180,73 @@ func (pr *progressReader) logProgress(pct int) {
169180
}
170181
}
171182

183+
func (d *Dataset) verifyDB() error {
184+
rows, err := d.db.Query("PRAGMA table_info(" + tableName + ");")
185+
if err != nil {
186+
return fmt.Errorf("failed to query table info for `%s`: %w", tableName, err)
187+
}
188+
defer func() {
189+
if cerr := rows.Close(); cerr != nil {
190+
d.logger.Error(cerr, "failed to close rows after querying table info")
191+
}
192+
}()
193+
194+
expectedColumns := map[string]string{
195+
promptHashCol: promptHashColType,
196+
genTokensCol: genTokensColType,
197+
nGenTokensCol: nGenTokensColType,
198+
}
199+
200+
columnsFound := make(map[string]bool)
201+
202+
var (
203+
columnName string
204+
columnType string
205+
cid int
206+
notnull int
207+
dfltValue interface{}
208+
pk int
209+
)
210+
211+
for rows.Next() {
212+
err := rows.Scan(&cid, &columnName, &columnType, &notnull, &dfltValue, &pk)
213+
if err != nil {
214+
return fmt.Errorf("failed to scan table info row: %w", err)
215+
}
216+
if expectedType, exists := expectedColumns[columnName]; exists {
217+
if columnType != expectedType {
218+
return fmt.Errorf("column %s has incorrect type: expected %s, got %s", columnName, expectedType, columnType)
219+
}
220+
columnsFound[columnName] = true
221+
}
222+
}
223+
224+
for col := range expectedColumns {
225+
if !columnsFound[col] {
226+
return fmt.Errorf("missing expected column in %s table: %s", tableName, col)
227+
}
228+
}
229+
230+
return nil
231+
}
232+
233+
func (d *Dataset) getRecordsCount() (int, error) {
234+
var count int
235+
err := d.db.QueryRow("SELECT COUNT(" + promptHashCol + ") FROM " + tableName + ";").Scan(&count)
236+
if err != nil {
237+
return 0, fmt.Errorf("failed to query database: %w", err)
238+
}
239+
return count, nil
240+
}
241+
172242
func (d *Dataset) connectToDB(path string) error {
243+
if d.db != nil {
244+
err := d.db.Close()
245+
if err != nil {
246+
d.logger.Error(err, "failed to close existing database connection")
247+
}
248+
d.db = nil
249+
}
173250
// check if file exists
174251
_, err := os.Stat(path)
175252
if err != nil {
@@ -180,13 +257,15 @@ func (d *Dataset) connectToDB(path string) error {
180257
return fmt.Errorf("failed to open database: %w", err)
181258
}
182259

183-
var count int
184-
err = d.db.QueryRow("SELECT COUNT(generated) FROM llmd;").Scan(&count)
260+
err = d.verifyDB()
261+
185262
if err != nil {
186-
err := d.db.Close()
187-
if err != nil {
188-
d.logger.Error(err, "failed to close database after query failure")
189-
}
263+
return fmt.Errorf("failed to verify database: %w", err)
264+
}
265+
266+
count, err := d.getRecordsCount()
267+
if err != nil {
268+
d.logger.Error(err, "failed to get records count")
190269
return fmt.Errorf("failed to query database: %w", err)
191270
}
192271
d.logger.Info("Database connected successfully", "path", path, "records count", count)

pkg/llm-d-inference-sim/dataset_test.go

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
package llmdinferencesim
1818

1919
import (
20+
"fmt"
2021
"os"
2122

2223
"github.com/go-logr/logr"
@@ -28,19 +29,31 @@ import (
2829

2930
var _ = Describe("Dataset", func() {
3031
var (
31-
dataset *Dataset
32-
file_folder string
33-
savePath string
32+
dataset *Dataset
33+
file_folder string
34+
savePath string
35+
validDBPath string
36+
pathToInvalidDB string
37+
pathNotExist string
38+
pathToInvalidTableDB string
39+
pathToInvalidColumnDB string
40+
pathToInvalidTypeDB string
3441
)
3542

3643
BeforeEach(func() {
3744
dataset = &Dataset{
3845
logger: logr.Discard(),
3946
}
40-
file_folder = "./.llm-d"
47+
file_folder = ".llm-d"
4148
savePath = file_folder + "/test.sqlite3"
4249
err := os.MkdirAll(file_folder, os.ModePerm)
4350
Expect(err).NotTo(HaveOccurred())
51+
validDBPath = file_folder + "/test.valid.sqlite3"
52+
pathNotExist = file_folder + "/test.notexist.sqlite3"
53+
pathToInvalidDB = file_folder + "/test.invalid.sqlite3"
54+
pathToInvalidTableDB = file_folder + "/test.invalid.table.sqlite3"
55+
pathToInvalidColumnDB = file_folder + "/test.invalid.column.sqlite3"
56+
pathToInvalidTypeDB = file_folder + "/test.invalid.type.sqlite3"
4457
})
4558

4659
AfterEach(func() {
@@ -72,24 +85,47 @@ var _ = Describe("Dataset", func() {
7285
})
7386

7487
It("should successfully init dataset", func() {
75-
validDBPath := file_folder + "/test.valid.sqlite3"
7688
err := dataset.Init(validDBPath, "", "")
89+
// debug: get the realpath
90+
wd, _ := os.Getwd()
91+
realpath := fmt.Sprintf("%s/%s", wd, validDBPath)
92+
fmt.Println("Using realpath:", realpath)
7793
Expect(err).NotTo(HaveOccurred())
7894

79-
row := dataset.db.QueryRow("SELECT generated FROM llmd WHERE prompt_hash=X'b94d27b9934d041c52e5b721d7373f13a07ed5e79179d63c5d8a0c102a9d00b2';")
80-
var value string
81-
err = row.Scan(&value)
95+
row := dataset.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'b94d27b9934d041c52e5b721d7373f13a07ed5e79179d63c5d8a0c102a9d00b2';")
96+
var n_gen_tokens int
97+
err = row.Scan(&n_gen_tokens)
8298
Expect(err).NotTo(HaveOccurred())
83-
Expect(value).To(Equal("world!"))
99+
Expect(n_gen_tokens).To(Equal(3))
84100
})
85101

86-
It("should raise err with invalid DB content", func() {
87-
err := dataset.connectToDB(file_folder)
102+
It("should return error for non-existing DB path", func() {
103+
err := dataset.connectToDB(pathNotExist)
88104
Expect(err).To(HaveOccurred())
89-
// read from the db to verify it's not valid
90-
row := dataset.db.QueryRow("SELECT * FROM llmd;")
91-
var value string
92-
err = row.Scan(&value)
105+
Expect(err.Error()).To(ContainSubstring("database file does not exist"))
106+
})
107+
108+
It("should return error for invalid DB file", func() {
109+
err := dataset.connectToDB(pathToInvalidDB)
110+
Expect(err).To(HaveOccurred())
111+
Expect(err.Error()).To(ContainSubstring("file is not a database"))
112+
})
113+
114+
It("should return error for DB with invalid table", func() {
115+
err := dataset.connectToDB(pathToInvalidTableDB)
116+
Expect(err).To(HaveOccurred())
117+
Expect(err.Error()).To(ContainSubstring("failed to verify database"))
118+
})
119+
120+
It("should return error for DB with invalid column", func() {
121+
err := dataset.connectToDB(pathToInvalidColumnDB)
122+
Expect(err).To(HaveOccurred())
123+
Expect(err.Error()).To(ContainSubstring("missing expected column"))
124+
})
125+
126+
It("should return error for DB with invalid column type", func() {
127+
err := dataset.connectToDB(pathToInvalidTypeDB)
93128
Expect(err).To(HaveOccurred())
129+
Expect(err.Error()).To(ContainSubstring("incorrect type"))
94130
})
95131
})

pkg/llm-d-inference-sim/simulator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ func (s *VllmSimulator) startSim(ctx context.Context) error {
221221
s.logger.Info("No dataset provided, will generate random responses")
222222
} else {
223223
dataset := &Dataset{
224-
logger: s.logger,
224+
logger: s.logger,
225225
}
226226
err = dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath)
227227
if err != nil {

0 commit comments

Comments
 (0)