Skip to content

Commit 1a1a91a

Browse files
authored
Refactor modelInfo into localModelInfo and remoteModelInfo (#7163)
* Restores ML Pods after M77. * Fix Package.swift * Refactor model info to be in-memory. * Add improved model download functionality w/ single set of download handlers per task. * Rename model download to better match URLSessionTask. * Rename model download to better match URLSessionTask. * Update unit test to read/write to user defaults. * Configure downloader with Firebase app. * Minor updates to model downloader and tests. * Singleton for model downloader instance, for default and custom app. * Download types WIP - local model. * Download types WIP. * Manually add app lifecycle handling + TODOs for Firebase Components. * Replace options with app name as a property in ModelDownloader. * Download types WIP. * Refactor model info. * Add convenience inits + other minor fixes. * Explicit enum for model info result.
1 parent ae854d8 commit 1a1a91a

File tree

8 files changed

+279
-279
lines changed

8 files changed

+279
-279
lines changed

FirebaseMLModelDownloader/Sources/CustomModel.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,21 @@ public struct CustomModel: Hashable {
2424
public let path: String
2525
/// Hash for the model, used for model verification.
2626
public let hash: String
27+
28+
init(name: String, size: Int, path: String, hash: String) {
29+
self.name = name
30+
self.size = size
31+
self.path = path
32+
self.hash = hash
33+
}
34+
35+
/// Convenience init to create model from local model info.
36+
init(localModelInfo: LocalModelInfo) {
37+
self.init(
38+
name: localModelInfo.name,
39+
size: localModelInfo.size,
40+
path: localModelInfo.path,
41+
hash: localModelInfo.modelHash
42+
)
43+
}
2744
}

FirebaseMLModelDownloader/Sources/ModelInfo.swift renamed to FirebaseMLModelDownloader/Sources/LocalModelInfo.swift

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
import Foundation
1616
import FirebaseCore
1717

18-
/// Model info object with details about pending or downloaded model.
19-
struct ModelInfo {
18+
/// Model info object with details about downloaded and locally available model.
19+
// TODO: Can this be backed by user defaults property wrappers?
20+
class LocalModelInfo {
2021
/// Model name.
2122
let name: String
2223

23-
// TODO: revisit UserDefaultsBacked
2424
/// Download URL for the model file, as returned by server.
2525
let downloadURL: URL
2626

@@ -31,25 +31,36 @@ struct ModelInfo {
3131
let size: Int
3232

3333
/// Local path of the model.
34-
var path: String?
34+
let path: String
3535

36-
/// Initialize model info and create user default keys.
37-
init(name: String, downloadURL: URL, modelHash: String, size: Int) {
36+
/// Get user defaults key prefix.
37+
private static func getUserDefaultsKeyPrefix(appName: String, modelName: String) -> String {
38+
let bundleID = Bundle.main.bundleIdentifier ?? ""
39+
return "\(bundleID).\(appName).\(modelName)"
40+
}
41+
42+
init(name: String, downloadURL: URL, modelHash: String, size: Int, path: String) {
3843
self.name = name
3944
self.downloadURL = downloadURL
4045
self.modelHash = modelHash
4146
self.size = size
47+
self.path = path
4248
}
4349

44-
/// Get user defaults key prefix.
45-
private static func getUserDefaultsKeyPrefix(appName: String, modelName: String) -> String {
46-
let bundleID = Bundle.main.bundleIdentifier ?? ""
47-
return "\(bundleID).\(appName).\(modelName)"
50+
/// Convenience init to create local model info from remotely downloaded model info and a local model path.
51+
convenience init(from remoteModelInfo: RemoteModelInfo, path: String) {
52+
self.init(
53+
name: remoteModelInfo.name,
54+
downloadURL: remoteModelInfo.downloadURL,
55+
modelHash: remoteModelInfo.modelHash,
56+
size: remoteModelInfo.size,
57+
path: path
58+
)
4859
}
4960

50-
// TODO: Move reading and writing to user defaults to a new file.
51-
init?(fromDefaults defaults: UserDefaults, modelName: String, appName: String) {
52-
let defaultsPrefix = ModelInfo.getUserDefaultsKeyPrefix(appName: appName, modelName: modelName)
61+
/// Convenience init to create local model info from stored info in user defaults.
62+
convenience init?(fromDefaults defaults: UserDefaults, name: String, appName: String) {
63+
let defaultsPrefix = LocalModelInfo.getUserDefaultsKeyPrefix(appName: appName, modelName: name)
5364
guard let downloadURL = defaults
5465
.value(forKey: "\(defaultsPrefix).model-download-url") as? String,
5566
let url = URL(string: downloadURL),
@@ -58,22 +69,25 @@ struct ModelInfo {
5869
let path = defaults.value(forKey: "\(defaultsPrefix).model-path") as? String else {
5970
return nil
6071
}
61-
name = modelName
62-
self.downloadURL = url
63-
self.modelHash = modelHash
64-
self.size = size
65-
self.path = path
72+
self.init(name: name, downloadURL: url, modelHash: modelHash, size: size, path: path)
6673
}
6774

68-
func writeToDefaults(_ defaults: UserDefaults, appName: String) throws {
69-
guard let modelPath = path else {
70-
throw DownloadedModelError
71-
.fileIOError(description: "Could not save model info to user defaults.")
72-
}
73-
let defaultsPrefix = ModelInfo.getUserDefaultsKeyPrefix(appName: appName, modelName: name)
75+
/// Write local model info to user defaults.
76+
func writeToDefaults(_ defaults: UserDefaults, appName: String) {
77+
let defaultsPrefix = LocalModelInfo.getUserDefaultsKeyPrefix(appName: appName, modelName: name)
7478
defaults.setValue(downloadURL.absoluteString, forKey: "\(defaultsPrefix).model-download-url")
7579
defaults.setValue(modelHash, forKey: "\(defaultsPrefix).model-hash")
7680
defaults.setValue(size, forKey: "\(defaultsPrefix).model-size")
77-
defaults.setValue(modelPath, forKey: "\(defaultsPrefix).model-path")
81+
defaults.setValue(path, forKey: "\(defaultsPrefix).model-path")
82+
}
83+
}
84+
85+
/// Named user defaults for FirebaseML.
86+
extension UserDefaults {
87+
static var firebaseMLDefaults: UserDefaults {
88+
let suiteName = "com.google.firebase.ml"
89+
// TODO: reconsider force unwrapping
90+
let defaults = UserDefaults(suiteName: suiteName)!
91+
return defaults
7892
}
7993
}

FirebaseMLModelDownloader/Sources/ModelDownloadTask.swift

Lines changed: 37 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import Foundation
1616
import FirebaseCore
1717

18+
/// Possible states of model downloading.
1819
enum DownloadStatus {
1920
case notStarted
2021
case inProgress
@@ -37,22 +38,29 @@ class DownloadHandlers {
3738

3839
/// Manager to handle model downloading device and storing downloaded model info to persistent storage.
3940
class ModelDownloadTask: NSObject {
41+
/// Name of the app associated with this instance of ModelDownloadTask.
4042
private let appName: String
41-
private(set) var modelInfo: ModelInfo
43+
/// Model info downloaded from server.
44+
private(set) var remoteModelInfo: RemoteModelInfo
45+
/// User defaults to which local model info should ultimately be written.
46+
private let defaults: UserDefaults
47+
/// Task to handle model file download.
4248
private var downloadTask: URLSessionDownloadTask?
49+
/// Progress and completion handlers associated with this model download task.
4350
private let downloadHandlers: DownloadHandlers
44-
51+
/// Keeps track of download associated with this model download task.
4552
private(set) var downloadStatus: DownloadStatus = .notStarted
46-
53+
/// URLSession to handle model downloads.
4754
private lazy var downloadSession = URLSession(configuration: .ephemeral,
4855
delegate: self,
4956
delegateQueue: nil)
5057

51-
init(modelInfo: ModelInfo, appName: String,
58+
init(remoteModelInfo: RemoteModelInfo, appName: String, defaults: UserDefaults,
5259
progressHandler: DownloadHandlers.ProgressHandler? = nil,
5360
completion: @escaping DownloadHandlers.Completion) {
54-
self.modelInfo = modelInfo
61+
self.remoteModelInfo = remoteModelInfo
5562
self.appName = appName
63+
self.defaults = defaults
5664
downloadHandlers = DownloadHandlers(
5765
progressHandler: progressHandler,
5866
completion: completion
@@ -62,7 +70,7 @@ class ModelDownloadTask: NSObject {
6270
/// Asynchronously download model file to device.
6371
func resumeModelDownload() {
6472
guard downloadStatus == .notStarted else { return }
65-
let downloadTask = downloadSession.downloadTask(with: modelInfo.downloadURL)
73+
let downloadTask = downloadSession.downloadTask(with: remoteModelInfo.downloadURL)
6674
downloadTask.resume()
6775
downloadStatus = .inProgress
6876
self.downloadTask = downloadTask
@@ -71,6 +79,11 @@ class ModelDownloadTask: NSObject {
7179

7280
/// Extension to handle delegate methods.
7381
extension ModelDownloadTask: URLSessionDownloadDelegate {
82+
/// Name for model file stored on device.
83+
var downloadedModelFileName: String {
84+
return "fbml_model__\(appName)__\(remoteModelInfo.name)"
85+
}
86+
7487
func urlSession(_ session: URLSession,
7588
downloadTask: URLSessionDownloadTask,
7689
didFinishDownloadingTo location: URL) {
@@ -80,32 +93,27 @@ extension ModelDownloadTask: URLSessionDownloadDelegate {
8093
.appendingPathComponent(downloadedModelFileName)
8194
do {
8295
try ModelFileManager.moveFile(at: location, to: savedURL)
96+
} catch let error as DownloadError {
97+
DispatchQueue.main.async {
98+
self.downloadHandlers
99+
.completion(.failure(error))
100+
}
83101
} catch {
84-
downloadHandlers
85-
.completion(.failure(.internalError(description: error.localizedDescription)))
86-
return
102+
DispatchQueue.main.async {
103+
self.downloadHandlers
104+
.completion(.failure(.internalError(description: error.localizedDescription)))
105+
}
87106
}
88107

89-
/// Set path to local model.
90-
modelInfo.path = savedURL.absoluteString
108+
/// Generate local model info.
109+
let localModelInfo = LocalModelInfo(from: remoteModelInfo, path: savedURL.absoluteString)
91110
/// Write model to user defaults.
92-
do {
93-
try modelInfo.writeToDefaults(.firebaseMLDefaults, appName: appName)
94-
} catch {
95-
downloadHandlers
96-
.completion(.failure(.internalError(description: error.localizedDescription)))
97-
}
111+
localModelInfo.writeToDefaults(defaults, appName: appName)
98112
/// Build model from model info.
99-
guard let model = buildModel() else {
100-
downloadHandlers
101-
.completion(
102-
.failure(
103-
.internalError(description: "Could not create model due to incomplete model info.")
104-
)
105-
)
106-
return
113+
let model = CustomModel(localModelInfo: localModelInfo)
114+
DispatchQueue.main.async {
115+
self.downloadHandlers.completion(.success(model))
107116
}
108-
downloadHandlers.completion(.success(model))
109117
}
110118

111119
func urlSession(_ session: URLSession,
@@ -114,50 +122,11 @@ extension ModelDownloadTask: URLSessionDownloadDelegate {
114122
totalBytesWritten: Int64,
115123
totalBytesExpectedToWrite: Int64) {
116124
assert(downloadTask == self.downloadTask)
125+
/// Check if progress handler is unspecified.
117126
guard let progressHandler = downloadHandlers.progressHandler else { return }
118127
let calculatedProgress = Float(totalBytesWritten) / Float(totalBytesExpectedToWrite)
119-
progressHandler(calculatedProgress)
120-
}
121-
}
122-
123-
/// Extension to handle post-download operations.
124-
extension ModelDownloadTask {
125-
var downloadedModelFileName: String {
126-
return "fbml_model__\(appName)__\(modelInfo.name)"
127-
}
128-
129-
/// Build custom model object from model info.
130-
// TODO: Consider moving this to CustomModel as a convenience init
131-
func buildModel() -> CustomModel? {
132-
/// Build custom model only if the model file is already on device.
133-
guard let path = modelInfo.path else { return nil }
134-
let model = CustomModel(
135-
name: modelInfo.name,
136-
size: modelInfo.size,
137-
path: path,
138-
hash: modelInfo.modelHash
139-
)
140-
return model
141-
}
142-
143-
/// Get the local path to model on device.
144-
func getLocalModelPath(model: CustomModel) -> URL? {
145-
let fileURL: URL = ModelFileManager.modelsDirectory
146-
.appendingPathComponent(downloadedModelFileName)
147-
if ModelFileManager.isFileReachable(at: fileURL) {
148-
return fileURL
149-
} else {
150-
return nil
128+
DispatchQueue.main.async {
129+
progressHandler(calculatedProgress)
151130
}
152131
}
153132
}
154-
155-
/// Named user defaults for FirebaseML.
156-
extension UserDefaults {
157-
static var firebaseMLDefaults: UserDefaults {
158-
let suiteName = "com.google.firebase.ml"
159-
// TODO: reconsider force unwrapping
160-
let defaults = UserDefaults(suiteName: suiteName)!
161-
return defaults
162-
}
163-
}

FirebaseMLModelDownloader/Sources/ModelFileManager.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ enum ModelFileManager {
4444
do {
4545
try FileManager.default.removeItem(at: destinationURL)
4646
} catch {
47-
throw DownloadedModelError
48-
.fileIOError(description: "Could not replace existing model file.")
47+
throw DownloadError
48+
.internalError(description: "Could not replace existing model file.")
4949
}
5050
}
5151
do {
5252
try FileManager.default.moveItem(at: sourceURL, to: destinationURL)
5353
} catch {
54-
throw DownloadedModelError.fileIOError(description: "Unable to save model file.")
54+
throw DownloadError.internalError(description: "Unable to save model file.")
5555
}
5656
}
5757
}

0 commit comments

Comments
 (0)