@@ -23,46 +23,77 @@ struct TestUser {
23
23
let password : String
24
24
}
25
25
26
- class AuthRecorderInterceptor : URLRequestInterceptor {
27
- let awsAuthService : AWSAuthService = AWSAuthService ( )
28
- var consumedAuthTypes : Set < AWSAuthorizationType > = [ ]
29
- private let accessQueue = DispatchQueue ( label: " com.amazon.AuthRecorderInterceptor.consumedAuthTypes " )
30
-
31
- private func recordAuthType( _ authType: AWSAuthorizationType ) {
32
- accessQueue. async {
33
- self . consumedAuthTypes. insert ( authType)
34
- }
35
- }
26
+ class DataStoreAuthBaseTestURLSessionFactory : URLSessionBehaviorFactory {
27
+ static let testIdHeaderKey = " x-amplify-test "
36
28
37
- func intercept( _ request: URLRequest ) throws -> URLRequest {
38
- guard let headers = request. allHTTPHeaderFields else {
39
- fatalError ( " No headers found in request \( request) " )
40
- }
29
+ static let subject = PassthroughSubject < ( String , Set < AWSAuthorizationType > ) , Never > ( )
41
30
42
- let authHeaderValue = headers [ " Authorization " ]
43
- let apiKeyHeaderValue = headers [ " x-api-key " ]
31
+ class Sniffer : URLProtocol {
44
32
45
- if apiKeyHeaderValue != nil {
46
- recordAuthType ( . apiKey)
47
- }
33
+ override class func canInit( with request: URLRequest ) -> Bool {
34
+ guard let headers = request. allHTTPHeaderFields else {
35
+ fatalError ( " No headers found in request \( request) " )
36
+ }
37
+
38
+ guard let testId = headers [ DataStoreAuthBaseTestURLSessionFactory . testIdHeaderKey] else {
39
+ return false
40
+ }
41
+
42
+ var result : Set < AWSAuthorizationType > = [ ]
43
+ let authHeaderValue = headers [ " Authorization " ]
44
+ let apiKeyHeaderValue = headers [ " x-api-key " ]
45
+
46
+ if apiKeyHeaderValue != nil {
47
+ result. insert ( . apiKey)
48
+ }
49
+
50
+ if let authHeaderValue = authHeaderValue,
51
+ case let . success( claims) = AWSAuthService ( ) . getTokenClaims ( tokenString: authHeaderValue) ,
52
+ let cognitoIss = claims [ " iss " ] as? String , cognitoIss. contains ( " cognito " ) {
53
+ result. insert ( . amazonCognitoUserPools)
54
+ }
48
55
49
- if let authHeaderValue = authHeaderValue,
50
- case let . success( claims) = awsAuthService. getTokenClaims ( tokenString: authHeaderValue) ,
51
- let cognitoIss = claims [ " iss " ] as? String , cognitoIss. contains ( " cognito " ) {
52
- recordAuthType ( . amazonCognitoUserPools)
56
+ if let authHeaderValue = authHeaderValue,
57
+ authHeaderValue. starts ( with: " AWS4-HMAC-SHA256 " ) {
58
+ result. insert ( . awsIAM)
59
+ }
60
+
61
+ DataStoreAuthBaseTestURLSessionFactory . subject. send ( ( testId, result) )
62
+ return false
53
63
}
54
64
55
- if let authHeaderValue = authHeaderValue,
56
- authHeaderValue. starts ( with: " AWS4-HMAC-SHA256 " ) {
57
- recordAuthType ( . awsIAM)
65
+ }
66
+
67
+ class Interceptor : URLRequestInterceptor {
68
+ let testId : String ?
69
+
70
+ init ( testId: String ? ) {
71
+ self . testId = testId
58
72
}
59
73
60
- return request
74
+ func intercept( _ request: URLRequest ) async throws -> URLRequest {
75
+ if let testId {
76
+ var mutableRequest = request
77
+ mutableRequest. setValue ( testId, forHTTPHeaderField: DataStoreAuthBaseTestURLSessionFactory . testIdHeaderKey)
78
+ return mutableRequest
79
+ }
80
+ return request
81
+ }
61
82
}
62
83
63
- func reset( ) {
64
- consumedAuthTypes = [ ]
84
+ func makeSession( withDelegate delegate: URLSessionBehaviorDelegate ? ) -> URLSessionBehavior {
85
+ let urlSessionDelegate = delegate? . asURLSessionDelegate
86
+ let configuration = URLSessionConfiguration . default
87
+ configuration. tlsMinimumSupportedProtocolVersion = . TLSv12
88
+ configuration. tlsMaximumSupportedProtocolVersion = . TLSv13
89
+ configuration. protocolClasses? . insert ( Sniffer . self, at: 0 )
90
+
91
+ let session = URLSession ( configuration: configuration,
92
+ delegate: urlSessionDelegate,
93
+ delegateQueue: nil )
94
+ return AmplifyURLSession ( session: session)
65
95
}
96
+
66
97
}
67
98
68
99
class AWSDataStoreAuthBaseTest : XCTestCase {
@@ -71,7 +102,6 @@ class AWSDataStoreAuthBaseTest: XCTestCase {
71
102
var amplifyConfig : AmplifyConfiguration !
72
103
var user1 : TestUser ?
73
104
var user2 : TestUser ?
74
- var authRecorderInterceptor : AuthRecorderInterceptor !
75
105
76
106
override func setUp( ) {
77
107
continueAfterFailure = false
@@ -138,8 +168,6 @@ class AWSDataStoreAuthBaseTest: XCTestCase {
138
168
self . user1 = TestUser ( username: user1, password: passwordUser1)
139
169
self . user2 = TestUser ( username: user2, password: passwordUser2)
140
170
141
- authRecorderInterceptor = AuthRecorderInterceptor ( )
142
-
143
171
amplifyConfig = try TestConfigHelper . retrieveAmplifyConfiguration ( forResource: configFile)
144
172
145
173
} catch {
@@ -161,7 +189,8 @@ class AWSDataStoreAuthBaseTest: XCTestCase {
161
189
func setup(
162
190
withModels models: AmplifyModelRegistration ,
163
191
testType: DataStoreAuthTestType ,
164
- apiPluginFactory: ( ) -> AWSAPIPlugin = { AWSAPIPlugin ( sessionFactory: AmplifyURLSessionFactory ( ) ) }
192
+ testId: String ? = nil ,
193
+ apiPluginFactory: ( ) -> AWSAPIPlugin = { AWSAPIPlugin ( sessionFactory: DataStoreAuthBaseTestURLSessionFactory ( ) ) }
165
194
) async {
166
195
do {
167
196
setupCredentials ( forAuthStrategy: testType)
@@ -182,7 +211,10 @@ class AWSDataStoreAuthBaseTest: XCTestCase {
182
211
183
212
// register auth recorder interceptor
184
213
let apiName = try apiEndpointName ( )
185
- try apiPlugin. add ( interceptor: authRecorderInterceptor, for: apiName)
214
+ try apiPlugin. add (
215
+ interceptor: DataStoreAuthBaseTestURLSessionFactory . Interceptor ( testId: testId) ,
216
+ for: apiName
217
+ )
186
218
187
219
await signOut ( )
188
220
} catch {
@@ -486,13 +518,27 @@ extension AWSDataStoreAuthBaseTest {
486
518
await waitForExpectations ( [ expectations. mutationDelete, expectations. mutationDeleteProcessed] , timeout: 60 )
487
519
}
488
520
489
- func assertUsedAuthTypes( _ authTypes: [ AWSAuthorizationType ] ,
490
- file: StaticString = #file,
491
- line: UInt = #line) {
492
- XCTAssertEqual ( authRecorderInterceptor. consumedAuthTypes,
493
- Set ( authTypes) ,
494
- file: file,
495
- line: line)
521
+ func assertUsedAuthTypes(
522
+ testId: String ,
523
+ authTypes: [ AWSAuthorizationType ] ,
524
+ file: StaticString = #file,
525
+ line: UInt = #line
526
+ ) -> XCTestExpectation {
527
+ let expectation = expectation ( description: " Should have expected auth types " )
528
+ expectation. assertForOverFulfill = false
529
+ DataStoreAuthBaseTestURLSessionFactory . subject
530
+ . filter { $0. 0 == testId }
531
+ . map { $0. 1 }
532
+ . collect ( . byTime( DispatchQueue . global ( ) , . milliseconds( 3500 ) ) )
533
+ . sink {
534
+ let result = $0. reduce ( Set < AWSAuthorizationType > ( ) ) { partialResult, data in
535
+ partialResult. union ( data)
536
+ }
537
+ XCTAssertEqual ( result, Set ( authTypes) , file: file, line: line)
538
+ expectation. fulfill ( )
539
+ }
540
+ . store ( in: & requests)
541
+ return expectation
496
542
}
497
543
}
498
544
0 commit comments