Skip to content

Commit 63ab365

Browse files
committed
Parallelize downloads, skip HEAD requests for cached files, weight download progress by file size
1 parent 1789228 commit 63ab365

File tree

4 files changed

+387
-190
lines changed

4 files changed

+387
-190
lines changed

Sources/Hub/Hub.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public extension Hub {
7777
}
7878

7979
/// The type of repository on the Hugging Face Hub.
80-
enum RepoType: String, Codable {
80+
enum RepoType: String, Codable, Sendable {
8181
/// Model repositories containing machine learning models.
8282
case models
8383
/// Dataset repositories containing training and evaluation data.
@@ -90,7 +90,7 @@ public extension Hub {
9090
///
9191
/// A repository is identified by its unique ID and type, allowing access to
9292
/// different kinds of resources hosted on the Hub platform.
93-
struct Repo: Codable {
93+
struct Repo: Codable, Sendable {
9494
/// The unique identifier for the repository (e.g., "microsoft/DialoGPT-medium").
9595
public let id: String
9696
/// The type of repository (models, datasets, or spaces).

Sources/Hub/HubApi.swift

Lines changed: 133 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,17 @@ public extension HubApi {
172172
struct Sibling: Codable {
173173
/// The relative filename within the repository.
174174
let rfilename: String
175+
/// The size of the file in bytes (optional, may not be present for all files).
176+
let size: Int64?
175177
}
176178

177-
/// Response structure for repository file listings.
179+
/// Response structure for repository info including file listings.
178180
///
179-
/// Contains the list of files available in a repository,
181+
/// Contains the commit hash and list of files available in a repository,
180182
/// returned by the Hub API when querying repository contents.
181-
struct SiblingsResponse: Codable {
183+
struct RepoInfoResponse: Codable {
184+
/// The commit hash (SHA) for this revision.
185+
let sha: String
182186
/// Array of files in the repository.
183187
let siblings: [Sibling]
184188
}
@@ -257,23 +261,42 @@ public extension HubApi {
257261
/// - Returns: Array of matching filenames
258262
/// - Throws: HubClientError if the repository cannot be accessed or parsed
259263
func getFilenames(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [String] {
260-
// Read repo info and only parse "siblings"
261-
let url = URL(string: endpoint)!
264+
let (files, _) = try await getRepoInfo(from: repo, revision: revision, matching: globs)
265+
return files.map { $0.filename }
266+
}
267+
268+
/// File info returned from getRepoInfo, containing filename and optional size.
269+
struct RepoFileInfo {
270+
let filename: String
271+
let size: Int64?
272+
}
273+
274+
/// Retrieves repository info including filenames, sizes, and commit hash.
275+
///
276+
/// - Parameters:
277+
/// - repo: The repository to query
278+
/// - revision: Git revision (branch, tag, or commit hash)
279+
/// - globs: Optional glob patterns to filter files
280+
/// - Returns: Tuple of (matching file info, commit SHA)
281+
/// - Throws: HubClientError if the repository cannot be accessed or parsed
282+
func getRepoInfo(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> (files: [RepoFileInfo], sha: String) {
283+
var url = URL(string: endpoint)!
262284
.appending(path: "api")
263285
.appending(path: repo.type.rawValue)
264286
.appending(path: repo.id)
265287
.appending(path: "revision")
266288
.appending(component: revision) // Encode slashes (e.g., "pr/1" -> "pr%2F1")
289+
// blobs=true includes file sizes in the response, used for size-weighted progress reporting
290+
url.append(queryItems: [URLQueryItem(name: "blobs", value: "true")])
267291
let (data, _) = try await httpGet(for: url)
268-
let response = try JSONDecoder().decode(SiblingsResponse.self, from: data)
269-
let filenames = response.siblings.map { $0.rfilename }
270-
guard globs.count > 0 else { return filenames }
292+
let response = try JSONDecoder().decode(RepoInfoResponse.self, from: data)
293+
let allFiles = response.siblings.map { RepoFileInfo(filename: $0.rfilename, size: $0.size) }
271294

272-
var selected: Set<String> = []
273-
for glob in globs {
274-
selected = selected.union(filenames.matching(glob: glob))
275-
}
276-
return Array(selected)
295+
guard globs.count > 0 else { return (allFiles, response.sha) }
296+
297+
let matchingFilenames = Set(allFiles.map { $0.filename }.matching(globs: globs))
298+
let matchingFiles = allFiles.filter { matchingFilenames.contains($0.filename) }
299+
return (matchingFiles, response.sha)
277300
}
278301

279302
func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] {
@@ -500,12 +523,25 @@ public extension HubApi {
500523
}
501524

502525
/// Downloads the file with progress tracking.
503-
/// - Parameter progressHandler: Called with download progress (0.0-1.0) and speed in bytes/sec, if available.
526+
/// - Parameters:
527+
/// - knownCommitHash: If provided and matches local metadata, skips the HEAD request for faster cache hits.
528+
/// - progressHandler: Called with download progress (0.0-1.0) and speed in bytes/sec, if available.
504529
/// - Returns: Local file URL (uses cached file if commit hash matches).
505530
/// - Throws: ``EnvironmentError`` errors for file and metadata validation failures, ``Downloader.DownloadError`` errors during transfer, or ``CancellationError`` if the task is cancelled.
506531
@discardableResult
507-
func download(progressHandler: @escaping (Double, Double?) -> Void) async throws -> URL {
532+
func download(knownCommitHash: String? = nil, progressHandler: @escaping (Double, Double?) -> Void) async throws -> URL {
508533
let localMetadata = try hub.readDownloadMetadata(metadataPath: metadataDestination)
534+
535+
// Fast path: if we know the repo's commit hash and local file matches, skip HEAD request
536+
if let knownCommitHash,
537+
hub.isValidHash(hash: knownCommitHash, pattern: hub.commitHashPattern),
538+
downloaded,
539+
let localMetadata,
540+
localMetadata.commitHash == knownCommitHash
541+
{
542+
return destination
543+
}
544+
509545
let remoteMetadata = try await hub.getFileMetadata(url: source)
510546

511547
let localCommitHash = localMetadata?.commitHash ?? ""
@@ -592,7 +628,7 @@ public extension HubApi {
592628
}
593629

594630
@discardableResult
595-
func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in })
631+
func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping @Sendable (Progress) -> Void = { _ in })
596632
async throws -> URL
597633
{
598634
let repoDestination = localRepoLocation(repo)
@@ -634,77 +670,111 @@ public extension HubApi {
634670
return repoDestination
635671
}
636672

637-
let filenames = try await getFilenames(from: repo, revision: revision, matching: globs)
638-
let progress = Progress(totalUnitCount: Int64(filenames.count))
639-
for filename in filenames {
640-
let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1)
641-
let downloader = HubFileDownloader(
642-
hub: self,
643-
repo: repo,
644-
revision: revision,
645-
repoDestination: repoDestination,
646-
repoMetadataDestination: repoMetadataDestination,
647-
relativeFilename: filename,
648-
hfToken: hfToken,
649-
endpoint: endpoint,
650-
backgroundSession: useBackgroundSession
651-
)
673+
let (files, repoCommitHash) = try await getRepoInfo(from: repo, revision: revision, matching: globs)
652674

653-
try await downloader.download { fractionDownloaded, speed in
654-
fileProgress.completedUnitCount = Int64(100 * fractionDownloaded)
655-
if let speed {
656-
fileProgress.setUserInfoObject(speed, forKey: .throughputKey)
657-
progress.setUserInfoObject(speed, forKey: .throughputKey)
675+
// Size-weighted progress: total is sum of file sizes (bytes)
676+
let totalBytes = files.reduce(Int64(0)) { $0 + ($1.size ?? 1) }
677+
let progress = Progress(totalUnitCount: totalBytes)
678+
progressHandler(progress)
679+
680+
// Process files in parallel with limited concurrency
681+
let maxConcurrentDownloads = 8
682+
try await withThrowingTaskGroup(of: Void.self) { group in
683+
var nextIndex = 0
684+
685+
// Helper to add a download task for a file
686+
func addDownloadTask(index: Int, file: RepoFileInfo) {
687+
// Create child progress weighted by file size
688+
let fileSize = file.size ?? 1
689+
let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: fileSize)
690+
691+
let hub = self
692+
let relativeFilename = file.filename
693+
group.addTask {
694+
let downloader = HubFileDownloader(
695+
hub: hub,
696+
repo: repo,
697+
revision: revision,
698+
repoDestination: repoDestination,
699+
repoMetadataDestination: repoMetadataDestination,
700+
relativeFilename: relativeFilename,
701+
hfToken: hfToken,
702+
endpoint: endpoint,
703+
backgroundSession: useBackgroundSession
704+
)
705+
try await downloader.download(knownCommitHash: repoCommitHash) { fractionDownloaded, speed in
706+
let newCount = Int64(100 * fractionDownloaded)
707+
// Only update if progress increased (handles out-of-order updates)
708+
if newCount > fileProgress.completedUnitCount {
709+
fileProgress.completedUnitCount = newCount
710+
}
711+
if let speed {
712+
fileProgress.setUserInfoObject(speed, forKey: .throughputKey)
713+
progress.setUserInfoObject(speed, forKey: .throughputKey)
714+
}
715+
progressHandler(progress)
716+
}
717+
fileProgress.completedUnitCount = 100
718+
progressHandler(progress)
658719
}
659-
progressHandler(progress)
660720
}
661-
if Task.isCancelled {
662-
return repoDestination
721+
722+
// Start initial batch of tasks
723+
while nextIndex < min(maxConcurrentDownloads, files.count) {
724+
addDownloadTask(index: nextIndex, file: files[nextIndex])
725+
nextIndex += 1
663726
}
664727

665-
fileProgress.completedUnitCount = 100
728+
// Process completions and add new tasks as slots open up
729+
for try await _ in group {
730+
if nextIndex < files.count, !Task.isCancelled {
731+
addDownloadTask(index: nextIndex, file: files[nextIndex])
732+
nextIndex += 1
733+
}
734+
}
666735
}
667736

737+
// Final progress update to ensure 100% is reported
668738
progressHandler(progress)
669739
return repoDestination
670740
}
671741

672742
/// New overloads exposing speed directly in the snapshot progress handler
673-
@discardableResult func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
743+
@discardableResult func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping @Sendable (Progress, Double?) -> Void) async throws -> URL {
674744
try await snapshot(from: repo, revision: revision, matching: globs) { progress in
675745
let speed = progress.userInfo[.throughputKey] as? Double
676746
progressHandler(progress, speed)
677747
}
678748
}
679749

680750
@discardableResult
681-
func snapshot(from repoId: String, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
751+
func snapshot(from repoId: String, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping @Sendable (Progress) -> Void = { _ in }) async throws -> URL {
682752
try await snapshot(from: Repo(id: repoId), revision: revision, matching: globs, progressHandler: progressHandler)
683753
}
684754

685755
@discardableResult
686-
func snapshot(from repo: Repo, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
756+
func snapshot(from repo: Repo, revision: String = "main", matching glob: String, progressHandler: @escaping @Sendable (Progress) -> Void = { _ in }) async throws -> URL {
687757
try await snapshot(from: repo, revision: revision, matching: [glob], progressHandler: progressHandler)
688758
}
689759

690760
@discardableResult
691-
func snapshot(from repoId: String, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
761+
func snapshot(from repoId: String, revision: String = "main", matching glob: String, progressHandler: @escaping @Sendable (Progress) -> Void = { _ in }) async throws -> URL {
692762
try await snapshot(from: Repo(id: repoId), revision: revision, matching: [glob], progressHandler: progressHandler)
693763
}
694764

695765
/// Convenience overloads for other snapshot entry points with speed
696766
@discardableResult
697-
func snapshot(from repoId: String, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
767+
func snapshot(from repoId: String, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping @Sendable (Progress, Double?) -> Void) async throws -> URL {
698768
try await snapshot(from: Repo(id: repoId), revision: revision, matching: globs, progressHandler: progressHandler)
699769
}
700770

701771
@discardableResult
702-
func snapshot(from repo: Repo, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
772+
func snapshot(from repo: Repo, revision: String = "main", matching glob: String, progressHandler: @escaping @Sendable (Progress, Double?) -> Void) async throws -> URL {
703773
try await snapshot(from: repo, revision: revision, matching: [glob], progressHandler: progressHandler)
704774
}
705775

706776
@discardableResult
707-
func snapshot(from repoId: String, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
777+
func snapshot(from repoId: String, revision: String = "main", matching glob: String, progressHandler: @escaping @Sendable (Progress, Double?) -> Void) async throws -> URL {
708778
try await snapshot(from: Repo(id: repoId), revision: revision, matching: [glob], progressHandler: progressHandler)
709779
}
710780
}
@@ -940,38 +1010,38 @@ public extension Hub {
9401010
/// - progressHandler: Closure called with download progress updates
9411011
/// - Returns: URL to the local repository directory
9421012
/// - Throws: HubClientError if the download fails
943-
static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
1013+
static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping @Sendable (Progress) -> Void = { _ in }) async throws -> URL {
9441014
try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler)
9451015
}
9461016

947-
static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws
1017+
static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping @Sendable (Progress) -> Void = { _ in }) async throws
9481018
-> URL
9491019
{
9501020
try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
9511021
}
9521022

953-
static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
1023+
static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping @Sendable (Progress) -> Void = { _ in }) async throws -> URL {
9541024
try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler)
9551025
}
9561026

957-
static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
1027+
static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping @Sendable (Progress) -> Void = { _ in }) async throws -> URL {
9581028
try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler)
9591029
}
9601030

9611031
/// Overloads exposing speed via (Progress, Double?) where Double is bytes/sec
962-
static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
1032+
static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping @Sendable (Progress, Double?) -> Void) async throws -> URL {
9631033
try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler)
9641034
}
9651035

966-
static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
1036+
static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping @Sendable (Progress, Double?) -> Void) async throws -> URL {
9671037
try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
9681038
}
9691039

970-
static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
1040+
static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping @Sendable (Progress, Double?) -> Void) async throws -> URL {
9711041
try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler)
9721042
}
9731043

974-
static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
1044+
static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping @Sendable (Progress, Double?) -> Void) async throws -> URL {
9751045
try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler)
9761046
}
9771047

@@ -1009,6 +1079,14 @@ private extension [String] {
10091079
func matching(glob: String) -> [String] {
10101080
filter { fnmatch(glob, $0, 0) == 0 }
10111081
}
1082+
1083+
func matching(globs: [String]) -> [String] {
1084+
var selected: Set<String> = []
1085+
for glob in globs {
1086+
selected = selected.union(matching(glob: glob))
1087+
}
1088+
return Array(selected)
1089+
}
10121090
}
10131091

10141092
private extension FileManager {

0 commit comments

Comments
 (0)