Skip to content

Commit e4bf3b3

Browse files
committed
Init dataset when sim starts and show downloading speed of url
Signed-off-by: Qifan Deng <[email protected]>
1 parent 321f405 commit e4bf3b3

File tree

3 files changed

+48
-8
lines changed

3 files changed

+48
-8
lines changed

pkg/common/config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ type Dataset struct {
186186
// Url is the URL to download the sqlite db file if set, default is empty
187187
Url string `yaml:"url" json:"url"`
188188
// SavePath is the local path to save the downloaded sqlite db file
189-
// if Url is set but SavePath is not, "~/.llmd/dataset.db" will be used
189+
// if Url is set but SavePath is not, "USER_HOME/.llm-d/dataset.db" will be used
190190
SavePath string `yaml:"save-path" json:"save-path"`
191191
}
192192

pkg/llm-d-inference-sim/dataset.go

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ import (
2727
"os/signal"
2828
"path/filepath"
2929
"syscall"
30+
"time"
3031

3132
"github.com/go-logr/logr"
33+
_ "github.com/mattn/go-sqlite3"
3234
)
3335

3436
type Dataset struct {
@@ -79,10 +81,12 @@ func (d *Dataset) downloadDataset(url string, savePath string) error {
7981

8082
// Progress reader with context
8183
pr := &progressReader{
82-
Reader: resp.Body,
83-
total: resp.ContentLength,
84-
logger: d.logger,
85-
ctx: ctx,
84+
Reader: resp.Body,
85+
total: resp.ContentLength,
86+
logger: d.logger,
87+
ctx: ctx,
88+
startTime: time.Now(),
89+
hasShownSpeed: false,
8690
}
8791

8892
written, err := io.Copy(out, pr)
@@ -124,9 +128,11 @@ type progressReader struct {
124128
io.Reader
125129
total int64
126130
downloaded int64
131+
startTime time.Time
127132
lastPct int
128133
logger logr.Logger
129134
ctx context.Context
135+
hasShownSpeed bool
130136
}
131137

132138
func (pr *progressReader) Read(p []byte) (int, error) {
@@ -139,13 +145,30 @@ func (pr *progressReader) Read(p []byte) (int, error) {
139145
pr.downloaded += int64(n)
140146
if pr.total > 0 {
141147
pct := int(float64(pr.downloaded) * 100 / float64(pr.total))
142-
if pct != pr.lastPct && pct%10 == 0 { // log every 10%
143-
pr.logger.Info(fmt.Sprintf("Download progress: %d%%", pct))
148+
if !pr.hasShownSpeed && time.Since(pr.startTime).Seconds() > 2 {
149+
pr.hasShownSpeed = true
150+
pr.logProgress(pct)
151+
pr.lastPct = pct
152+
}
153+
if pct != pr.lastPct && pct%10 == 0 {
154+
pr.logProgress(pct)
144155
pr.lastPct = pct
145156
}
146157
}
147158
return n, err
148159
}
160+
161+
func (pr *progressReader) logProgress(pct int) {
162+
elapsedTime := time.Since(pr.startTime).Seconds()
163+
speed := float64(pr.downloaded) / (1024 * 1024 * elapsedTime)
164+
remainingTime := float64(pr.total-pr.downloaded) / (float64(pr.downloaded) / elapsedTime)
165+
if pct != 100 {
166+
pr.logger.Info(fmt.Sprintf("Download progress: %d%%, Speed: %.2f MB/s, Remaining time: %.2fs", pct, speed, remainingTime))
167+
} else {
168+
pr.logger.Info(fmt.Sprintf("Download completed: 100%%, Average Speed: %.2f MB/s, Total time: %.2fs", speed, elapsedTime))
169+
}
170+
}
171+
149172
func (d *Dataset) connectToDB(path string) error {
150173
// check if file exists
151174
_, err := os.Stat(path)
@@ -156,6 +179,8 @@ func (d *Dataset) connectToDB(path string) error {
156179
if err != nil {
157180
return fmt.Errorf("failed to open database: %w", err)
158181
}
182+
// Test the connection
183+
159184
return nil
160185
}
161186

@@ -165,7 +190,11 @@ func (d *Dataset) Init(path string, url string, savePath string) error {
165190
}
166191
if url != "" {
167192
if savePath == "" {
168-
savePath = "~/.llmd/dataset.sqlite3"
193+
user, err := os.UserHomeDir()
194+
if err != nil {
195+
return fmt.Errorf("failed to get user home directory: %w", err)
196+
}
197+
savePath = filepath.Join(user, ".llm-d", "dataset.sqlite3")
169198
}
170199

171200
_, err := os.Stat(savePath)

pkg/llm-d-inference-sim/simulator.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ type VllmSimulator struct {
116116
pod string
117117
// tokenizer is currently used in kv-cache and in /tokenize
118118
tokenizer tokenization.Tokenizer
119+
// dataset is used for managing dataset files
120+
dataset *Dataset
119121
}
120122

121123
// New creates a new VllmSimulator instance with the given logger
@@ -153,6 +155,15 @@ func (s *VllmSimulator) Start(ctx context.Context) error {
153155
return err
154156
}
155157

158+
dataset := &Dataset{
159+
logger: s.logger,
160+
}
161+
err = dataset.Init(s.config.Dataset.Path, s.config.Dataset.Url, s.config.Dataset.SavePath)
162+
if err != nil {
163+
return err
164+
}
165+
s.dataset = dataset
166+
156167
// For Data Parallel, start data-parallel-size - 1 additional simulators
157168
g, ctx := errgroup.WithContext(ctx)
158169
if s.config.DPSize > 1 {

0 commit comments

Comments
 (0)