Skip to content

Commit 50b556e

Browse files
authored
fix(api): Change the getToken to async (#2856)
1 parent c8ab069 commit 50b556e

File tree

10 files changed

+143
-65
lines changed

10 files changed

+143
-65
lines changed

Amplify/Categories/API/Operation/RetryableGraphQLOperation.swift

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public protocol RetryableGraphQLOperationBehavior: Operation, DefaultLogger {
2121
/// GraphQLOperation concrete type
2222
associatedtype OperationType: AnyGraphQLOperation
2323

24-
typealias RequestFactory = () -> GraphQLRequest<Payload>
24+
typealias RequestFactory = (@escaping (GraphQLRequest<Payload>) -> Void) -> Void
2525
typealias OperationFactory = (GraphQLRequest<Payload>, @escaping OperationResultListener) -> OperationType
2626
typealias OperationResultListener = OperationType.ResultListener
2727

@@ -64,7 +64,10 @@ extension RetryableGraphQLOperationBehavior {
6464
let wrappedResultListener: OperationResultListener = { result in
6565
if case let .failure(error) = result, self.shouldRetry(error: error as? APIError) {
6666
self.log.debug("\(error)")
67-
self.start(request: self.requestFactory())
67+
self.requestFactory { [weak self] request in
68+
self?.start(request: request)
69+
}
70+
6871
return
6972
}
7073

@@ -95,7 +98,7 @@ public final class RetryableGraphQLOperation<Payload: Decodable>: Operation, Ret
9598
public var resultListener: OperationResultListener
9699
public var operationFactory: OperationFactory
97100

98-
public init(requestFactory: @escaping () -> GraphQLRequest<Payload>,
101+
public init(requestFactory: @escaping RetryableGraphQLOperation<Payload>.RequestFactory,
99102
maxRetries: Int,
100103
resultListener: @escaping OperationResultListener,
101104
_ operationFactory: @escaping OperationFactory) {
@@ -106,7 +109,9 @@ public final class RetryableGraphQLOperation<Payload: Decodable>: Operation, Ret
106109
self.resultListener = resultListener
107110
}
108111
public override func main() {
109-
start(request: requestFactory())
112+
requestFactory { [weak self] request in
113+
self?.start(request: request)
114+
}
110115
}
111116

112117
public override func cancel() {
@@ -154,7 +159,9 @@ public final class RetryableGraphQLSubscriptionOperation<Payload: Decodable>: Op
154159
self.resultListener = resultListener
155160
}
156161
public override func main() {
157-
start(request: requestFactory())
162+
requestFactory { [weak self] request in
163+
self?.start(request: request)
164+
}
158165
}
159166

160167
public override func cancel() {

Amplify/Core/Support/Optional+Extension.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
// SPDX-License-Identifier: Apache-2.0
66
//
77

8-
98
import Foundation
109

1110
extension Optional {

AmplifyPlugins/API/AWSAPICategoryPlugin/Operation/AWSRESTOperation.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,4 @@ final public class AWSRESTOperation: AmplifyOperation<
125125
task.resume()
126126
}
127127
}
128+

AmplifyPlugins/API/Podfile.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,4 @@ SPEC CHECKSUMS:
116116

117117
PODFILE CHECKSUM: 5170578806036f2ba018abb8868d56e448fb0ada
118118

119-
COCOAPODS: 1.11.3
119+
COCOAPODS: 1.12.0

AmplifyPlugins/DataStore/AWSDataStoreCategoryPlugin/Sync/InitialSync/InitialSyncOperation.swift

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,15 +140,16 @@ final class InitialSyncOperation: AsynchronousOperation {
140140
}
141141

142142
var authTypes = authModeStrategy.authTypesFor(schema: modelSchema,
143-
operation: .read)
144-
145-
RetryableGraphQLOperation(requestFactory: {
146-
GraphQLRequest<SyncQueryResult>.syncQuery(modelSchema: self.modelSchema,
147-
where: queryPredicate,
148-
limit: limit,
149-
nextToken: nextToken,
150-
lastSync: lastSyncTime,
151-
authType: authTypes.next())
143+
operation: .read)
144+
145+
RetryableGraphQLOperation(requestFactory: { completion in
146+
completion(GraphQLRequest<SyncQueryResult>.syncQuery(modelSchema: self.modelSchema,
147+
where: queryPredicate,
148+
limit: limit,
149+
nextToken: nextToken,
150+
lastSync: lastSyncTime,
151+
authType: authTypes.next()))
152+
152153
},
153154
maxRetries: authTypes.count,
154155
resultListener: completionListener, { nextRequest, wrappedCompletionListener in

AmplifyPlugins/DataStore/AWSDataStoreCategoryPlugin/Sync/SubscriptionSync/IncomingAsyncSubscriptionEventPublisher.swift

Lines changed: 98 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -196,38 +196,97 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable {
196196
}
197197

198198
// swiftlint:disable:next function_parameter_count
199-
static func makeAPIRequest(for modelSchema: ModelSchema,
200-
subscriptionType: GraphQLSubscriptionType,
201-
api: APICategoryGraphQLBehavior,
202-
auth: AuthCategoryBehavior?,
203-
authType: AWSAuthorizationType?,
204-
awsAuthService: AWSAuthServiceBehavior) -> GraphQLRequest<Payload> {
205-
let request: GraphQLRequest<Payload>
206-
if modelSchema.hasAuthenticationRules,
207-
auth != nil,
208-
case .success(let tokenString) = awsAuthService.getToken(),
209-
case .success(let claims) = awsAuthService.getTokenClaims(tokenString: tokenString) {
210-
request = GraphQLRequest<Payload>.subscription(to: modelSchema,
211-
subscriptionType: subscriptionType,
212-
claims: claims,
213-
authType: authType)
214-
} else if modelSchema.hasAuthenticationRules,
215-
let oidcAuthProvider = hasOIDCAuthProviderAvailable(api: api),
216-
case .success(let tokenString) = oidcAuthProvider.getLatestAuthToken(),
217-
case .success(let claims) = awsAuthService.getTokenClaims(tokenString: tokenString) {
218-
request = GraphQLRequest<Payload>.subscription(to: modelSchema,
219-
subscriptionType: subscriptionType,
220-
claims: claims,
221-
authType: authType)
199+
static func makeAPIRequest(
200+
for modelSchema: ModelSchema,
201+
subscriptionType: GraphQLSubscriptionType,
202+
api: APICategoryGraphQLBehavior,
203+
auth: AuthCategoryBehavior?,
204+
authType: AWSAuthorizationType?,
205+
awsAuthService: AWSAuthServiceBehavior,
206+
completion: @escaping (GraphQLRequest<Payload>) -> Void) {
207+
208+
let requestWithOutClaims = GraphQLRequest<Payload>.subscription(
209+
to: modelSchema,
210+
subscriptionType: subscriptionType,
211+
authType: authType)
212+
213+
guard modelSchema.hasAuthenticationRules else {
214+
completion(requestWithOutClaims)
215+
return
216+
}
217+
218+
getClaims(api: api,
219+
auth: auth,
220+
awsAuthService: awsAuthService) { claims in
221+
222+
guard let claims = claims else {
223+
completion(requestWithOutClaims)
224+
return
225+
}
226+
let request = GraphQLRequest<Payload>.subscription(
227+
to: modelSchema,
228+
subscriptionType: subscriptionType,
229+
claims: claims,
230+
authType: authType)
231+
completion(request)
232+
return
233+
}
234+
235+
}
236+
237+
static func getClaims(api: APICategoryGraphQLBehavior,
238+
auth: AuthCategoryBehavior?,
239+
awsAuthService: AWSAuthServiceBehavior,
240+
completion: @escaping ([String: AnyObject]?) -> Void) {
241+
if auth != nil {
242+
getClaimsFromUserPool(awsAuthService: awsAuthService) { claims in
243+
if let claims = claims {
244+
completion(claims)
245+
} else {
246+
getClaimsFromOIDCProvider(
247+
api: api,
248+
awsAuthService: awsAuthService,
249+
completion: completion)
250+
}
251+
}
222252
} else {
223-
request = GraphQLRequest<Payload>.subscription(to: modelSchema,
224-
subscriptionType: subscriptionType,
225-
authType: authType)
253+
getClaimsFromOIDCProvider(
254+
api: api,
255+
awsAuthService: awsAuthService,
256+
completion: completion)
226257
}
227258

228-
return request
229259
}
230260

261+
static func getClaimsFromUserPool(
262+
awsAuthService: AWSAuthServiceBehavior,
263+
completion: @escaping ([String: AnyObject]?) -> Void) {
264+
265+
awsAuthService.getUserPoolAccessToken { result in
266+
if case .success(let tokenString) = result,
267+
case .success(let claims) = awsAuthService.getTokenClaims(tokenString: tokenString) {
268+
completion(claims)
269+
} else {
270+
completion(nil)
271+
}
272+
}
273+
}
274+
275+
static func getClaimsFromOIDCProvider(
276+
api: APICategoryGraphQLBehavior,
277+
awsAuthService: AWSAuthServiceBehavior,
278+
completion: @escaping ([String: AnyObject]?) -> Void) {
279+
280+
guard let oidcAuthProvider = hasOIDCAuthProviderAvailable(api: api),
281+
case .success(let tokenString) = oidcAuthProvider.getLatestAuthToken(),
282+
case .success(let claims) = awsAuthService.getTokenClaims(tokenString: tokenString)
283+
else {
284+
completion(nil)
285+
return
286+
}
287+
completion(claims)
288+
}
289+
231290
static func hasOIDCAuthProviderAvailable(api: APICategoryGraphQLBehavior) -> AmplifyOIDCAuthProvider? {
232291
if let apiPlugin = api as? APICategoryAuthProviderFactoryBehavior,
233292
let oidcAuthProvider = apiPlugin.apiAuthProviderFactory().oidcAuthProvider() {
@@ -292,16 +351,20 @@ extension IncomingAsyncSubscriptionEventPublisher {
292351
api: APICategoryGraphQLBehavior,
293352
auth: AuthCategoryBehavior?,
294353
awsAuthService: AWSAuthServiceBehavior,
295-
authTypeProvider: AWSAuthorizationTypeIterator) -> RetryableGraphQLOperation<Payload>.RequestFactory {
354+
authTypeProvider: AWSAuthorizationTypeIterator)
355+
-> RetryableGraphQLOperation<Payload>.RequestFactory {
356+
296357
// swiftlint:disable:previous line_length
297358
var authTypes = authTypeProvider
298-
return {
299-
return IncomingAsyncSubscriptionEventPublisher.makeAPIRequest(for: modelSchema,
300-
subscriptionType: subscriptionType,
301-
api: api,
302-
auth: auth,
303-
authType: authTypes.next(),
304-
awsAuthService: awsAuthService)
359+
return { completion in
360+
return IncomingAsyncSubscriptionEventPublisher.makeAPIRequest(
361+
for: modelSchema,
362+
subscriptionType: subscriptionType,
363+
api: api,
364+
auth: auth,
365+
authType: authTypes.next(),
366+
awsAuthService: awsAuthService,
367+
completion: completion)
305368
}
306369
}
307370
}

AmplifyPlugins/DataStore/Podfile.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,4 @@ SPEC CHECKSUMS:
127127

128128
PODFILE CHECKSUM: 0bab7193bebdf470839514f327440893b0d26090
129129

130-
COCOAPODS: 1.11.3
130+
COCOAPODS: 1.12.0

AmplifyTests/CategoryTests/API/RetryableGraphQLOperationTests.swift

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,9 @@ class RetryableGraphQLOperationTests: XCTestCase {
2929
resultExpectation.fulfill()
3030
}
3131

32-
let requestFactory: RequestFactory = {
32+
let requestFactory: RequestFactory = { completion in
3333
requestFactoryExpectation.fulfill()
34-
return self.makeTestRequest()
35-
34+
self.makeTestRequestAsync(completion: completion)
3635
}
3736

3837
let operation = RetryableGraphQLOperation<Payload>(requestFactory: requestFactory,
@@ -69,10 +68,9 @@ class RetryableGraphQLOperationTests: XCTestCase {
6968
resultExpectation.fulfill()
7069
}
7170

72-
let requestFactory: RequestFactory = {
71+
let requestFactory: RequestFactory = { completion in
7372
requestFactoryExpectation.fulfill()
74-
return self.makeTestRequest()
75-
73+
completion(self.makeTestRequest())
7674
}
7775

7876
let operation = RetryableGraphQLOperation<Payload>(requestFactory: requestFactory,
@@ -103,10 +101,9 @@ class RetryableGraphQLOperationTests: XCTestCase {
103101
resultExpectation.fulfill()
104102
}
105103

106-
let requestFactory: RequestFactory = {
104+
let requestFactory: RequestFactory = { completion in
107105
requestFactoryExpectation.fulfill()
108-
return self.makeTestRequest()
109-
106+
completion(self.makeTestRequest())
110107
}
111108

112109
let operation = RetryableGraphQLOperation<Payload>(requestFactory: requestFactory,
@@ -133,6 +130,16 @@ extension RetryableGraphQLOperationTests {
133130
responseType: Payload.self)
134131
}
135132

133+
private func makeTestRequestAsync(completion: @escaping (GraphQLRequest<Payload>) -> Void ) {
134+
DispatchQueue.global().asyncAfter(deadline: .now() + 2) {
135+
let request = GraphQLRequest<Payload>(apiName: self.testApiName,
136+
document: "",
137+
responseType: Payload.self)
138+
completion(request)
139+
}
140+
141+
}
142+
136143
private func makeTestOperation() -> GraphQLOperation<Payload> {
137144
let requestOptions = GraphQLOperationRequest<Payload>.Options(pluginOptions: nil)
138145
let operationRequest = GraphQLOperationRequest<Payload>(apiName: testApiName,

AmplifyTests/CoreTests/Optional+ExtensionTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,6 @@ class OptionalExtensionTests: XCTestCase {
5757

5858
}
5959

60-
fileprivate struct TestRuntimeError: Error, Equatable {
60+
private struct TestRuntimeError: Error, Equatable {
6161
let id = UUID()
6262
}

Podfile.lock

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
PODS:
2-
- AWSCore (2.30.1)
2+
- AWSCore (2.30.4)
33
- CwlCatchException (2.1.1):
44
- CwlCatchExceptionSupport (~> 2.1.1)
55
- CwlCatchExceptionSupport (2.1.1)
@@ -39,7 +39,7 @@ CHECKOUT OPTIONS:
3939
:tag: 2.1.0
4040

4141
SPEC CHECKSUMS:
42-
AWSCore: 493e49f8118e04fa57d927ceb117ba24a9b5ca02
42+
AWSCore: 19b8233fe2d0ed3ccf5cff833a615814282cdc90
4343
CwlCatchException: 86760545af2a490a23e964d76d7c77442dbce79b
4444
CwlCatchExceptionSupport: a004322095d7101b945442c86adc7cec0650f676
4545
CwlMachBadInstructionHandler: aa1fe9f2d08b29507c150d099434b2890247e7f8
@@ -50,4 +50,4 @@ SPEC CHECKSUMS:
5050

5151
PODFILE CHECKSUM: 5e20e56b8ef40444b018a3736b7b726ff9772f00
5252

53-
COCOAPODS: 1.11.3
53+
COCOAPODS: 1.12.0

0 commit comments

Comments
 (0)