diff --git a/runner/internal/model_hub/model_hub.go b/runner/internal/model_hub/model_hub.go index 855d1c85..cabab647 100644 --- a/runner/internal/model_hub/model_hub.go +++ b/runner/internal/model_hub/model_hub.go @@ -11,6 +11,7 @@ import ( "reflect" "sync" "sync/atomic" + "time" "github.com/NexaAI/nexa-sdk/runner/internal/types" "github.com/bytedance/sonic" @@ -143,16 +144,38 @@ func StartDownload(ctx context.Context, modelName, outputPath string, files []Mo defer close(errCh) defer close(resCh) - var downloaded int64 var totalSize int64 for _, f := range files { totalSize += f.Size } + // Load existing download progress if available + progress, err := types.LoadDownloadProgress(outputPath) + if err != nil { + slog.Warn("Failed to load download progress, starting fresh", "error", err) + progress = nil + } + + // Initialize progress if not exists + if progress == nil { + progress = types.NewDownloadProgress(modelName, totalSize) + for _, f := range files { + progress.AddFile(f.Name, f.Size) + } + slog.Info("Starting fresh download", "total_size", totalSize) + } else { + slog.Info("Resuming download", "already_downloaded", progress.Downloaded, "total_size", totalSize) + } + + var downloaded int64 = progress.Downloaded + // create tasks tasks := make(chan downloadTask, maxConcurrency) nctx, cancel := context.WithCancel(ctx) + // Mutex for progress updates + var progressMutex sync.Mutex + var wg1 sync.WaitGroup wg1.Add(1) go func() { @@ -165,26 +188,69 @@ func StartDownload(ctx context.Context, modelName, outputPath string, files []Mo return } - // create task - task := downloadTask{ - OutputPath: outputPath, - ModelName: modelName, - FileName: f.Name, + // Get file progress (progressMutex guards progress.Files: workers read this map concurrently) + progressMutex.Lock() + fileProgress := progress.Files[f.Name] + if fileProgress == nil { + fileProgress = &types.FileProgress{ + FileName: f.Name, + TotalSize: f.Size, + CompletedRange: make([]types.CompletedRange, 0), + } + progress.Files[f.Name] = fileProgress + } + progressMutex.Unlock() + + // Skip if file already complete + if 
fileProgress.IsComplete() { + slog.Info("File already downloaded, skipping", "file", f.Name) + continue } - // enqueue tasks + // Get missing ranges chunkSize := max(minChunkSize, f.Size/128) - slog.Info("Download file", "name", f.Name, "size", f.Size, "chunkSize", chunkSize) - for task.Offset = 0; task.Offset < f.Size; task.Offset += chunkSize { - task.Limit = min(chunkSize, f.Size-task.Offset) + missingRanges := fileProgress.GetMissingRanges(chunkSize) + + slog.Info("Download file", "name", f.Name, "size", f.Size, "chunkSize", chunkSize, "missingRanges", len(missingRanges)) + + // Create tasks for missing ranges only + for _, missingRange := range missingRanges { + for offset := missingRange.Start; offset < missingRange.End; offset += chunkSize { + task := downloadTask{ + OutputPath: outputPath, + ModelName: modelName, + FileName: f.Name, + Offset: offset, + Limit: min(chunkSize, missingRange.End-offset), + } + + // send chunk + select { + case tasks <- task: + case <-nctx.Done(): + slog.Warn("download canceled", "error", nctx.Err()) + return + } + } + } + } + }() - // send chunk + // Periodic progress saver + progressTicker := time.NewTicker(5 * time.Second) + defer progressTicker.Stop() + + saveChan := make(chan struct{}, 1) + go func() { + for { + select { + case <-progressTicker.C: select { - case tasks <- task: - case <-nctx.Done(): - slog.Warn("download canceled", "error", nctx.Err()) - return + case saveChan <- struct{}{}: + default: } + case <-nctx.Done(): + return } } }() @@ -205,10 +269,30 @@ func StartDownload(ctx context.Context, modelName, outputPath string, files []Mo return } + // Update progress + progressMutex.Lock() + fileProgress := progress.Files[task.FileName] + fileProgress.MarkRangeComplete(task.Offset, task.Offset+task.Limit) + progress.Downloaded = calculateTotalDownloaded(progress) + progressMutex.Unlock() + resCh <- types.DownloadInfo{ TotalDownloaded: atomic.AddInt64(&downloaded, task.Limit), TotalSize: totalSize, } + + // Try to save 
progress periodically + select { + case <-saveChan: + progressMutex.Lock() + if err := progress.Save(outputPath); err != nil { + slog.Warn("Failed to save progress", "error", err) + } else { + slog.Debug("Progress saved", "downloaded", progress.Downloaded) + } + progressMutex.Unlock() + default: + } } }() } @@ -216,12 +300,38 @@ func StartDownload(ctx context.Context, modelName, outputPath string, files []Mo wg1.Wait() close(tasks) wg2.Wait() + + // Final progress save + progressMutex.Lock() + if err := progress.Save(outputPath); err != nil { + slog.Warn("Failed to save final progress", "error", err) + } + progressMutex.Unlock() + + // Cleanup progress file on successful completion + if progress.Downloaded >= totalSize { + if err := types.CleanupProgress(outputPath); err != nil { + slog.Warn("Failed to cleanup progress file", "error", err) + } else { + slog.Info("Download completed, progress file cleaned up") + } + } + cancel() }() return resCh, errCh } +// calculateTotalDownloaded sums up downloaded bytes across all files +func calculateTotalDownloaded(progress *types.DownloadProgress) int64 { + var total int64 + for _, fp := range progress.Files { + total += fp.Downloaded + } + return total +} + func getHub(ctx context.Context, modelName string) (ModelHub, error) { for _, h := range hubs { if err := h.CheckAvailable(ctx, modelName); err != nil { diff --git a/runner/internal/types/download.go b/runner/internal/types/download.go new file mode 100644 index 00000000..972eddab --- /dev/null +++ b/runner/internal/types/download.go @@ -0,0 +1,204 @@ +package types + +import ( + "encoding/json" + "os" + "path/filepath" + "time" +) + +// DownloadProgress tracks the download state for resumable downloads +type DownloadProgress struct { + ModelName string `json:"model_name"` + TotalSize int64 `json:"total_size"` + Downloaded int64 `json:"downloaded"` + Files map[string]*FileProgress `json:"files"` + LastModified time.Time `json:"last_modified"` + Version int 
`json:"version"` // For future compatibility +} + +// FileProgress tracks individual file download progress +type FileProgress struct { + FileName string `json:"file_name"` + TotalSize int64 `json:"total_size"` + Downloaded int64 `json:"downloaded"` + CompletedRange []CompletedRange `json:"completed_ranges"` + SHA256 string `json:"sha256,omitempty"` // Optional validation +} + +// CompletedRange represents a downloaded byte range +type CompletedRange struct { + Start int64 `json:"start"` + End int64 `json:"end"` // Exclusive +} + +// NewDownloadProgress creates a new download progress tracker +func NewDownloadProgress(modelName string, totalSize int64) *DownloadProgress { + return &DownloadProgress{ + ModelName: modelName, + TotalSize: totalSize, + Files: make(map[string]*FileProgress), + LastModified: time.Now(), + Version: 1, + } +} + +// AddFile adds a file to track +func (dp *DownloadProgress) AddFile(fileName string, size int64) { + dp.Files[fileName] = &FileProgress{ + FileName: fileName, + TotalSize: size, + CompletedRange: make([]CompletedRange, 0), + } +} + +// MarkRangeComplete marks a byte range as downloaded and merges adjacent ranges +func (fp *FileProgress) MarkRangeComplete(start, end int64) { + // Add the new range + newRange := CompletedRange{Start: start, End: end} + fp.CompletedRange = append(fp.CompletedRange, newRange) + + // Merge adjacent/overlapping ranges + fp.mergeRanges() + + // Update downloaded count + fp.Downloaded = fp.calculateDownloaded() +} + +// mergeRanges merges overlapping or adjacent ranges +func (fp *FileProgress) mergeRanges() { + if len(fp.CompletedRange) <= 1 { + return + } + + // Sort ranges by start position + for i := 0; i < len(fp.CompletedRange)-1; i++ { + for j := i + 1; j < len(fp.CompletedRange); j++ { + if fp.CompletedRange[i].Start > fp.CompletedRange[j].Start { + fp.CompletedRange[i], fp.CompletedRange[j] = fp.CompletedRange[j], fp.CompletedRange[i] + } + } + } + + // Merge overlapping ranges + merged := 
make([]CompletedRange, 0, len(fp.CompletedRange)) + current := fp.CompletedRange[0] + + for i := 1; i < len(fp.CompletedRange); i++ { + next := fp.CompletedRange[i] + if current.End >= next.Start { + // Merge overlapping ranges + current.End = max(current.End, next.End) + } else { + merged = append(merged, current) + current = next + } + } + merged = append(merged, current) + fp.CompletedRange = merged +} + +// calculateDownloaded calculates total downloaded bytes from ranges +func (fp *FileProgress) calculateDownloaded() int64 { + var total int64 + for _, r := range fp.CompletedRange { + total += r.End - r.Start + } + return total +} + +// IsComplete checks if file is fully downloaded +func (fp *FileProgress) IsComplete() bool { + return fp.Downloaded >= fp.TotalSize +} + +// GetMissingRanges returns ranges that still need to be downloaded +func (fp *FileProgress) GetMissingRanges(chunkSize int64) []CompletedRange { + if len(fp.CompletedRange) == 0 { + // Nothing downloaded yet + return []CompletedRange{{Start: 0, End: fp.TotalSize}} + } + + missing := make([]CompletedRange, 0) + + // Check gap before first range + if fp.CompletedRange[0].Start > 0 { + missing = append(missing, CompletedRange{ + Start: 0, + End: fp.CompletedRange[0].Start, + }) + } + + // Check gaps between ranges + for i := 0; i < len(fp.CompletedRange)-1; i++ { + gap := fp.CompletedRange[i+1].Start - fp.CompletedRange[i].End + if gap > 0 { + missing = append(missing, CompletedRange{ + Start: fp.CompletedRange[i].End, + End: fp.CompletedRange[i+1].Start, + }) + } + } + + // Check gap after last range + lastEnd := fp.CompletedRange[len(fp.CompletedRange)-1].End + if lastEnd < fp.TotalSize { + missing = append(missing, CompletedRange{ + Start: lastEnd, + End: fp.TotalSize, + }) + } + + return missing +} + +// Save saves progress to disk +func (dp *DownloadProgress) Save(outputPath string) error { + dp.LastModified = time.Now() + + progressFile := filepath.Join(outputPath, ".download_progress.json") 
+ data, err := json.MarshalIndent(dp, "", " ") + if err != nil { + return err + } + + return os.WriteFile(progressFile, data, 0644) +} + +// LoadDownloadProgress loads progress from disk +func LoadDownloadProgress(outputPath string) (*DownloadProgress, error) { + progressFile := filepath.Join(outputPath, ".download_progress.json") + + data, err := os.ReadFile(progressFile) + if err != nil { + if os.IsNotExist(err) { + return nil, nil // No previous progress + } + return nil, err + } + + var progress DownloadProgress + if err := json.Unmarshal(data, &progress); err != nil { + return nil, err + } + + return &progress, nil +} + +// CleanupProgress removes progress file after successful download +func CleanupProgress(outputPath string) error { + progressFile := filepath.Join(outputPath, ".download_progress.json") + err := os.Remove(progressFile) + if err != nil && !os.IsNotExist(err) { + return err + } + return nil +} + +func max(a, b int64) int64 { + if a > b { + return a + } + return b +} + diff --git a/runner/internal/types/download_test.go b/runner/internal/types/download_test.go new file mode 100644 index 00000000..ddb41d01 --- /dev/null +++ b/runner/internal/types/download_test.go @@ -0,0 +1,239 @@ +package types + +import ( + "os" + "path/filepath" + "testing" +) + +func TestDownloadProgress_NewAndSave(t *testing.T) { + tmpDir := t.TempDir() + + progress := NewDownloadProgress("test-model", 1000000) + progress.AddFile("model.gguf", 1000000) + + if err := progress.Save(tmpDir); err != nil { + t.Fatalf("Failed to save progress: %v", err) + } + + // Verify file exists + progressFile := filepath.Join(tmpDir, ".download_progress.json") + if _, err := os.Stat(progressFile); os.IsNotExist(err) { + t.Fatal("Progress file was not created") + } +} + +func TestDownloadProgress_LoadAndResume(t *testing.T) { + tmpDir := t.TempDir() + + // Create and save progress + progress := NewDownloadProgress("test-model", 1000000) + progress.AddFile("model.gguf", 1000000) + 
progress.Files["model.gguf"].MarkRangeComplete(0, 500000) + + if err := progress.Save(tmpDir); err != nil { + t.Fatalf("Failed to save progress: %v", err) + } + + // Load progress + loaded, err := LoadDownloadProgress(tmpDir) + if err != nil { + t.Fatalf("Failed to load progress: %v", err) + } + + if loaded == nil { + t.Fatal("Loaded progress is nil") + } + + if loaded.ModelName != "test-model" { + t.Errorf("Expected model name 'test-model', got '%s'", loaded.ModelName) + } + + if loaded.Files["model.gguf"].Downloaded != 500000 { + t.Errorf("Expected 500000 bytes downloaded, got %d", loaded.Files["model.gguf"].Downloaded) + } +} + +func TestFileProgress_MarkRangeComplete(t *testing.T) { + fp := &FileProgress{ + FileName: "test.bin", + TotalSize: 1000, + CompletedRange: make([]CompletedRange, 0), + } + + // Mark first range + fp.MarkRangeComplete(0, 100) + if fp.Downloaded != 100 { + t.Errorf("Expected 100 bytes downloaded, got %d", fp.Downloaded) + } + + // Mark second range + fp.MarkRangeComplete(200, 300) + if fp.Downloaded != 200 { + t.Errorf("Expected 200 bytes downloaded, got %d", fp.Downloaded) + } + + // Mark overlapping range (should merge) + fp.MarkRangeComplete(100, 200) + if fp.Downloaded != 300 { + t.Errorf("Expected 300 bytes downloaded after merge, got %d", fp.Downloaded) + } + + // Should have merged into single range + if len(fp.CompletedRange) != 1 { + t.Errorf("Expected 1 range after merge, got %d", len(fp.CompletedRange)) + } + + if fp.CompletedRange[0].Start != 0 || fp.CompletedRange[0].End != 300 { + t.Errorf("Expected range 0-300, got %d-%d", fp.CompletedRange[0].Start, fp.CompletedRange[0].End) + } +} + +func TestFileProgress_GetMissingRanges(t *testing.T) { + fp := &FileProgress{ + FileName: "test.bin", + TotalSize: 1000, + CompletedRange: make([]CompletedRange, 0), + } + + // No downloads yet + missing := fp.GetMissingRanges(100) + if len(missing) != 1 { + t.Fatalf("Expected 1 missing range, got %d", len(missing)) + } + if missing[0].Start 
!= 0 || missing[0].End != 1000 { + t.Errorf("Expected missing range 0-1000, got %d-%d", missing[0].Start, missing[0].End) + } + + // Download middle section + fp.MarkRangeComplete(400, 600) + missing = fp.GetMissingRanges(100) + + if len(missing) != 2 { + t.Fatalf("Expected 2 missing ranges, got %d", len(missing)) + } + + // First gap: 0-400 + if missing[0].Start != 0 || missing[0].End != 400 { + t.Errorf("Expected first gap 0-400, got %d-%d", missing[0].Start, missing[0].End) + } + + // Second gap: 600-1000 + if missing[1].Start != 600 || missing[1].End != 1000 { + t.Errorf("Expected second gap 600-1000, got %d-%d", missing[1].Start, missing[1].End) + } +} + +func TestFileProgress_IsComplete(t *testing.T) { + fp := &FileProgress{ + FileName: "test.bin", + TotalSize: 1000, + CompletedRange: make([]CompletedRange, 0), + } + + if fp.IsComplete() { + t.Error("File should not be complete initially") + } + + // Download entire file + fp.MarkRangeComplete(0, 1000) + + if !fp.IsComplete() { + t.Error("File should be complete after downloading entire range") + } + + if fp.Downloaded != 1000 { + t.Errorf("Expected 1000 bytes downloaded, got %d", fp.Downloaded) + } +} + +func TestFileProgress_MergeAdjacentRanges(t *testing.T) { + fp := &FileProgress{ + FileName: "test.bin", + TotalSize: 1000, + CompletedRange: make([]CompletedRange, 0), + } + + // Download in reverse order + fp.MarkRangeComplete(600, 800) + fp.MarkRangeComplete(400, 600) + fp.MarkRangeComplete(200, 400) + fp.MarkRangeComplete(0, 200) + + // All ranges should be merged into one + if len(fp.CompletedRange) != 1 { + t.Errorf("Expected 1 merged range, got %d", len(fp.CompletedRange)) + } + + if fp.CompletedRange[0].Start != 0 || fp.CompletedRange[0].End != 800 { + t.Errorf("Expected merged range 0-800, got %d-%d", fp.CompletedRange[0].Start, fp.CompletedRange[0].End) + } + + if fp.Downloaded != 800 { + t.Errorf("Expected 800 bytes downloaded, got %d", fp.Downloaded) + } +} + +func TestCleanupProgress(t 
*testing.T) { + tmpDir := t.TempDir() + + // Create progress file + progress := NewDownloadProgress("test-model", 1000000) + if err := progress.Save(tmpDir); err != nil { + t.Fatalf("Failed to save progress: %v", err) + } + + // Verify it exists + progressFile := filepath.Join(tmpDir, ".download_progress.json") + if _, err := os.Stat(progressFile); os.IsNotExist(err) { + t.Fatal("Progress file should exist before cleanup") + } + + // Cleanup + if err := CleanupProgress(tmpDir); err != nil { + t.Fatalf("Failed to cleanup progress: %v", err) + } + + // Verify it's gone + if _, err := os.Stat(progressFile); !os.IsNotExist(err) { + t.Error("Progress file should be deleted after cleanup") + } +} + +func TestDownloadProgress_MultipleFiles(t *testing.T) { + tmpDir := t.TempDir() + + progress := NewDownloadProgress("test-model", 3000000) + progress.AddFile("model1.gguf", 1000000) + progress.AddFile("model2.gguf", 1000000) + progress.AddFile("model3.gguf", 1000000) + + // Download parts of each file + progress.Files["model1.gguf"].MarkRangeComplete(0, 500000) + progress.Files["model2.gguf"].MarkRangeComplete(0, 750000) + progress.Files["model3.gguf"].MarkRangeComplete(0, 1000000) // Complete + + // Save and reload + if err := progress.Save(tmpDir); err != nil { + t.Fatalf("Failed to save progress: %v", err) + } + + loaded, err := LoadDownloadProgress(tmpDir) + if err != nil { + t.Fatalf("Failed to load progress: %v", err) + } + + // Verify all files + if loaded.Files["model1.gguf"].Downloaded != 500000 { + t.Errorf("model1 expected 500000, got %d", loaded.Files["model1.gguf"].Downloaded) + } + + if loaded.Files["model2.gguf"].Downloaded != 750000 { + t.Errorf("model2 expected 750000, got %d", loaded.Files["model2.gguf"].Downloaded) + } + + if !loaded.Files["model3.gguf"].IsComplete() { + t.Error("model3 should be complete") + } +} +