Commit 5497c02

chore(CoreMLPredictionsPlugin): Performing simulator work on CPU and surfacing errors
1 parent 159fca7 commit 5497c02

5 files changed: +76 -24 lines

AmplifyPlugins/Predictions/CoreMLPredictionsPlugin/CoreMLPredictionsPlugin+ClientBehavior.swift

Lines changed: 14 additions & 10 deletions

@@ -43,7 +43,7 @@ extension CoreMLPredictionsPlugin {
              .detectLabels(.moderation, _):
             throw predictionsError
         case let .detectText(lift):
-            guard let result = coreMLVisionAdapter.detectText(image) else {
+            guard let result = try coreMLVisionAdapter.detectText(image) else {
                 let errorDescription = CoreMLPluginErrorString.detectTextNoResult.errorDescription
                 let recovery = CoreMLPluginErrorString.detectTextNoResult.recoverySuggestion
                 let predictionsError = PredictionsError.service(
@@ -53,7 +53,7 @@ extension CoreMLPredictionsPlugin {
             }
             return lift.outputSpecificToGeneric(result)
         case let .detectLabels(_, lift):
-            guard let result = coreMLVisionAdapter.detectLabels(image) else {
+            guard let result = try coreMLVisionAdapter.detectLabels(image) else {
                 let errorDescription = CoreMLPluginErrorString.detectLabelsNoResult.errorDescription
                 let recovery = CoreMLPluginErrorString.detectLabelsNoResult.recoverySuggestion
                 let predictionsError = PredictionsError.service(
@@ -90,14 +90,18 @@ extension CoreMLPredictionsPlugin {
         )
         let stream = AsyncThrowingStream<Predictions.Convert.SpeechToText.Result, Error> { continuation in
             Task {
-                let result = try await coreMLSpeech.getTranscription(
-                    request.speechToText
-                )
-                continuation.yield(
-                    .init(transcription: result.bestTranscription.formattedString)
-                )
-                if result.isFinal {
-                    continuation.finish()
+                do {
+                    let result = try await coreMLSpeech.getTranscription(
+                        request.speechToText
+                    )
+                    continuation.yield(
+                        .init(transcription: result.bestTranscription.formattedString)
+                    )
+                    if result.isFinal {
+                        continuation.finish()
+                    }
+                } catch {
+                    continuation.yield(with: .failure(error))
                 }
             }
         }
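
The substantive change in this file is the `do`/`catch` inside the `Task`: before, an error thrown by `getTranscription` killed the task silently and the `AsyncThrowingStream` never terminated for its consumer. A minimal standalone sketch of the pattern (names below are illustrative, not the plugin's API):

```swift
import Foundation

// Stand-in for coreMLSpeech.getTranscription(_:).
func produceTranscription() async throws -> String {
    throw NSError(domain: "Demo", code: 1)
}

let stream = AsyncThrowingStream<String, Error> { continuation in
    Task {
        do {
            let transcription = try await produceTranscription()
            continuation.yield(transcription)
            continuation.finish()
        } catch {
            // yield(with: .failure(_)) finishes the stream by rethrowing
            // the error out of the consumer's `for try await` loop.
            continuation.yield(with: .failure(error))
        }
    }
}
```

A `for try await` loop over such a stream now exits with the underlying error, which is what the new integration test at the bottom of this commit asserts.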

AmplifyPlugins/Predictions/CoreMLPredictionsPlugin/Dependency/CoreMLSpeechAdapter.swift

Lines changed: 15 additions & 3 deletions

@@ -12,12 +12,24 @@ class CoreMLSpeechAdapter: CoreMLSpeechBehavior {
     func getTranscription(_ audioData: URL) async throws -> SFSpeechRecognitionResult {
         let request = SFSpeechURLRecognitionRequest(url: audioData)
         request.requiresOnDeviceRecognition = true
-        let recognizer = SFSpeechRecognizer()
+        guard let recognizer = SFSpeechRecognizer() else {
+            throw PredictionsError.client(
+                .init(
+                    description: "CoreML Service is not configured",
+                    recoverySuggestion: "Ensure that dictation is enabled on your device."
+                )
+            )
+        }

         let result = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<SFSpeechRecognitionResult, Error>) in
-            recognizer?.recognitionTask(
+            recognizer.recognitionTask(
                 with: request,
-                resultHandler: { result, _ in
+                resultHandler: { (result, error) in
+                    if let error = error {
+                        continuation.resume(with: .failure(error))
+                        return
+                    }
+
                     guard let result = result else {
                         continuation.resume(with: .failure(
                             PredictionsError.client(
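
Two failure modes get surfaced in this adapter: `SFSpeechRecognizer()` returning `nil` (no recognizer available, e.g. dictation disabled), and the recognition task's error, which the old `resultHandler: { result, _ in ... }` discarded. A compressed sketch of the same continuation bridge, assuming only the final result matters (the error type is illustrative; the plugin throws `PredictionsError.client` instead):

```swift
import Speech

struct RecognizerUnavailableError: Error {}

func transcribe(_ audioURL: URL) async throws -> SFSpeechRecognitionResult {
    // SFSpeechRecognizer() is a failable initializer; it returns nil when
    // no recognizer is available for the default locale.
    guard let recognizer = SFSpeechRecognizer() else {
        throw RecognizerUnavailableError()
    }
    let request = SFSpeechURLRecognitionRequest(url: audioURL)
    request.requiresOnDeviceRecognition = true
    return try await withCheckedThrowingContinuation { continuation in
        recognizer.recognitionTask(with: request) { result, error in
            // The handler fires repeatedly with partial results; resume
            // only once, on an error or on the final result.
            if let error = error {
                continuation.resume(throwing: error)
            } else if let result = result, result.isFinal {
                continuation.resume(returning: result)
            }
        }
    }
}
```

Because the handler can be invoked more than once, a production version (including the plugin's) has to take care not to resume the checked continuation twice.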

AmplifyPlugins/Predictions/CoreMLPredictionsPlugin/Dependency/CoreMLVisionAdapter.swift

Lines changed: 15 additions & 6 deletions

@@ -10,11 +10,14 @@ import Vision

 class CoreMLVisionAdapter: CoreMLVisionBehavior {

-    public func detectLabels(_ imageURL: URL) -> Predictions.Identify.Labels.Result? {
+    func detectLabels(_ imageURL: URL) throws -> Predictions.Identify.Labels.Result? {
         var labelsResult = [Predictions.Label]()
         let handler = VNImageRequestHandler(url: imageURL, options: [:])
         let request = VNClassifyImageRequest()
-        try? handler.perform([request])
+        #if targetEnvironment(simulator)
+        request.usesCPUOnly = true
+        #endif
+        try handler.perform([request])
         guard let observations = request.results else { return nil }

         let categories = observations.filter { $0.hasMinimumRecall(0.01, forPrecision: 0.9) }
@@ -26,11 +29,14 @@ class CoreMLVisionAdapter: CoreMLVisionBehavior {
         return Predictions.Identify.Labels.Result(labels: labelsResult)
     }

-    public func detectText(_ imageURL: URL) -> Predictions.Identify.Text.Result? {
+    public func detectText(_ imageURL: URL) throws -> Predictions.Identify.Text.Result? {
         let handler = VNImageRequestHandler(url: imageURL, options: [:])
         let request = VNRecognizeTextRequest()
+        #if targetEnvironment(simulator)
+        request.usesCPUOnly = true
+        #endif
         request.recognitionLevel = .accurate
-        try? handler.perform([request])
+        try handler.perform([request])
         guard let observations = request.results else { return nil }

         var identifiedLines = [Predictions.IdentifiedLine]()
@@ -62,10 +68,13 @@ class CoreMLVisionAdapter: CoreMLVisionBehavior {
         )
     }

-    func detectEntities(_ imageURL: URL) -> Predictions.Identify.Entities.Result? {
+    func detectEntities(_ imageURL: URL) throws -> Predictions.Identify.Entities.Result? {
         let handler = VNImageRequestHandler(url: imageURL, options: [:])
         let faceLandmarksRequest = VNDetectFaceLandmarksRequest()
-        try? handler.perform([faceLandmarksRequest])
+        #if targetEnvironment(simulator)
+        faceLandmarksRequest.usesCPUOnly = true
+        #endif
+        try handler.perform([faceLandmarksRequest])
         guard let observations = faceLandmarksRequest.results else { return nil }

         var entities: [Predictions.Entity] = []
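
All three Vision entry points get the same two-part treatment: on the simulator, where Vision's GPU path is typically unavailable, requests are pinned to the CPU, and `try?` becomes `try` so a failed `perform` propagates instead of being indistinguishable from "no observations". The pattern, extracted into a hypothetical helper (the plugin inlines it per request; note `usesCPUOnly` was deprecated in iOS 17 in favor of `setComputeDevice(_:for:)`):

```swift
import Vision

// Hypothetical helper illustrating the per-request pattern above.
func perform(_ request: VNRequest, with handler: VNImageRequestHandler) throws {
    #if targetEnvironment(simulator)
    // The simulator can't run Vision's GPU path; force CPU execution.
    request.usesCPUOnly = true
    #endif
    // `try`, not `try?`: a failed perform now throws to the caller rather
    // than falling through to a misleading nil-results return.
    try handler.perform([request])
}
```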

AmplifyPlugins/Predictions/CoreMLPredictionsPlugin/Dependency/CoreMLVisionBehavior.swift

Lines changed: 3 additions & 3 deletions

@@ -9,7 +9,7 @@ import Foundation
 import Amplify

 protocol CoreMLVisionBehavior: AnyObject {
-    func detectLabels(_ imageURL: URL) -> Predictions.Identify.Labels.Result?
-    func detectText(_ imageURL: URL) -> Predictions.Identify.Text.Result?
-    func detectEntities(_ imageURL: URL) -> Predictions.Identify.Entities.Result?
+    func detectLabels(_ imageURL: URL) throws -> Predictions.Identify.Labels.Result?
+    func detectText(_ imageURL: URL) throws -> Predictions.Identify.Text.Result?
+    func detectEntities(_ imageURL: URL) throws -> Predictions.Identify.Entities.Result?
 }

AmplifyPlugins/Predictions/Tests/PredictionsHostApp/CoreMLPredictionsPluginIntegrationTests/CoreMLPredictionsPluginIntegrationTest.swift

Lines changed: 29 additions & 2 deletions

@@ -8,17 +8,44 @@
 import XCTest
 import Amplify
 import XCTest
+import AVFoundation

 class CoreMLPredictionsPluginIntegrationTest: AWSPredictionsPluginTestBase {
+
     func testIdentify() async throws {
         let testBundle = Bundle(for: type(of: self))
         let url = try XCTUnwrap(testBundle.url(forResource: "people", withExtension: "jpg"))

-        let result = try await Amplify.Predictions.identify(
+        let result: Predictions.Identify.Labels.Result = try await Amplify.Predictions.identify(
             .labels(type: .all),
             in: url
         )

-        XCTAssertNotNil(result, "Result should contain value")
+        XCTAssertEqual(result.labels.count, 0, String(describing: result))
+        XCTAssertNil(result.unsafeContent, String(describing: result))
+    }
+
+    func testConvertSpeechToText() async throws {
+        let testBundle = Bundle(for: type(of: self))
+        let url = try XCTUnwrap(testBundle.url(forResource: "audio", withExtension: "wav"))
+
+        let options = Predictions.Convert.SpeechToText.Options(
+            defaultNetworkPolicy: .auto,
+            language: .usEnglish
+        )
+
+        let result = try await Amplify.Predictions.convert(
+            .speechToText(url: url), options: options
+        )
+        let responses = result.map(\.transcription)
+
+        do {
+            for try await response in responses {
+                XCTFail("Expecting failure but got: \(response)")
+            }
+        } catch let error as NSError {
+            XCTAssertEqual(error.code, 201)
+            XCTAssertEqual(error.localizedDescription, "Siri and Dictation are disabled")
+        }
     }
 }
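
The new `testConvertSpeechToText` exercises the failure path end to end: on a simulator host with dictation disabled, the recognizer error (code 201, "Siri and Dictation are disabled") should now reach the stream's consumer. App code consuming the same API would observe it like this (a sketch assumed to run inside an async throwing context, with `audioURL` pointing at a local audio file):

```swift
import Amplify

let options = Predictions.Convert.SpeechToText.Options(
    defaultNetworkPolicy: .auto,
    language: .usEnglish
)
let result = try await Amplify.Predictions.convert(
    .speechToText(url: audioURL), options: options
)
do {
    for try await transcription in result.map(\.transcription) {
        print("Transcript so far: \(transcription)")
    }
} catch {
    // Recognizer failures now terminate the loop here instead of
    // leaving the stream open indefinitely.
    print("Transcription failed: \(error)")
}
```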
