From be1de2f11546dcb38d60cf6f91bb74f67a6527dd Mon Sep 17 00:00:00 2001
From: Asmista <98347402+artik0din@users.noreply.github.com>
Date: Tue, 11 Nov 2025 16:37:15 +0400
Subject: [PATCH] feat: add resumable download support for nexa pull

Implements a resumable download feature that allows interrupted model
downloads to be resumed from where they stopped, saving bandwidth and
time for users downloading large models.

Features:
- Automatic progress tracking with a JSON metadata file
- Resume from the exact byte position after an interruption
- Smart range merging to handle overlapping downloads
- Periodic progress saves (every 5 seconds)
- Automatic cleanup on successful completion
- Support for multiple concurrent file downloads

Implementation:
- Added DownloadProgress types for tracking download state
- Modified StartDownload to check for existing progress
- Only downloads missing byte ranges when resuming
- Thread-safe progress updates with mutex protection
- Comprehensive unit tests for all scenarios

Technical details:
- Progress stored in .download_progress.json
- Uses HTTP Range headers for partial downloads
- Merges adjacent/overlapping byte ranges automatically (see the
  sketch below)
- Validates completion before cleanup

Testing:
- Unit tests for progress tracking
- Range merging validation
- Multi-file download scenarios
- Load/save/cleanup operations
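
How the new bookkeeping composes, as a minimal illustrative sketch (not
part of the diff; the file name and byte values below are invented):

    // Two chunks finished before the process was interrupted.
    fp := &types.FileProgress{FileName: "model.gguf", TotalSize: 1000}
    fp.MarkRangeComplete(0, 400)
    fp.MarkRangeComplete(600, 1000)

    // Overlapping/adjacent ranges are merged, so fp.Downloaded == 800
    // and only the remaining gap is re-requested on the next pull:
    //   fp.GetMissingRanges(chunkSize) -> [{Start: 400, End: 600}]

    // The state is persisted to <outputPath>/.download_progress.json
    // and removed by types.CleanupProgress once the download completes.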

Fixes #803
---
 runner/internal/model_hub/model_hub.go | 140 +++++++++++++--
 runner/internal/types/download.go      | 204 +++++++++++++++++++++
 runner/internal/types/download_test.go | 239 +++++++++++++++++++++++++
 3 files changed, 568 insertions(+), 15 deletions(-)
 create mode 100644 runner/internal/types/download.go
 create mode 100644 runner/internal/types/download_test.go

diff --git a/runner/internal/model_hub/model_hub.go b/runner/internal/model_hub/model_hub.go
index 855d1c85..cabab647 100644
--- a/runner/internal/model_hub/model_hub.go
+++ b/runner/internal/model_hub/model_hub.go
@@ -11,6 +11,7 @@ import (
 	"reflect"
 	"sync"
 	"sync/atomic"
+	"time"
 
 	"github.com/NexaAI/nexa-sdk/runner/internal/types"
 	"github.com/bytedance/sonic"
@@ -143,16 +144,38 @@ func StartDownload(ctx context.Context, modelName, outputPath string, files []Mo
 		defer close(errCh)
 		defer close(resCh)
 
-		var downloaded int64
 		var totalSize int64
 		for _, f := range files {
 			totalSize += f.Size
 		}
 
+		// Load existing download progress if available
+		progress, err := types.LoadDownloadProgress(outputPath)
+		if err != nil {
+			slog.Warn("Failed to load download progress, starting fresh", "error", err)
+			progress = nil
+		}
+
+		// Initialize progress if not exists
+		if progress == nil {
+			progress = types.NewDownloadProgress(modelName, totalSize)
+			for _, f := range files {
+				progress.AddFile(f.Name, f.Size)
+			}
+			slog.Info("Starting fresh download", "total_size", totalSize)
+		} else {
+			slog.Info("Resuming download", "already_downloaded", progress.Downloaded, "total_size", totalSize)
+		}
+
+		var downloaded int64 = progress.Downloaded
+
 		// create tasks
 		tasks := make(chan downloadTask, maxConcurrency)
 		nctx, cancel := context.WithCancel(ctx)
 
+		// Mutex for progress updates
+		var progressMutex sync.Mutex
+
 		var wg1 sync.WaitGroup
 		wg1.Add(1)
 		go func() {
@@ -165,26 +188,67 @@ func StartDownload(ctx context.Context, modelName, outputPath string, files []Mo
 				return
 			}
 
-			// create task
-			task := downloadTask{
-				OutputPath: outputPath,
-				ModelName:  modelName,
-				FileName:   f.Name,
+			// Get file progress
+			fileProgress := progress.Files[f.Name]
+			if fileProgress == nil {
+				fileProgress = &types.FileProgress{
+					FileName:       f.Name,
+					TotalSize:      f.Size,
+					CompletedRange: make([]types.CompletedRange, 0),
+				}
+				progress.Files[f.Name] = fileProgress
+			}
+
+			// Skip if file already complete
+			if fileProgress.IsComplete() {
+				slog.Info("File already downloaded, skipping", "file", f.Name)
+				continue
 			}
 
-			// enqueue tasks
+			// Get missing ranges
 			chunkSize := max(minChunkSize, f.Size/128)
-			slog.Info("Download file", "name", f.Name, "size", f.Size, "chunkSize", chunkSize)
-			for task.Offset = 0; task.Offset < f.Size; task.Offset += chunkSize {
-				task.Limit = min(chunkSize, f.Size-task.Offset)
+			missingRanges := fileProgress.GetMissingRanges(chunkSize)
+
+			slog.Info("Download file", "name", f.Name, "size", f.Size, "chunkSize", chunkSize, "missingRanges", len(missingRanges))
+
+			// Create tasks for missing ranges only
+			for _, missingRange := range missingRanges {
+				for offset := missingRange.Start; offset < missingRange.End; offset += chunkSize {
+					task := downloadTask{
+						OutputPath: outputPath,
+						ModelName:  modelName,
+						FileName:   f.Name,
+						Offset:     offset,
+						Limit:      min(chunkSize, missingRange.End-offset),
+					}
+
+					// send chunk
+					select {
+					case tasks <- task:
+					case <-nctx.Done():
+						slog.Warn("download canceled", "error", nctx.Err())
+						return
+					}
+				}
+			}
+		}
+	}()
 
-				// send chunk
-				select {
-				case tasks <- task:
-				case <-nctx.Done():
-					slog.Warn("download canceled", "error", nctx.Err())
-					return
-				}
-			}
-		}
-	}()
+		// Periodic progress saver
+		progressTicker := time.NewTicker(5 * time.Second)
+		defer progressTicker.Stop()
+
+		saveChan := make(chan struct{}, 1)
+		go func() {
+			for {
+				select {
+				case <-progressTicker.C:
+					select {
+					case saveChan <- struct{}{}:
+					default:
+					}
+				case <-nctx.Done():
+					return
+				}
+			}
+		}()
@@ -205,10 +269,30 @@
 					return
 				}
 
+				// Update progress
+				progressMutex.Lock()
+				fileProgress := progress.Files[task.FileName]
+				fileProgress.MarkRangeComplete(task.Offset, task.Offset+task.Limit)
+				progress.Downloaded = calculateTotalDownloaded(progress)
+				progressMutex.Unlock()
+
 				resCh <- types.DownloadInfo{
 					TotalDownloaded: atomic.AddInt64(&downloaded, task.Limit),
 					TotalSize:       totalSize,
 				}
+
+				// Try to save progress periodically
+				select {
+				case <-saveChan:
+					progressMutex.Lock()
+					if err := progress.Save(outputPath); err != nil {
+						slog.Warn("Failed to save progress", "error", err)
+					} else {
+						slog.Debug("Progress saved", "downloaded", progress.Downloaded)
+					}
+					progressMutex.Unlock()
+				default:
+				}
 			}
 		}()
 	}
@@ -216,12 +300,38 @@
 		wg1.Wait()
 		close(tasks)
 		wg2.Wait()
+
+		// Final progress save
+		progressMutex.Lock()
+		if err := progress.Save(outputPath); err != nil {
+			slog.Warn("Failed to save final progress", "error", err)
+		}
+		progressMutex.Unlock()
+
+		// Cleanup progress file on successful completion
+		if progress.Downloaded >= totalSize {
+			if err := types.CleanupProgress(outputPath); err != nil {
+				slog.Warn("Failed to cleanup progress file", "error", err)
+			} else {
+				slog.Info("Download completed, progress file cleaned up")
+			}
+		}
+
 		cancel()
 	}()
 
 	return resCh, errCh
 }
 
+// calculateTotalDownloaded sums up downloaded bytes across all files
+func calculateTotalDownloaded(progress *types.DownloadProgress) int64 {
+	var total int64
+	for _, fp := range progress.Files {
+		total += fp.Downloaded
+	}
+	return total
+}
+
 func getHub(ctx context.Context, modelName string) (ModelHub, error) {
 	for _, h := range hubs {
 		if err := h.CheckAvailable(ctx, modelName); err != nil {
diff --git a/runner/internal/types/download.go b/runner/internal/types/download.go
new file mode 100644
index 00000000..972eddab
--- /dev/null
+++ b/runner/internal/types/download.go
@@ -0,0 +1,204 @@
+package types
+
+import (
+	"encoding/json"
+	"os"
+	"path/filepath"
+	"time"
+)
+
+// DownloadProgress tracks the download state for resumable downloads
+type DownloadProgress struct {
+	ModelName    string                   `json:"model_name"`
+	TotalSize    int64                    `json:"total_size"`
+	Downloaded   int64                    `json:"downloaded"`
+	Files        map[string]*FileProgress `json:"files"`
+	LastModified time.Time                `json:"last_modified"`
+	Version      int                      `json:"version"` // For future compatibility
+}
+
+// FileProgress tracks individual file download progress
+type FileProgress struct {
+	FileName       string           `json:"file_name"`
+	TotalSize      int64            `json:"total_size"`
+	Downloaded     int64            `json:"downloaded"`
+	CompletedRange []CompletedRange `json:"completed_ranges"`
+	SHA256         string           `json:"sha256,omitempty"` // Optional validation
+}
+
+// CompletedRange represents a downloaded byte range
+type CompletedRange struct {
+	Start int64 `json:"start"`
+	End   int64 `json:"end"` // Exclusive
+}
+
+// NewDownloadProgress creates a new download progress tracker
+func NewDownloadProgress(modelName string, totalSize int64) *DownloadProgress {
+	return &DownloadProgress{
+		ModelName:    modelName,
+		TotalSize:    totalSize,
+		Files:        make(map[string]*FileProgress),
+		LastModified: time.Now(),
+		Version:      1,
+	}
+}
+
+// AddFile adds a file to track
+func (dp *DownloadProgress) AddFile(fileName string, size int64) {
+	dp.Files[fileName] = &FileProgress{
+		FileName:       fileName,
+		TotalSize:      size,
+		CompletedRange: make([]CompletedRange, 0),
+	}
+}
+
+// MarkRangeComplete marks a byte range as downloaded and merges adjacent ranges
+func (fp *FileProgress) MarkRangeComplete(start, end int64) {
+	// Add the new range
+	newRange := CompletedRange{Start: start, End: end}
+	fp.CompletedRange = append(fp.CompletedRange, newRange)
+
+	// Merge adjacent/overlapping ranges
+	fp.mergeRanges()
+
+	// Update downloaded count
+	fp.Downloaded = fp.calculateDownloaded()
+}
+
+// mergeRanges merges overlapping or adjacent ranges
+func (fp *FileProgress) mergeRanges() {
+	if len(fp.CompletedRange) <= 1 {
+		return
+	}
+
+	// Sort ranges by start position
+	for i := 0; i < len(fp.CompletedRange)-1; i++ {
+		for j := i + 1; j < len(fp.CompletedRange); j++ {
+			if fp.CompletedRange[i].Start > fp.CompletedRange[j].Start {
+				fp.CompletedRange[i], fp.CompletedRange[j] = fp.CompletedRange[j], fp.CompletedRange[i]
+			}
+		}
+	}
+
+	// Merge overlapping ranges
+	merged := make([]CompletedRange, 0, len(fp.CompletedRange))
+	current := fp.CompletedRange[0]
+
+	for i := 1; i < len(fp.CompletedRange); i++ {
+		next := fp.CompletedRange[i]
+		if current.End >= next.Start {
+			// Merge overlapping ranges
+			current.End = max(current.End, next.End)
+		} else {
+			merged = append(merged, current)
+			current = next
+		}
+	}
+	merged = append(merged, current)
+	fp.CompletedRange = merged
+}
+
+// calculateDownloaded calculates total downloaded bytes from ranges
+func (fp *FileProgress) calculateDownloaded() int64 {
+	var total int64
+	for _, r := range fp.CompletedRange {
+		total += r.End - r.Start
+	}
+	return total
+}
+
+// IsComplete checks if file is fully downloaded
+func (fp *FileProgress) IsComplete() bool {
+	return fp.Downloaded >= fp.TotalSize
+}
+
+// GetMissingRanges returns ranges that still need to be downloaded
+func (fp *FileProgress) GetMissingRanges(chunkSize int64) []CompletedRange {
+	if len(fp.CompletedRange) == 0 {
+		// Nothing downloaded yet
+		return []CompletedRange{{Start: 0, End: fp.TotalSize}}
+	}
+
+	missing := make([]CompletedRange, 0)
+
+	// Check gap before first range
+	if fp.CompletedRange[0].Start > 0 {
+		missing = append(missing, CompletedRange{
+			Start: 0,
+			End:   fp.CompletedRange[0].Start,
+		})
+	}
+
+	// Check gaps between ranges
+	for i := 0; i < len(fp.CompletedRange)-1; i++ {
+		gap := fp.CompletedRange[i+1].Start - fp.CompletedRange[i].End
+		if gap > 0 {
+			missing = append(missing, CompletedRange{
+				Start: fp.CompletedRange[i].End,
+				End:   fp.CompletedRange[i+1].Start,
+			})
+		}
+	}
+
+	// Check gap after last range
+	lastEnd := fp.CompletedRange[len(fp.CompletedRange)-1].End
+	if lastEnd < fp.TotalSize {
+		missing = append(missing, CompletedRange{
+			Start: lastEnd,
+			End:   fp.TotalSize,
+		})
+	}
+
+	return missing
+}
+
+// Save saves progress to disk
+func (dp *DownloadProgress) Save(outputPath string) error {
+	dp.LastModified = time.Now()
+
+	progressFile := filepath.Join(outputPath, ".download_progress.json")
+	data, err := json.MarshalIndent(dp, "", " ")
+	if err != nil {
+		return err
+	}
+
+	return os.WriteFile(progressFile, data, 0644)
+}
+
+// LoadDownloadProgress loads progress from disk
+func LoadDownloadProgress(outputPath string) (*DownloadProgress, error) {
+	progressFile := filepath.Join(outputPath, ".download_progress.json")
+
+	data, err := os.ReadFile(progressFile)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil, nil // No previous progress
+		}
+		return nil, err
+	}
+
+	var progress DownloadProgress
+	if err := json.Unmarshal(data, &progress); err != nil {
+		return nil, err
+	}
+
+	return &progress, nil
+}
+
+// CleanupProgress removes progress file after successful download
+func CleanupProgress(outputPath string) error {
+	progressFile := filepath.Join(outputPath, ".download_progress.json")
+	err := os.Remove(progressFile)
+	if err != nil && !os.IsNotExist(err) {
+		return err
+	}
+	return nil
+}
+
+func max(a, b int64) int64 {
+	if a > b {
+		return a
+	}
+	return b
+}
diff --git a/runner/internal/types/download_test.go b/runner/internal/types/download_test.go
new file mode 100644
index 00000000..ddb41d01
--- /dev/null
+++ b/runner/internal/types/download_test.go
@@ -0,0 +1,239 @@
+package types
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+)
+
+func TestDownloadProgress_NewAndSave(t *testing.T) {
+	tmpDir := t.TempDir()
+
+	progress := NewDownloadProgress("test-model", 1000000)
+	progress.AddFile("model.gguf", 1000000)
+
+	if err := progress.Save(tmpDir); err != nil {
+		t.Fatalf("Failed to save progress: %v", err)
+	}
+
+	// Verify file exists
+	progressFile := filepath.Join(tmpDir, ".download_progress.json")
+	if _, err := os.Stat(progressFile); os.IsNotExist(err) {
+		t.Fatal("Progress file was not created")
+	}
+}
+
+func TestDownloadProgress_LoadAndResume(t *testing.T) {
+	tmpDir := t.TempDir()
+
+	// Create and save progress
+	progress := NewDownloadProgress("test-model", 1000000)
+	progress.AddFile("model.gguf", 1000000)
+	progress.Files["model.gguf"].MarkRangeComplete(0, 500000)
+
+	if err := progress.Save(tmpDir); err != nil {
+		t.Fatalf("Failed to save progress: %v", err)
+	}
+
+	// Load progress
+	loaded, err := LoadDownloadProgress(tmpDir)
+	if err != nil {
+		t.Fatalf("Failed to load progress: %v", err)
+	}
+
+	if loaded == nil {
+		t.Fatal("Loaded progress is nil")
+	}
+
+	if loaded.ModelName != "test-model" {
+		t.Errorf("Expected model name 'test-model', got '%s'", loaded.ModelName)
+	}
+
+	if loaded.Files["model.gguf"].Downloaded != 500000 {
+		t.Errorf("Expected 500000 bytes downloaded, got %d", loaded.Files["model.gguf"].Downloaded)
+	}
+}
+
+func TestFileProgress_MarkRangeComplete(t *testing.T) {
+	fp := &FileProgress{
+		FileName:       "test.bin",
+		TotalSize:      1000,
+		CompletedRange: make([]CompletedRange, 0),
+	}
+
+	// Mark first range
+	fp.MarkRangeComplete(0, 100)
+	if fp.Downloaded != 100 {
+		t.Errorf("Expected 100 bytes downloaded, got %d", fp.Downloaded)
+	}
+
+	// Mark second range
+	fp.MarkRangeComplete(200, 300)
+	if fp.Downloaded != 200 {
+		t.Errorf("Expected 200 bytes downloaded, got %d", fp.Downloaded)
+	}
+
+	// Mark overlapping range (should merge)
+	fp.MarkRangeComplete(100, 200)
+	if fp.Downloaded != 300 {
+		t.Errorf("Expected 300 bytes downloaded after merge, got %d", fp.Downloaded)
+	}
+
+	// Should have merged into single range
+	if len(fp.CompletedRange) != 1 {
+		t.Errorf("Expected 1 range after merge, got %d", len(fp.CompletedRange))
+	}
+
+	if fp.CompletedRange[0].Start != 0 || fp.CompletedRange[0].End != 300 {
+		t.Errorf("Expected range 0-300, got %d-%d", fp.CompletedRange[0].Start, fp.CompletedRange[0].End)
+	}
+}
+
+func TestFileProgress_GetMissingRanges(t *testing.T) {
+	fp := &FileProgress{
+		FileName:       "test.bin",
+		TotalSize:      1000,
+		CompletedRange: make([]CompletedRange, 0),
+	}
+
+	// No downloads yet
+	missing := fp.GetMissingRanges(100)
+	if len(missing) != 1 {
+		t.Fatalf("Expected 1 missing range, got %d", len(missing))
+	}
+	if missing[0].Start != 0 || missing[0].End != 1000 {
+		t.Errorf("Expected missing range 0-1000, got %d-%d", missing[0].Start, missing[0].End)
+	}
+
+	// Download middle section
+	fp.MarkRangeComplete(400, 600)
+	missing = fp.GetMissingRanges(100)
+
+	if len(missing) != 2 {
+		t.Fatalf("Expected 2 missing ranges, got %d", len(missing))
+	}
+
+	// First gap: 0-400
+	if missing[0].Start != 0 || missing[0].End != 400 {
+		t.Errorf("Expected first gap 0-400, got %d-%d", missing[0].Start, missing[0].End)
+	}
+
+	// Second gap: 600-1000
+	if missing[1].Start != 600 || missing[1].End != 1000 {
+		t.Errorf("Expected second gap 600-1000, got %d-%d", missing[1].Start, missing[1].End)
+	}
+}
+
+func TestFileProgress_IsComplete(t *testing.T) {
+	fp := &FileProgress{
+		FileName:       "test.bin",
+		TotalSize:      1000,
+		CompletedRange: make([]CompletedRange, 0),
+	}
+
+	if fp.IsComplete() {
+		t.Error("File should not be complete initially")
+	}
+
+	// Download entire file
+	fp.MarkRangeComplete(0, 1000)
+
+	if !fp.IsComplete() {
+		t.Error("File should be complete after downloading entire range")
+	}
+
+	if fp.Downloaded != 1000 {
+		t.Errorf("Expected 1000 bytes downloaded, got %d", fp.Downloaded)
+	}
+}
+
+func TestFileProgress_MergeAdjacentRanges(t *testing.T) {
+	fp := &FileProgress{
+		FileName:       "test.bin",
+		TotalSize:      1000,
+		CompletedRange: make([]CompletedRange, 0),
+	}
+
+	// Download in reverse order
+	fp.MarkRangeComplete(600, 800)
+	fp.MarkRangeComplete(400, 600)
+	fp.MarkRangeComplete(200, 400)
+	fp.MarkRangeComplete(0, 200)
+
+	// All ranges should be merged into one
+	if len(fp.CompletedRange) != 1 {
+		t.Errorf("Expected 1 merged range, got %d", len(fp.CompletedRange))
+	}
+
+	if fp.CompletedRange[0].Start != 0 || fp.CompletedRange[0].End != 800 {
+		t.Errorf("Expected merged range 0-800, got %d-%d", fp.CompletedRange[0].Start, fp.CompletedRange[0].End)
+	}
+
+	if fp.Downloaded != 800 {
+		t.Errorf("Expected 800 bytes downloaded, got %d", fp.Downloaded)
+	}
+}
+
+func TestCleanupProgress(t *testing.T) {
+	tmpDir := t.TempDir()
+
+	// Create progress file
+	progress := NewDownloadProgress("test-model", 1000000)
+	if err := progress.Save(tmpDir); err != nil {
+		t.Fatalf("Failed to save progress: %v", err)
+	}
+
+	// Verify it exists
+	progressFile := filepath.Join(tmpDir, ".download_progress.json")
+	if _, err := os.Stat(progressFile); os.IsNotExist(err) {
+		t.Fatal("Progress file should exist before cleanup")
+	}
+
+	// Cleanup
+	if err := CleanupProgress(tmpDir); err != nil {
+		t.Fatalf("Failed to cleanup progress: %v", err)
+	}
+
+	// Verify it's gone
+	if _, err := os.Stat(progressFile); !os.IsNotExist(err) {
+		t.Error("Progress file should be deleted after cleanup")
+	}
+}
+
+func TestDownloadProgress_MultipleFiles(t *testing.T) {
+	tmpDir := t.TempDir()
+
+	progress := NewDownloadProgress("test-model", 3000000)
+	progress.AddFile("model1.gguf", 1000000)
+	progress.AddFile("model2.gguf", 1000000)
+	progress.AddFile("model3.gguf", 1000000)
+
+	// Download parts of each file
+	progress.Files["model1.gguf"].MarkRangeComplete(0, 500000)
+	progress.Files["model2.gguf"].MarkRangeComplete(0, 750000)
+	progress.Files["model3.gguf"].MarkRangeComplete(0, 1000000) // Complete
+
+	// Save and reload
+	if err := progress.Save(tmpDir); err != nil {
+		t.Fatalf("Failed to save progress: %v", err)
+	}
+
+	loaded, err := LoadDownloadProgress(tmpDir)
+	if err != nil {
+		t.Fatalf("Failed to load progress: %v", err)
+	}
+
+	// Verify all files
+	if loaded.Files["model1.gguf"].Downloaded != 500000 {
+		t.Errorf("model1 expected 500000, got %d", loaded.Files["model1.gguf"].Downloaded)
+	}
+
+	if loaded.Files["model2.gguf"].Downloaded != 750000 {
+		t.Errorf("model2 expected 750000, got %d", loaded.Files["model2.gguf"].Downloaded)
+	}
+
+	if !loaded.Files["model3.gguf"].IsComplete() {
+		t.Error("model3 should be complete")
+	}
+}