@@ -196,38 +196,97 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable {
196
196
}
197
197
198
198
// swiftlint:disable:next function_parameter_count
199
- static func makeAPIRequest( for modelSchema: ModelSchema ,
200
- subscriptionType: GraphQLSubscriptionType ,
201
- api: APICategoryGraphQLBehavior ,
202
- auth: AuthCategoryBehavior ? ,
203
- authType: AWSAuthorizationType ? ,
204
- awsAuthService: AWSAuthServiceBehavior ) -> GraphQLRequest < Payload > {
205
- let request : GraphQLRequest < Payload >
206
- if modelSchema. hasAuthenticationRules,
207
- auth != nil ,
208
- case . success( let tokenString) = awsAuthService. getToken ( ) ,
209
- case . success( let claims) = awsAuthService. getTokenClaims ( tokenString: tokenString) {
210
- request = GraphQLRequest< Payload> . subscription( to: modelSchema,
211
- subscriptionType: subscriptionType,
212
- claims: claims,
213
- authType: authType)
214
- } else if modelSchema. hasAuthenticationRules,
215
- let oidcAuthProvider = hasOIDCAuthProviderAvailable ( api: api) ,
216
- case . success( let tokenString) = oidcAuthProvider. getLatestAuthToken ( ) ,
217
- case . success( let claims) = awsAuthService. getTokenClaims ( tokenString: tokenString) {
218
- request = GraphQLRequest< Payload> . subscription( to: modelSchema,
219
- subscriptionType: subscriptionType,
220
- claims: claims,
221
- authType: authType)
199
+ static func makeAPIRequest(
200
+ for modelSchema: ModelSchema ,
201
+ subscriptionType: GraphQLSubscriptionType ,
202
+ api: APICategoryGraphQLBehavior ,
203
+ auth: AuthCategoryBehavior ? ,
204
+ authType: AWSAuthorizationType ? ,
205
+ awsAuthService: AWSAuthServiceBehavior ,
206
+ completion: @escaping ( GraphQLRequest < Payload > ) -> Void ) {
207
+
208
+ let requestWithOutClaims = GraphQLRequest< Payload> . subscription(
209
+ to: modelSchema,
210
+ subscriptionType: subscriptionType,
211
+ authType: authType)
212
+
213
+ guard modelSchema. hasAuthenticationRules else {
214
+ completion ( requestWithOutClaims)
215
+ return
216
+ }
217
+
218
+ getClaims ( api: api,
219
+ auth: auth,
220
+ awsAuthService: awsAuthService) { claims in
221
+
222
+ guard let claims = claims else {
223
+ completion ( requestWithOutClaims)
224
+ return
225
+ }
226
+ let request = GraphQLRequest< Payload> . subscription(
227
+ to: modelSchema,
228
+ subscriptionType: subscriptionType,
229
+ claims: claims,
230
+ authType: authType)
231
+ completion ( request)
232
+ return
233
+ }
234
+
235
+ }
236
+
237
+ static func getClaims( api: APICategoryGraphQLBehavior ,
238
+ auth: AuthCategoryBehavior ? ,
239
+ awsAuthService: AWSAuthServiceBehavior ,
240
+ completion: @escaping ( [ String : AnyObject ] ? ) -> Void ) {
241
+ if auth != nil {
242
+ getClaimsFromUserPool ( awsAuthService: awsAuthService) { claims in
243
+ if let claims = claims {
244
+ completion ( claims)
245
+ } else {
246
+ getClaimsFromOIDCProvider (
247
+ api: api,
248
+ awsAuthService: awsAuthService,
249
+ completion: completion)
250
+ }
251
+ }
222
252
} else {
223
- request = GraphQLRequest< Payload> . subscription( to: modelSchema,
224
- subscriptionType: subscriptionType,
225
- authType: authType)
253
+ getClaimsFromOIDCProvider (
254
+ api: api,
255
+ awsAuthService: awsAuthService,
256
+ completion: completion)
226
257
}
227
258
228
- return request
229
259
}
230
260
261
+ static func getClaimsFromUserPool(
262
+ awsAuthService: AWSAuthServiceBehavior ,
263
+ completion: @escaping ( [ String : AnyObject ] ? ) -> Void ) {
264
+
265
+ awsAuthService. getUserPoolAccessToken { result in
266
+ if case . success( let tokenString) = result,
267
+ case . success( let claims) = awsAuthService. getTokenClaims ( tokenString: tokenString) {
268
+ completion ( claims)
269
+ } else {
270
+ completion ( nil )
271
+ }
272
+ }
273
+ }
274
+
275
+ static func getClaimsFromOIDCProvider(
276
+ api: APICategoryGraphQLBehavior ,
277
+ awsAuthService: AWSAuthServiceBehavior ,
278
+ completion: @escaping ( [ String : AnyObject ] ? ) -> Void ) {
279
+
280
+ guard let oidcAuthProvider = hasOIDCAuthProviderAvailable ( api: api) ,
281
+ case . success( let tokenString) = oidcAuthProvider. getLatestAuthToken ( ) ,
282
+ case . success( let claims) = awsAuthService. getTokenClaims ( tokenString: tokenString)
283
+ else {
284
+ completion ( nil )
285
+ return
286
+ }
287
+ completion ( claims)
288
+ }
289
+
231
290
static func hasOIDCAuthProviderAvailable( api: APICategoryGraphQLBehavior ) -> AmplifyOIDCAuthProvider ? {
232
291
if let apiPlugin = api as? APICategoryAuthProviderFactoryBehavior ,
233
292
let oidcAuthProvider = apiPlugin. apiAuthProviderFactory ( ) . oidcAuthProvider ( ) {
@@ -292,16 +351,20 @@ extension IncomingAsyncSubscriptionEventPublisher {
292
351
api: APICategoryGraphQLBehavior ,
293
352
auth: AuthCategoryBehavior ? ,
294
353
awsAuthService: AWSAuthServiceBehavior ,
295
- authTypeProvider: AWSAuthorizationTypeIterator ) -> RetryableGraphQLOperation < Payload > . RequestFactory {
354
+ authTypeProvider: AWSAuthorizationTypeIterator )
355
+ -> RetryableGraphQLOperation < Payload > . RequestFactory {
356
+
296
357
// swiftlint:disable:previous line_length
297
358
var authTypes = authTypeProvider
298
- return {
299
- return IncomingAsyncSubscriptionEventPublisher . makeAPIRequest ( for: modelSchema,
300
- subscriptionType: subscriptionType,
301
- api: api,
302
- auth: auth,
303
- authType: authTypes. next ( ) ,
304
- awsAuthService: awsAuthService)
359
+ return { completion in
360
+ return IncomingAsyncSubscriptionEventPublisher . makeAPIRequest (
361
+ for: modelSchema,
362
+ subscriptionType: subscriptionType,
363
+ api: api,
364
+ auth: auth,
365
+ authType: authTypes. next ( ) ,
366
+ awsAuthService: awsAuthService,
367
+ completion: completion)
305
368
}
306
369
}
307
370
}
0 commit comments