Skip to content

Commit a8ce5b0

Browse files
authored
Handle expired download URL with support for retries (#7329)
* Restores ML Pods after M77. * Fix Package.swift * Re-add catalyst to GHA workflow. * Update error handling. * Add logging for background download failure. * Handle expired download URL. * Remove empty file. * Fix integration test. * Minor refactor.
1 parent 899b562 commit a8ce5b0

File tree

3 files changed

+56
-3
lines changed

3 files changed

+56
-3
lines changed

FirebaseMLModelDownloader/Sources/ModelDownloadTask.swift

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,23 +55,31 @@ class ModelDownloadTask: NSObject {
5555
private lazy var downloadSession = URLSession(configuration: .ephemeral,
5656
delegate: self,
5757
delegateQueue: nil)
58+
/// Model info retriever in case of retries.
59+
private let modelInfoRetriever: ModelInfoRetriever
60+
/// Number of retries in case of model download URL expiry.
61+
private var numberOfRetries: Int = 1
5862
/// Telemetry logger.
5963
private let telemetryLogger: TelemetryLogger?
6064

6165
init(remoteModelInfo: RemoteModelInfo, appName: String, defaults: UserDefaults,
66+
modelInfoRetriever: ModelInfoRetriever,
6267
telemetryLogger: TelemetryLogger? = nil,
6368
progressHandler: DownloadHandlers.ProgressHandler? = nil,
6469
completion: @escaping DownloadHandlers.Completion) {
6570
self.remoteModelInfo = remoteModelInfo
6671
self.appName = appName
72+
self.modelInfoRetriever = modelInfoRetriever
6773
self.telemetryLogger = telemetryLogger
6874
self.defaults = defaults
6975
downloadHandlers = DownloadHandlers(
7076
progressHandler: progressHandler,
7177
completion: completion
7278
)
7379
}
80+
}
7481

82+
extension ModelDownloadTask {
7583
/// Asynchronously download model file to device.
7684
func resumeModelDownload() {
7785
guard downloadStatus == .notStarted else { return }
@@ -116,7 +124,42 @@ extension ModelDownloadTask: URLSessionDownloadDelegate {
116124
}
117125

118126
guard (200 ..< 299).contains(response.statusCode) else {
119-
// TODO: Handle download url expiry + retries.
127+
/// Possible failure due to download URL expiry.
128+
if response.statusCode == 400 {
129+
let currentDateTime = Date()
130+
/// Retry download if allowed.
131+
guard currentDateTime > remoteModelInfo.urlExpiryTime, numberOfRetries > 0 else {
132+
downloadStatus = .failed
133+
downloadHandlers
134+
.completion(.failure(.internalError(description: ModelDownloadTask.ErrorDescription
135+
.expiredModelInfo)))
136+
return
137+
}
138+
numberOfRetries -= 1
139+
modelInfoRetriever.downloadModelInfo { result in
140+
switch result {
141+
case let .success(downloadModelInfoResult):
142+
switch downloadModelInfoResult {
143+
/// New model info was downloaded from server.
144+
case let .modelInfo(remoteModelInfo):
145+
self.remoteModelInfo = remoteModelInfo
146+
self.resumeModelDownload()
147+
/// This should not ever be the case - model info cannot be unmodified within ModelDownloadTask.
148+
case .notModified:
149+
DispatchQueue.main.async {
150+
self.downloadHandlers
151+
.completion(.failure(.internalError(description: ModelDownloadTask
152+
.ErrorDescription.expiredModelInfo)))
153+
}
154+
}
155+
case let .failure(downloadError):
156+
self.downloadStatus = .failed
157+
DispatchQueue.main.async {
158+
self.downloadHandlers.completion(.failure(downloadError))
159+
}
160+
}
161+
}
162+
}
120163
return
121164
}
122165

@@ -196,5 +239,6 @@ extension ModelDownloadTask {
196239
"Could not get server response for model downloading."
197240
static let saveModel: StaticString =
198241
"Unable to save downloaded remote model file."
242+
static let expiredModelInfo = "Unable to update expired model info."
199243
}
200244
}

FirebaseMLModelDownloader/Sources/ModelDownloader.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import FirebaseCore
1717
import FirebaseInstallations
1818

1919
/// Possible errors with model downloading.
20-
public enum DownloadError: Error, Equatable {
20+
public enum DownloadError: Error {
2121
/// No model with this name found on server.
2222
case notFound
2323
/// Caller does not have necessary permissions for this operation.
@@ -33,7 +33,7 @@ public enum DownloadError: Error, Equatable {
3333
}
3434

3535
/// Possible errors with locating model on device.
36-
public enum DownloadedModelError: Error, Equatable {
36+
public enum DownloadedModelError: Error {
3737
/// File system error.
3838
case fileIOError(description: String)
3939
/// Model not found on device.
@@ -282,6 +282,7 @@ extension ModelDownloader {
282282
remoteModelInfo: remoteModelInfo,
283283
appName: self.appName,
284284
defaults: self.userDefaults,
285+
modelInfoRetriever: modelInfoRetriever,
285286
telemetryLogger: self.telemetryLogger,
286287
progressHandler: progressHandler,
287288
completion: completion

FirebaseMLModelDownloader/Tests/Integration/ModelDownloaderIntegrationTests.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,19 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
157157
urlExpiryTime: Date()
158158
)
159159

160+
let modelInfoRetriever = ModelInfoRetriever(
161+
modelName: testModelName,
162+
options: testApp.options,
163+
installations: Installations.installations(app: testApp),
164+
appName: testApp.name
165+
)
166+
160167
let expectation = self.expectation(description: "Wait for model to download.")
161168
let modelDownloadManager = ModelDownloadTask(
162169
remoteModelInfo: remoteModelInfo,
163170
appName: testApp.name,
164171
defaults: .createTestInstance(testName: #function),
172+
modelInfoRetriever: modelInfoRetriever,
165173
progressHandler: { progress in
166174
XCTAssertLessThanOrEqual(progress, 1)
167175
XCTAssertGreaterThanOrEqual(progress, 0)

0 commit comments

Comments
 (0)