Skip to content

Commit 5d54b5d

Browse files
authored
feat(predictions): add web socket retry for clock skew (#3816)
* feat(predictions): add web socket retry for clock skew * address review comments
1 parent 611368c commit 5d54b5d

File tree

3 files changed

+105
-34
lines changed

3 files changed

+105
-34
lines changed

AmplifyPlugins/Predictions/AWSPredictionsPlugin/Dependency/AWSTranscribeStreamingAdapter.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,17 +131,17 @@ class AWSTranscribeStreamingAdapter: AWSTranscribeStreamingBehavior {
131131
continuation.yield(transcribedPayload)
132132
let isPartial = transcribedPayload.transcript?.results?.map(\.isPartial) ?? []
133133
let shouldContinue = isPartial.allSatisfy { $0 }
134-
return shouldContinue
134+
return shouldContinue ? .continueToReceive : .stopAndInvalidateSession
135135
} catch {
136-
return true
136+
return .continueToReceive
137137
}
138138
case .success(.string):
139-
return true
139+
return .continueToReceive
140140
case .failure(let error):
141141
continuation.finish(throwing: error)
142-
return false
142+
return .stopAndInvalidateSession
143143
@unknown default:
144-
return true
144+
return .continueToReceive
145145
}
146146
}
147147
}

AmplifyPlugins/Predictions/AWSPredictionsPlugin/Liveness/Service/FaceLivenessSession.swift

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ public final class FaceLivenessSession: LivenessService {
1717
let baseURL: URL
1818
var serverEventListeners: [LivenessEventKind.Server: (FaceLivenessSession.SessionConfiguration) -> Void] = [:]
1919
var onComplete: (ServerDisconnection) -> Void = { _ in }
20+
var serverDate: Date?
21+
var savedURLForReconnect: URL?
22+
var connectingState: ConnectingState = .normal
23+
24+
enum ConnectingState {
25+
case normal
26+
case reconnect
27+
}
2028

2129
private let livenessServiceDispatchQueue = DispatchQueue(
2230
label: "com.amazon.aws.amplify.liveness.service",
@@ -35,12 +43,16 @@ public final class FaceLivenessSession: LivenessService {
3543
self.websocket = websocket
3644

3745
websocket.onMessageReceived { [weak self] result in
38-
self?.receive(result: result) ?? false
46+
self?.receive(result: result) ?? .stopAndInvalidateSession
3947
}
4048

4149
websocket.onSocketClosed { [weak self] closeCode in
4250
self?.onComplete(.unexpectedClosure(closeCode))
4351
}
52+
53+
websocket.onServerDateReceived { [weak self] serverDate in
54+
self?.serverDate = serverDate
55+
}
4456
}
4557

4658
public var onServiceException: (FaceLivenessSessionError) -> Void = { _ in }
@@ -75,6 +87,7 @@ public final class FaceLivenessSession: LivenessService {
7587
guard let url = components?.url
7688
else { throw FaceLivenessSessionError.invalidURL }
7789

90+
savedURLForReconnect = url
7891
let signedConnectionURL = signer.sign(url: url)
7992
websocket.open(url: signedConnectionURL)
8093
}
@@ -93,17 +106,22 @@ public final class FaceLivenessSession: LivenessService {
93106
]
94107
)
95108

96-
let eventDate = eventDate()
109+
let dateForSigning: Date
110+
if let serverDate = serverDate {
111+
dateForSigning = serverDate
112+
} else {
113+
dateForSigning = eventDate()
114+
}
97115

98116
let signedPayload = self.signer.signWithPreviousSignature(
99117
payload: encodedPayload,
100-
dateHeader: (key: ":date", value: eventDate)
118+
dateHeader: (key: ":date", value: dateForSigning)
101119
)
102120

103121
let encodedEvent = self.eventStreamEncoder.encode(
104122
payload: encodedPayload,
105123
headers: [
106-
":date": .timestamp(eventDate),
124+
":date": .timestamp(dateForSigning),
107125
":chunk-signature": .data(signedPayload)
108126
]
109127
)
@@ -115,7 +133,7 @@ public final class FaceLivenessSession: LivenessService {
115133
}
116134
}
117135

118-
private func fallbackDecoding(_ message: EventStream.Message) -> Bool {
136+
private func fallbackDecoding(_ message: EventStream.Message) -> WebSocketSession.WebSocketMessageResult {
119137
// We only care about two events above.
120138
// Just in case the header value changes (it shouldn't)
121139
// We'll try to decode each of these events
@@ -124,12 +142,12 @@ public final class FaceLivenessSession: LivenessService {
124142
self.serverEventListeners[.challenge]?(sessionConfiguration)
125143
} else if (try? JSONDecoder().decode(DisconnectEvent.self, from: message.payload)) != nil {
126144
onComplete(.disconnectionEvent)
127-
return false
145+
return .stopAndInvalidateSession
128146
}
129-
return true
147+
return .continueToReceive
130148
}
131149

132-
private func receive(result: Result<URLSessionWebSocketTask.Message, Error>) -> Bool {
150+
private func receive(result: Result<URLSessionWebSocketTask.Message, Error>) -> WebSocketSession.WebSocketMessageResult {
133151
switch result {
134152
case .success(.data(let data)):
135153
do {
@@ -145,28 +163,41 @@ public final class FaceLivenessSession: LivenessService {
145163
)
146164
let sessionConfiguration = sessionConfiguration(from: payload)
147165
serverEventListeners[.challenge]?(sessionConfiguration)
148-
return true
166+
return .continueToReceive
149167
case .disconnect:
150168
// :event-type DisconnectionEvent
151169
onComplete(.disconnectionEvent)
152-
return false
170+
return .stopAndInvalidateSession
153171
default:
154-
return true
172+
return .continueToReceive
155173
}
156174
} else if let exceptionType = message.headers.first(where: { $0.name == ":exception-type" }) {
157175
let exceptionEvent = LivenessEventKind.Exception(rawValue: exceptionType.value)
158-
onServiceException(.init(event: exceptionEvent))
159-
return false
176+
Amplify.log.verbose("\(#function): Received exception: \(exceptionEvent)")
177+
guard exceptionEvent == .invalidSignature,
178+
connectingState == .normal,
179+
let savedURLForReconnect = savedURLForReconnect,
180+
let serverDate = serverDate else {
181+
onServiceException(.init(event: exceptionEvent))
182+
return .stopAndInvalidateSession
183+
}
184+
185+
connectingState = .reconnect
186+
let signedConnectionURL = signer.sign(
187+
url: savedURLForReconnect,
188+
date: { serverDate }
189+
)
190+
return .invalidateSessionAndRetry(url: signedConnectionURL)
160191
} else {
161192
return fallbackDecoding(message)
162193
}
163194
} catch {
164-
return false
195+
return .stopAndInvalidateSession
165196
}
166197
case .success:
167-
return true
198+
return .continueToReceive
168199
case .failure:
169-
return false
200+
return .stopAndInvalidateSession
170201
}
171202
}
172203
}

AmplifyPlugins/Predictions/AWSPredictionsPlugin/Liveness/Service/WebSocketSession.swift

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
//
77

88
import Foundation
9+
import Amplify
910

1011
final class WebSocketSession {
1112
private let urlSessionWebSocketDelegate: Delegate
1213
private let session: URLSession
1314
private var task: URLSessionWebSocketTask?
14-
private var receiveMessage: ((Result<URLSessionWebSocketTask.Message, Error>) -> Bool)?
15+
private var receiveMessage: ((Result<URLSessionWebSocketTask.Message, Error>) -> WebSocketMessageResult)?
1516
private var onSocketClosed: ((URLSessionWebSocketTask.CloseCode) -> Void)?
17+
private var onServerDateReceived: ((Date?) -> Void)?
1618

1719
init() {
1820
self.urlSessionWebSocketDelegate = Delegate()
@@ -23,7 +25,7 @@ final class WebSocketSession {
2325
)
2426
}
2527

26-
func onMessageReceived(_ receive: @escaping (Result<URLSessionWebSocketTask.Message, Error>) -> Bool) {
28+
func onMessageReceived(_ receive: @escaping (Result<URLSessionWebSocketTask.Message, Error>) -> WebSocketMessageResult) {
2729
self.receiveMessage = receive
2830
}
2931

@@ -34,25 +36,32 @@ final class WebSocketSession {
3436
func onSocketOpened(_ onOpen: @escaping () -> Void) {
3537
urlSessionWebSocketDelegate.onOpen = onOpen
3638
}
39+
40+
func onServerDateReceived(_ onServerDateReceived: @escaping (Date?) -> Void) {
41+
urlSessionWebSocketDelegate.onServerDateReceived = onServerDateReceived
42+
}
3743

38-
func receive(shouldContinue: Bool) {
39-
guard shouldContinue else {
44+
func receive(result: WebSocketMessageResult) {
45+
switch result {
46+
case .continueToReceive:
47+
task?.receive(completionHandler: { [weak self] result in
48+
if let webSocketResult = self?.receiveMessage?(result) {
49+
self?.receive(result: webSocketResult)
50+
}
51+
})
52+
case .stopAndInvalidateSession:
53+
session.finishTasksAndInvalidate()
54+
case .invalidateSessionAndRetry(let url):
4055
session.finishTasksAndInvalidate()
41-
return
56+
open(url: url)
4257
}
43-
44-
task?.receive(completionHandler: { [weak self] result in
45-
if let shouldContinue = self?.receiveMessage?(result) {
46-
self?.receive(shouldContinue: shouldContinue)
47-
}
48-
})
4958
}
5059

5160
func open(url: URL) {
5261
var request = URLRequest(url: url)
5362
request.setValue("no-store", forHTTPHeaderField: "Cache-Control")
5463
task = session.webSocketTask(with: request)
55-
receive(shouldContinue: true)
64+
receive(result: .continueToReceive)
5665
task?.resume()
5766
}
5867

@@ -77,10 +86,12 @@ final class WebSocketSession {
7786
)
7887
}
7988

80-
final class Delegate: NSObject, URLSessionWebSocketDelegate {
89+
final class Delegate: NSObject, URLSessionWebSocketDelegate, URLSessionTaskDelegate {
8190
var onClose: (URLSessionWebSocketTask.CloseCode) -> Void = { _ in }
8291
var onOpen: () -> Void = {}
92+
var onServerDateReceived: (Date?) -> Void = { _ in }
8393

94+
// MARK: - URLSessionWebSocketDelegate methods
8495
func urlSession(
8596
_ session: URLSession,
8697
webSocketTask: URLSessionWebSocketTask,
@@ -97,5 +108,34 @@ final class WebSocketSession {
97108
) {
98109
onClose(closeCode)
99110
}
111+
112+
// MARK: - URLSessionTaskDelegate methods
113+
func urlSession(_ session: URLSession,
114+
task: URLSessionTask,
115+
didFinishCollecting metrics: URLSessionTaskMetrics
116+
) {
117+
guard let httpResponse = metrics.transactionMetrics.first?.response as? HTTPURLResponse,
118+
let dateString = httpResponse.value(forHTTPHeaderField: "Date") else {
119+
Amplify.log.verbose("\(#function): Couldn't find Date header in URLSession metrics")
120+
onServerDateReceived(nil)
121+
return
122+
}
123+
124+
let dateFormatter = DateFormatter()
125+
dateFormatter.dateFormat = "EEE, d MMM yyyy HH:mm:ss z"
126+
guard let serverDate = dateFormatter.date(from: dateString) else {
127+
Amplify.log.verbose("\(#function): Error parsing Date header in expected format")
128+
onServerDateReceived(nil)
129+
return
130+
}
131+
132+
onServerDateReceived(serverDate)
133+
}
134+
}
135+
136+
enum WebSocketMessageResult {
137+
case continueToReceive
138+
case stopAndInvalidateSession
139+
case invalidateSessionAndRetry(url: URL)
100140
}
101141
}

0 commit comments

Comments
 (0)