Skip to content

Commit 7228fde

Browse files
lawmichaatierian
andauthored
fix(datastore): retry on subscription connection error (#2571)
* fix(datastore): retry on subscription connection error * duplicate code for taskRunner and unit tests * Update AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift Co-authored-by: Ian Saultz <[email protected]> Co-authored-by: Ian Saultz <[email protected]>
1 parent 2c53d42 commit 7228fde

File tree

6 files changed

+139
-8
lines changed

6 files changed

+139
-8
lines changed

AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ public class AWSGraphQLSubscriptionTaskRunner<R: Decodable>: InternalTaskRunner,
174174
return
175175
} else if case ConnectionProviderError.unauthorized = error {
176176
errorDescription += ": \(APIError.UnauthorizedMessageString)"
177+
} else if case ConnectionProviderError.connection = error {
178+
errorDescription += ": connection"
179+
let error = URLError(.networkConnectionLost)
180+
fail(APIError.networkError(errorDescription, nil, error))
181+
return
177182
}
178183

179184
fail(APIError.operationError(errorDescription, "", error))
@@ -361,8 +366,13 @@ final public class AWSGraphQLSubscriptionOperation<R: Decodable>: GraphQLSubscri
361366
return
362367
} else if case ConnectionProviderError.unauthorized = error {
363368
errorDescription += ": \(APIError.UnauthorizedMessageString)"
369+
} else if case ConnectionProviderError.connection = error {
370+
errorDescription += ": connection"
371+
let error = URLError(.networkConnectionLost)
372+
dispatch(result: .failure(APIError.networkError(errorDescription, nil, error)))
373+
finish()
374+
return
364375
}
365-
366376
dispatch(result: .failure(APIError.operationError(errorDescription, "", error)))
367377
finish()
368378
}

AmplifyPlugins/API/Tests/APIHostApp/APIHostApp.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/GraphQLSubscribeTaskTests.swift

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class GraphQLSubscribeTasksTests: OperationTestBase {
3737

3838
var connectionStateSink: AnyCancellable?
3939
var subscriptionDataSink: AnyCancellable?
40+
var expectedCompletionFailureError: APIError?
4041

4142
override func setUp() async throws {
4243
try await super.setUp()
@@ -109,7 +110,7 @@ class GraphQLSubscribeTasksTests: OperationTestBase {
109110
await waitForSubscriptionExpectations()
110111
}
111112

112-
func testConnectionError() async throws {
113+
func testConnectionErrorWithLimitExceeded() async throws {
113114
await receivedCompletionSuccess.setShouldTrigger(false)
114115
await receivedCompletionFailure.setShouldTrigger(true)
115116
await receivedStateValueConnecting.setShouldTrigger(true)
@@ -123,8 +124,65 @@ class GraphQLSubscribeTasksTests: OperationTestBase {
123124
await waitForExpectations([onSubscribeInvoked], timeout: 0.05)
124125

125126
subscriptionEventHandler(.connection(.connecting), subscriptionItem)
126-
subscriptionEventHandler(.failed("Error"), subscriptionItem)
127+
subscriptionEventHandler(.failed(ConnectionProviderError.limitExceeded(nil)), subscriptionItem)
128+
expectedCompletionFailureError = APIError.operationError("", "", ConnectionProviderError.limitExceeded(nil))
129+
await waitForSubscriptionExpectations()
130+
}
131+
132+
func testConnectionErrorWithSubscriptionError() async throws {
133+
await receivedCompletionSuccess.setShouldTrigger(false)
134+
await receivedCompletionFailure.setShouldTrigger(true)
135+
await receivedStateValueConnecting.setShouldTrigger(true)
136+
await receivedStateValueConnected.setShouldTrigger(false)
137+
await receivedStateValueDisconnected.setShouldTrigger(false)
138+
139+
await receivedDataValueSuccess.setShouldTrigger(false)
140+
await receivedDataValueError.setShouldTrigger(false)
141+
142+
try await subscribe()
143+
await waitForExpectations([onSubscribeInvoked], timeout: 0.05)
127144

145+
subscriptionEventHandler(.connection(.connecting), subscriptionItem)
146+
subscriptionEventHandler(.failed(ConnectionProviderError.subscription("", nil)), subscriptionItem)
147+
expectedCompletionFailureError = APIError.operationError("", "", ConnectionProviderError.subscription("", nil))
148+
await waitForSubscriptionExpectations()
149+
}
150+
151+
func testConnectionErrorWithConnectionUnauthorizedError() async throws {
152+
await receivedCompletionSuccess.setShouldTrigger(false)
153+
await receivedCompletionFailure.setShouldTrigger(true)
154+
await receivedStateValueConnecting.setShouldTrigger(true)
155+
await receivedStateValueConnected.setShouldTrigger(false)
156+
await receivedStateValueDisconnected.setShouldTrigger(false)
157+
158+
await receivedDataValueSuccess.setShouldTrigger(false)
159+
await receivedDataValueError.setShouldTrigger(false)
160+
161+
try await subscribe()
162+
await waitForExpectations([onSubscribeInvoked], timeout: 0.05)
163+
164+
subscriptionEventHandler(.connection(.connecting), subscriptionItem)
165+
subscriptionEventHandler(.failed(ConnectionProviderError.unauthorized), subscriptionItem)
166+
expectedCompletionFailureError = APIError.operationError("", "", ConnectionProviderError.unauthorized)
167+
await waitForSubscriptionExpectations()
168+
}
169+
170+
func testConnectionErrorWithConnectionProviderConnectionError() async throws {
171+
await receivedCompletionSuccess.setShouldTrigger(false)
172+
await receivedCompletionFailure.setShouldTrigger(true)
173+
await receivedStateValueConnecting.setShouldTrigger(true)
174+
await receivedStateValueConnected.setShouldTrigger(false)
175+
await receivedStateValueDisconnected.setShouldTrigger(false)
176+
177+
await receivedDataValueSuccess.setShouldTrigger(false)
178+
await receivedDataValueError.setShouldTrigger(false)
179+
180+
try await subscribe()
181+
await waitForExpectations([onSubscribeInvoked], timeout: 0.05)
182+
183+
subscriptionEventHandler(.connection(.connecting), subscriptionItem)
184+
subscriptionEventHandler(.failed(ConnectionProviderError.connection), subscriptionItem)
185+
expectedCompletionFailureError = APIError.networkError("", nil, URLError(.networkConnectionLost))
128186
await waitForSubscriptionExpectations()
129187
}
130188

@@ -270,8 +328,53 @@ class GraphQLSubscribeTasksTests: OperationTestBase {
270328

271329
await self.receivedCompletionSuccess.fulfill()
272330
} catch {
331+
if let apiError = error as? APIError,
332+
let expectedError = expectedCompletionFailureError {
333+
XCTAssertEqual(apiError, expectedError)
334+
}
335+
273336
await self.receivedCompletionFailure.fulfill()
274337
}
275338
}
276339
}
277340
}
341+
342+
extension APIError: Equatable {
343+
public static func == (lhs: APIError, rhs: APIError) -> Bool {
344+
switch (lhs, rhs) {
345+
case (.unknown, .unknown),
346+
(.invalidConfiguration, .invalidConfiguration),
347+
(.httpStatusError, .httpStatusError),
348+
(.pluginError, .pluginError):
349+
return true
350+
case (.operationError(_, _, let lhs), .operationError(_, _, let rhs)):
351+
if let lhs = lhs as? ConnectionProviderError, let rhs = rhs as? ConnectionProviderError {
352+
switch (lhs, rhs) {
353+
case (.connection, .connection),
354+
(.jsonParse, .jsonParse),
355+
(.limitExceeded, .limitExceeded),
356+
(.subscription, .subscription),
357+
(.unauthorized, .unauthorized),
358+
(.unknown, .unknown):
359+
return true
360+
default:
361+
return false
362+
}
363+
} else if lhs == nil && rhs == nil {
364+
return true
365+
} else {
366+
return false
367+
}
368+
case (.networkError(_, _, let lhs), .networkError(_, _, let rhs)):
369+
if let lhs = lhs as? URLError, let rhs = rhs as? URLError {
370+
return lhs.code == rhs.code
371+
} else if lhs == nil && rhs == nil {
372+
return true
373+
} else {
374+
return false
375+
}
376+
default:
377+
return false
378+
}
379+
}
380+
}

AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Sync/RemoteSyncEngine+Retryable.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ extension RemoteSyncEngine {
3535
urlErrorOptional = underlyingError
3636
} else if let urlError = error as? URLError {
3737
urlErrorOptional = urlError
38+
} else if let dataStoreError = error as? DataStoreError,
39+
case .api(let amplifyError, _) = dataStoreError,
40+
let apiError = amplifyError as? APIError,
41+
case .networkError(_, _, let error) = apiError,
42+
let urlError = error as? URLError {
43+
urlErrorOptional = urlError
3844
}
3945

4046
let advice = requestRetryablePolicy.retryRequestAdvice(urlError: urlErrorOptional,

AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Sync/RequestRetryablePolicy.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ class RequestRetryablePolicy: RequestRetryable {
3939
.cannotFindHost,
4040
.timedOut,
4141
.dataNotAllowed,
42-
.cannotParseResponse:
42+
.cannotParseResponse,
43+
.networkConnectionLost:
4344
let waitMillis = retryDelayInMillseconds(for: attemptNumber)
4445
return RequestRetryAdvice(shouldRetry: true, retryInterval: .milliseconds(waitMillis))
4546
default:

AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/RequestRetryablePolicyTests.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,17 @@ class RequestRetryablePolicyTests: XCTestCase {
187187
XCTAssert(retryAdvice.shouldRetry)
188188
assertMilliseconds(retryAdvice.retryInterval, greaterThan: 200, lessThan: 300)
189189
}
190+
191+
func testNetworkConnectionLostError() {
192+
let retryableErrorCode = URLError.init(.networkConnectionLost)
193+
194+
let retryAdvice = retryPolicy.retryRequestAdvice(urlError: retryableErrorCode,
195+
httpURLResponse: nil,
196+
attemptNumber: 1)
197+
198+
XCTAssert(retryAdvice.shouldRetry)
199+
assertMilliseconds(retryAdvice.retryInterval, greaterThan: 200, lessThan: 300)
200+
}
190201

191202
func testHTTPTooManyRedirectsError() {
192203
let nonRetryableErrorCode = URLError.init(.httpTooManyRedirects)

0 commit comments

Comments
 (0)