Skip to content

Commit f6e61e3

Browse files
authored
Logical checkpoint of ModelDownloader implementation (#7259)
* Restores ML Pods after M77. * Fix Package.swift * Re-add catalyst to GHA workflow. * Includes the implementation of: 1) download types, 2) list models, 3) delete model. * Add protocol to handle writing to user defaults. * Add protocol to handle writing to user defaults. * Add empty init to ModelDownloadConditions. * Copyright year + other minor fixes. * TODO to use FirebaseApp internal init for test app. * Test if download progress is in range. * Remove not nil test for progress. * Remove not nil test for progress.
1 parent 11f55a1 commit f6e61e3

11 files changed

+522
-58
lines changed

FirebaseMLModelDownloader/Sources/CustomModel.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2020 Google LLC
1+
// Copyright 2021 Google LLC
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright 2021 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// Protocol to write info to user defaults.
18+
protocol DownloaderUserDefaultsWriteable {
19+
func writeToDefaults(_ defaults: UserDefaults, appName: String)
20+
}

FirebaseMLModelDownloader/Sources/LocalModelInfo.swift

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2020 Google LLC
1+
// Copyright 2021 Google LLC
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -33,12 +33,6 @@ class LocalModelInfo {
3333
/// Local path of the model.
3434
let path: String
3535

36-
/// Get user defaults key prefix.
37-
private static func getUserDefaultsKeyPrefix(appName: String, modelName: String) -> String {
38-
let bundleID = Bundle.main.bundleIdentifier ?? ""
39-
return "\(bundleID).\(appName).\(modelName)"
40-
}
41-
4236
init(name: String, downloadURL: URL, modelHash: String, size: Int, path: String) {
4337
self.name = name
4438
self.downloadURL = downloadURL
@@ -71,6 +65,15 @@ class LocalModelInfo {
7165
}
7266
self.init(name: name, downloadURL: url, modelHash: modelHash, size: size, path: path)
7367
}
68+
}
69+
70+
/// Extension to write local model info to user defaults.
71+
extension LocalModelInfo: DownloaderUserDefaultsWriteable {
72+
/// Get user defaults key prefix.
73+
private static func getUserDefaultsKeyPrefix(appName: String, modelName: String) -> String {
74+
let bundleID = Bundle.main.bundleIdentifier ?? ""
75+
return "\(bundleID).\(appName).\(modelName)"
76+
}
7477

7578
/// Write local model info to user defaults.
7679
func writeToDefaults(_ defaults: UserDefaults, appName: String) {

FirebaseMLModelDownloader/Sources/ModelDownloadConditions.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2020 Google LLC
1+
// Copyright 2021 Google LLC
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -15,4 +15,8 @@
1515
import Foundation
1616

1717
/// Model download conditions.
18-
public struct ModelDownloadConditions {}
18+
// TODO: Implement model download conditions.
19+
public struct ModelDownloadConditions {
20+
// TODO: Intentionally left blank until implementation.
21+
public init() {}
22+
}

FirebaseMLModelDownloader/Sources/ModelDownloadTask.swift

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2020 Google LLC
1+
// Copyright 2021 Google LLC
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -84,15 +84,21 @@ extension ModelDownloadTask: URLSessionDownloadDelegate {
8484
return "fbml_model__\(appName)__\(remoteModelInfo.name)"
8585
}
8686

87+
func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
88+
// TODO: Handle model download url expiry and other download errors
89+
}
90+
8791
func urlSession(_ session: URLSession,
8892
downloadTask: URLSessionDownloadTask,
8993
didFinishDownloadingTo location: URL) {
9094
assert(downloadTask == self.downloadTask)
9195
downloadStatus = .completed
92-
let savedURL = ModelFileManager.modelsDirectory
93-
.appendingPathComponent(downloadedModelFileName)
96+
let modelFileURL = ModelFileManager.getDownloadedModelFilePath(
97+
appName: appName,
98+
modelName: remoteModelInfo.name
99+
)
94100
do {
95-
try ModelFileManager.moveFile(at: location, to: savedURL)
101+
try ModelFileManager.moveFile(at: location, to: modelFileURL)
96102
} catch let error as DownloadError {
97103
DispatchQueue.main.async {
98104
self.downloadHandlers
@@ -106,7 +112,7 @@ extension ModelDownloadTask: URLSessionDownloadDelegate {
106112
}
107113

108114
/// Generate local model info.
109-
let localModelInfo = LocalModelInfo(from: remoteModelInfo, path: savedURL.absoluteString)
115+
let localModelInfo = LocalModelInfo(from: remoteModelInfo, path: modelFileURL.absoluteString)
110116
/// Write model to user defaults.
111117
localModelInfo.writeToDefaults(defaults, appName: appName)
112118
/// Build model from model info.

FirebaseMLModelDownloader/Sources/ModelDownloader.swift

Lines changed: 199 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2020 Google LLC
1+
// Copyright 2021 Google LLC
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
1414

1515
import Foundation
1616
import FirebaseCore
17+
import FirebaseInstallations
1718

1819
/// Possible errors with model downloading.
1920
public enum DownloadError: Error, Equatable {
@@ -37,6 +38,8 @@ public enum DownloadedModelError: Error, Equatable {
3738
case fileIOError(description: String)
3839
/// Model not found on device.
3940
case notFound
41+
/// Other errors with description.
42+
case internalError(description: String)
4043
}
4144

4245
/// Possible ways to get a custom model.
@@ -53,14 +56,24 @@ public enum ModelDownloadType {
5356
public class ModelDownloader {
5457
/// Name of the app associated with this instance of ModelDownloader.
5558
private let appName: String
59+
/// Current Firebase app options.
60+
private let options: FirebaseOptions
61+
/// Installations instance for current Firebase app.
62+
private let installations: Installations
63+
/// User defaults for model info.
64+
private let userDefaults: UserDefaults
5665

5766
/// Shared dictionary mapping app name to a specific instance of model downloader.
5867
// TODO: Switch to using Firebase components.
5968
private static var modelDownloaderDictionary: [String: ModelDownloader] = [:]
6069

6170
/// Private init for downloader.
62-
private init(app: FirebaseApp) {
71+
private init(app: FirebaseApp, defaults: UserDefaults = .firebaseMLDefaults) {
6372
appName = app.name
73+
options = app.options
74+
installations = Installations.installations(app: app)
75+
userDefaults = defaults
76+
6477
NotificationCenter.default.addObserver(
6578
self,
6679
selector: #selector(deleteModelDownloader),
@@ -103,36 +116,202 @@ public class ModelDownloader {
103116
conditions: ModelDownloadConditions,
104117
progressHandler: ((Float) -> Void)? = nil,
105118
completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
106-
// TODO: Model download
107-
let modelSize = Int()
108-
let modelPath = String()
109-
let modelHash = String()
119+
switch downloadType {
120+
case .localModel:
121+
if let localModel = getLocalModel(modelName: modelName) {
122+
DispatchQueue.main.async {
123+
completion(.success(localModel))
124+
}
125+
} else {
126+
getRemoteModel(
127+
modelName: modelName,
128+
progressHandler: progressHandler,
129+
completion: completion
130+
)
131+
}
132+
case .localModelUpdateInBackground:
133+
if let localModel = getLocalModel(modelName: modelName) {
134+
DispatchQueue.main.async {
135+
completion(.success(localModel))
136+
}
137+
DispatchQueue.global(qos: .utility).async { [weak self] in
138+
self?.getRemoteModel(
139+
modelName: modelName,
140+
progressHandler: nil,
141+
completion: { result in
142+
switch result {
143+
// TODO: Log outcome of background download
144+
case .success: break
145+
case .failure: break
146+
}
147+
}
148+
)
149+
}
150+
} else {
151+
getRemoteModel(
152+
modelName: modelName,
153+
progressHandler: progressHandler,
154+
completion: completion
155+
)
156+
}
110157

111-
let customModel = CustomModel(
112-
name: modelName,
113-
size: modelSize,
114-
path: modelPath,
115-
hash: modelHash
116-
)
117-
completion(.success(customModel))
118-
completion(.failure(.notFound))
158+
case .latestModel:
159+
getRemoteModel(
160+
modelName: modelName,
161+
progressHandler: progressHandler,
162+
completion: completion
163+
)
164+
}
119165
}
120166

121167
/// Gets all downloaded models.
122168
public func listDownloadedModels(completion: @escaping (Result<Set<CustomModel>,
123169
DownloadedModelError>) -> Void) {
124-
let customModels = Set<CustomModel>()
125-
// TODO: List downloaded models
126-
completion(.success(customModels))
127-
completion(.failure(.notFound))
170+
do {
171+
let modelPaths = try ModelFileManager.contentsOfModelsDirectory()
172+
var customModels = Set<CustomModel>()
173+
for path in modelPaths {
174+
guard let modelName = ModelFileManager.getModelNameFromFilePath(path) else {
175+
completion(.failure(.internalError(description: "Invalid model file name.")))
176+
return
177+
}
178+
guard let modelInfo = getLocalModelInfo(modelName: modelName) else {
179+
completion(
180+
.failure(.internalError(description: "Failed to get model info for model file."))
181+
)
182+
return
183+
}
184+
guard modelInfo.path == path.absoluteString else {
185+
completion(
186+
.failure(.internalError(description: "Outdated model paths in local storage."))
187+
)
188+
return
189+
}
190+
let model = CustomModel(localModelInfo: modelInfo)
191+
customModels.insert(model)
192+
}
193+
completion(.success(customModels))
194+
} catch let error as DownloadedModelError {
195+
completion(.failure(error))
196+
} catch {
197+
completion(.failure(.internalError(description: error.localizedDescription)))
198+
}
128199
}
129200

130201
/// Deletes a custom model from device.
131202
public func deleteDownloadedModel(name modelName: String,
132203
completion: @escaping (Result<Void, DownloadedModelError>)
133204
-> Void) {
134205
// TODO: Delete previously downloaded model
135-
completion(.success(()))
136-
completion(.failure(.notFound))
206+
guard let localModelInfo = getLocalModelInfo(modelName: modelName),
207+
let localPath = URL(string: localModelInfo.path)
208+
else {
209+
completion(.failure(.notFound))
210+
return
211+
}
212+
do {
213+
try ModelFileManager.removeFile(at: localPath)
214+
completion(.success(()))
215+
} catch let error as DownloadedModelError {
216+
completion(.failure(error))
217+
} catch {
218+
completion(.failure(.internalError(description: error.localizedDescription)))
219+
}
220+
}
221+
}
222+
223+
extension ModelDownloader {
224+
/// Return local model info only if the model info is available and the corresponding model file is already on device.
225+
private func getLocalModelInfo(modelName: String) -> LocalModelInfo? {
226+
guard let localModelInfo = LocalModelInfo(
227+
fromDefaults: userDefaults,
228+
name: modelName,
229+
appName: appName
230+
) else {
231+
return nil
232+
}
233+
/// There is local model info on device, but no model file at the expected path.
234+
guard let localPath = URL(string: localModelInfo.path),
235+
ModelFileManager.isFileReachable(at: localPath) else {
236+
// TODO: Consider deleting local model info in user defaults.
237+
return nil
238+
}
239+
return localModelInfo
240+
}
241+
242+
/// Get model saved on device if available.
243+
private func getLocalModel(modelName: String) -> CustomModel? {
244+
guard let localModelInfo = getLocalModelInfo(modelName: modelName) else { return nil }
245+
let model = CustomModel(localModelInfo: localModelInfo)
246+
return model
247+
}
248+
249+
/// Download and get model from server, unless the latest model is already available on device.
250+
private func getRemoteModel(modelName: String,
251+
progressHandler: ((Float) -> Void)? = nil,
252+
completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
253+
let localModelInfo = getLocalModelInfo(modelName: modelName)
254+
let modelInfoRetriever = ModelInfoRetriever(
255+
modelName: modelName,
256+
options: options,
257+
installations: installations,
258+
appName: appName,
259+
localModelInfo: localModelInfo
260+
)
261+
modelInfoRetriever.downloadModelInfo { result in
262+
switch result {
263+
case let .success(downloadModelInfoResult):
264+
switch downloadModelInfoResult {
265+
/// New model info was downloaded from server.
266+
case let .modelInfo(remoteModelInfo):
267+
let downloadTask = ModelDownloadTask(
268+
remoteModelInfo: remoteModelInfo,
269+
appName: self.appName,
270+
defaults: self.userDefaults,
271+
progressHandler: progressHandler,
272+
completion: completion
273+
)
274+
downloadTask.resumeModelDownload()
275+
/// Local model info is the latest model info.
276+
case .notModified:
277+
guard let localModel = self.getLocalModel(modelName: modelName) else {
278+
DispatchQueue.main.async {
279+
/// This can only happen if either local model info or the model file was suddenly wiped out in the middle of model info request and server response
280+
// TODO: Consider handling: if model file is deleted after local model info is retrieved but before model info network call
281+
completion(
282+
.failure(
283+
.internalError(description: "Model unavailable due to deleted local model info.")
284+
)
285+
)
286+
}
287+
return
288+
}
289+
290+
DispatchQueue.main.async {
291+
completion(.success(localModel))
292+
}
293+
}
294+
case let .failure(downloadError):
295+
DispatchQueue.main.async {
296+
completion(.failure(downloadError))
297+
}
298+
}
299+
}
300+
}
301+
}
302+
303+
/// Model downloader extension for testing.
304+
extension ModelDownloader {
305+
/// Model downloader instance for testing.
306+
// TODO: Consider using protocols
307+
static func modelDownloaderWithDefaults(_ defaults: UserDefaults,
308+
app: FirebaseApp) -> ModelDownloader {
309+
if let downloader = modelDownloaderDictionary[app.name] {
310+
return downloader
311+
} else {
312+
let downloader = ModelDownloader(app: app, defaults: defaults)
313+
modelDownloaderDictionary[app.name] = downloader
314+
return downloader
315+
}
137316
}
138317
}

0 commit comments

Comments
 (0)