Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 125 additions & 15 deletions runner/internal/model_hub/model_hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"reflect"
"sync"
"sync/atomic"
"time"

"github.com/NexaAI/nexa-sdk/runner/internal/types"
"github.com/bytedance/sonic"
Expand Down Expand Up @@ -143,16 +144,38 @@ func StartDownload(ctx context.Context, modelName, outputPath string, files []Mo
defer close(errCh)
defer close(resCh)

var downloaded int64
var totalSize int64
for _, f := range files {
totalSize += f.Size
}

// Load existing download progress if available
progress, err := types.LoadDownloadProgress(outputPath)
if err != nil {
slog.Warn("Failed to load download progress, starting fresh", "error", err)
progress = nil
}

// Initialize progress if not exists
if progress == nil {
progress = types.NewDownloadProgress(modelName, totalSize)
for _, f := range files {
progress.AddFile(f.Name, f.Size)
}
slog.Info("Starting fresh download", "total_size", totalSize)
} else {
slog.Info("Resuming download", "already_downloaded", progress.Downloaded, "total_size", totalSize)
}

var downloaded int64 = progress.Downloaded

// create tasks
tasks := make(chan downloadTask, maxConcurrency)
nctx, cancel := context.WithCancel(ctx)

// Mutex for progress updates
var progressMutex sync.Mutex

var wg1 sync.WaitGroup
wg1.Add(1)
go func() {
Expand All @@ -165,26 +188,67 @@ func StartDownload(ctx context.Context, modelName, outputPath string, files []Mo
return
}

// create task
task := downloadTask{
OutputPath: outputPath,
ModelName: modelName,
FileName: f.Name,
// Get file progress
fileProgress := progress.Files[f.Name]
if fileProgress == nil {
fileProgress = &types.FileProgress{
FileName: f.Name,
TotalSize: f.Size,
CompletedRange: make([]types.CompletedRange, 0),
}
progress.Files[f.Name] = fileProgress
}

// Skip if file already complete
if fileProgress.IsComplete() {
slog.Info("File already downloaded, skipping", "file", f.Name)
continue
}

// enqueue tasks
// Get missing ranges
chunkSize := max(minChunkSize, f.Size/128)
slog.Info("Download file", "name", f.Name, "size", f.Size, "chunkSize", chunkSize)
for task.Offset = 0; task.Offset < f.Size; task.Offset += chunkSize {
task.Limit = min(chunkSize, f.Size-task.Offset)
missingRanges := fileProgress.GetMissingRanges(chunkSize)

slog.Info("Download file", "name", f.Name, "size", f.Size, "chunkSize", chunkSize, "missingRanges", len(missingRanges))

// Create tasks for missing ranges only
for _, missingRange := range missingRanges {
for offset := missingRange.Start; offset < missingRange.End; offset += chunkSize {
task := downloadTask{
OutputPath: outputPath,
ModelName: modelName,
FileName: f.Name,
Offset: offset,
Limit: min(chunkSize, missingRange.End-offset),
}

// send chunk
select {
case tasks <- task:
case <-nctx.Done():
slog.Warn("download canceled", "error", nctx.Err())
return
}
}
}
}
}()

// send chunk
// Periodic progress saver
progressTicker := time.NewTicker(5 * time.Second)
defer progressTicker.Stop()

saveChan := make(chan struct{}, 1)
go func() {
for {
select {
case <-progressTicker.C:
select {
case tasks <- task:
case <-nctx.Done():
slog.Warn("download canceled", "error", nctx.Err())
return
case saveChan <- struct{}{}:
default:
}
case <-nctx.Done():
return
}
}
}()
Expand All @@ -205,23 +269,69 @@ func StartDownload(ctx context.Context, modelName, outputPath string, files []Mo
return
}

// Update progress
progressMutex.Lock()
fileProgress := progress.Files[task.FileName]
fileProgress.MarkRangeComplete(task.Offset, task.Offset+task.Limit)
progress.Downloaded = calculateTotalDownloaded(progress)
progressMutex.Unlock()

resCh <- types.DownloadInfo{
TotalDownloaded: atomic.AddInt64(&downloaded, task.Limit),
TotalSize: totalSize,
}

// Try to save progress periodically
select {
case <-saveChan:
progressMutex.Lock()
if err := progress.Save(outputPath); err != nil {
slog.Warn("Failed to save progress", "error", err)
} else {
slog.Debug("Progress saved", "downloaded", progress.Downloaded)
}
progressMutex.Unlock()
default:
}
}
}()
}

wg1.Wait()
close(tasks)
wg2.Wait()

// Final progress save
progressMutex.Lock()
if err := progress.Save(outputPath); err != nil {
slog.Warn("Failed to save final progress", "error", err)
}
progressMutex.Unlock()

// Cleanup progress file on successful completion
if progress.Downloaded >= totalSize {
if err := types.CleanupProgress(outputPath); err != nil {
slog.Warn("Failed to cleanup progress file", "error", err)
} else {
slog.Info("Download completed, progress file cleaned up")
}
}

cancel()
}()

return resCh, errCh
}

// calculateTotalDownloaded sums up downloaded bytes across all files
func calculateTotalDownloaded(progress *types.DownloadProgress) int64 {
var total int64
for _, fp := range progress.Files {
total += fp.Downloaded
}
return total
}

func getHub(ctx context.Context, modelName string) (ModelHub, error) {
for _, h := range hubs {
if err := h.CheckAvailable(ctx, modelName); err != nil {
Expand Down
Loading