Skip to content

Commit e6b6020

Browse files
authored
Add singleton-ish instance methods for model downloader (public API). (#7124)
* Restores ML Pods after M77. * Fix Package.swift * Refactor model info to be in-memory. * Add improved model download functionality w/ single set of download handlers per task. * Rename model download to better match URLSessionTask. * Rename model download to better match URLSessionTask. * Update unit test to read/write to user defaults. * Configure downloader with Firebase app. * Minor updates to model downloader and tests. * Singleton for model downloader instance, for default and custom app. * Manually add app lifecycle handling + TODOs for Firebase Components. * Replace options with app name as a property in ModelDownloader. * Update unit test to check model downloader instance creation + minor rename.
1 parent 3f95f59 commit e6b6020

File tree

6 files changed

+104
-36
lines changed

6 files changed

+104
-36
lines changed

FirebaseMLModelDownloader/Sources/ModelDownloadTask.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class DownloadHandlers {
3737

3838
/// Manager to handle model downloading device and storing downloaded model info to persistent storage.
3939
class ModelDownloadTask: NSObject {
40-
let app: FirebaseApp
40+
private let appName: String
4141
private(set) var modelInfo: ModelInfo
4242
private var downloadTask: URLSessionDownloadTask?
4343
private let downloadHandlers: DownloadHandlers
@@ -48,11 +48,11 @@ class ModelDownloadTask: NSObject {
4848
delegate: self,
4949
delegateQueue: nil)
5050

51-
init(app: FirebaseApp, modelInfo: ModelInfo,
51+
init(modelInfo: ModelInfo, appName: String,
5252
progressHandler: DownloadHandlers.ProgressHandler? = nil,
5353
completion: @escaping DownloadHandlers.Completion) {
54-
self.app = app
5554
self.modelInfo = modelInfo
55+
self.appName = appName
5656
downloadHandlers = DownloadHandlers(
5757
progressHandler: progressHandler,
5858
completion: completion
@@ -90,7 +90,7 @@ extension ModelDownloadTask: URLSessionDownloadDelegate {
9090
modelInfo.path = savedURL.absoluteString
9191
/// Write model to user defaults.
9292
do {
93-
try modelInfo.writeToDefaults(app: app, defaults: .firebaseMLDefaults)
93+
try modelInfo.writeToDefaults(.firebaseMLDefaults, appName: appName)
9494
} catch {
9595
downloadHandlers
9696
.completion(.failure(.internalError(description: error.localizedDescription)))
@@ -123,7 +123,7 @@ extension ModelDownloadTask: URLSessionDownloadDelegate {
123123
/// Extension to handle post-download operations.
124124
extension ModelDownloadTask {
125125
var downloadedModelFileName: String {
126-
return "fbml_model__\(app.name)__\(modelInfo.name)"
126+
return "fbml_model__\(appName)__\(modelInfo.name)"
127127
}
128128

129129
/// Build custom model object from model info.

FirebaseMLModelDownloader/Sources/ModelDownloader.swift

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
import Foundation
16+
import FirebaseCore
1617

1718
/// Possible errors with model downloading.
1819
public enum DownloadError: Error, Equatable {
@@ -49,7 +50,54 @@ public enum ModelDownloadType {
4950
}
5051

5152
/// Downloader to manage custom model downloads.
52-
public struct ModelDownloader {
53+
public class ModelDownloader {
54+
/// Name of the app associated with this instance of ModelDownloader.
55+
private let appName: String
56+
57+
/// Shared dictionary mapping app name to a specific instance of model downloader.
58+
// TODO: Switch to using Firebase components.
59+
private static var modelDownloaderDictionary: [String: ModelDownloader] = [:]
60+
61+
/// Private init for downloader.
62+
private init(app: FirebaseApp) {
63+
appName = app.name
64+
NotificationCenter.default.addObserver(
65+
self,
66+
selector: #selector(deleteModelDownloader),
67+
name: Notification.Name("FIRAppDeleteNotification"),
68+
object: nil
69+
)
70+
}
71+
72+
/// Handles app deletion notification.
73+
@objc private func deleteModelDownloader(notification: Notification) {
74+
if let userInfo = notification.userInfo,
75+
let appName = userInfo["FIRAppNameKey"] as? String {
76+
ModelDownloader.modelDownloaderDictionary.removeValue(forKey: appName)
77+
// TODO: Clean up user defaults
78+
// TODO: Clean up local instances of app
79+
}
80+
}
81+
82+
/// Model downloader with default app.
83+
public static func modelDownloader() -> ModelDownloader {
84+
guard let defaultApp = FirebaseApp.app() else {
85+
fatalError("Default Firebase app not configured.")
86+
}
87+
return modelDownloader(app: defaultApp)
88+
}
89+
90+
/// Model Downloader with custom app.
91+
public static func modelDownloader(app: FirebaseApp) -> ModelDownloader {
92+
if let downloader = modelDownloaderDictionary[app.name] {
93+
return downloader
94+
} else {
95+
let downloader = ModelDownloader(app: app)
96+
modelDownloaderDictionary[app.name] = downloader
97+
return downloader
98+
}
99+
}
100+
53101
/// Downloads a custom model to device or gets a custom model already on device, w/ optional handler for progress.
54102
public func getModel(name modelName: String, downloadType: ModelDownloadType,
55103
conditions: ModelDownloadConditions,

FirebaseMLModelDownloader/Sources/ModelInfo.swift

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,15 @@ struct ModelInfo {
4141
self.size = size
4242
}
4343

44-
init?(fromDefaults defaults: UserDefaults, name: String, app: FirebaseApp) {
44+
/// Get user defaults key prefix.
45+
private static func getUserDefaultsKeyPrefix(appName: String, modelName: String) -> String {
4546
let bundleID = Bundle.main.bundleIdentifier ?? ""
46-
let defaultsPrefix = "\(bundleID).\(app.name).\(name)"
47+
return "\(bundleID).\(appName).\(modelName)"
48+
}
49+
50+
// TODO: Move reading and writing to user defaults to a new file.
51+
init?(fromDefaults defaults: UserDefaults, modelName: String, appName: String) {
52+
let defaultsPrefix = ModelInfo.getUserDefaultsKeyPrefix(appName: appName, modelName: modelName)
4753
guard let downloadURL = defaults
4854
.value(forKey: "\(defaultsPrefix).model-download-url") as? String,
4955
let url = URL(string: downloadURL),
@@ -52,20 +58,19 @@ struct ModelInfo {
5258
let path = defaults.value(forKey: "\(defaultsPrefix).model-path") as? String else {
5359
return nil
5460
}
55-
self.name = name
61+
name = modelName
5662
self.downloadURL = url
5763
self.modelHash = modelHash
5864
self.size = size
5965
self.path = path
6066
}
6167

62-
func writeToDefaults(app: FirebaseApp, defaults: UserDefaults) throws {
68+
func writeToDefaults(_ defaults: UserDefaults, appName: String) throws {
6369
guard let modelPath = path else {
6470
throw DownloadedModelError
6571
.fileIOError(description: "Could not save model info to user defaults.")
6672
}
67-
let bundleID = Bundle.main.bundleIdentifier ?? ""
68-
let defaultsPrefix = "\(bundleID).\(app.name).\(name)"
73+
let defaultsPrefix = ModelInfo.getUserDefaultsKeyPrefix(appName: appName, modelName: name)
6974
defaults.setValue(downloadURL.absoluteString, forKey: "\(defaultsPrefix).model-download-url")
7075
defaults.setValue(modelHash, forKey: "\(defaultsPrefix).model-hash")
7176
defaults.setValue(size, forKey: "\(defaultsPrefix).model-size")

FirebaseMLModelDownloader/Sources/ModelInfoRetriever.swift

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ extension ModelInfoResponse {
3434

3535
/// Model info retriever for a model from local user defaults or server.
3636
class ModelInfoRetriever: NSObject {
37-
/// Current Firebase app.
38-
var app: FirebaseApp
37+
/// Current Firebase app options
38+
private var options: FirebaseOptions
3939
/// Model info associated with model.
4040
var modelInfo: ModelInfo?
4141
/// Model name.
@@ -44,10 +44,10 @@ class ModelInfoRetriever: NSObject {
4444
var installations: Installations
4545

4646
/// Associate model info retriever with current Firebase app, and model name.
47-
init(app: FirebaseApp, modelName: String) {
48-
self.app = app
47+
init(modelName: String, options: FirebaseOptions, installations: Installations) {
4948
self.modelName = modelName
50-
installations = Installations.installations(app: app)
49+
self.options = options
50+
self.installations = installations
5151
}
5252
}
5353

@@ -70,8 +70,8 @@ extension ModelInfoRetriever {
7070

7171
/// Construct model fetch base URL.
7272
var modelInfoFetchURL: URL {
73-
let projectID = app.options.projectID ?? ""
74-
let apiKey = app.options.apiKey
73+
let projectID = options.projectID ?? ""
74+
let apiKey = options.apiKey
7575
var components = URLComponents()
7676
components.scheme = "https"
7777
components.host = "firebaseml.googleapis.com"

FirebaseMLModelDownloader/Tests/Integration/ModelDownloaderIntegrationTests.swift

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import XCTest
1616
@testable import FirebaseCore
17+
@testable import FirebaseInstallations
1718
@testable import FirebaseMLModelDownloader
1819

1920
extension UserDefaults {
@@ -47,8 +48,9 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
4748
}
4849
let testModelName = "image-classification"
4950
let modelInfoRetriever = ModelInfoRetriever(
50-
app: testApp,
51-
modelName: testModelName
51+
modelName: testModelName,
52+
options: testApp.options,
53+
installations: Installations.installations(app: testApp)
5254
)
5355
let expectation = self.expectation(description: "Wait for FIS auth token.")
5456
modelInfoRetriever.getAuthToken(completion: { result in
@@ -72,8 +74,9 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
7274
}
7375
let testModelName = "pose-detection"
7476
let modelInfoRetriever = ModelInfoRetriever(
75-
app: testApp,
76-
modelName: testModelName
77+
modelName: testModelName,
78+
options: testApp.options,
79+
installations: Installations.installations(app: testApp)
7780
)
7881
let downloadExpectation = expectation(description: "Wait for model info to download.")
7982
modelInfoRetriever.downloadModelInfo(completion: { error in
@@ -111,10 +114,10 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
111114
let functionName = #function.dropLast(2)
112115
let testModelName = "\(functionName)-test-model"
113116
let modelInfoRetriever = ModelInfoRetriever(
114-
app: testApp,
115-
modelName: testModelName
117+
modelName: testModelName,
118+
options: testApp.options,
119+
installations: Installations.installations(app: testApp)
116120
)
117-
118121
let urlString =
119122
"https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite"
120123
let url = URL(string: urlString)!
@@ -127,8 +130,7 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
127130
)
128131
let expectation = self.expectation(description: "Wait for model to download.")
129132
let modelDownloadManager = ModelDownloadTask(
130-
app: testApp,
131-
modelInfo: modelInfoRetriever.modelInfo!,
133+
modelInfo: modelInfoRetriever.modelInfo!, appName: testApp.name,
132134
progressHandler: { progress in
133135
XCTAssertNotNil(progress)
134136
}

FirebaseMLModelDownloader/Tests/Unit/ModelDownloaderUnitTests.swift

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import XCTest
1616
@testable import FirebaseCore
17+
@testable import FirebaseInstallations
1718
@testable import FirebaseMLModelDownloader
1819

1920
/// Mock options to configure default Firebase app.
@@ -73,21 +74,21 @@ final class ModelDownloaderUnitTests: XCTestCase {
7374
)
7475
// This fails because there is no model path.
7576
do {
76-
try modelInfo.writeToDefaults(app: testApp, defaults: .getTestInstance())
77+
try modelInfo.writeToDefaults(.getTestInstance(), appName: testApp.name)
7778
} catch {
7879
XCTAssertNotNil(error)
7980
}
8081
modelInfo.path = testModelPath
8182
// This shouldn't fail because model info object is now complete.
8283
do {
83-
try modelInfo.writeToDefaults(app: testApp, defaults: .getTestInstance())
84+
try modelInfo.writeToDefaults(.getTestInstance(), appName: testApp.name)
8485
} catch {
8586
XCTFail(error.localizedDescription)
8687
}
8788
guard let savedModelInfo = ModelInfo(
8889
fromDefaults: .getTestInstance(cleared: false),
89-
name: testModelName,
90-
app: testApp
90+
modelName: testModelName,
91+
appName: testApp.name
9192
) else {
9293
XCTFail("Model info not saved to user defaults.")
9394
return
@@ -106,8 +107,9 @@ final class ModelDownloaderUnitTests: XCTestCase {
106107
let functionName = #function
107108
let testModelName = "\(functionName)-test-model"
108109
let modelInfoRetriever = ModelInfoRetriever(
109-
app: testApp,
110-
modelName: testModelName
110+
modelName: testModelName,
111+
options: testApp.options,
112+
installations: Installations.installations(app: testApp)
111113
)
112114
let sampleResponse: String = """
113115
{
@@ -137,7 +139,18 @@ final class ModelDownloaderUnitTests: XCTestCase {
137139
// This is an example of a functional test case.
138140
// Use XCTAssert and related functions to verify your tests produce the correct
139141
// results.
140-
let modelDownloader = ModelDownloader()
142+
guard let testApp = FirebaseApp.app() else {
143+
XCTFail("Default app was not configured.")
144+
return
145+
}
146+
147+
let modelDownloader = ModelDownloader.modelDownloader()
148+
149+
let modelDownloaderWithApp = ModelDownloader.modelDownloader(app: testApp)
150+
151+
/// These should point to the same instance.
152+
XCTAssert(modelDownloader === modelDownloaderWithApp)
153+
141154
let conditions = ModelDownloadConditions()
142155

143156
// Download model w/ progress handler
@@ -161,7 +174,7 @@ final class ModelDownloaderUnitTests: XCTestCase {
161174
}
162175

163176
// Access array of downloaded models
164-
modelDownloader.listDownloadedModels { result in
177+
modelDownloaderWithApp.listDownloadedModels { result in
165178
switch result {
166179
case .success:
167180
// Pick model(s) for further use

0 commit comments

Comments
 (0)