@@ -27,8 +27,10 @@ import (
2727 "os/signal"
2828 "path/filepath"
2929 "syscall"
30+ "time"
3031
3132 "github.com/go-logr/logr"
33+ _ "github.com/mattn/go-sqlite3"
3234)
3335
3436type Dataset struct {
@@ -79,10 +81,12 @@ func (d *Dataset) downloadDataset(url string, savePath string) error {
7981
8082 // Progress reader with context
8183 pr := & progressReader {
82- Reader : resp .Body ,
83- total : resp .ContentLength ,
84- logger : d .logger ,
85- ctx : ctx ,
84+ Reader : resp .Body ,
85+ total : resp .ContentLength ,
86+ logger : d .logger ,
87+ ctx : ctx ,
88+ startTime : time .Now (),
89+ hasShownSpeed : false ,
8690 }
8791
8892 written , err := io .Copy (out , pr )
@@ -124,9 +128,11 @@ type progressReader struct {
124128 io.Reader
125129 total int64
126130 downloaded int64
131+ startTime time.Time
127132 lastPct int
128133 logger logr.Logger
129134 ctx context.Context
135+ hasShownSpeed bool
130136}
131137
132138func (pr * progressReader ) Read (p []byte ) (int , error ) {
@@ -139,13 +145,30 @@ func (pr *progressReader) Read(p []byte) (int, error) {
139145 pr .downloaded += int64 (n )
140146 if pr .total > 0 {
141147 pct := int (float64 (pr .downloaded ) * 100 / float64 (pr .total ))
142- if pct != pr .lastPct && pct % 10 == 0 { // log every 10%
143- pr .logger .Info (fmt .Sprintf ("Download progress: %d%%" , pct ))
148+ if ! pr .hasShownSpeed && time .Since (pr .startTime ).Seconds () > 2 {
149+ pr .hasShownSpeed = true
150+ pr .logProgress (pct )
151+ pr .lastPct = pct
152+ }
153+ if pct != pr .lastPct && pct % 10 == 0 {
154+ pr .logProgress (pct )
144155 pr .lastPct = pct
145156 }
146157 }
147158 return n , err
148159}
160+
161+ func (pr * progressReader ) logProgress (pct int ) {
162+ elapsedTime := time .Since (pr .startTime ).Seconds ()
163+ speed := float64 (pr .downloaded ) / (1024 * 1024 * elapsedTime )
164+ remainingTime := float64 (pr .total - pr .downloaded ) / (float64 (pr .downloaded ) / elapsedTime )
165+ if pct != 100 {
166+ pr .logger .Info (fmt .Sprintf ("Download progress: %d%%, Speed: %.2f MB/s, Remaining time: %.2fs" , pct , speed , remainingTime ))
167+ } else {
168+ pr .logger .Info (fmt .Sprintf ("Download completed: 100%%, Average Speed: %.2f MB/s, Total time: %.2fs" , speed , elapsedTime ))
169+ }
170+ }
171+
149172func (d * Dataset ) connectToDB (path string ) error {
150173 // check if file exists
151174 _ , err := os .Stat (path )
@@ -156,6 +179,8 @@ func (d *Dataset) connectToDB(path string) error {
156179 if err != nil {
157180 return fmt .Errorf ("failed to open database: %w" , err )
158181 }
182+ // Test the connection
183+
159184 return nil
160185}
161186
@@ -165,7 +190,11 @@ func (d *Dataset) Init(path string, url string, savePath string) error {
165190 }
166191 if url != "" {
167192 if savePath == "" {
168- savePath = "~/.llmd/dataset.sqlite3"
193+ user , err := os .UserHomeDir ()
194+ if err != nil {
195+ return fmt .Errorf ("failed to get user home directory: %w" , err )
196+ }
197+ savePath = filepath .Join (user , ".llm-d" , "dataset.sqlite3" )
169198 }
170199
171200 _ , err := os .Stat (savePath )
0 commit comments