diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..08891d83f --- /dev/null +++ b/.editorconfig @@ -0,0 +1,8 @@ +root = true + +[*] +indent_style = space +indent_size = 4 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true \ No newline at end of file diff --git a/.github/release.yml b/.github/release.yml new file mode 100644 index 000000000..e29eb8464 --- /dev/null +++ b/.github/release.yml @@ -0,0 +1,14 @@ +changelog: + categories: + - title: SemVer Major + labels: + - ⚠️ semver/major + - title: SemVer Minor + labels: + - 🆕 semver/minor + - title: SemVer Patch + labels: + - 🔨 semver/patch + - title: Other Changes + labels: + - semver/none diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 000000000..f63d89a3e --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,18 @@ +name: Main + +on: + push: + branches: [main] + schedule: + - cron: "0 8,20 * * *" + +jobs: + unit-tests: + name: Unit tests + uses: apple/swift-nio/.github/workflows/unit_tests.yml@main + with: + linux_5_9_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" + linux_5_10_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" + linux_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + linux_nightly_6_1_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + linux_nightly_main_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml new file mode 100644 index 000000000..a4fde6207 --- /dev/null +++ b/.github/workflows/pull_request.yml @@ -0,0 +1,25 @@ +name: PR + +on: + pull_request: + types: [opened, reopened, synchronize] + +jobs: + soundness: + name: Soundness + uses: swiftlang/github-workflows/.github/workflows/soundness.yml@main + with: + license_header_check_project_name: "AsyncHTTPClient" + unit-tests: + name: Unit tests + uses: apple/swift-nio/.github/workflows/unit_tests.yml@main + with: + linux_5_9_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" + linux_5_10_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" + linux_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + linux_nightly_6_1_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + linux_nightly_main_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + + cxx-interop: + name: Cxx interop + uses: apple/swift-nio/.github/workflows/cxx_interop.yml@main diff --git a/.github/workflows/pull_request_label.yml b/.github/workflows/pull_request_label.yml new file mode 100644 index 000000000..8fd47c13f --- /dev/null +++ b/.github/workflows/pull_request_label.yml @@ -0,0 +1,18 @@ +name: PR label + +on: + pull_request: + types: [labeled, unlabeled, opened, reopened, synchronize] + +jobs: + semver-label-check: + name: Semantic version label check + runs-on: ubuntu-latest + timeout-minutes: 1 + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Check for Semantic Version label + uses: apple/swift-nio/.github/actions/pull_request_semver_label_checker@main diff --git a/.licenseignore b/.licenseignore new file mode 100644 index 000000000..edceaab62 --- /dev/null +++ b/.licenseignore @@ -0,0 +1,37 @@ +.gitignore +**/.gitignore +.licenseignore +.gitattributes +.git-blame-ignore-revs +.mailfilter +.mailmap +.spi.yml +.swift-format +.editorconfig +.github/* +*.md +*.txt +*.yml +*.yaml +*.json +Package.swift +**/Package.swift +Package@-*.swift +**/Package@-*.swift +Package.resolved +**/Package.resolved +Makefile +*.modulemap +**/*.modulemap +**/*.docc/* +*.xcprivacy +**/*.xcprivacy +*.symlink +**/*.symlink +Dockerfile +**/Dockerfile +.dockerignore +Snippets/* +dev/git.commit.template +.unacceptablelanguageignore +Tests/AsyncHTTPClientTests/Resources/*.pem diff --git a/.spi.yml b/.spi.yml new file mode 100644 index 000000000..795484b35 --- /dev/null +++ b/.spi.yml @@ -0,0 +1,4 @@ +version: 1 +builder: + configs: + - documentation_targets: [AsyncHTTPClient] diff --git a/.swift-format b/.swift-format new file mode 100644 index 000000000..7e8ae7391 --- /dev/null +++ b/.swift-format @@ -0,0 +1,68 @@ +{ + "version" : 1, + "indentation" : { + "spaces" : 4 + }, + "tabWidth" : 4, + "fileScopedDeclarationPrivacy" : { + "accessLevel" : "private" + }, + "spacesAroundRangeFormationOperators" : false, + "indentConditionalCompilationBlocks" : false, + "indentSwitchCaseLabels" : false, + "lineBreakAroundMultilineExpressionChainComponents" : false, + "lineBreakBeforeControlFlowKeywords" : false, + "lineBreakBeforeEachArgument" : true, + "lineBreakBeforeEachGenericRequirement" : true, + "lineLength" : 120, + "maximumBlankLines" : 1, + "respectsExistingLineBreaks" : true, + "prioritizeKeepingFunctionOutputTogether" : true, + "noAssignmentInExpressions" : { + "allowedFunctions" : [ + "XCTAssertNoThrow", + "XCTAssertThrowsError" + ] + }, + "rules" : { + "AllPublicDeclarationsHaveDocumentation" : false, + "AlwaysUseLiteralForEmptyCollectionInit" : false, + "AlwaysUseLowerCamelCase" : false, + "AmbiguousTrailingClosureOverload" : true, + "BeginDocumentationCommentWithOneLineSummary" : false, + "DoNotUseSemicolons" : true, + "DontRepeatTypeInStaticProperties" : true, + "FileScopedDeclarationPrivacy" : true, + "FullyIndirectEnum" : true, + "GroupNumericLiterals" : true, + "IdentifiersMustBeASCII" : true, + "NeverForceUnwrap" : false, + "NeverUseForceTry" : false, + "NeverUseImplicitlyUnwrappedOptionals" : false, + "NoAccessLevelOnExtensionDeclaration" : true, + "NoAssignmentInExpressions" : true, + "NoBlockComments" : true, + "NoCasesWithOnlyFallthrough" : true, + "NoEmptyTrailingClosureParentheses" : true, + "NoLabelsInCasePatterns" : true, + "NoLeadingUnderscores" : false, + "NoParensAroundConditions" : true, + "NoVoidReturnOnFunctionSignature" : true, + "OmitExplicitReturns" : true, + "OneCasePerLine" : true, + "OneVariableDeclarationPerLine" : true, + "OnlyOneTrailingClosureArgument" : true, + "OrderedImports" : true, + "ReplaceForEachWithForLoop" : true, + "ReturnVoidInsteadOfEmptyTuple" : true, + "UseEarlyExits" : false, + "UseExplicitNilCheckInConditions" : false, + "UseLetInEveryBoundCaseVariable" : false, + "UseShorthandTypeNames" : true, + "UseSingleLinePropertyGetter" : false, + "UseSynthesizedInitializer" : false, + "UseTripleSlashForDocumentationComments" : true, + "UseWhereClausesInForLoops" : false, + "ValidateDocumentationComments" : false + } +} diff --git a/.swiftformat b/.swiftformat deleted file mode 100644 index dac9cd4d3..000000000 --- a/.swiftformat +++ /dev/null @@ -1,23 +0,0 @@ -# file options - ---swiftversion 5.2 ---exclude .build - -# format options - ---self insert ---patternlet inline ---ranges nospace ---stripunusedargs unnamed-only ---ifdef no-indent ---extensionacl on-declarations ---disable typeSugar # https://github.com/nicklockwood/SwiftFormat/issues/636 ---disable andOperator ---disable wrapMultilineStatementBraces ---disable enumNamespaces ---disable redundantExtensionACL ---disable redundantReturn ---disable preferKeyPath ---disable sortedSwitchCases - -# rules diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index c7a248828..76501d7d6 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,55 +1,5 @@ # Code of Conduct -To be a truly great community, AsyncHTTPClient needs to welcome developers from all walks of life, -with different backgrounds, and with a wide range of experience. A diverse and friendly -community will have more great ideas, more unique perspectives, and produce more great -code. We will work diligently to make the AsyncHTTPClient community welcoming to everyone. -To give clarity of what is expected of our members, AsyncHTTPClient has adopted the code of conduct -defined by [contributor-covenant.org](https://www.contributor-covenant.org). This document is used across many open source -communities, and we think it articulates our values well. The full text is copied below: +The code of conduct for this project can be found at https://swift.org/code-of-conduct. -### Contributor Code of Conduct v1.3 -As contributors and maintainers of this project, and in the interest of fostering an open and -welcoming community, we pledge to respect all people who contribute through reporting -issues, posting feature requests, updating documentation, submitting pull requests or patches, -and other activities. - -We are committed to making participation in this project a harassment-free experience for -everyone, regardless of level of experience, gender, gender identity and expression, sexual -orientation, disability, personal appearance, body size, race, ethnicity, age, religion, or -nationality. - -Examples of unacceptable behavior by participants include: -- The use of sexualized language or imagery -- Personal attacks -- Trolling or insulting/derogatory comments -- Public or private harassment -- Publishing other’s private information, such as physical or electronic addresses, without explicit permission -- Other unethical or unprofessional conduct - -Project maintainers have the right and responsibility to remove, edit, or reject comments, -commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of -Conduct, or to ban temporarily or permanently any contributor for other behaviors that they -deem inappropriate, threatening, offensive, or harmful. - -By adopting this Code of Conduct, project maintainers commit themselves to fairly and -consistently applying these principles to every aspect of managing this project. Project -maintainers who do not follow or enforce the Code of Conduct may be permanently removed -from the project team. - -This code of conduct applies both within project spaces and in public spaces when an -individual is representing the project or its community. - -Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by -contacting a project maintainer at [conduct@swiftserver.group](mailto:conduct@swiftserver.group). All complaints will be reviewed and -investigated and will result in a response that is deemed necessary and appropriate to the -circumstances. Maintainers are obligated to maintain confidentiality with regard to the reporter -of an incident. - -*This policy is adapted from the Contributor Code of Conduct [version 1.3.0](https://contributor-covenant.org/version/1/3/0/).* - -### Reporting -A working group of community members is committed to promptly addressing any [reported issues](mailto:conduct@swiftserver.group). -Working group members are volunteers appointed by the project lead, with a -preference for individuals with varied backgrounds and perspectives. Membership is expected -to change regularly, and may grow or shrink. + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6382fcd4d..dddcb3ba4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -57,7 +57,7 @@ A good AsyncHTTPClient patch is: 3. Documented, adding API documentation as needed to cover new functions and properties. 4. Accompanied by a great commit message, using our commit message template. -*Note* as of version 1.5.0 AsyncHTTPClient requires Swift 5.2. Earlier versions support as far back as Swift 5.0. +*Note* as of version 1.10.0 AsyncHTTPClient requires Swift 5.4. Earlier versions support as far back as Swift 5.0. ### Commit Message Template @@ -65,10 +65,10 @@ We require that your commit messages match our template. The easiest way to do t git config commit.template dev/git.commit.template -### Make sure Tests work on Linux -AsyncHTTPClient uses XCTest to run tests on both macOS and Linux. While the macOS version of XCTest is able to use the Objective-C runtime to discover tests at execution time, the Linux version is not. -For this reason, whenever you add new tests **you have to run a script** that generates the hooks needed to run those tests on Linux, or our CI will complain that the tests are not all present on Linux. To do this, merely execute `ruby ./scripts/generate_linux_tests.rb` at the root of the package and check the changes it made. +### Run CI checks locally + +You can run the Github Actions workflows locally using [act](https://github.com/nektos/act). For detailed steps on how to do this please see [https://github.com/swiftlang/github-workflows?tab=readme-ov-file#running-workflows-locally](https://github.com/swiftlang/github-workflows?tab=readme-ov-file#running-workflows-locally). ## How to contribute your work diff --git a/Examples/GetHTML/GetHTML.swift b/Examples/GetHTML/GetHTML.swift index dfefa922b..ca3bacbea 100644 --- a/Examples/GetHTML/GetHTML.swift +++ b/Examples/GetHTML/GetHTML.swift @@ -18,12 +18,12 @@ import NIOCore @main struct GetHTML { static func main() async throws { - let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + let httpClient = HTTPClient(eventLoopGroupProvider: .singleton) do { let request = HTTPClientRequest(url: "https://apple.com") let response = try await httpClient.execute(request, timeout: .seconds(30)) print("HTTP head", response) - let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB + let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB print(String(buffer: body)) } catch { print("request failed:", error) diff --git a/Examples/GetJSON/GetJSON.swift b/Examples/GetJSON/GetJSON.swift index ae58ffeaa..1af7a5144 100644 --- a/Examples/GetJSON/GetJSON.swift +++ b/Examples/GetJSON/GetJSON.swift @@ -33,12 +33,12 @@ struct Comic: Codable { @main struct GetJSON { static func main() async throws { - let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + let httpClient = HTTPClient(eventLoopGroupProvider: .singleton) do { let request = HTTPClientRequest(url: "https://xkcd.com/info.0.json") let response = try await httpClient.execute(request, timeout: .seconds(30)) print("HTTP head", response) - let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB + let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB // we use an overload defined in `NIOFoundationCompat` for `decode(_:from:)` to // efficiently decode from a `ByteBuffer` let comic = try JSONDecoder().decode(Comic.self, from: body) diff --git a/Examples/Package.swift b/Examples/Package.swift index 696092cba..9986b17b5 100644 --- a/Examples/Package.swift +++ b/Examples/Package.swift @@ -43,7 +43,8 @@ let package = Package( dependencies: [ .product(name: "AsyncHTTPClient", package: "async-http-client"), .product(name: "NIOCore", package: "swift-nio"), - ], path: "GetHTML" + ], + path: "GetHTML" ), .executableTarget( name: "GetJSON", @@ -51,14 +52,16 @@ let package = Package( .product(name: "AsyncHTTPClient", package: "async-http-client"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOFoundationCompat", package: "swift-nio"), - ], path: "GetJSON" + ], + path: "GetJSON" ), .executableTarget( name: "StreamingByteCounter", dependencies: [ .product(name: "AsyncHTTPClient", package: "async-http-client"), .product(name: "NIOCore", package: "swift-nio"), - ], path: "StreamingByteCounter" + ], + path: "StreamingByteCounter" ), ] ) diff --git a/Examples/StreamingByteCounter/StreamingByteCounter.swift b/Examples/StreamingByteCounter/StreamingByteCounter.swift index dc340d14b..ecfb48776 100644 --- a/Examples/StreamingByteCounter/StreamingByteCounter.swift +++ b/Examples/StreamingByteCounter/StreamingByteCounter.swift @@ -18,7 +18,7 @@ import NIOCore @main struct StreamingByteCounter { static func main() async throws { - let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + let httpClient = HTTPClient(eventLoopGroupProvider: .singleton) do { let request = HTTPClientRequest(url: "https://apple.com") let response = try await httpClient.execute(request, timeout: .seconds(30)) diff --git a/NOTICE.txt b/NOTICE.txt index 095a11740..86a969171 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -50,13 +50,13 @@ This product contains a derivation of the Tony Stone's 'process_test_files.rb'. * https://www.apache.org/licenses/LICENSE-2.0 * HOMEPAGE: * https://github.com/tonystone/build-tools/commit/6c417b7569df24597a48a9aa7b505b636e8f73a1 - * https://github.com/tonystone/build-tools/blob/master/source/xctest_tool.rb + * https://github.com/tonystone/build-tools/blob/cf3440f43bde2053430285b4ed0709c865892eb5/source/xctest_tool.rb --- This product contains a derivation of Fabian Fett's 'Base64.swift'. * LICENSE (Apache License 2.0): - * https://github.com/fabianfett/swift-base64-kit/blob/master/LICENSE + * https://github.com/swift-extras/swift-extras-base64/blob/b8af49699d59ad065b801715a5009619100245ca/LICENSE * HOMEPAGE: * https://github.com/fabianfett/swift-base64-kit diff --git a/Package.swift b/Package.swift index e4dcc717c..8bec2bd55 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.2 +// swift-tools-version:5.8 //===----------------------------------------------------------------------===// // // This source file is part of the AsyncHTTPClient open source project @@ -18,40 +18,50 @@ import PackageDescription let package = Package( name: "async-http-client", products: [ - .library(name: "AsyncHTTPClient", targets: ["AsyncHTTPClient"]), + .library(name: "AsyncHTTPClient", targets: ["AsyncHTTPClient"]) ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.38.0"), - .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.1"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.78.0"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.27.1"), .package(url: "https://github.com/apple/swift-nio-http2.git", from: "1.19.0"), - .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.10.0"), - .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"), - .package(url: "https://github.com/apple/swift-log.git", from: "1.4.0"), + .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.13.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"), + .package(url: "https://github.com/apple/swift-log.git", from: "1.4.4"), + .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), + .package(url: "https://github.com/apple/swift-algorithms.git", from: "1.0.0"), ], targets: [ - .target(name: "CAsyncHTTPClient"), + .target( + name: "CAsyncHTTPClient", + cSettings: [ + .define("_GNU_SOURCE") + ] + ), .target( name: "AsyncHTTPClient", dependencies: [ .target(name: "CAsyncHTTPClient"), .product(name: "NIO", package: "swift-nio"), + .product(name: "NIOTLS", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), .product(name: "NIOHTTP1", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), - .product(name: "NIOFoundationCompat", package: "swift-nio"), .product(name: "NIOHTTP2", package: "swift-nio-http2"), .product(name: "NIOSSL", package: "swift-nio-ssl"), .product(name: "NIOHTTPCompression", package: "swift-nio-extras"), .product(name: "NIOSOCKS", package: "swift-nio-extras"), .product(name: "NIOTransportServices", package: "swift-nio-transport-services"), .product(name: "Logging", package: "swift-log"), + .product(name: "Atomics", package: "swift-atomics"), + .product(name: "Algorithms", package: "swift-algorithms"), ] ), .testTarget( name: "AsyncHTTPClientTests", dependencies: [ .target(name: "AsyncHTTPClient"), + .product(name: "NIOTLS", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), .product(name: "NIOEmbedded", package: "swift-nio"), @@ -61,7 +71,26 @@ let package = Package( .product(name: "NIOHTTP2", package: "swift-nio-http2"), .product(name: "NIOSOCKS", package: "swift-nio-extras"), .product(name: "Logging", package: "swift-log"), + .product(name: "Atomics", package: "swift-atomics"), + .product(name: "Algorithms", package: "swift-algorithms"), + ], + resources: [ + .copy("Resources/self_signed_cert.pem"), + .copy("Resources/self_signed_key.pem"), + .copy("Resources/example.com.cert.pem"), + .copy("Resources/example.com.private-key.pem"), ] ), ] ) + +// --- STANDARD CROSS-REPO SETTINGS DO NOT EDIT --- // +for target in package.targets { + if target.type != .plugin { + var settings = target.swiftSettings ?? [] + // https://github.com/swiftlang/swift-evolution/blob/main/proposals/0444-member-import-visibility.md + settings.append(.enableUpcomingFeature("MemberImportVisibility")) + target.swiftSettings = settings + } +} +// --- END: STANDARD CROSS-REPO SETTINGS DO NOT EDIT --- // diff --git a/README.md b/README.md index 6dad76de3..871eb910b 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,15 @@ # AsyncHTTPClient -This package provides simple HTTP Client library built on top of SwiftNIO. +This package provides an HTTP Client library built on top of SwiftNIO. This library provides the following: -- First class support for Swift Concurrency (since version 1.9.0) +- First class support for Swift Concurrency - Asynchronous and non-blocking request methods - Simple follow-redirects (cookie headers are dropped) - Streaming body download - TLS support -- Automatic HTTP/2 over HTTPS (since version 1.7.0) +- Automatic HTTP/2 over HTTPS - Cookie parsing (but not storage) ---- - -**NOTE**: You will need [Xcode 13.2](https://apps.apple.com/gb/app/xcode/id497799835?mt=12) or [Swift 5.5.2](https://swift.org/download/#swift-552) to try out `AsyncHTTPClient`s new async/await APIs. - ---- - ## Getting Started #### Adding the dependency @@ -33,18 +27,12 @@ and `AsyncHTTPClient` dependency to your target: The code snippet below illustrates how to make a simple GET request to a remote server. -Please note that the example will spawn a new `EventLoopGroup` which will _create fresh threads_ which is a very costly operation. In a real-world application that uses [SwiftNIO](https://github.com/apple/swift-nio) for other parts of your application (for example a web server), please prefer `eventLoopGroupProvider: .shared(myExistingEventLoopGroup)` to share the `EventLoopGroup` used by AsyncHTTPClient with other parts of your application. - -If your application does not use SwiftNIO yet, it is acceptable to use `eventLoopGroupProvider: .createNew` but please make sure to share the returned `HTTPClient` instance throughout your whole application. Do not create a large number of `HTTPClient` instances with `eventLoopGroupProvider: .createNew`, this is very wasteful and might exhaust the resources of your program. - ```swift import AsyncHTTPClient -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) - /// MARK: - Using Swift Concurrency let request = HTTPClientRequest(url: "https://apple.com/") -let response = try await httpClient.execute(request, timeout: .seconds(30)) +let response = try await HTTPClient.shared.execute(request, timeout: .seconds(30)) print("HTTP head", response) if response.status == .ok { let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB @@ -55,7 +43,7 @@ if response.status == .ok { /// MARK: - Using SwiftNIO EventLoopFuture -httpClient.get(url: "https://apple.com/").whenComplete { result in +HTTPClient.shared.get(url: "https://apple.com/").whenComplete { result in switch result { case .failure(let error): // process error @@ -69,7 +57,8 @@ httpClient.get(url: "https://apple.com/").whenComplete { result in } ``` -You should always shut down `HTTPClient` instances you created using `try httpClient.syncShutdown()`. Please note that you must not call `httpClient.syncShutdown` before all requests of the HTTP client have finished, or else the in-flight requests will likely fail because their network connections are interrupted. +If you create your own `HTTPClient` instances, you should shut them down using `httpClient.shutdown()` when you're done using them. Failing to do so will leak resources. + Please note that you must not call `httpClient.shutdown` before all requests of the HTTP client have finished, or else the in-flight requests will likely fail because their network connections are interrupted. ### async/await examples @@ -84,14 +73,13 @@ The default HTTP Method is `GET`. In case you need to have more control over the ```swift import AsyncHTTPClient -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) do { var request = HTTPClientRequest(url: "https://apple.com/") request.method = .POST request.headers.add(name: "User-Agent", value: "Swift HTTPClient") request.body = .bytes(ByteBuffer(string: "some data")) - let response = try await httpClient.execute(request, timeout: .seconds(30)) + let response = try await HTTPClient.shared.execute(request, timeout: .seconds(30)) if response.status == .ok { // handle response } else { @@ -100,8 +88,6 @@ do { } catch { // handle error } -// it's important to shutdown the httpClient after all requests are done, even if one failed -try await httpClient.shutdown() ``` #### Using SwiftNIO EventLoopFuture @@ -109,16 +95,11 @@ try await httpClient.shutdown() ```swift import AsyncHTTPClient -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) -defer { - try? httpClient.syncShutdown() -} - var request = try HTTPClient.Request(url: "https://apple.com/", method: .POST) request.headers.add(name: "User-Agent", value: "Swift HTTPClient") request.body = .string("some-body") -httpClient.execute(request: request).whenComplete { result in +HTTPClient.shared.execute(request: request).whenComplete { result in switch result { case .failure(let error): // process error @@ -133,9 +114,11 @@ httpClient.execute(request: request).whenComplete { result in ``` ### Redirects following -Enable follow-redirects behavior using the client configuration: + +The globally shared instance `HTTPClient.shared` follows redirects by default. If you create your own `HTTPClient`, you can enable the follow-redirects behavior using the client configuration: + ```swift -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, +let httpClient = HTTPClient(eventLoopGroupProvider: .singleton, configuration: HTTPClient.Configuration(followRedirects: true)) ``` @@ -143,7 +126,7 @@ let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, Timeouts (connect and read) can also be set using the client configuration: ```swift let timeout = HTTPClient.Configuration.Timeout(connect: .seconds(1), read: .seconds(1)) -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, +let httpClient = HTTPClient(eventLoopGroupProvider: .singleton, configuration: HTTPClient.Configuration(timeout: timeout)) ``` or on a per-request basis: @@ -152,15 +135,14 @@ httpClient.execute(request: request, deadline: .now() + .milliseconds(1)) ``` ### Streaming -When dealing with larger amount of data, it's critical to stream the response body instead of aggregating in-memory. +When dealing with larger amount of data, it's critical to stream the response body instead of aggregating in-memory. The following example demonstrates how to count the number of bytes in a streaming response body: #### Using Swift Concurrency ```swift -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) do { let request = HTTPClientRequest(url: "https://apple.com/") - let response = try await httpClient.execute(request, timeout: .seconds(30)) + let response = try await HTTPClient.shared.execute(request, timeout: .seconds(30)) print("HTTP head", response) // if defined, the content-length headers announces the size of the body @@ -172,7 +154,7 @@ do { for try await buffer in response.body { // for this example, we are just interested in the size of the fragment receivedBytes += buffer.readableBytes - + if let expectedBytes = expectedBytes { // if the body size is known, we calculate a progress indicator let progress = Double(receivedBytes) / Double(expectedBytes) @@ -181,10 +163,8 @@ do { } print("did receive \(receivedBytes) bytes") } catch { - print("request failed:", error) + print("request failed:", error) } -// it is important to shutdown the httpClient after all requests are done, even if one failed -try await httpClient.shutdown() ``` #### Using HTTPClientResponseDelegate and SwiftNIO EventLoopFuture @@ -211,17 +191,17 @@ class CountingDelegate: HTTPClientResponseDelegate { } func didReceiveHead( - task: HTTPClient.Task, + task: HTTPClient.Task, _ head: HTTPResponseHead ) -> EventLoopFuture { - // this is executed when we receive HTTP response head part of the request - // (it contains response code and headers), called once in case backpressure + // this is executed when we receive HTTP response head part of the request + // (it contains response code and headers), called once in case backpressure // is needed, all reads will be paused until returned future is resolved return task.eventLoop.makeSucceededFuture(()) } func didReceiveBodyPart( - task: HTTPClient.Task, + task: HTTPClient.Task, _ buffer: ByteBuffer ) -> EventLoopFuture { // this is executed when we receive parts of the response body, could be called zero or more times @@ -244,7 +224,7 @@ class CountingDelegate: HTTPClientResponseDelegate { let request = try HTTPClient.Request(url: "https://apple.com/") let delegate = CountingDelegate() -httpClient.execute(request: request, delegate: delegate).futureResult.whenSuccess { count in +HTTPClient.shared.execute(request: request, delegate: delegate).futureResult.whenSuccess { count in print(count) } ``` @@ -257,7 +237,6 @@ asynchronously, while reporting the download progress at the same time, like in example: ```swift -let client = HTTPClient(eventLoopGroupProvider: .createNew) let request = try HTTPClient.Request( url: "https://swift.org/builds/development/ubuntu1804/latest-build.yml" ) @@ -269,7 +248,7 @@ let delegate = try FileDownloadDelegate(path: "/tmp/latest-build.yml", reportPro print("Downloaded \($0.receivedBytes) bytes so far") }) -client.execute(request: request, delegate: delegate).futureResult +HTTPClient.shared.execute(request: request, delegate: delegate).futureResult .whenSuccess { progress in if let totalBytes = progress.totalBytes { print("Final total bytes count: \(totalBytes)") @@ -281,21 +260,19 @@ client.execute(request: request, delegate: delegate).futureResult ### Unix Domain Socket Paths Connecting to servers bound to socket paths is easy: ```swift -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) -httpClient.execute( - .GET, - socketPath: "/tmp/myServer.socket", +HTTPClient.shared.execute( + .GET, + socketPath: "/tmp/myServer.socket", urlPath: "/path/to/resource" ).whenComplete (...) ``` Connecting over TLS to a unix domain socket path is possible as well: ```swift -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) -httpClient.execute( - .POST, - secureSocketPath: "/tmp/myServer.socket", - urlPath: "/path/to/resource", +HTTPClient.shared.execute( + .POST, + secureSocketPath: "/tmp/myServer.socket", + urlPath: "/path/to/resource", body: .string("hello") ).whenComplete (...) ``` @@ -303,11 +280,11 @@ httpClient.execute( Direct URLs can easily be constructed to be executed in other scenarios: ```swift let socketPathBasedURL = URL( - httpURLWithSocketPath: "/tmp/myServer.socket", + httpURLWithSocketPath: "/tmp/myServer.socket", uri: "/path/to/resource" ) let secureSocketPathBasedURL = URL( - httpsURLWithSocketPath: "/tmp/myServer.socket", + httpsURLWithSocketPath: "/tmp/myServer.socket", uri: "/path/to/resource" ) ``` @@ -318,7 +295,7 @@ The exclusive use of HTTP/1 is possible by setting `httpVersion` to `.http1Only` var configuration = HTTPClient.Configuration() configuration.httpVersion = .http1Only let client = HTTPClient( - eventLoopGroupProvider: .createNew, + eventLoopGroupProvider: .singleton, configuration: configuration ) ``` @@ -326,3 +303,17 @@ let client = HTTPClient( ## Security Please have a look at [SECURITY.md](SECURITY.md) for AsyncHTTPClient's security process. + +## Supported Versions + +The most recent versions of AsyncHTTPClient support Swift 5.6 and newer. The minimum Swift version supported by AsyncHTTPClient releases are detailed below: + +AsyncHTTPClient | Minimum Swift Version +--------------------|---------------------- +`1.0.0 ..< 1.5.0` | 5.0 +`1.5.0 ..< 1.10.0` | 5.2 +`1.10.0 ..< 1.13.0` | 5.4 +`1.13.0 ..< 1.18.0` | 5.5.2 +`1.18.0 ..< 1.20.0` | 5.6 +`1.20.0 ..< 1.21.0` | 5.7 +`1.21.0 ...` | 5.8 diff --git a/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift b/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift new file mode 100644 index 000000000..8f6b32bd2 --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@usableFromInline +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +struct AnyAsyncSequence: Sendable, AsyncSequence { + @usableFromInline typealias AsyncIteratorNextCallback = () async throws -> Element? + + @usableFromInline struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline let nextCallback: AsyncIteratorNextCallback + + @inlinable init(nextCallback: @escaping AsyncIteratorNextCallback) { + self.nextCallback = nextCallback + } + + @inlinable mutating func next() async throws -> Element? { + try await self.nextCallback() + } + } + + @usableFromInline var makeAsyncIteratorCallback: @Sendable () -> AsyncIteratorNextCallback + + @inlinable init( + _ asyncSequence: SequenceOfBytes + ) where SequenceOfBytes: AsyncSequence & Sendable, SequenceOfBytes.Element == Element { + self.makeAsyncIteratorCallback = { + var iterator = asyncSequence.makeAsyncIterator() + return { + try await iterator.next() + } + } + } + + @inlinable func makeAsyncIterator() -> AsyncIterator { + .init(nextCallback: self.makeAsyncIteratorCallback()) + } +} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequenceProucerDelete.swift b/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequenceProucerDelete.swift new file mode 100644 index 000000000..1e35df7f2 --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequenceProucerDelete.swift @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +@usableFromInline +struct AnyAsyncSequenceProducerDelegate: NIOAsyncSequenceProducerDelegate { + @usableFromInline + var delegate: NIOAsyncSequenceProducerDelegate + + @inlinable + init(_ delegate: Delegate) { + self.delegate = delegate + } + + @inlinable + func produceMore() { + self.delegate.produceMore() + } + + @inlinable + func didTerminate() { + self.delegate.didTerminate() + } +} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/AsyncLazySequence.swift b/Sources/AsyncHTTPClient/AsyncAwait/AsyncLazySequence.swift new file mode 100644 index 000000000..fe37dd5e7 --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/AsyncLazySequence.swift @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +@usableFromInline +struct AsyncLazySequence: AsyncSequence { + @usableFromInline typealias Element = Base.Element + @usableFromInline struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline var iterator: Base.Iterator + @inlinable init(iterator: Base.Iterator) { + self.iterator = iterator + } + + @inlinable mutating func next() async throws -> Base.Element? { + self.iterator.next() + } + } + + @usableFromInline var base: Base + + @inlinable init(base: Base) { + self.base = base + } + + @inlinable func makeAsyncIterator() -> AsyncIterator { + .init(iterator: self.base.makeIterator()) + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension AsyncLazySequence: Sendable where Base: Sendable {} +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension AsyncLazySequence.AsyncIterator: Sendable where Base.Iterator: Sendable {} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension Sequence { + /// Turns `self` into an `AsyncSequence` by vending each element of `self` asynchronously. + @inlinable var async: AsyncLazySequence { + .init(base: self) + } +} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift index 043ad510b..3c3a6030c 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift @@ -12,12 +12,12 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) -import struct Foundation.URL import Logging import NIOCore import NIOHTTP1 +import struct Foundation.URL + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClient { /// Execute arbitrary HTTP requests. @@ -26,6 +26,10 @@ extension HTTPClient { /// - request: HTTP request to execute. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. + /// + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. + /// /// - Returns: The response to the request. Note that the `body` of the response may not yet have been fully received. public func execute( _ request: HTTPClientRequest, @@ -51,6 +55,10 @@ extension HTTPClient { /// - request: HTTP request to execute. /// - timeout: time the the request has to complete. /// - logger: The logger to use for this request. + /// + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. + /// /// - Returns: The response to the request. Note that the `body` of the response may not yet have been fully received. public func execute( _ request: HTTPClientRequest, @@ -67,6 +75,8 @@ extension HTTPClient { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClient { + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. private func executeAndFollowRedirectsIfNeeded( _ request: HTTPClientRequest, deadline: NIODeadline, @@ -78,19 +88,21 @@ extension HTTPClient { // this loop is there to follow potential redirects while true { - let preparedRequest = try HTTPClientRequest.Prepared(currentRequest) - let response = try await executeCancellable(preparedRequest, deadline: deadline, logger: logger) + let preparedRequest = try HTTPClientRequest.Prepared(currentRequest, dnsOverride: configuration.dnsOverride) + let response = try await self.executeCancellable(preparedRequest, deadline: deadline, logger: logger) guard var redirectState = currentRedirectState else { // a `nil` redirectState means we should not follow redirects return response } - guard let redirectURL = response.headers.extractRedirectTarget( - status: response.status, - originalURL: preparedRequest.url, - originalScheme: preparedRequest.poolKey.scheme - ) else { + guard + let redirectURL = response.headers.extractRedirectTarget( + status: response.status, + originalURL: preparedRequest.url, + originalScheme: preparedRequest.poolKey.scheme + ) + else { // response does not want a redirect return response } @@ -114,6 +126,8 @@ extension HTTPClient { } } + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. private func executeCancellable( _ request: HTTPClientRequest.Prepared, deadline: NIODeadline, @@ -121,31 +135,35 @@ extension HTTPClient { ) async throws -> HTTPClientResponse { let cancelHandler = TransactionCancelHandler() - return try await withTaskCancellationHandler(operation: { () async throws -> HTTPClientResponse in - let eventLoop = self.eventLoopGroup.any() - let deadlineTask = eventLoop.scheduleTask(deadline: deadline) { - cancelHandler.cancel(reason: .deadlineExceeded) + return try await withTaskCancellationHandler( + operation: { () async throws -> HTTPClientResponse in + let eventLoop = self.eventLoopGroup.any() + let deadlineTask = eventLoop.scheduleTask(deadline: deadline) { + cancelHandler.cancel(reason: .deadlineExceeded) + } + defer { + deadlineTask.cancel() + } + return try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) -> Void in + let transaction = Transaction( + request: request, + requestOptions: .fromClientConfiguration(self.configuration), + logger: logger, + connectionDeadline: .now() + (self.configuration.timeout.connectionCreationTimeout), + preferredEventLoop: eventLoop, + responseContinuation: continuation + ) + + cancelHandler.registerTransaction(transaction) + + self.poolManager.executeRequest(transaction) + } + }, + onCancel: { + cancelHandler.cancel(reason: .taskCanceled) } - defer { - deadlineTask.cancel() - } - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) -> Void in - let transaction = Transaction( - request: request, - requestOptions: .init(idleReadTimeout: nil), - logger: logger, - connectionDeadline: deadline, - preferredEventLoop: eventLoop, - responseContinuation: continuation - ) - - cancelHandler.registerTransaction(transaction) - - self.poolManager.executeRequest(transaction) - } - }, onCancel: { - cancelHandler.cancel(reason: .taskCanceled) - }) + ) } } @@ -215,5 +233,3 @@ private actor TransactionCancelHandler { } } } - -#endif diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+shutdown.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+shutdown.swift index 4e7090dbf..43020c3e5 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+shutdown.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+shutdown.swift @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) +import NIOCore extension HTTPClient { /// Shuts down the client and `EventLoopGroup` if it was created by the client. @@ -30,5 +30,3 @@ extension HTTPClient { } } } - -#endif diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift index de09df5b8..d4eeae03e 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift @@ -12,25 +12,41 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) -import struct Foundation.URL +import NIOCore import NIOHTTP1 +import NIOSSL + +import struct Foundation.URL @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientRequest { struct Prepared { + enum Body { + case asyncSequence( + length: RequestBodyLength, + nextBodyPart: (ByteBufferAllocator) async throws -> ByteBuffer? + ) + case sequence( + length: RequestBodyLength, + canBeConsumedMultipleTimes: Bool, + makeCompleteBody: (ByteBufferAllocator) -> ByteBuffer + ) + case byteBuffer(ByteBuffer) + } + var url: URL var poolKey: ConnectionPool.Key var requestFramingMetadata: RequestFramingMetadata var head: HTTPRequestHead var body: Body? + var tlsConfiguration: TLSConfiguration? } } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientRequest.Prepared { - init(_ request: HTTPClientRequest) throws { - guard let url = URL(string: request.url) else { + init(_ request: HTTPClientRequest, dnsOverride: [String: String] = [:]) throws { + guard !request.url.isEmpty, let url = URL(string: request.url) else { throw HTTPClientError.invalidURL } @@ -45,7 +61,7 @@ extension HTTPClientRequest.Prepared { self.init( url: url, - poolKey: .init(url: deconstructedURL, tlsConfiguration: nil), + poolKey: .init(url: deconstructedURL, tlsConfiguration: request.tlsConfiguration, dnsOverride: dnsOverride), requestFramingMetadata: metadata, head: .init( version: .http1_1, @@ -53,11 +69,30 @@ extension HTTPClientRequest.Prepared { uri: deconstructedURL.uri, headers: headers ), - body: request.body + body: request.body.map { .init($0) }, + tlsConfiguration: request.tlsConfiguration ) } } +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientRequest.Prepared.Body { + init(_ body: HTTPClientRequest.Body) { + switch body.mode { + case .asyncSequence(let length, let makeAsyncIterator): + self = .asyncSequence(length: length, nextBodyPart: makeAsyncIterator()) + case .sequence(let length, let canBeConsumedMultipleTimes, let makeCompleteBody): + self = .sequence( + length: length, + canBeConsumedMultipleTimes: canBeConsumedMultipleTimes, + makeCompleteBody: makeCompleteBody + ) + case .byteBuffer(let byteBuffer): + self = .byteBuffer(byteBuffer) + } + } +} + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension RequestBodyLength { init(_ body: HTTPClientRequest.Body?) { @@ -65,7 +100,7 @@ extension RequestBodyLength { case .none: self = .known(0) case .byteBuffer(let buffer): - self = .known(buffer.readableBytes) + self = .known(Int64(buffer.readableBytes)) case .sequence(let length, _, _), .asyncSequence(let length, _): self = length } @@ -94,5 +129,3 @@ extension HTTPClientRequest { return newRequest } } - -#endif diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+auth.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+auth.swift new file mode 100644 index 000000000..106a8f76b --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+auth.swift @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2024 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Foundation + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientRequest { + /// Set basic auth for a request. + /// + /// - parameters: + /// - username: the username to authenticate with + /// - password: authentication password associated with the username + public mutating func setBasicAuth(username: String, password: String) { + self.headers.setBasicAuth(username: username, password: password) + } +} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift index cfab828a0..f07a2ed41 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift @@ -12,33 +12,85 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) +import Algorithms import NIOCore import NIOHTTP1 +import NIOSSL +@usableFromInline +let bagOfBytesToByteBufferConversionChunkSize = 1024 * 1024 * 4 + +#if arch(arm) || arch(i386) +// on 32-bit platforms we can't make use of a whole UInt32.max (as it doesn't fit in an Int) +@usableFromInline +let byteBufferMaxSize = Int.max +#else +// on 64-bit platforms we're good +@usableFromInline +let byteBufferMaxSize = Int(UInt32.max) +#endif + +/// A representation of an HTTP request for the Swift Concurrency HTTPClient API. +/// +/// This object is similar to ``HTTPClient/Request``, but used for the Swift Concurrency API. +/// +/// - note: For many ``HTTPClientRequest/body-swift.property`` configurations, this type is _not_ a value type +/// (https://github.com/swift-server/async-http-client/issues/708). @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -public struct HTTPClientRequest { +public struct HTTPClientRequest: Sendable { + /// The request URL, including scheme, hostname, and optionally port. public var url: String + + /// The request method. public var method: HTTPMethod + + /// The request headers. public var headers: HTTPHeaders + /// The request body, if any. public var body: Body? + /// Request-specific TLS configuration, defaults to no request-specific TLS configuration. + public var tlsConfiguration: TLSConfiguration? + public init(url: String) { self.url = url self.method = .GET self.headers = .init() self.body = .none + self.tlsConfiguration = nil } } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientRequest { - public struct Body { + /// An HTTP request body. + /// + /// This object encapsulates the difference between streamed HTTP request bodies and those bodies that + /// are already entirely in memory. + public struct Body: Sendable { @usableFromInline - internal enum Mode { - case asyncSequence(length: RequestBodyLength, (ByteBufferAllocator) async throws -> ByteBuffer?) - case sequence(length: RequestBodyLength, canBeConsumedMultipleTimes: Bool, (ByteBufferAllocator) -> ByteBuffer) + internal enum Mode: Sendable { + /// - parameters: + /// - length: complete body length. + /// If `length` is `.known`, `nextBodyPart` is not allowed to produce more bytes than `length` defines. + /// - makeAsyncIterator: Creates a new async iterator under the hood and returns a function which will call `next()` on it. + /// The returned function then produce the next body buffer asynchronously. + /// We use a closure as an abstraction instead of an existential to enable specialization. + case asyncSequence( + length: RequestBodyLength, + makeAsyncIterator: @Sendable () -> ((ByteBufferAllocator) async throws -> ByteBuffer?) + ) + /// - parameters: + /// - length: complete body length. + /// If `length` is `.known`, `nextBodyPart` is not allowed to produce more bytes than `length` defines. + /// - canBeConsumedMultipleTimes: if `makeBody` can be called multiple times and returns the same result. + /// - makeCompleteBody: function to produce the complete body. + case sequence( + length: RequestBodyLength, + canBeConsumedMultipleTimes: Bool, + makeCompleteBody: @Sendable (ByteBufferAllocator) -> ByteBuffer + ) case byteBuffer(ByteBuffer) } @@ -54,91 +106,233 @@ extension HTTPClientRequest { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientRequest.Body { + /// Create an ``HTTPClientRequest/Body-swift.struct`` from a `ByteBuffer`. + /// + /// - parameter byteBuffer: The bytes of the body. public static func bytes(_ byteBuffer: ByteBuffer) -> Self { self.init(.byteBuffer(byteBuffer)) } + /// Create an ``HTTPClientRequest/Body-swift.struct`` from a `RandomAccessCollection` of bytes. + /// + /// This construction will flatten the `bytes` into a `ByteBuffer` in chunks of ~4MB. + /// As a result, the peak memory usage of this construction will be a small multiple of ~4MB. + /// The construction of the `ByteBuffer` will be delayed until it's needed. + /// + /// - parameter bytes: The bytes of the request body. @inlinable - public static func bytes( + @preconcurrency + public static func bytes( _ bytes: Bytes ) -> Self where Bytes.Element == UInt8 { - self.init(.sequence( - length: .known(bytes.count), - canBeConsumedMultipleTimes: true - ) { allocator in - if let buffer = bytes.withContiguousStorageIfAvailable({ allocator.buffer(bytes: $0) }) { - // fastpath - return buffer - } - // potentially really slow path - return allocator.buffer(bytes: bytes) - }) + self.bytes(bytes, length: .known(Int64(bytes.count))) } + /// Create an ``HTTPClientRequest/Body-swift.struct`` from a `Sequence` of bytes. + /// + /// This construction will flatten the bytes into a `ByteBuffer`. As a result, the peak memory + /// usage of this construction will be double the size of the original collection. The construction + /// of the `ByteBuffer` will be delayed until it's needed. + /// + /// Unlike ``bytes(_:)-1uns7``, this construction does not assume that the body can be replayed. As a result, + /// if a redirect is encountered that would need us to replay the request body, the redirect will instead + /// not be followed. Prefer ``bytes(_:)-1uns7`` wherever possible. + /// + /// Caution should be taken with this method to ensure that the `length` is correct. Incorrect lengths + /// will cause unnecessary runtime failures. Setting `length` to ``Length/unknown`` will trigger the upload + /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)-9q0ge`` will use `Content-Length`. + /// + /// - parameters: + /// - bytes: The bytes of the request body. + /// - length: The length of the request body. @inlinable - public static func bytes( + @preconcurrency + public static func bytes( _ bytes: Bytes, length: Length ) -> Self where Bytes.Element == UInt8 { - self.init(.sequence( - length: length.storage, - canBeConsumedMultipleTimes: false - ) { allocator in - if let buffer = bytes.withContiguousStorageIfAvailable({ allocator.buffer(bytes: $0) }) { - // fastpath - return buffer + Self._bytes( + bytes, + length: length, + bagOfBytesToByteBufferConversionChunkSize: bagOfBytesToByteBufferConversionChunkSize, + byteBufferMaxSize: byteBufferMaxSize + ) + } + + /// internal method to test chunking + @inlinable + @preconcurrency + static func _bytes( + _ bytes: Bytes, + length: Length, + bagOfBytesToByteBufferConversionChunkSize: Int, + byteBufferMaxSize: Int + ) -> Self where Bytes.Element == UInt8 { + // fast path + let body: Self? = bytes.withContiguousStorageIfAvailable { bufferPointer -> Self in + // `some Sequence` is special as it can't be efficiently chunked lazily. + // Therefore we need to do the chunking eagerly if it implements the fast path withContiguousStorageIfAvailable + // If we do it eagerly, it doesn't make sense to do a bunch of small chunks, so we only chunk if it exceeds + // the maximum size of a ByteBuffer. + if bufferPointer.count <= byteBufferMaxSize { + let buffer = ByteBuffer(bytes: bufferPointer) + return Self( + .sequence( + length: length.storage, + canBeConsumedMultipleTimes: true, + makeCompleteBody: { _ in buffer } + ) + ) + } else { + // we need to copy `bufferPointer` eagerly as the pointer is only valid during the call to `withContiguousStorageIfAvailable` + let buffers: [ByteBuffer] = bufferPointer.chunks(ofCount: byteBufferMaxSize).map { + ByteBuffer(bytes: $0) + } + return Self( + .asyncSequence( + length: length.storage, + makeAsyncIterator: { + var iterator = buffers.makeIterator() + return { _ in + iterator.next() + } + } + ) + ) } - // potentially really slow path - return allocator.buffer(bytes: bytes) - }) + } + if let body = body { + return body + } + + // slow path + return Self( + .asyncSequence( + length: length.storage + ) { + var iterator = bytes.makeIterator() + return { allocator in + var buffer = allocator.buffer(capacity: bagOfBytesToByteBufferConversionChunkSize) + while buffer.writableBytes > 0, let byte = iterator.next() { + buffer.writeInteger(byte) + } + if buffer.readableBytes > 0 { + return buffer + } + return nil + } + } + ) } + /// Create an ``HTTPClientRequest/Body-swift.struct`` from a `Collection` of bytes. + /// + /// This construction will flatten the `bytes` into a `ByteBuffer` in chunks of ~4MB. + /// As a result, the peak memory usage of this construction will be a small multiple of ~4MB. + /// The construction of the `ByteBuffer` will be delayed until it's needed. + /// + /// Caution should be taken with this method to ensure that the `length` is correct. Incorrect lengths + /// will cause unnecessary runtime failures. Setting `length` to ``Length/unknown`` will trigger the upload + /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)-9q0ge`` will use `Content-Length`. + /// + /// - parameters: + /// - bytes: The bytes of the request body. + /// - length: The length of the request body. @inlinable - public static func bytes( + @preconcurrency + public static func bytes( _ bytes: Bytes, length: Length ) -> Self where Bytes.Element == UInt8 { - self.init(.sequence( - length: length.storage, - canBeConsumedMultipleTimes: true - ) { allocator in - if let buffer = bytes.withContiguousStorageIfAvailable({ allocator.buffer(bytes: $0) }) { - // fastpath - return buffer - } - // potentially really slow path - return allocator.buffer(bytes: bytes) - }) + if bytes.count <= bagOfBytesToByteBufferConversionChunkSize { + return self.init( + .sequence( + length: length.storage, + canBeConsumedMultipleTimes: true + ) { allocator in + allocator.buffer(bytes: bytes) + } + ) + } else { + return self.init( + .asyncSequence( + length: length.storage, + makeAsyncIterator: { + var iterator = bytes.chunks(ofCount: bagOfBytesToByteBufferConversionChunkSize).makeIterator() + return { allocator in + guard let chunk = iterator.next() else { + return nil + } + return allocator.buffer(bytes: chunk) + } + } + ) + ) + } } + /// Create an ``HTTPClientRequest/Body-swift.struct`` from an `AsyncSequence` of `ByteBuffer`s. + /// + /// This construction will stream the upload one `ByteBuffer` at a time. + /// + /// Caution should be taken with this method to ensure that the `length` is correct. Incorrect lengths + /// will cause unnecessary runtime failures. Setting `length` to ``Length/unknown`` will trigger the upload + /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)-9q0ge`` will use `Content-Length`. + /// + /// - parameters: + /// - sequenceOfBytes: The bytes of the request body. + /// - length: The length of the request body. @inlinable - public static func stream( + @preconcurrency + public static func stream( _ sequenceOfBytes: SequenceOfBytes, length: Length ) -> Self where SequenceOfBytes.Element == ByteBuffer { - var iterator = sequenceOfBytes.makeAsyncIterator() - let body = self.init(.asyncSequence(length: length.storage) { _ -> ByteBuffer? in - try await iterator.next() - }) + let body = self.init( + .asyncSequence(length: length.storage) { + var iterator = sequenceOfBytes.makeAsyncIterator() + return { _ -> ByteBuffer? in + try await iterator.next() + } + } + ) return body } + /// Create an ``HTTPClientRequest/Body-swift.struct`` from an `AsyncSequence` of bytes. + /// + /// This construction will consume 4MB chunks from the `Bytes` and send them at once. This optimizes for + /// `AsyncSequence`s where larger chunks are buffered up and available without actually suspending, such + /// as those provided by `FileHandle`. + /// + /// Caution should be taken with this method to ensure that the `length` is correct. Incorrect lengths + /// will cause unnecessary runtime failures. Setting `length` to ``Length/unknown`` will trigger the upload + /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)-9q0ge`` will use `Content-Length`. + /// + /// - parameters: + /// - bytes: The bytes of the request body. + /// - length: The length of the request body. @inlinable - public static func stream( + @preconcurrency + public static func stream( _ bytes: Bytes, length: Length ) -> Self where Bytes.Element == UInt8 { - var iterator = bytes.makeAsyncIterator() - let body = self.init(.asyncSequence(length: length.storage) { allocator -> ByteBuffer? in - var buffer = allocator.buffer(capacity: 1024) // TODO: Magic number - while buffer.writableBytes > 0, let byte = try await iterator.next() { - buffer.writeInteger(byte) - } - if buffer.readableBytes > 0 { - return buffer + let body = self.init( + .asyncSequence(length: length.storage) { + var iterator = bytes.makeAsyncIterator() + return { allocator -> ByteBuffer? in + var buffer = allocator.buffer(capacity: bagOfBytesToByteBufferConversionChunkSize) + while buffer.writableBytes > 0, let byte = try await iterator.next() { + buffer.writeInteger(byte) + } + if buffer.readableBytes > 0 { + return buffer + } + return nil + } } - return nil - }) + ) return body } } @@ -157,11 +351,19 @@ extension Optional where Wrapped == HTTPClientRequest.Body { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientRequest.Body { - public struct Length { - /// size of the request body is not known before starting the request + /// The length of a HTTP request body. + public struct Length: Sendable { + /// The size of the request body is not known before starting the request public static let unknown: Self = .init(storage: .unknown) - /// size of the request body is fixed and exactly `count` bytes + + /// The size of the request body is known and exactly `count` bytes + @available(*, deprecated, message: "Use `known(_ count: Int64)` with an explicit Int64 argument instead") public static func known(_ count: Int) -> Self { + .init(storage: .known(Int64(count))) + } + + /// The size of the request body is known and exactly `count` bytes + public static func known(_ count: Int64) -> Self { .init(storage: .known(count)) } @@ -170,4 +372,52 @@ extension HTTPClientRequest.Body { } } -#endif +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientRequest.Body: AsyncSequence { + public typealias Element = ByteBuffer + + @inlinable + public func makeAsyncIterator() -> AsyncIterator { + switch self.mode { + case .asyncSequence(_, let makeAsyncIterator): + return .init(storage: .makeNext(makeAsyncIterator())) + case .sequence(_, _, let makeCompleteBody): + return .init(storage: .byteBuffer(makeCompleteBody(AsyncIterator.allocator))) + case .byteBuffer(let byteBuffer): + return .init(storage: .byteBuffer(byteBuffer)) + } + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientRequest.Body { + public struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline + static let allocator = ByteBufferAllocator() + + @usableFromInline + enum Storage { + case byteBuffer(ByteBuffer?) + case makeNext((ByteBufferAllocator) async throws -> ByteBuffer?) + } + + @usableFromInline + var storage: Storage + + @inlinable + init(storage: Storage) { + self.storage = storage + } + + @inlinable + public mutating func next() async throws -> ByteBuffer? { + switch self.storage { + case .byteBuffer(let buffer): + self.storage = .byteBuffer(nil) + return buffer + case .makeNext(let makeNext): + return try await makeNext(Self.allocator) + } + } + } +} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift index 52f03089b..832eb7b41 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift @@ -12,105 +12,208 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) import NIOCore import NIOHTTP1 +/// A representation of an HTTP response for the Swift Concurrency HTTPClient API. +/// +/// This object is similar to ``HTTPClient/Response``, but used for the Swift Concurrency API. @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -public struct HTTPClientResponse { +public struct HTTPClientResponse: Sendable { + /// The HTTP version on which the response was received. public var version: HTTPVersion + + /// The HTTP status for this response. public var status: HTTPResponseStatus + + /// The HTTP headers of this response. public var headers: HTTPHeaders - public var body: Body - public struct Body { - private let bag: Transaction - private let reference: ResponseRef + /// The body of this HTTP response. + public var body: Body - fileprivate init(_ transaction: Transaction) { - self.bag = transaction - self.reference = ResponseRef(transaction: transaction) - } + @inlinable public init( + version: HTTPVersion = .http1_1, + status: HTTPResponseStatus = .ok, + headers: HTTPHeaders = [:], + body: Body = Body() + ) { + self.version = version + self.status = status + self.headers = headers + self.body = body } init( - bag: Transaction, + requestMethod: HTTPMethod, version: HTTPVersion, status: HTTPResponseStatus, - headers: HTTPHeaders + headers: HTTPHeaders, + body: TransactionBody ) { - self.body = Body(bag) - self.version = version - self.status = status - self.headers = headers + self.init( + version: version, + status: status, + headers: headers, + body: .init( + .transaction( + body, + expectedContentLength: HTTPClientResponse.expectedContentLength( + requestMethod: requestMethod, + headers: headers, + status: status + ) + ) + ) + ) } } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -extension HTTPClientResponse.Body: AsyncSequence { - public typealias Element = AsyncIterator.Element +extension HTTPClientResponse { + /// A representation of the response body for an HTTP response. + /// + /// The body is streamed as an `AsyncSequence` of `ByteBuffer`, where each `ByteBuffer` contains + /// an arbitrarily large chunk of data. The boundaries between `ByteBuffer` objects in the sequence + /// are entirely synthetic and have no semantic meaning. + public struct Body: AsyncSequence, Sendable { + public typealias Element = ByteBuffer + public struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline var storage: Storage.AsyncIterator + + @inlinable init(storage: Storage.AsyncIterator) { + self.storage = storage + } + + @inlinable public mutating func next() async throws -> ByteBuffer? { + try await self.storage.next() + } + } - public struct AsyncIterator: AsyncIteratorProtocol { - private let stream: IteratorStream + @usableFromInline var storage: Storage - fileprivate init(stream: IteratorStream) { - self.stream = stream + @inlinable public func makeAsyncIterator() -> AsyncIterator { + .init(storage: self.storage.makeAsyncIterator()) } - public mutating func next() async throws -> ByteBuffer? { - try await self.stream.next() + @inlinable init(storage: Storage) { + self.storage = storage + } + + /// Accumulates `Body` of `ByteBuffer`s into a single `ByteBuffer`. + /// - Parameters: + /// - maxBytes: The maximum number of bytes this method is allowed to accumulate + /// - Throws: `NIOTooManyBytesError` if the the sequence contains more than `maxBytes`. + /// - Returns: the number of bytes collected over time + @inlinable public func collect(upTo maxBytes: Int) async throws -> ByteBuffer { + switch self.storage { + case .transaction(_, let expectedContentLength): + if let contentLength = expectedContentLength { + if contentLength > maxBytes { + throw NIOTooManyBytesError(maxBytes: maxBytes) + } + } + case .anyAsyncSequence: + break + } + + /// calling collect function within here in order to ensure the correct nested type + func collect(_ body: Body, maxBytes: Int) async throws -> ByteBuffer + where Body.Element == ByteBuffer { + try await body.collect(upTo: maxBytes) + } + return try await collect(self, maxBytes: maxBytes) } } +} - public func makeAsyncIterator() -> AsyncIterator { - AsyncIterator(stream: IteratorStream(bag: self.bag)) +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientResponse { + static func expectedContentLength( + requestMethod: HTTPMethod, + headers: HTTPHeaders, + status: HTTPResponseStatus + ) -> Int? { + if status == .notModified { + return 0 + } else if requestMethod == .HEAD { + return 0 + } else { + let contentLength = headers["content-length"].first.flatMap { Int($0, radix: 10) } + return contentLength + } } } +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +@usableFromInline +typealias TransactionBody = NIOThrowingAsyncSequenceProducer< + ByteBuffer, + Error, + NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark, + AnyAsyncSequenceProducerDelegate +> + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientResponse.Body { - /// The purpose of this object is to inform the transaction about the response body being deinitialized. - /// If the users has not called `makeAsyncIterator` on the body, before it is deinited, the http - /// request needs to be cancelled. - fileprivate class ResponseRef { - private let transaction: Transaction - - init(transaction: Transaction) { - self.transaction = transaction - } + @usableFromInline enum Storage: Sendable { + case transaction(TransactionBody, expectedContentLength: Int?) + case anyAsyncSequence(AnyAsyncSequence) + } +} - deinit { - self.transaction.responseBodyDeinited() +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientResponse.Body.Storage: AsyncSequence { + @usableFromInline typealias Element = ByteBuffer + + @inlinable func makeAsyncIterator() -> AsyncIterator { + switch self { + case .transaction(let transaction, _): + return .transaction(transaction.makeAsyncIterator()) + case .anyAsyncSequence(let anyAsyncSequence): + return .anyAsyncSequence(anyAsyncSequence.makeAsyncIterator()) } } } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -extension HTTPClientResponse.Body { - internal class IteratorStream { - struct ID: Hashable { - private let objectID: ObjectIdentifier +extension HTTPClientResponse.Body.Storage { + @usableFromInline enum AsyncIterator { + case transaction(TransactionBody.AsyncIterator) + case anyAsyncSequence(AnyAsyncSequence.AsyncIterator) + } +} - init(_ object: IteratorStream) { - self.objectID = ObjectIdentifier(object) - } +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientResponse.Body.Storage.AsyncIterator: AsyncIteratorProtocol { + @inlinable mutating func next() async throws -> ByteBuffer? { + switch self { + case .transaction(let iterator): + return try await iterator.next() + case .anyAsyncSequence(var iterator): + defer { self = .anyAsyncSequence(iterator) } + return try await iterator.next() } + } +} - private var id: ID { ID(self) } - private let bag: Transaction +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientResponse.Body { + @inlinable init(_ storage: Storage) { + self.storage = storage + } - init(bag: Transaction) { - self.bag = bag - } + public init() { + self = .stream(EmptyCollection().async) + } - deinit { - self.bag.responseBodyIteratorDeinited(streamID: self.id) - } + @inlinable public static func stream( + _ sequenceOfBytes: SequenceOfBytes + ) -> Self where SequenceOfBytes: AsyncSequence & Sendable, SequenceOfBytes.Element == ByteBuffer { + Self(storage: .anyAsyncSequence(AnyAsyncSequence(sequenceOfBytes.singleIteratorPrecondition))) + } - func next() async throws -> ByteBuffer? { - try await self.bag.nextResponsePart(streamID: self.id) - } + public static func bytes(_ byteBuffer: ByteBuffer) -> Self { + .stream(CollectionOfOne(byteBuffer).async) } } - -#endif diff --git a/Sources/AsyncHTTPClient/AsyncAwait/SingleIteratorPrecondition.swift b/Sources/AsyncHTTPClient/AsyncAwait/SingleIteratorPrecondition.swift new file mode 100644 index 000000000..04034db2d --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/SingleIteratorPrecondition.swift @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Atomics + +/// Makes sure that a consumer of this `AsyncSequence` only calls `makeAsyncIterator()` at most once. +/// If `makeAsyncIterator()` is called multiple times, the program crashes. +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +@usableFromInline struct SingleIteratorPrecondition: AsyncSequence { + @usableFromInline let base: Base + @usableFromInline let didCreateIterator: ManagedAtomic = .init(false) + @usableFromInline typealias Element = Base.Element + @inlinable init(base: Base) { + self.base = base + } + + @inlinable func makeAsyncIterator() -> Base.AsyncIterator { + precondition( + self.didCreateIterator.exchange(true, ordering: .relaxed) == false, + "makeAsyncIterator() is only allowed to be called at most once." + ) + return self.base.makeAsyncIterator() + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension SingleIteratorPrecondition: @unchecked Sendable where Base: Sendable {} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension AsyncSequence { + @inlinable var singleIteratorPrecondition: SingleIteratorPrecondition { + .init(base: self) + } +} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift index dea1093db..6cf0dbc07 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift @@ -11,13 +11,14 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) + import Logging import NIOCore import NIOHTTP1 @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension Transaction { + @usableFromInline struct StateMachine { struct ExecutionContext { let executor: HTTPRequestExecutor @@ -28,32 +29,24 @@ extension Transaction { private enum State { case initialized(CheckedContinuation) case queued(CheckedContinuation, HTTPRequestScheduler) + case deadlineExceededWhileQueued(CheckedContinuation) case executing(ExecutionContext, RequestStreamState, ResponseStreamState) - case finished(error: Error?, HTTPClientResponse.Body.IteratorStream.ID?) + case finished(error: Error?) } - fileprivate enum RequestStreamState { + fileprivate enum RequestStreamState: Sendable { case requestHeadSent case producing case paused(continuation: CheckedContinuation?) case finished } - fileprivate enum ResponseStreamState { - enum Next { - case askExecutorForMore - case error(Error) - case endOfFile - } - - // Waiting for response head. Valid transitions to: waitingForStream. + fileprivate enum ResponseStreamState: Sendable { + // Waiting for response head. Valid transitions to: streamingBody. case waitingForResponseHead - // We are waiting for the user to create a response body iterator and to call next on - // it for the first time. - case waitingForResponseIterator(CircularBuffer, next: Next) - case buffering(HTTPClientResponse.Body.IteratorStream.ID, CircularBuffer, next: Next) - case waitingForRemote(HTTPClientResponse.Body.IteratorStream.ID, CheckedContinuation) - case finished(HTTPClientResponse.Body.IteratorStream.ID, CheckedContinuation) + // streaming response body. Valid transitions to: finished. + case streamingBody(TransactionBody.Source) + case finished } private var state: State @@ -89,9 +82,20 @@ extension Transaction { enum FailAction { case none /// fail response before head received. scheduler and executor are exclusive here. - case failResponseHead(CheckedContinuation, Error, HTTPRequestScheduler?, HTTPRequestExecutor?, bodyStreamContinuation: CheckedContinuation?) + case failResponseHead( + CheckedContinuation, + Error, + HTTPRequestScheduler?, + HTTPRequestExecutor?, + bodyStreamContinuation: CheckedContinuation? + ) /// fail response after response head received. fail the response stream (aka call to `next()`) - case failResponseStream(CheckedContinuation, Error, HTTPRequestExecutor, bodyStreamContinuation: CheckedContinuation?) + case failResponseStream( + TransactionBody.Source, + Error, + HTTPRequestExecutor, + bodyStreamContinuation: CheckedContinuation? + ) case failRequestStreamContinuation(CheckedContinuation, Error) } @@ -99,78 +103,66 @@ extension Transaction { mutating func fail(_ error: Error) -> FailAction { switch self.state { case .initialized(let continuation): - self.state = .finished(error: error, nil) + self.state = .finished(error: error) return .failResponseHead(continuation, error, nil, nil, bodyStreamContinuation: nil) case .queued(let continuation, let scheduler): - self.state = .finished(error: error, nil) + self.state = .finished(error: error) return .failResponseHead(continuation, error, scheduler, nil, bodyStreamContinuation: nil) - + case .deadlineExceededWhileQueued(let continuation): + let realError: Error = { + if (error as? HTTPClientError) == .cancelled { + /// if we just get a `HTTPClientError.cancelled` we can use the original cancellation reason + /// to give a more descriptive error to the user. + return HTTPClientError.deadlineExceeded + } else { + /// otherwise we already had an intermediate connection error which we should present to the user instead + return error + } + }() + + self.state = .finished(error: realError) + return .failResponseHead(continuation, realError, nil, nil, bodyStreamContinuation: nil) case .executing(let context, let requestStreamState, .waitingForResponseHead): switch requestStreamState { case .paused(continuation: .some(let continuation)): - self.state = .finished(error: error, nil) - return .failResponseHead(context.continuation, error, nil, context.executor, bodyStreamContinuation: continuation) + self.state = .finished(error: error) + return .failResponseHead( + context.continuation, + error, + nil, + context.executor, + bodyStreamContinuation: continuation + ) case .requestHeadSent, .finished, .producing, .paused(continuation: .none): - self.state = .finished(error: error, nil) - return .failResponseHead(context.continuation, error, nil, context.executor, bodyStreamContinuation: nil) - } - - case .executing(let context, let requestStreamState, .waitingForResponseIterator(let buffer, next: .askExecutorForMore)), - .executing(let context, let requestStreamState, .waitingForResponseIterator(let buffer, next: .endOfFile)): - switch requestStreamState { - case .paused(.some(let continuation)): - self.state = .executing(context, .finished, .waitingForResponseIterator(buffer, next: .error(error))) - return .failRequestStreamContinuation(continuation, error) - - case .requestHeadSent, .producing, .paused(continuation: .none), .finished: - self.state = .executing(context, .finished, .waitingForResponseIterator(buffer, next: .error(error))) - return .none - } - - case .executing(let context, let requestStreamState, .buffering(let streamID, let buffer, next: .askExecutorForMore)), - .executing(let context, let requestStreamState, .buffering(let streamID, let buffer, next: .endOfFile)): - switch requestStreamState { - case .paused(continuation: .some(let continuation)): - self.state = .executing(context, .finished, .buffering(streamID, buffer, next: .error(error))) - return .failRequestStreamContinuation(continuation, error) - - case .requestHeadSent, .paused(continuation: .none), .producing, .finished: - self.state = .executing(context, .finished, .buffering(streamID, buffer, next: .error(error))) - return .none + self.state = .finished(error: error) + return .failResponseHead( + context.continuation, + error, + nil, + context.executor, + bodyStreamContinuation: nil + ) } - case .executing(let context, let requestStreamState, .waitingForRemote(let streamID, let continuation)): - // We are in response streaming. The response stream is waiting for the next bytes - // from the server. We can fail the call to `next` immediately. + case .executing(let context, let requestStreamState, .streamingBody(let source)): + self.state = .finished(error: error) switch requestStreamState { - case .paused(continuation: .some(let bodyStreamContinuation)): - self.state = .finished(error: error, streamID) - return .failResponseStream(continuation, error, context.executor, bodyStreamContinuation: bodyStreamContinuation) - - case .requestHeadSent, .paused(continuation: .none), .producing, .finished: - self.state = .finished(error: error, streamID) - return .failResponseStream(continuation, error, context.executor, bodyStreamContinuation: nil) + case .paused(let bodyStreamContinuation): + return .failResponseStream( + source, + error, + context.executor, + bodyStreamContinuation: bodyStreamContinuation + ) + case .finished, .producing, .requestHeadSent: + return .failResponseStream(source, error, context.executor, bodyStreamContinuation: nil) } - case .finished(error: _, _), - .executing(_, _, .waitingForResponseIterator(_, next: .error)), - .executing(_, _, .buffering(_, _, next: .error)): - // The request has already failed, succeeded, or the users is not interested in the - // response. There is no more way to reach the user code. Just drop the error. + case .finished(error: _), + .executing(_, _, .finished): return .none - - case .executing(let context, let requestStreamState, .finished(let streamID, let continuation)): - switch requestStreamState { - case .paused(continuation: .some(let bodyStreamContinuation)): - self.state = .finished(error: error, streamID) - return .failResponseStream(continuation, error, context.executor, bodyStreamContinuation: bodyStreamContinuation) - - case .requestHeadSent, .paused(continuation: .none), .producing, .finished: - self.state = .finished(error: error, streamID) - return .failResponseStream(continuation, error, context.executor, bodyStreamContinuation: nil) - } } } @@ -178,6 +170,7 @@ extension Transaction { enum StartExecutionAction { case cancel(HTTPRequestExecutor) + case cancelAndFail(HTTPRequestExecutor, CheckedContinuation, with: Error) case none } @@ -191,13 +184,16 @@ extension Transaction { ) self.state = .executing(context, .requestHeadSent, .waitingForResponseHead) return .none + case .deadlineExceededWhileQueued(let continuation): + let error = HTTPClientError.deadlineExceeded + self.state = .finished(error: error) + return .cancelAndFail(executor, continuation, with: error) - case .finished(error: .some, .none): + case .finished(error: .some): return .cancel(executor) case .executing, - .finished(error: .none, _), - .finished(error: .some, .some): + .finished(error: .none): preconditionFailure("Invalid state: \(self.state)") } } @@ -210,8 +206,10 @@ extension Transaction { mutating func resumeRequestBodyStream() -> ResumeProducingAction { switch self.state { - case .initialized, .queued: - preconditionFailure("Received a resumeBodyRequest on a request, that isn't executing. Invalid state: \(self.state)") + case .initialized, .queued, .deadlineExceededWhileQueued: + preconditionFailure( + "Received a resumeBodyRequest on a request, that isn't executing. Invalid state: \(self.state)" + ) case .executing(let context, .requestHeadSent, let responseState): // the request can start to send its body. @@ -219,7 +217,9 @@ extension Transaction { return .startStream(context.allocator) case .executing(_, .producing, _): - preconditionFailure("Received a resumeBodyRequest on a request, that is producing. Invalid state: \(self.state)") + preconditionFailure( + "Received a resumeBodyRequest on a request, that is producing. Invalid state: \(self.state)" + ) case .executing(let context, .paused(.none), let responseState): // request stream is currently paused, but there is no write waiting. We don't need @@ -245,16 +245,17 @@ extension Transaction { mutating func pauseRequestBodyStream() { switch self.state { case .initialized, - .queued, - .executing(_, .requestHeadSent, _): + .queued, + .deadlineExceededWhileQueued, + .executing(_, .requestHeadSent, _): preconditionFailure("A request stream can only be resumed, if the request was started") case .executing(let context, .producing, let responseSteam): self.state = .executing(context, .paused(continuation: nil), responseSteam) case .executing(_, .paused, _), - .executing(_, .finished, _), - .finished: + .executing(_, .finished, _), + .finished: // the channels writability changed to paused after we have already forwarded all // request bytes. Can be ignored. break @@ -270,9 +271,12 @@ extension Transaction { func writeNextRequestPart() -> NextWriteAction { switch self.state { case .initialized, - .queued, - .executing(_, .requestHeadSent, _): - preconditionFailure("A request stream can only produce, if the request was started. Invalid state: \(self.state)") + .queued, + .deadlineExceededWhileQueued, + .executing(_, .requestHeadSent, _): + preconditionFailure( + "A request stream can only produce, if the request was started. Invalid state: \(self.state)" + ) case .executing(let context, .producing, _): // We are currently producing the request body. The executors channel is writable. @@ -290,7 +294,9 @@ extension Transaction { return .writeAndWait(context.executor) case .executing(_, .paused(continuation: .some), _): - preconditionFailure("A write continuation already exists, but we tried to set another one. Invalid state: \(self.state)") + preconditionFailure( + "A write continuation already exists, but we tried to set another one. Invalid state: \(self.state)" + ) case .finished, .executing(_, .finished, _): return .fail @@ -300,10 +306,13 @@ extension Transaction { mutating func waitForRequestBodyDemand(continuation: CheckedContinuation) { switch self.state { case .initialized, - .queued, - .executing(_, .requestHeadSent, _), - .executing(_, .finished, _): - preconditionFailure("A request stream can only produce, if the request was started. Invalid state: \(self.state)") + .queued, + .deadlineExceededWhileQueued, + .executing(_, .requestHeadSent, _), + .executing(_, .finished, _): + preconditionFailure( + "A request stream can only produce, if the request was started. Invalid state: \(self.state)" + ) case .executing(_, .producing, _): preconditionFailure() @@ -324,36 +333,38 @@ extension Transaction { } enum FinishAction { - // forward the notice that the request stream has finished. If finalContinuation is not - // nil, succeed the continuation with nil to signal the requests end. - case forwardStreamFinished(HTTPRequestExecutor, finalContinuation: CheckedContinuation?) + // forward the notice that the request stream has finished. + case forwardStreamFinished(HTTPRequestExecutor) case none } mutating func finishRequestBodyStream() -> FinishAction { switch self.state { case .initialized, - .queued, - .executing(_, .finished, _): + .queued, + .deadlineExceededWhileQueued, + .executing(_, .finished, _): preconditionFailure("Invalid state: \(self.state)") case .executing(_, .paused(continuation: .some), _): - preconditionFailure("Received a request body end, while having a registered back-pressure continuation. Invalid state: \(self.state)") + preconditionFailure( + "Received a request body end, while having a registered back-pressure continuation. Invalid state: \(self.state)" + ) case .executing(let context, .producing, let responseState), - .executing(let context, .paused(continuation: .none), let responseState), - .executing(let context, .requestHeadSent, let responseState): + .executing(let context, .paused(continuation: .none), let responseState), + .executing(let context, .requestHeadSent, let responseState): switch responseState { - case .finished(let registeredStreamID, let continuation): + case .finished: // if the response stream has already finished before the request, we must succeed // the final continuation. - self.state = .finished(error: nil, registeredStreamID) - return .forwardStreamFinished(context.executor, finalContinuation: continuation) + self.state = .finished(error: nil) + return .forwardStreamFinished(context.executor) - case .waitingForResponseHead, .waitingForResponseIterator, .waitingForRemote, .buffering: + case .waitingForResponseHead, .streamingBody: self.state = .executing(context, .finished, responseState) - return .forwardStreamFinished(context.executor, finalContinuation: nil) + return .forwardStreamFinished(context.executor) } case .finished: @@ -364,319 +375,125 @@ extension Transaction { // MARK: - Response - enum ReceiveResponseHeadAction { - case succeedResponseHead(HTTPResponseHead, CheckedContinuation) + case succeedResponseHead(TransactionBody, CheckedContinuation) case none } - mutating func receiveResponseHead(_ head: HTTPResponseHead) -> ReceiveResponseHeadAction { + mutating func receiveResponseHead( + _ head: HTTPResponseHead, + delegate: Delegate + ) -> ReceiveResponseHeadAction { switch self.state { case .initialized, - .queued, - .executing(_, _, .waitingForResponseIterator), - .executing(_, _, .buffering), - .executing(_, _, .waitingForRemote): - preconditionFailure("How can we receive a response, if the request hasn't started yet.") + .queued, + .deadlineExceededWhileQueued, + .executing(_, _, .streamingBody), + .executing(_, _, .finished): + preconditionFailure("invalid state \(self.state)") case .executing(let context, let requestState, .waitingForResponseHead): // The response head was received. Next we will wait for the consumer to create a // response body stream. - self.state = .executing(context, requestState, .waitingForResponseIterator(.init(), next: .askExecutorForMore)) - return .succeedResponseHead(head, context.continuation) + let body = TransactionBody.makeSequence( + backPressureStrategy: .init(lowWatermark: 1, highWatermark: 1), + finishOnDeinit: true, + delegate: AnyAsyncSequenceProducerDelegate(delegate) + ) - case .finished(error: .some, _): + self.state = .executing(context, requestState, .streamingBody(body.source)) + return .succeedResponseHead(body.sequence, context.continuation) + + case .finished(error: .some): // If the request failed before, we don't need to do anything in response to // receiving the response head. return .none - case .executing(_, _, .finished), - .finished(error: .none, _): + case .finished(error: .none): preconditionFailure("How can the request be finished without error, before receiving response head?") } } - enum ReceiveResponsePartAction { + enum ProduceMoreAction { case none - case succeedContinuation(CheckedContinuation, ByteBuffer) + case requestMoreResponseBodyParts(HTTPRequestExecutor) } - mutating func receiveResponseBodyParts(_ buffer: CircularBuffer) -> ReceiveResponsePartAction { + mutating func produceMore() -> ProduceMoreAction { switch self.state { - case .initialized, .queued: - preconditionFailure("Received a response body part, but request hasn't started yet. Invalid state: \(self.state)") - - case .executing(_, _, .waitingForResponseHead): - preconditionFailure("If we receive a response body, we must have received a head before") - - case .executing(let context, let requestState, .buffering(let streamID, var currentBuffer, next: let next)): - guard case .askExecutorForMore = next else { - preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") - } - - if currentBuffer.isEmpty { - currentBuffer = buffer - } else { - currentBuffer.append(contentsOf: buffer) - } - self.state = .executing(context, requestState, .buffering(streamID, currentBuffer, next: next)) - return .none - - case .executing(let executor, let requestState, .waitingForResponseIterator(var currentBuffer, next: let next)): - guard case .askExecutorForMore = next else { - preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") - } - - if currentBuffer.isEmpty { - currentBuffer = buffer - } else { - currentBuffer.append(contentsOf: buffer) - } - self.state = .executing(executor, requestState, .waitingForResponseIterator(currentBuffer, next: next)) - return .none - - case .executing(let executor, let requestState, .waitingForRemote(let streamID, let continuation)): - var buffer = buffer - let first = buffer.removeFirst() - self.state = .executing(executor, requestState, .buffering(streamID, buffer, next: .askExecutorForMore)) - return .succeedContinuation(continuation, first) - - case .finished: - // the request failed or was cancelled before, we can ignore further data + case .initialized, + .queued, + .deadlineExceededWhileQueued, + .executing(_, _, .waitingForResponseHead): + preconditionFailure("invalid state \(self.state)") + + case .executing(let context, _, .streamingBody): + return .requestMoreResponseBodyParts(context.executor) + case .finished, + .executing(_, _, .finished): return .none - - case .executing(_, _, .finished): - preconditionFailure("Received response end. Must not receive further body parts after that. Invalid state: \(self.state)") } } - enum ResponseBodyDeinitedAction { - case cancel(HTTPRequestExecutor) + enum ReceiveResponsePartAction { case none + case yieldResponseBodyParts(TransactionBody.Source, CircularBuffer, HTTPRequestExecutor) } - mutating func responseBodyDeinited() -> ResponseBodyDeinitedAction { + mutating func receiveResponseBodyParts(_ buffer: CircularBuffer) -> ReceiveResponsePartAction { switch self.state { - case .initialized, - .queued, - .executing(_, _, .waitingForResponseHead): - preconditionFailure("Got notice about a deinited response, before we even received a response. Invalid state: \(self.state)") + case .initialized, .queued, .deadlineExceededWhileQueued: + preconditionFailure( + "Received a response body part, but request hasn't started yet. Invalid state: \(self.state)" + ) - case .executing(_, _, .waitingForResponseIterator(_, next: .endOfFile)): - self.state = .finished(error: nil, nil) - return .none + case .executing(_, _, .waitingForResponseHead): + preconditionFailure("If we receive a response body, we must have received a head before") - case .executing(let context, _, .waitingForResponseIterator(_, next: .askExecutorForMore)): - self.state = .finished(error: nil, nil) - return .cancel(context.executor) - - case .executing(_, _, .waitingForResponseIterator(_, next: .error(let error))): - self.state = .finished(error: error, nil) - return .none + case .executing(let context, _, .streamingBody(let source)): + return .yieldResponseBodyParts(source, buffer, context.executor) case .finished: - // body was released after the response was consumed - return .none - - case .executing(_, _, .buffering), - .executing(_, _, .waitingForRemote), - .executing(_, _, .finished): - // user is consuming the stream with an iterator - return .none - } - } - - mutating func responseBodyIteratorDeinited(streamID: HTTPClientResponse.Body.IteratorStream.ID) -> FailAction { - switch self.state { - case .initialized, .queued, .executing(_, _, .waitingForResponseHead): - preconditionFailure("Got notice about a deinited response body iterator, before we even received a response. Invalid state: \(self.state)") - - case .executing(_, _, .buffering(let registeredStreamID, _, next: _)), - .executing(_, _, .waitingForRemote(let registeredStreamID, _)): - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - return self.fail(HTTPClientError.cancelled) - - case .executing(_, _, .waitingForResponseIterator), - .executing(_, _, .finished), - .finished: - // the iterator went out of memory after the request was done. nothing to do. + // the request failed or was cancelled before, we can ignore further data return .none - } - } - - enum ConsumeAction { - case succeedContinuation(CheckedContinuation, ByteBuffer?) - case failContinuation(CheckedContinuation, Error) - case askExecutorForMore(HTTPRequestExecutor) - case none - } - - mutating func consumeNextResponsePart( - streamID: HTTPClientResponse.Body.IteratorStream.ID, - continuation: CheckedContinuation - ) -> ConsumeAction { - switch self.state { - case .initialized, - .queued, - .executing(_, _, .waitingForResponseHead): - preconditionFailure("If we receive a response body, we must have received a head before") case .executing(_, _, .finished): - preconditionFailure("This is an invalid state at this point. We are waiting for the request stream to finish to succeed the response stream. By sending a fi") - - case .executing(let context, let requestState, .waitingForResponseIterator(var buffer, next: .askExecutorForMore)): - if buffer.isEmpty { - self.state = .executing(context, requestState, .waitingForRemote(streamID, continuation)) - return .askExecutorForMore(context.executor) - } else { - let toReturn = buffer.removeFirst() - self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .askExecutorForMore)) - return .succeedContinuation(continuation, toReturn) - } - - case .executing(_, _, .waitingForResponseIterator(_, next: .error(let error))): - self.state = .finished(error: error, streamID) - return .failContinuation(continuation, error) - - case .executing(_, _, .waitingForResponseIterator(let buffer, next: .endOfFile)) where buffer.isEmpty: - self.state = .finished(error: nil, streamID) - return .succeedContinuation(continuation, nil) - - case .executing(let context, let requestState, .waitingForResponseIterator(var buffer, next: .endOfFile)): - assert(!buffer.isEmpty) - let toReturn = buffer.removeFirst() - self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .endOfFile)) - return .succeedContinuation(continuation, toReturn) - - case .executing(let context, let requestState, .buffering(let registeredStreamID, var buffer, next: .askExecutorForMore)): - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - - if buffer.isEmpty { - self.state = .executing(context, requestState, .waitingForRemote(streamID, continuation)) - return .askExecutorForMore(context.executor) - } else { - let toReturn = buffer.removeFirst() - self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .askExecutorForMore)) - return .succeedContinuation(continuation, toReturn) - } - - case .executing(_, _, .buffering(let registeredStreamID, _, next: .error(let error))): - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - self.state = .finished(error: error, registeredStreamID) - return .failContinuation(continuation, error) - - case .executing(_, _, .buffering(let registeredStreamID, let buffer, next: .endOfFile)) where buffer.isEmpty: - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - self.state = .finished(error: nil, registeredStreamID) - return .succeedContinuation(continuation, nil) - - case .executing(let context, let requestState, .buffering(let registeredStreamID, var buffer, next: .endOfFile)): - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - if let toReturn = buffer.popFirst() { - // As long as we have bytes in the local store, we can hand them to the user. - self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .endOfFile)) - return .succeedContinuation(continuation, toReturn) - } - - switch requestState { - case .requestHeadSent, .paused, .producing: - // if the request isn't finished yet, we don't succeed the final response stream - // continuation. We will succeed it once the request has been fully send. - self.state = .executing(context, requestState, .finished(streamID, continuation)) - return .none - case .finished: - // if the request is finished, we can succeed the final continuation. - self.state = .finished(error: nil, streamID) - return .succeedContinuation(continuation, nil) - } - - case .executing(_, _, .waitingForRemote(let registeredStreamID, _)): - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - preconditionFailure("A body response continuation from this iterator already exists! Queuing calls to `next()` is not supported.") - - case .finished(error: .some(let error), let registeredStreamID): - if let registeredStreamID = registeredStreamID { - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - } else { - self.state = .finished(error: error, streamID) - } - return .failContinuation(continuation, error) - - case .finished(error: .none, let registeredStreamID): - if let registeredStreamID = registeredStreamID { - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - } else { - self.state = .finished(error: .none, streamID) - } - - return .succeedContinuation(continuation, nil) - } - } - - private func verifyStreamIDIsEqual( - registered: HTTPClientResponse.Body.IteratorStream.ID, - this: HTTPClientResponse.Body.IteratorStream.ID, - file: StaticString = #file, - line: UInt = #line - ) { - if registered != this { preconditionFailure( - "Tried to use a second iterator on response body stream. Multiple iterators are not supported.", - file: file, line: line + "Received response end. Must not receive further body parts after that. Invalid state: \(self.state)" ) } } enum ReceiveResponseEndAction { - case succeedContinuation(CheckedContinuation, ByteBuffer) - case finishResponseStream(CheckedContinuation) + case finishResponseStream(TransactionBody.Source, finalBody: CircularBuffer?) case none } mutating func succeedRequest(_ newChunks: CircularBuffer?) -> ReceiveResponseEndAction { switch self.state { case .initialized, - .queued, - .executing(_, _, .waitingForResponseHead): - preconditionFailure("Received no response head, but received a response end. Invalid state: \(self.state)") - - case .executing(let context, let requestState, .waitingForResponseIterator(var buffer, next: .askExecutorForMore)): - if let newChunks = newChunks, !newChunks.isEmpty { - buffer.append(contentsOf: newChunks) - } - self.state = .executing(context, requestState, .waitingForResponseIterator(buffer, next: .endOfFile)) - return .none - - case .executing(let context, let requestState, .waitingForRemote(let streamID, let continuation)): - if var newChunks = newChunks, !newChunks.isEmpty { - let first = newChunks.removeFirst() - self.state = .executing(context, requestState, .buffering(streamID, newChunks, next: .endOfFile)) - return .succeedContinuation(continuation, first) - } - - self.state = .finished(error: nil, streamID) - return .finishResponseStream(continuation) - - case .executing(let context, let requestState, .buffering(let streamID, var buffer, next: .askExecutorForMore)): - if let newChunks = newChunks, !newChunks.isEmpty { - buffer.append(contentsOf: newChunks) - } - self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .endOfFile)) - return .none + .queued, + .deadlineExceededWhileQueued, + .executing(_, _, .waitingForResponseHead): + preconditionFailure( + "Received no response head, but received a response end. Invalid state: \(self.state)" + ) + case .executing(let context, let requestState, .streamingBody(let source)): + self.state = .executing(context, requestState, .finished) + return .finishResponseStream(source, finalBody: newChunks) case .finished: // the request failed or was cancelled before, we can ignore all events return .none - - case .executing(_, _, .waitingForResponseIterator(_, next: .error)), - .executing(_, _, .waitingForResponseIterator(_, next: .endOfFile)), - .executing(_, _, .buffering(_, _, next: .error)), - .executing(_, _, .buffering(_, _, next: .endOfFile)), - .executing(_, _, .finished(_, _)): - preconditionFailure("Already received an eof or error before. Must not receive further events. Invalid state: \(self.state)") + case .executing(_, _, .finished): + preconditionFailure( + "Already received an eof or error before. Must not receive further events. Invalid state: \(self.state)" + ) } } enum DeadlineExceededAction { case none + case cancelSchedulerOnly(scheduler: HTTPRequestScheduler) /// fail response before head received. scheduler and executor are exclusive here. case cancel( requestContinuation: CheckedContinuation, @@ -690,7 +507,7 @@ extension Transaction { let error = HTTPClientError.deadlineExceeded switch self.state { case .initialized(let continuation): - self.state = .finished(error: error, nil) + self.state = .finished(error: error) return .cancel( requestContinuation: continuation, scheduler: nil, @@ -699,18 +516,16 @@ extension Transaction { ) case .queued(let continuation, let scheduler): - self.state = .finished(error: error, nil) - return .cancel( - requestContinuation: continuation, - scheduler: scheduler, - executor: nil, - bodyStreamContinuation: nil + self.state = .deadlineExceededWhileQueued(continuation) + return .cancelSchedulerOnly( + scheduler: scheduler ) - + case .deadlineExceededWhileQueued: + return .none case .executing(let context, let requestStreamState, .waitingForResponseHead): switch requestStreamState { case .paused(continuation: .some(let continuation)): - self.state = .finished(error: error, nil) + self.state = .finished(error: error) return .cancel( requestContinuation: context.continuation, scheduler: nil, @@ -718,7 +533,7 @@ extension Transaction { bodyStreamContinuation: continuation ) case .requestHeadSent, .finished, .producing, .paused(continuation: .none): - self.state = .finished(error: error, nil) + self.state = .finished(error: error) return .cancel( requestContinuation: context.continuation, scheduler: nil, @@ -737,5 +552,3 @@ extension Transaction { } } } - -#endif diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift index c2ce52eeb..408ebeeb6 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) import Logging import NIOConcurrencyHelpers import NIOCore @@ -20,7 +19,11 @@ import NIOHTTP1 import NIOSSL @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -final class Transaction: @unchecked Sendable { +@usableFromInline +final class Transaction: + // until NIOLockedValueBox learns `sending` because StateMachine cannot be Sendable + @unchecked Sendable +{ let logger: Logger let request: HTTPClientRequest.Prepared @@ -29,8 +32,7 @@ final class Transaction: @unchecked Sendable { let preferredEventLoop: EventLoop let requestOptions: RequestOptions - private let stateLock = Lock() - private var state: StateMachine + private let state: NIOLockedValueBox init( request: HTTPClientRequest.Prepared, @@ -45,11 +47,11 @@ final class Transaction: @unchecked Sendable { self.logger = logger self.connectionDeadline = connectionDeadline self.preferredEventLoop = preferredEventLoop - self.state = StateMachine(responseContinuation) + self.state = NIOLockedValueBox(StateMachine(responseContinuation)) } func cancel() { - self.fail(HTTPClientError.cancelled) + self.fail(CancellationError()) } // MARK: Request body helpers @@ -57,13 +59,13 @@ final class Transaction: @unchecked Sendable { private func writeOnceAndOneTimeOnly(byteBuffer: ByteBuffer) { // This method is synchronously invoked after sending the request head. For this reason we // can make a number of assumptions, how the state machine will react. - let writeAction = self.stateLock.withLock { - self.state.writeNextRequestPart() + let writeAction = self.state.withLockedValue { state in + state.writeNextRequestPart() } switch writeAction { case .writeAndWait(let executor), .writeAndContinue(let executor): - executor.writeRequestBodyPart(.byteBuffer(byteBuffer), request: self) + executor.writeRequestBodyPart(.byteBuffer(byteBuffer), request: self, promise: nil) case .fail: // an error/cancellation has happened. we don't need to continue here @@ -100,30 +102,33 @@ final class Transaction: @unchecked Sendable { struct BreakTheWriteLoopError: Swift.Error {} + // FIXME: Refactor this to not use `self.state.unsafe`. private func writeRequestBodyPart(_ part: ByteBuffer) async throws { - self.stateLock.lock() - switch self.state.writeNextRequestPart() { + self.state.unsafe.lock() + switch self.state.unsafe.withValueAssumingLockIsAcquired({ state in state.writeNextRequestPart() }) { case .writeAndContinue(let executor): - self.stateLock.unlock() - executor.writeRequestBodyPart(.byteBuffer(part), request: self) + self.state.unsafe.unlock() + executor.writeRequestBodyPart(.byteBuffer(part), request: self, promise: nil) case .writeAndWait(let executor): try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - self.state.waitForRequestBodyDemand(continuation: continuation) - self.stateLock.unlock() + self.state.unsafe.withValueAssumingLockIsAcquired({ state in + state.waitForRequestBodyDemand(continuation: continuation) + }) + self.state.unsafe.unlock() - executor.writeRequestBodyPart(.byteBuffer(part), request: self) + executor.writeRequestBodyPart(.byteBuffer(part), request: self, promise: nil) } case .fail: - self.stateLock.unlock() + self.state.unsafe.unlock() throw BreakTheWriteLoopError() } } private func requestBodyStreamFinished() { - let finishAction = self.stateLock.withLock { - self.state.finishRequestBodyStream() + let finishAction = self.state.withLockedValue { state in + state.finishRequestBodyStream() } switch finishAction { @@ -131,9 +136,8 @@ final class Transaction: @unchecked Sendable { // an error/cancellation has happened. nothing to do. break - case .forwardStreamFinished(let executor, let succeedContinuation): - executor.finishRequestBodyStream(self) - succeedContinuation?.resume(returning: nil) + case .forwardStreamFinished(let executor): + executor.finishRequestBodyStream(self, promise: nil) } return } @@ -148,12 +152,12 @@ final class Transaction: @unchecked Sendable { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension Transaction: HTTPSchedulableRequest { var poolKey: ConnectionPool.Key { self.request.poolKey } - var tlsConfiguration: TLSConfiguration? { return nil } - var requiredEventLoop: EventLoop? { return nil } + var tlsConfiguration: TLSConfiguration? { self.request.tlsConfiguration } + var requiredEventLoop: EventLoop? { nil } func requestWasQueued(_ scheduler: HTTPRequestScheduler) { - self.stateLock.withLock { - self.state.requestWasQueued(scheduler) + self.state.withLockedValue { state in + state.requestWasQueued(scheduler) } } } @@ -167,14 +171,16 @@ extension Transaction: HTTPExecutableRequest { // MARK: Request func willExecuteRequest(_ executor: HTTPRequestExecutor) { - let action = self.stateLock.withLock { - self.state.willExecuteRequest(executor) + let action = self.state.withLockedValue { state in + state.willExecuteRequest(executor) } switch action { case .cancel(let executor): executor.cancelRequest(self) - + case .cancelAndFail(let executor, let continuation, with: let error): + executor.cancelRequest(self) + continuation.resume(throwing: error) case .none: break } @@ -183,8 +189,8 @@ extension Transaction: HTTPExecutableRequest { func requestHeadSent() {} func resumeRequestBodyStream() { - let action = self.stateLock.withLock { - self.state.resumeRequestBodyStream() + let action = self.state.withLockedValue { state in + state.resumeRequestBodyStream() } switch action { @@ -192,7 +198,7 @@ extension Transaction: HTTPExecutableRequest { break case .startStream(let allocator): - switch self.request.body?.mode { + switch self.request.body { case .asyncSequence(_, let next): // it is safe to call this async here. it dispatches... self.continueRequestBodyStream(allocator, next: next) @@ -214,62 +220,70 @@ extension Transaction: HTTPExecutableRequest { } func pauseRequestBodyStream() { - self.stateLock.withLock { - self.state.pauseRequestBodyStream() + self.state.withLockedValue { state in + state.pauseRequestBodyStream() } } // MARK: Response func receiveResponseHead(_ head: HTTPResponseHead) { - let action = self.stateLock.withLock { - self.state.receiveResponseHead(head) + let action = self.state.withLockedValue { state in + state.receiveResponseHead(head, delegate: self) } switch action { case .none: break - case .succeedResponseHead(let head, let continuation): - let asyncResponse = HTTPClientResponse( - bag: self, + case .succeedResponseHead(let body, let continuation): + let response = HTTPClientResponse( + requestMethod: self.requestHead.method, version: head.version, status: head.status, - headers: head.headers + headers: head.headers, + body: body ) - continuation.resume(returning: asyncResponse) + continuation.resume(returning: response) } } func receiveResponseBodyParts(_ buffer: CircularBuffer) { - let action = self.stateLock.withLock { - self.state.receiveResponseBodyParts(buffer) + let action = self.state.withLockedValue { state in + state.receiveResponseBodyParts(buffer) } switch action { case .none: break - case .succeedContinuation(let continuation, let bytes): - continuation.resume(returning: bytes) + case .yieldResponseBodyParts(let source, let responseBodyParts, let executer): + switch source.yield(contentsOf: responseBodyParts) { + case .dropped, .stopProducing: + break + case .produceMore: + executer.demandResponseBodyStream(self) + } } } func succeedRequest(_ buffer: CircularBuffer?) { - let succeedAction = self.stateLock.withLock { - self.state.succeedRequest(buffer) + let succeedAction = self.state.withLockedValue { state in + state.succeedRequest(buffer) } switch succeedAction { - case .finishResponseStream(let continuation): - continuation.resume(returning: nil) - case .succeedContinuation(let continuation, let byteBuffer): - continuation.resume(returning: byteBuffer) + case .finishResponseStream(let source, let finalResponse): + if let finalResponse = finalResponse { + _ = source.yield(contentsOf: finalResponse) + } + source.finish() + case .none: break } } func fail(_ error: Error) { - let action = self.stateLock.withLock { - self.state.fail(error) + let action = self.state.withLockedValue { state in + state.fail(error) } self.performFailAction(action) } @@ -282,12 +296,12 @@ extension Transaction: HTTPExecutableRequest { case .failResponseHead(let continuation, let error, let scheduler, let executor, let bodyStreamContinuation): continuation.resume(throwing: error) bodyStreamContinuation?.resume(throwing: error) - scheduler?.cancelRequest(self) // NOTE: scheduler and executor are exclusive here + scheduler?.cancelRequest(self) // NOTE: scheduler and executor are exclusive here executor?.cancelRequest(self) - case .failResponseStream(let continuation, let error, let executor, let bodyStreamContinuation): - continuation.resume(throwing: error) - bodyStreamContinuation?.resume(throwing: error) + case .failResponseStream(let source, let error, let executor, let requestBodyStreamContinuation): + source.finish(error) + requestBodyStreamContinuation?.resume(throwing: error) executor.cancelRequest(self) case .failRequestStreamContinuation(let bodyStreamContinuation, let error): @@ -296,8 +310,8 @@ extension Transaction: HTTPExecutableRequest { } func deadlineExceeded() { - let action = self.stateLock.withLock { - self.state.deadlineExceeded() + let action = self.state.withLockedValue { state in + state.deadlineExceeded() } self.performDeadlineExceededAction(action) } @@ -309,7 +323,8 @@ extension Transaction: HTTPExecutableRequest { scheduler?.cancelRequest(self) executor?.cancelRequest(self) bodyStreamContinuation?.resume(throwing: HTTPClientError.deadlineExceeded) - + case .cancelSchedulerOnly(let scheduler): + scheduler.cancelRequest(self) case .none: break } @@ -317,46 +332,22 @@ extension Transaction: HTTPExecutableRequest { } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -extension Transaction { - func responseBodyDeinited() { - let deinitedAction = self.stateLock.withLock { - self.state.responseBodyDeinited() +extension Transaction: NIOAsyncSequenceProducerDelegate { + @usableFromInline + func produceMore() { + let action = self.state.withLockedValue { state in + state.produceMore() } - - switch deinitedAction { - case .cancel(let executor): - executor.cancelRequest(self) + switch action { case .none: break + case .requestMoreResponseBodyParts(let executer): + executer.demandResponseBodyStream(self) } } - func nextResponsePart(streamID: HTTPClientResponse.Body.IteratorStream.ID) async throws -> ByteBuffer? { - try await withCheckedThrowingContinuation { continuation in - let action = self.stateLock.withLock { - self.state.consumeNextResponsePart(streamID: streamID, continuation: continuation) - } - switch action { - case .succeedContinuation(let continuation, let result): - continuation.resume(returning: result) - - case .failContinuation(let continuation, let error): - continuation.resume(throwing: error) - - case .askExecutorForMore(let executor): - executor.demandResponseBodyStream(self) - - case .none: - return - } - } - } - - func responseBodyIteratorDeinited(streamID: HTTPClientResponse.Body.IteratorStream.ID) { - let action = self.stateLock.withLock { - self.state.responseBodyIteratorDeinited(streamID: streamID) - } - self.performFailAction(action) + @usableFromInline + func didTerminate() { + self.fail(HTTPClientError.cancelled) } } -#endif diff --git a/Sources/AsyncHTTPClient/Base64.swift b/Sources/AsyncHTTPClient/Base64.swift index dbbf742ab..3162e7251 100644 --- a/Sources/AsyncHTTPClient/Base64.swift +++ b/Sources/AsyncHTTPClient/Base64.swift @@ -19,156 +19,156 @@ extension String { - /// Base64 encode a collection of UInt8 to a string, without the use of Foundation. - @inlinable - init(base64Encoding bytes: Buffer) - where Buffer.Element == UInt8 - { - self = Base64.encode(bytes: bytes) - } + /// Base64 encode a collection of UInt8 to a string, without the use of Foundation. + @inlinable + init(base64Encoding bytes: Buffer) + where Buffer.Element == UInt8 { + self = Base64.encode(bytes: bytes) + } } +// swift-format-ignore: DontRepeatTypeInStaticProperties @usableFromInline internal struct Base64 { - @inlinable - static func encode(bytes: Buffer) - -> String where Buffer.Element == UInt8 - { - guard !bytes.isEmpty else { - return "" - } - // In Base64, 3 bytes become 4 output characters, and we pad to the - // nearest multiple of four. - let base64StringLength = ((bytes.count + 2) / 3) * 4 - let alphabet = Base64.encodeBase64 - - return String(customUnsafeUninitializedCapacity: base64StringLength) { backingStorage in - var input = bytes.makeIterator() - var offset = 0 - while let firstByte = input.next() { - let secondByte = input.next() - let thirdByte = input.next() - - backingStorage[offset] = Base64.encode(alphabet: alphabet, firstByte: firstByte) - backingStorage[offset + 1] = Base64.encode(alphabet: alphabet, firstByte: firstByte, secondByte: secondByte) - backingStorage[offset + 2] = Base64.encode(alphabet: alphabet, secondByte: secondByte, thirdByte: thirdByte) - backingStorage[offset + 3] = Base64.encode(alphabet: alphabet, thirdByte: thirdByte) - offset += 4 - } - return offset + @inlinable + static func encode( + bytes: Buffer + ) + -> String where Buffer.Element == UInt8 + { + guard !bytes.isEmpty else { + return "" + } + // In Base64, 3 bytes become 4 output characters, and we pad to the + // nearest multiple of four. + let base64StringLength = ((bytes.count + 2) / 3) * 4 + let alphabet = Base64.encodeBase64 + + return String(customUnsafeUninitializedCapacity: base64StringLength) { backingStorage in + var input = bytes.makeIterator() + var offset = 0 + while let firstByte = input.next() { + let secondByte = input.next() + let thirdByte = input.next() + + backingStorage[offset] = Base64.encode(alphabet: alphabet, firstByte: firstByte) + backingStorage[offset + 1] = Base64.encode( + alphabet: alphabet, + firstByte: firstByte, + secondByte: secondByte + ) + backingStorage[offset + 2] = Base64.encode( + alphabet: alphabet, + secondByte: secondByte, + thirdByte: thirdByte + ) + backingStorage[offset + 3] = Base64.encode(alphabet: alphabet, thirdByte: thirdByte) + offset += 4 + } + return offset + } } - } - - // MARK: Internal - - // The base64 unicode table. - @usableFromInline - static let encodeBase64: [UInt8] = [ - UInt8(ascii: "A"), UInt8(ascii: "B"), UInt8(ascii: "C"), UInt8(ascii: "D"), - UInt8(ascii: "E"), UInt8(ascii: "F"), UInt8(ascii: "G"), UInt8(ascii: "H"), - UInt8(ascii: "I"), UInt8(ascii: "J"), UInt8(ascii: "K"), UInt8(ascii: "L"), - UInt8(ascii: "M"), UInt8(ascii: "N"), UInt8(ascii: "O"), UInt8(ascii: "P"), - UInt8(ascii: "Q"), UInt8(ascii: "R"), UInt8(ascii: "S"), UInt8(ascii: "T"), - UInt8(ascii: "U"), UInt8(ascii: "V"), UInt8(ascii: "W"), UInt8(ascii: "X"), - UInt8(ascii: "Y"), UInt8(ascii: "Z"), UInt8(ascii: "a"), UInt8(ascii: "b"), - UInt8(ascii: "c"), UInt8(ascii: "d"), UInt8(ascii: "e"), UInt8(ascii: "f"), - UInt8(ascii: "g"), UInt8(ascii: "h"), UInt8(ascii: "i"), UInt8(ascii: "j"), - UInt8(ascii: "k"), UInt8(ascii: "l"), UInt8(ascii: "m"), UInt8(ascii: "n"), - UInt8(ascii: "o"), UInt8(ascii: "p"), UInt8(ascii: "q"), UInt8(ascii: "r"), - UInt8(ascii: "s"), UInt8(ascii: "t"), UInt8(ascii: "u"), UInt8(ascii: "v"), - UInt8(ascii: "w"), UInt8(ascii: "x"), UInt8(ascii: "y"), UInt8(ascii: "z"), - UInt8(ascii: "0"), UInt8(ascii: "1"), UInt8(ascii: "2"), UInt8(ascii: "3"), - UInt8(ascii: "4"), UInt8(ascii: "5"), UInt8(ascii: "6"), UInt8(ascii: "7"), - UInt8(ascii: "8"), UInt8(ascii: "9"), UInt8(ascii: "+"), UInt8(ascii: "/"), - ] - - static let encodePaddingCharacter: UInt8 = UInt8(ascii: "=") - - @usableFromInline - static func encode(alphabet: [UInt8], firstByte: UInt8) -> UInt8 { - let index = firstByte >> 2 - return alphabet[Int(index)] - } - - @usableFromInline - static func encode(alphabet: [UInt8], firstByte: UInt8, secondByte: UInt8?) -> UInt8 { - var index = (firstByte & 0b00000011) << 4 - if let secondByte = secondByte { - index += (secondByte & 0b11110000) >> 4 + + // MARK: Internal + + // The base64 unicode table. + @usableFromInline + static let encodeBase64: [UInt8] = [ + UInt8(ascii: "A"), UInt8(ascii: "B"), UInt8(ascii: "C"), UInt8(ascii: "D"), + UInt8(ascii: "E"), UInt8(ascii: "F"), UInt8(ascii: "G"), UInt8(ascii: "H"), + UInt8(ascii: "I"), UInt8(ascii: "J"), UInt8(ascii: "K"), UInt8(ascii: "L"), + UInt8(ascii: "M"), UInt8(ascii: "N"), UInt8(ascii: "O"), UInt8(ascii: "P"), + UInt8(ascii: "Q"), UInt8(ascii: "R"), UInt8(ascii: "S"), UInt8(ascii: "T"), + UInt8(ascii: "U"), UInt8(ascii: "V"), UInt8(ascii: "W"), UInt8(ascii: "X"), + UInt8(ascii: "Y"), UInt8(ascii: "Z"), UInt8(ascii: "a"), UInt8(ascii: "b"), + UInt8(ascii: "c"), UInt8(ascii: "d"), UInt8(ascii: "e"), UInt8(ascii: "f"), + UInt8(ascii: "g"), UInt8(ascii: "h"), UInt8(ascii: "i"), UInt8(ascii: "j"), + UInt8(ascii: "k"), UInt8(ascii: "l"), UInt8(ascii: "m"), UInt8(ascii: "n"), + UInt8(ascii: "o"), UInt8(ascii: "p"), UInt8(ascii: "q"), UInt8(ascii: "r"), + UInt8(ascii: "s"), UInt8(ascii: "t"), UInt8(ascii: "u"), UInt8(ascii: "v"), + UInt8(ascii: "w"), UInt8(ascii: "x"), UInt8(ascii: "y"), UInt8(ascii: "z"), + UInt8(ascii: "0"), UInt8(ascii: "1"), UInt8(ascii: "2"), UInt8(ascii: "3"), + UInt8(ascii: "4"), UInt8(ascii: "5"), UInt8(ascii: "6"), UInt8(ascii: "7"), + UInt8(ascii: "8"), UInt8(ascii: "9"), UInt8(ascii: "+"), UInt8(ascii: "/"), + ] + + static let encodePaddingCharacter: UInt8 = UInt8(ascii: "=") + + @usableFromInline + static func encode(alphabet: [UInt8], firstByte: UInt8) -> UInt8 { + let index = firstByte >> 2 + return alphabet[Int(index)] } - return alphabet[Int(index)] - } - - @usableFromInline - static func encode(alphabet: [UInt8], secondByte: UInt8?, thirdByte: UInt8?) -> UInt8 { - guard let secondByte = secondByte else { - // No second byte means we are just emitting padding. - return Base64.encodePaddingCharacter + + @usableFromInline + static func encode(alphabet: [UInt8], firstByte: UInt8, secondByte: UInt8?) -> UInt8 { + var index = (firstByte & 0b00000011) << 4 + if let secondByte = secondByte { + index += (secondByte & 0b11110000) >> 4 + } + return alphabet[Int(index)] } - var index = (secondByte & 0b00001111) << 2 - if let thirdByte = thirdByte { - index += (thirdByte & 0b11000000) >> 6 + + @usableFromInline + static func encode(alphabet: [UInt8], secondByte: UInt8?, thirdByte: UInt8?) -> UInt8 { + guard let secondByte = secondByte else { + // No second byte means we are just emitting padding. + return Base64.encodePaddingCharacter + } + var index = (secondByte & 0b00001111) << 2 + if let thirdByte = thirdByte { + index += (thirdByte & 0b11000000) >> 6 + } + return alphabet[Int(index)] } - return alphabet[Int(index)] - } - - @usableFromInline - static func encode(alphabet: [UInt8], thirdByte: UInt8?) -> UInt8 { - guard let thirdByte = thirdByte else { - // No third byte means just padding. - return Base64.encodePaddingCharacter + + @usableFromInline + static func encode(alphabet: [UInt8], thirdByte: UInt8?) -> UInt8 { + guard let thirdByte = thirdByte else { + // No third byte means just padding. + return Base64.encodePaddingCharacter + } + let index = thirdByte & 0b00111111 + return alphabet[Int(index)] } - let index = thirdByte & 0b00111111 - return alphabet[Int(index)] - } } extension String { - /// This is a backport of a proposed String initializer that will allow writing directly into an uninitialized String's backing memory. - /// - /// As this API does not exist prior to 5.3 on Linux, or on older Apple platforms, we fake it out with a pointer and accept the extra copy. - @inlinable - init(backportUnsafeUninitializedCapacity capacity: Int, - initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int) rethrows { - // The buffer will store zero terminated C string - let buffer = UnsafeMutableBufferPointer.allocate(capacity: capacity + 1) - defer { - buffer.deallocate() + /// This is a backport of a proposed String initializer that will allow writing directly into an uninitialized String's backing memory. + /// + /// As this API does not exist prior to 5.3 on Linux, or on older Apple platforms, we fake it out with a pointer and accept the extra copy. + @inlinable + init( + backportUnsafeUninitializedCapacity capacity: Int, + initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int + ) rethrows { + // The buffer will store zero terminated C string + let buffer = UnsafeMutableBufferPointer.allocate(capacity: capacity + 1) + defer { + buffer.deallocate() + } + + let initializedCount = try initializer(buffer) + precondition(initializedCount <= capacity, "Overran buffer in initializer!") + // add zero termination + buffer[initializedCount] = 0 + + self = String(cString: buffer.baseAddress!) } - - let initializedCount = try initializer(buffer) - precondition(initializedCount <= capacity, "Overran buffer in initializer!") - // add zero termination - buffer[initializedCount] = 0 - - self = String(cString: buffer.baseAddress!) - } } -// Frustratingly, Swift 5.3 shipped before the macOS 11 SDK did, so we cannot gate the availability of -// this declaration on having the 5.3 compiler. This has caused a number of build issues. While updating -// to newer Xcodes does work, we can save ourselves some hassle and just wait until 5.4 to get this -// enhancement on Apple platforms. -#if (compiler(>=5.3) && !(os(macOS) || os(iOS) || os(tvOS) || os(watchOS))) || compiler(>=5.4) extension String { - @inlinable - init(customUnsafeUninitializedCapacity capacity: Int, - initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int) rethrows { - if #available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *) { - try self.init(unsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) - } else { - try self.init(backportUnsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) + @inlinable + init( + customUnsafeUninitializedCapacity capacity: Int, + initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int + ) rethrows { + if #available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *) { + try self.init(unsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) + } else { + try self.init(backportUnsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) + } } - } -} -#else -extension String { - @inlinable - init(customUnsafeUninitializedCapacity capacity: Int, - initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int) rethrows { - try self.init(backportUnsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) - } } -#endif diff --git a/Sources/AsyncHTTPClient/BasicAuth.swift b/Sources/AsyncHTTPClient/BasicAuth.swift new file mode 100644 index 000000000..3e69f8277 --- /dev/null +++ b/Sources/AsyncHTTPClient/BasicAuth.swift @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2024 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Foundation +import NIOHTTP1 + +/// Generates base64 encoded username + password for http basic auth. +/// +/// - Parameters: +/// - username: the username to authenticate with +/// - password: authentication password associated with the username +/// - Returns: encoded credentials to use the Authorization: Basic http header. +func encodeBasicAuthCredentials(username: String, password: String) -> String { + var value = Data() + value.reserveCapacity(username.utf8.count + password.utf8.count + 1) + value.append(contentsOf: username.utf8) + value.append(UInt8(ascii: ":")) + value.append(contentsOf: password.utf8) + return value.base64EncodedString() +} + +extension HTTPHeaders { + /// Sets the basic auth header + mutating func setBasicAuth(username: String, password: String) { + let encoded = encodeBasicAuthCredentials(username: username, password: password) + self.replaceOrAdd(name: "Authorization", value: "Basic \(encoded)") + } +} diff --git a/Sources/AsyncHTTPClient/BestEffortHashableTLSConfiguration.swift b/Sources/AsyncHTTPClient/BestEffortHashableTLSConfiguration.swift index 58169f645..aca0ce235 100644 --- a/Sources/AsyncHTTPClient/BestEffortHashableTLSConfiguration.swift +++ b/Sources/AsyncHTTPClient/BestEffortHashableTLSConfiguration.swift @@ -27,6 +27,6 @@ struct BestEffortHashableTLSConfiguration: Hashable { } static func == (lhs: BestEffortHashableTLSConfiguration, rhs: BestEffortHashableTLSConfiguration) -> Bool { - return lhs.base.bestEffortEquals(rhs.base) + lhs.base.bestEffortEquals(rhs.base) } } diff --git a/Sources/AsyncHTTPClient/Configuration+BrowserLike.swift b/Sources/AsyncHTTPClient/Configuration+BrowserLike.swift new file mode 100644 index 000000000..5a0abdfad --- /dev/null +++ b/Sources/AsyncHTTPClient/Configuration+BrowserLike.swift @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import NIOCore +import NIOHTTPCompression +import NIOSSL + +// swift-format-ignore: DontRepeatTypeInStaticProperties +extension HTTPClient.Configuration { + /// The ``HTTPClient/Configuration`` for ``HTTPClient/shared`` which tries to mimic the platform's default or prevalent browser as closely as possible. + /// + /// Don't rely on specific values of this configuration as they're subject to change. You can rely on them being somewhat sensible though. + /// + /// - note: At present, this configuration is nowhere close to a real browser configuration but in case of disagreements we will choose values that match + /// the default browser as closely as possible. + /// + /// Platform's default/prevalent browsers that we're trying to match (these might change over time): + /// - macOS: Safari + /// - iOS: Safari + /// - Android: Google Chrome + /// - Linux (non-Android): Google Chrome + public static var singletonConfiguration: HTTPClient.Configuration { + // To start with, let's go with these values. Obtained from Firefox's config. + HTTPClient.Configuration( + certificateVerification: .fullVerification, + redirectConfiguration: .follow(max: 20, allowCycles: false), + timeout: Timeout(connect: .seconds(90), read: .seconds(90)), + connectionPool: .seconds(600), + proxy: nil, + ignoreUncleanSSLShutdown: false, + decompression: .enabled(limit: .ratio(25)), + backgroundActivityLogger: nil + ) + } +} diff --git a/Sources/AsyncHTTPClient/ConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool.swift index 0dac50e5f..b5b058c2e 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool.swift @@ -12,8 +12,32 @@ // //===----------------------------------------------------------------------===// +import CNIOLinux +import NIOCore import NIOSSL +#if canImport(Darwin) +import Darwin.C +#elseif canImport(Musl) +import Musl +#elseif canImport(Android) +import Android +#elseif os(Linux) || os(FreeBSD) +import Glibc +#else +#error("unsupported target operating system") +#endif + +extension String { + var isIPAddress: Bool { + var ipv4Address = in_addr() + var ipv6Address = in6_addr() + return self.withCString { host in + inet_pton(AF_INET, host, &ipv4Address) == 1 || inet_pton(AF_INET6, host, &ipv6Address) == 1 + } + } +} + enum ConnectionPool { /// Used by the `ConnectionPool` to index its `HTTP1ConnectionProvider`s /// @@ -24,15 +48,18 @@ enum ConnectionPool { var scheme: Scheme var connectionTarget: ConnectionTarget private var tlsConfiguration: BestEffortHashableTLSConfiguration? + var serverNameIndicatorOverride: String? init( scheme: Scheme, connectionTarget: ConnectionTarget, - tlsConfiguration: BestEffortHashableTLSConfiguration? = nil + tlsConfiguration: BestEffortHashableTLSConfiguration? = nil, + serverNameIndicatorOverride: String? ) { self.scheme = scheme self.connectionTarget = connectionTarget self.tlsConfiguration = tlsConfiguration + self.serverNameIndicatorOverride = serverNameIndicatorOverride } var description: String { @@ -43,31 +70,50 @@ enum ConnectionPool { switch self.connectionTarget { case .ipAddress(let serialization, let addr): hostDescription = "\(serialization):\(addr.port!)" - case .domain(let domain, port: let port): + case .domain(let domain, let port): hostDescription = "\(domain):\(port)" case .unixSocket(let socketPath): hostDescription = socketPath } - return "\(self.scheme)://\(hostDescription) TLS-hash: \(hash)" + return + "\(self.scheme)://\(hostDescription)\(self.serverNameIndicatorOverride.map { " SNI: \($0)" } ?? "") TLS-hash: \(hash) " } } } +extension DeconstructedURL { + func applyDNSOverride(_ dnsOverride: [String: String]) -> (ConnectionTarget, serverNameIndicatorOverride: String?) { + guard + let originalHost = self.connectionTarget.host, + let hostOverride = dnsOverride[originalHost] + else { + return (self.connectionTarget, nil) + } + return ( + .init(remoteHost: hostOverride, port: self.connectionTarget.port ?? self.scheme.defaultPort), + serverNameIndicatorOverride: originalHost.isIPAddress ? nil : originalHost + ) + } +} + extension ConnectionPool.Key { - init(url: DeconstructedURL, tlsConfiguration: TLSConfiguration?) { + init(url: DeconstructedURL, tlsConfiguration: TLSConfiguration?, dnsOverride: [String: String]) { + let (connectionTarget, serverNameIndicatorOverride) = url.applyDNSOverride(dnsOverride) self.init( scheme: url.scheme, - connectionTarget: url.connectionTarget, + connectionTarget: connectionTarget, tlsConfiguration: tlsConfiguration.map { BestEffortHashableTLSConfiguration(wrapping: $0) - } + }, + serverNameIndicatorOverride: serverNameIndicatorOverride ) } - init(_ request: HTTPClient.Request) { + init(_ request: HTTPClient.Request, dnsOverride: [String: String] = [:]) { self.init( url: request.deconstructedURL, - tlsConfiguration: request.tlsConfiguration + tlsConfiguration: request.tlsConfiguration, + dnsOverride: dnsOverride ) } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift index 7340a59ea..db7b7b7ef 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift @@ -42,7 +42,7 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand private var proxyEstablishedPromise: EventLoopPromise? var proxyEstablishedFuture: EventLoopFuture? { - return self.proxyEstablishedPromise?.futureResult + self.proxyEstablishedPromise?.futureResult } convenience init( @@ -53,10 +53,10 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand let targetHost: String let targetPort: Int switch target { - case .ipAddress(serialization: let serialization, address: let address): + case .ipAddress(let serialization, let address): targetHost = serialization targetPort = address.port! - case .domain(name: let domain, port: let port): + case .domain(name: let domain, let port): targetHost = domain targetPort = port case .unixSocket: @@ -70,10 +70,12 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand ) } - init(targetHost: String, - targetPort: Int, - proxyAuthorization: HTTPClient.Authorization?, - deadline: NIODeadline) { + init( + targetHost: String, + targetPort: Int, + proxyAuthorization: HTTPClient.Authorization?, + deadline: NIODeadline + ) { self.targetHost = targetHost self.targetPort = targetPort self.proxyAuthorization = proxyAuthorization @@ -155,6 +157,7 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand method: .CONNECT, uri: "\(self.targetHost):\(self.targetPort)" ) + head.headers.replaceOrAdd(name: "host", value: "\(self.targetHost)") if let authorization = self.proxyAuthorization { head.headers.replaceOrAdd(name: "proxy-authorization", value: authorization.headerValue) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift index 5a46f44a7..a98f97d4d 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift @@ -31,7 +31,7 @@ final class SOCKSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { private var socksEstablishedPromise: EventLoopPromise? var socksEstablishedFuture: EventLoopFuture? { - return self.socksEstablishedPromise?.futureResult + self.socksEstablishedPromise?.futureResult } private let deadline: NIODeadline diff --git a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift index aab26fda8..bebd0bcc7 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift @@ -31,7 +31,7 @@ final class TLSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { private var tlsEstablishedPromise: EventLoopPromise? var tlsEstablishedFuture: EventLoopFuture? { - return self.tlsEstablishedPromise?.futureResult + self.tlsEstablishedPromise?.futureResult } private let deadline: NIODeadline? diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift index 15544dc4a..8203f07af 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift @@ -35,16 +35,24 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { didSet { if let newRequest = self.request { var requestLogger = newRequest.logger - requestLogger[metadataKey: "ahc-connection-id"] = "\(self.connection.id)" - requestLogger[metadataKey: "ahc-el"] = "\(self.connection.channel.eventLoop)" + requestLogger[metadataKey: "ahc-connection-id"] = self.connectionIdLoggerMetadata + requestLogger[metadataKey: "ahc-el"] = self.eventLoopDescription self.logger = requestLogger if let idleReadTimeout = newRequest.requestOptions.idleReadTimeout { self.idleReadTimeoutStateMachine = .init(timeAmount: idleReadTimeout) } + + if let idleWriteTimeout = newRequest.requestOptions.idleWriteTimeout { + self.idleWriteTimeoutStateMachine = .init( + timeAmount: idleWriteTimeout, + isWritabilityEnabled: self.channelContext?.channel.isWritable ?? false + ) + } } else { self.logger = self.backgroundLogger self.idleReadTimeoutStateMachine = nil + self.idleWriteTimeoutStateMachine = nil } } } @@ -52,22 +60,28 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { private var idleReadTimeoutStateMachine: IdleReadStateMachine? private var idleReadTimeoutTimer: Scheduled? + private var idleWriteTimeoutStateMachine: IdleWriteStateMachine? + private var idleWriteTimeoutTimer: Scheduled? + /// Cancelling a task in NIO does *not* guarantee that the task will not execute under certain race conditions. /// We therefore give each timer an ID and increase the ID every time we reset or cancel it. /// We check in the task if the timer ID has changed in the meantime and do not execute any action if has changed. private var currentIdleReadTimeoutTimerID: Int = 0 + private var currentIdleWriteTimeoutTimerID: Int = 0 private let backgroundLogger: Logger private var logger: Logger + private let eventLoop: EventLoop + private let eventLoopDescription: Logger.MetadataValue + private let connectionIdLoggerMetadata: Logger.MetadataValue - let connection: HTTP1Connection - let eventLoop: EventLoop - - init(connection: HTTP1Connection, eventLoop: EventLoop, logger: Logger) { - self.connection = connection + var onConnectionIdle: () -> Void = {} + init(eventLoop: EventLoop, backgroundLogger: Logger, connectionIdLoggerMetadata: Logger.MetadataValue) { self.eventLoop = eventLoop - self.backgroundLogger = logger - self.logger = self.backgroundLogger + self.eventLoopDescription = "\(eventLoop.description)" + self.backgroundLogger = backgroundLogger + self.logger = backgroundLogger + self.connectionIdLoggerMetadata = connectionIdLoggerMetadata } func handlerAdded(context: ChannelHandlerContext) { @@ -86,9 +100,12 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { // MARK: Channel Inbound Handler func channelActive(context: ChannelHandlerContext) { - self.logger.trace("Channel active", metadata: [ - "ahc-channel-writable": "\(context.channel.isWritable)", - ]) + self.logger.trace( + "Channel active", + metadata: [ + "ahc-channel-writable": "\(context.channel.isWritable)" + ] + ) let action = self.state.channelActive(isWritable: context.channel.isWritable) self.run(action, context: context) @@ -102,20 +119,31 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { } func channelWritabilityChanged(context: ChannelHandlerContext) { - self.logger.trace("Channel writability changed", metadata: [ - "ahc-channel-writable": "\(context.channel.isWritable)", - ]) + self.logger.trace( + "Channel writability changed", + metadata: [ + "ahc-channel-writable": "\(context.channel.isWritable)" + ] + ) + + if let timeoutAction = self.idleWriteTimeoutStateMachine?.channelWritabilityChanged(context: context) { + self.runTimeoutAction(timeoutAction, context: context) + } let action = self.state.writabilityChanged(writable: context.channel.isWritable) self.run(action, context: context) + context.fireChannelWritabilityChanged() } func channelRead(context: ChannelHandlerContext, data: NIOAny) { let httpPart = self.unwrapInboundIn(data) - self.logger.trace("HTTP response part received", metadata: [ - "ahc-http-part": "\(httpPart)", - ]) + self.logger.trace( + "HTTP response part received", + metadata: [ + "ahc-http-part": "\(httpPart)" + ] + ) if let timeoutAction = self.idleReadTimeoutStateMachine?.channelRead(httpPart) { self.runTimeoutAction(timeoutAction, context: context) @@ -133,9 +161,12 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { } func errorCaught(context: ChannelHandlerContext, error: Error) { - self.logger.trace("Channel error caught", metadata: [ - "ahc-error": "\(error)", - ]) + self.logger.trace( + "Channel error caught", + metadata: [ + "ahc-error": "\(error)" + ] + ) let action = self.state.errorHappened(error) self.run(action, context: context) @@ -149,6 +180,11 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.request = req self.logger.debug("Request was scheduled on connection") + + if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() { + self.runTimeoutAction(timeoutAction, context: context) + } + req.willExecuteRequest(self) let action = self.state.runNewRequest( @@ -182,39 +218,49 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { private func run(_ action: HTTP1ConnectionStateMachine.Action, context: ChannelHandlerContext) { switch action { - case .sendRequestHead(let head, startBody: let startBody): - if startBody { - context.write(self.wrapOutboundOut(.head(head)), promise: nil) - context.flush() - - self.request!.requestHeadSent() - self.request!.resumeRequestBodyStream() - } else { - context.write(self.wrapOutboundOut(.head(head)), promise: nil) - context.write(self.wrapOutboundOut(.end(nil)), promise: nil) - context.flush() - - self.request!.requestHeadSent() + case .sendRequestHead(let head, let sendEnd): + self.sendRequestHead(head, sendEnd: sendEnd, context: context) + case .notifyRequestHeadSendSuccessfully(let resumeRequestBodyStream, let startIdleTimer): + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet + self.request!.requestHeadSent() + if resumeRequestBodyStream, let request = self.request { + // The above request head send notification might lead the request to mark itself as + // cancelled, which in turn might pop the request of the handler. For this reason we + // must check if the request is still present here. + request.resumeRequestBodyStream() + } + if startIdleTimer { + if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(readTimeoutAction, context: context) + } - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) + if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(writeTimeoutAction, context: context) } } + case .sendBodyPart(let part, let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: writePromise) - case .sendBodyPart(let part): - context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: nil) + case .sendRequestEnd(let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(readTimeoutAction, context: context) + } - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) + if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(writeTimeoutAction, context: context) } case .pauseRequestBodyStream: + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet self.request!.pauseRequestBodyStream() case .resumeRequestBodyStream: + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet self.request!.resumeRequestBodyStream() case .fireChannelActive: @@ -239,15 +285,25 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { break case .forwardResponseHead(let head, let pauseRequestBodyStream): + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet self.request!.receiveResponseHead(head) - if pauseRequestBodyStream { - self.request!.pauseRequestBodyStream() + if pauseRequestBodyStream, let request = self.request { + // The above response head forward might lead the request to mark itself as + // cancelled, which in turn might pop the request of the handler. For this reason we + // must check if the request is still present here. + request.pauseRequestBodyStream() } case .forwardResponseBodyParts(let buffer): + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet self.request!.receiveResponseBodyParts(buffer) case .succeedRequest(let finalAction, let buffer): + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet + // The order here is very important... // We first nil our own task property! `taskCompleted` will potentially lead to // situations in which we get a new request right away. We should finish the task @@ -258,41 +314,82 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { let oldRequest = self.request! self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) switch finalAction { case .close: context.close(promise: nil) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + oldRequest.succeedRequest(buffer) + case .sendRequestEnd(let writePromise, let shouldClose): + let writePromise = writePromise ?? context.eventLoop.makePromise(of: Void.self) + // We need to defer succeeding the old request to avoid ordering issues + writePromise.futureResult.hop(to: context.eventLoop).whenComplete { result in + switch result { + case .success: + // If our final action was `sendRequestEnd`, that means we've already received + // the complete response. As a result, once we've uploaded all the body parts + // we need to tell the pool that the connection is idle or, if we were asked to + // close when we're done, send the close. Either way, we then succeed the request + if shouldClose { + context.close(promise: nil) + } else { + self.onConnectionIdle() + } + + oldRequest.succeedRequest(buffer) + case .failure(let error): + context.close(promise: nil) + oldRequest.fail(error) + } + } + + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) case .informConnectionIsIdle: - self.connection.taskCompleted() - case .none: - break + self.onConnectionIdle() + oldRequest.succeedRequest(buffer) } - oldRequest.succeedRequest(buffer) - case .failRequest(let error, let finalAction): // see comment in the `succeedRequest` case. let oldRequest = self.request! self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) switch finalAction { - case .close: + case .close(let writePromise): context.close(promise: nil) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + writePromise?.fail(error) + oldRequest.fail(error) + case .informConnectionIsIdle: - self.connection.taskCompleted() + self.onConnectionIdle() + oldRequest.fail(error) + + case .failWritePromise(let writePromise): + writePromise?.fail(error) + oldRequest.fail(error) + case .none: - break + oldRequest.fail(error) } - oldRequest.fail(error) + case .failSendBodyPart(let error, let writePromise), .failSendStreamFinished(let error, let writePromise): + writePromise?.fail(error) } } + private func sendRequestHead(_ head: HTTPRequestHead, sendEnd: Bool, context: ChannelHandlerContext) { + if sendEnd { + context.write(self.wrapOutboundOut(.head(head)), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + } else { + context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) + } + self.run(self.state.headSent(), context: context) + } + private func runTimeoutAction(_ action: IdleReadStateMachine.Action, context: ChannelHandlerContext) { switch action { case .startIdleReadTimeoutTimer(let timeAmount): @@ -328,29 +425,70 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { } } + private func runTimeoutAction(_ action: IdleWriteStateMachine.Action, context: ChannelHandlerContext) { + switch action { + case .startIdleWriteTimeoutTimer(let timeAmount): + assert(self.idleWriteTimeoutTimer == nil, "Expected there is no timeout timer so far.") + + let timerID = self.currentIdleWriteTimeoutTimerID + self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + guard self.currentIdleWriteTimeoutTimerID == timerID else { return } + let action = self.state.idleWriteTimeoutTriggered() + self.run(action, context: context) + } + case .resetIdleWriteTimeoutTimer(let timeAmount): + if let oldTimer = self.idleWriteTimeoutTimer { + oldTimer.cancel() + } + + self.currentIdleWriteTimeoutTimerID &+= 1 + let timerID = self.currentIdleWriteTimeoutTimerID + self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + guard self.currentIdleWriteTimeoutTimerID == timerID else { return } + let action = self.state.idleWriteTimeoutTriggered() + self.run(action, context: context) + } + case .clearIdleWriteTimeoutTimer: + if let oldTimer = self.idleWriteTimeoutTimer { + self.idleWriteTimeoutTimer = nil + self.currentIdleWriteTimeoutTimerID &+= 1 + oldTimer.cancel() + } + case .none: + break + } + } + // MARK: Private HTTPRequestExecutor - private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest) { + private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) + { guard self.request === request, let context = self.channelContext else { // Because the HTTPExecutableRequest may run in a different thread to our eventLoop, // calls from the HTTPExecutableRequest to our ChannelHandler may arrive here after // the request has been popped by the state machine or the ChannelHandler has been // removed from the Channel pipeline. This is a normal threading issue, noone has // screwed up. + promise?.fail(HTTPClientError.requestStreamCancelled) return } - let action = self.state.requestStreamPartReceived(data) + if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() { + self.runTimeoutAction(timeoutAction, context: context) + } + + let action = self.state.requestStreamPartReceived(data, promise: promise) self.run(action, context: context) } - private func finishRequestBodyStream0(_ request: HTTPExecutableRequest) { + private func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` + promise?.fail(HTTPClientError.requestStreamCancelled) return } - let action = self.state.requestStreamFinished() + let action = self.state.requestStreamFinished(promise: promise) self.run(action, context: context) } @@ -374,28 +512,35 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.logger.trace("Request was cancelled") + if let timeoutAction = self.idleWriteTimeoutStateMachine?.cancelRequest() { + self.runTimeoutAction(timeoutAction, context: context) + } + let action = self.state.requestCancelled(closeConnection: true) self.run(action, context: context) } } +@available(*, unavailable) +extension HTTP1ClientChannelHandler: Sendable {} + extension HTTP1ClientChannelHandler: HTTPRequestExecutor { - func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest) { + func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { if self.eventLoop.inEventLoop { - self.writeRequestBodyPart0(data, request: request) + self.writeRequestBodyPart0(data, request: request, promise: promise) } else { self.eventLoop.execute { - self.writeRequestBodyPart0(data, request: request) + self.writeRequestBodyPart0(data, request: request, promise: promise) } } } - func finishRequestBodyStream(_ request: HTTPExecutableRequest) { + func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { if self.eventLoop.inEventLoop { - self.finishRequestBodyStream0(request) + self.finishRequestBodyStream0(request, promise: promise) } else { self.eventLoop.execute { - self.finishRequestBodyStream0(request) + self.finishRequestBodyStream0(request, promise: promise) } } } @@ -483,3 +628,90 @@ struct IdleReadStateMachine { } } } + +struct IdleWriteStateMachine { + enum Action { + case startIdleWriteTimeoutTimer(TimeAmount) + case resetIdleWriteTimeoutTimer(TimeAmount) + case clearIdleWriteTimeoutTimer + case none + } + + enum State { + case waitingForRequestEnd + case waitingForWritabilityEnabled + case requestEndSent + } + + private var state: State + private let timeAmount: TimeAmount + + init(timeAmount: TimeAmount, isWritabilityEnabled: Bool) { + self.timeAmount = timeAmount + if isWritabilityEnabled { + self.state = .waitingForRequestEnd + } else { + self.state = .waitingForWritabilityEnabled + } + } + + mutating func cancelRequest() -> Action { + switch self.state { + case .waitingForRequestEnd, .waitingForWritabilityEnabled: + self.state = .requestEndSent + return .clearIdleWriteTimeoutTimer + case .requestEndSent: + return .none + } + } + + mutating func write() -> Action { + switch self.state { + case .waitingForRequestEnd: + return .resetIdleWriteTimeoutTimer(self.timeAmount) + case .waitingForWritabilityEnabled: + return .none + case .requestEndSent: + preconditionFailure("If the request end has been sent, we can't write more data.") + } + } + + mutating func requestEndSent() -> Action { + switch self.state { + case .waitingForRequestEnd: + self.state = .requestEndSent + return .clearIdleWriteTimeoutTimer + case .waitingForWritabilityEnabled: + self.state = .requestEndSent + return .none + case .requestEndSent: + return .none + } + } + + mutating func channelWritabilityChanged(context: ChannelHandlerContext) -> Action { + if context.channel.isWritable { + switch self.state { + case .waitingForRequestEnd: + preconditionFailure("If waiting for more data, the channel was already writable.") + case .waitingForWritabilityEnabled: + self.state = .waitingForRequestEnd + return .startIdleWriteTimeoutTimer(self.timeAmount) + case .requestEndSent: + return .none + } + } else { + switch self.state { + case .waitingForRequestEnd: + self.state = .waitingForWritabilityEnabled + return .clearIdleWriteTimeoutTimer + case .waitingForWritabilityEnabled: + preconditionFailure( + "If the channel was writable before, then we should have been waiting for more data." + ) + case .requestEndSent: + return .none + } + } + } +} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift index 173ac79e4..e0496f2e3 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift @@ -39,9 +39,11 @@ final class HTTP1Connection { let id: HTTPConnectionPool.Connection.ID - init(channel: Channel, - connectionID: HTTPConnectionPool.Connection.ID, - delegate: HTTP1ConnectionDelegate) { + init( + channel: Channel, + connectionID: HTTPConnectionPool.Connection.ID, + delegate: HTTP1ConnectionDelegate + ) { self.channel = channel self.id = connectionID self.delegate = delegate @@ -57,11 +59,11 @@ final class HTTP1Connection { channel: Channel, connectionID: HTTPConnectionPool.Connection.ID, delegate: HTTP1ConnectionDelegate, - configuration: HTTPClient.Configuration, + decompression: HTTPClient.Decompression, logger: Logger ) throws -> HTTP1Connection { let connection = HTTP1Connection(channel: channel, connectionID: connectionID, delegate: delegate) - try connection.start(configuration: configuration, logger: logger) + try connection.start(decompression: decompression, logger: logger) return connection } @@ -80,7 +82,7 @@ final class HTTP1Connection { } func close(promise: EventLoopPromise?) { - return self.channel.close(mode: .all, promise: promise) + self.channel.close(mode: .all, promise: promise) } func close() -> EventLoopFuture { @@ -101,7 +103,7 @@ final class HTTP1Connection { self.channel.write(request, promise: nil) } - private func start(configuration: HTTPClient.Configuration, logger: Logger) throws { + private func start(decompression: HTTPClient.Decompression, logger: Logger) throws { self.channel.eventLoop.assertInEventLoop() guard case .initialized = self.state else { @@ -127,16 +129,19 @@ final class HTTP1Connection { try sync.addHandler(requestEncoder) try sync.addHandler(ByteToMessageHandler(responseDecoder)) - if case .enabled(let limit) = configuration.decompression { + if case .enabled(let limit) = decompression { let decompressHandler = NIOHTTPResponseDecompressor(limit: limit) try sync.addHandler(decompressHandler) } let channelHandler = HTTP1ClientChannelHandler( - connection: self, eventLoop: channel.eventLoop, - logger: logger + backgroundLogger: logger, + connectionIdLoggerMetadata: "\(self.id)" ) + channelHandler.onConnectionIdle = { + self.taskCompleted() + } try sync.addHandler(channelHandler) } catch { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift index 19825aec7..2cde1df3f 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift @@ -28,21 +28,44 @@ struct HTTP1ConnectionStateMachine { enum Action { /// A action to execute, when we consider a request "done". - enum FinalStreamAction { + enum FinalSuccessfulStreamAction { /// Close the connection case close /// If the server has replied, with a status of 200...300 before all data was sent, a request is considered succeeded, /// as soon as we wrote the request end onto the wire. - case sendRequestEnd + /// + /// The promise is an optional write promise. + /// + /// `shouldClose` records whether we have attached a Connection: close header to this request, and so the connection should + /// be terminated + case sendRequestEnd(EventLoopPromise?, shouldClose: Bool) /// Inform an observer that the connection has become idle case informConnectionIsIdle + } + + /// A action to execute, when we consider a request "done". + enum FinalFailedStreamAction { + /// Close the connection + /// + /// The promise is an optional write promise. + case close(EventLoopPromise?) + /// Inform an observer that the connection has become idle + case informConnectionIsIdle + /// Fail the write promise + case failWritePromise(EventLoopPromise?) /// Do nothing. case none } - case sendRequestHead(HTTPRequestHead, startBody: Bool) - case sendBodyPart(IOData) - case sendRequestEnd + case sendRequestHead(HTTPRequestHead, sendEnd: Bool) + case notifyRequestHeadSendSuccessfully( + resumeRequestBodyStream: Bool, + startIdleTimer: Bool + ) + case sendBodyPart(IOData, EventLoopPromise?) + case sendRequestEnd(EventLoopPromise?) + case failSendBodyPart(Error, EventLoopPromise?) + case failSendStreamFinished(Error, EventLoopPromise?) case pauseRequestBodyStream case resumeRequestBodyStream @@ -50,8 +73,8 @@ struct HTTP1ConnectionStateMachine { case forwardResponseHead(HTTPResponseHead, pauseRequestBodyStream: Bool) case forwardResponseBodyParts(CircularBuffer) - case failRequest(Error, FinalStreamAction) - case succeedRequest(FinalStreamAction, CircularBuffer) + case failRequest(Error, FinalFailedStreamAction) + case succeedRequest(FinalSuccessfulStreamAction, CircularBuffer) case read case close @@ -83,14 +106,14 @@ struct HTTP1ConnectionStateMachine { return .wait case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } mutating func channelInactive() -> Action { switch self.state { case .initialized: - preconditionFailure("A channel that isn't active, must not become inactive") + fatalError("A channel that isn't active, must not become inactive") case .inRequest(var requestStateMachine, close: _): return self.avoidingStateMachineCoW { state -> Action in @@ -107,7 +130,7 @@ struct HTTP1ConnectionStateMachine { return .wait case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } @@ -117,7 +140,7 @@ struct HTTP1ConnectionStateMachine { self.state = .closed return .fireChannelError(error, closeConnection: false) - case .inRequest(var requestStateMachine, close: let close): + case .inRequest(var requestStateMachine, let close): return self.avoidingStateMachineCoW { state -> Action in let action = requestStateMachine.errorHappened(error) state = .inRequest(requestStateMachine, close: close) @@ -132,7 +155,7 @@ struct HTTP1ConnectionStateMachine { return .fireChannelError(error, closeConnection: false) case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } @@ -150,7 +173,7 @@ struct HTTP1ConnectionStateMachine { } case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } @@ -159,15 +182,15 @@ struct HTTP1ConnectionStateMachine { metadata: RequestFramingMetadata ) -> Action { switch self.state { - case .initialized, .closing, .inRequest: + case .initialized, .inRequest: // These states are unreachable as the connection pool state machine has put the // connection into these states. In other words the connection pool state machine must // be aware about these states before the connection itself. For this reason the // connection pool state machine must not send a new request to the connection, if the // connection is `.initialized`, `.closing` or `.inRequest` - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") - case .closed: + case .closing, .closed: // The remote may have closed the connection and the connection pool state machine // was not updated yet because of a race condition. New request vs. marking connection // as closed. @@ -185,29 +208,29 @@ struct HTTP1ConnectionStateMachine { return self.state.modify(with: action) case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } - mutating func requestStreamPartReceived(_ part: IOData) -> Action { + mutating func requestStreamPartReceived(_ part: IOData, promise: EventLoopPromise?) -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } return self.avoidingStateMachineCoW { state -> Action in - let action = requestStateMachine.requestStreamPartReceived(part) + let action = requestStateMachine.requestStreamPartReceived(part, promise: promise) state = .inRequest(requestStateMachine, close: close) return state.modify(with: action) } } - mutating func requestStreamFinished() -> Action { + mutating func requestStreamFinished(promise: EventLoopPromise?) -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } return self.avoidingStateMachineCoW { state -> Action in - let action = requestStateMachine.requestStreamFinished() + let action = requestStateMachine.requestStreamFinished(promise: promise) state = .inRequest(requestStateMachine, close: close) return state.modify(with: action) } @@ -216,7 +239,9 @@ struct HTTP1ConnectionStateMachine { mutating func requestCancelled(closeConnection: Bool) -> Action { switch self.state { case .initialized: - preconditionFailure("This event must only happen, if the connection is leased. During startup this is impossible. Invalid state: \(self.state)") + fatalError( + "This event must only happen, if the connection is leased. During startup this is impossible. Invalid state: \(self.state)" + ) case .idle: if closeConnection { @@ -226,7 +251,7 @@ struct HTTP1ConnectionStateMachine { return .wait } - case .inRequest(var requestStateMachine, close: let close): + case .inRequest(var requestStateMachine, let close): return self.avoidingStateMachineCoW { state -> Action in let action = requestStateMachine.requestCancelled() state = .inRequest(requestStateMachine, close: close || closeConnection) @@ -237,7 +262,7 @@ struct HTTP1ConnectionStateMachine { return .wait case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } @@ -246,7 +271,7 @@ struct HTTP1ConnectionStateMachine { mutating func read() -> Action { switch self.state { case .initialized: - preconditionFailure("Why should we read something, if we are not connected yet") + fatalError("Why should we read something, if we are not connected yet") case .idle: return .read case .inRequest(var requestStateMachine, let close): @@ -261,14 +286,14 @@ struct HTTP1ConnectionStateMachine { return .read case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } mutating func channelRead(_ part: HTTPClientResponsePart) -> Action { switch self.state { case .initialized, .idle: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") case .inRequest(var requestStateMachine, var close): return self.avoidingStateMachineCoW { state -> Action in @@ -287,7 +312,7 @@ struct HTTP1ConnectionStateMachine { return .wait case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } @@ -304,13 +329,13 @@ struct HTTP1ConnectionStateMachine { } case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } mutating func demandMoreResponseBodyParts() -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } return self.avoidingStateMachineCoW { state -> Action in @@ -322,7 +347,7 @@ struct HTTP1ConnectionStateMachine { mutating func idleReadTimeoutTriggered() -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } return self.avoidingStateMachineCoW { state -> Action in @@ -331,6 +356,29 @@ struct HTTP1ConnectionStateMachine { return state.modify(with: action) } } + + mutating func idleWriteTimeoutTriggered() -> Action { + guard case .inRequest(var requestStateMachine, let close) = self.state else { + return .wait + } + + return self.avoidingStateMachineCoW { state -> Action in + let action = requestStateMachine.idleWriteTimeoutTriggered() + state = .inRequest(requestStateMachine, close: close) + return state.modify(with: action) + } + } + + mutating func headSent() -> Action { + guard case .inRequest(var requestStateMachine, let close) = self.state else { + return .wait + } + return self.avoidingStateMachineCoW { state in + let action = requestStateMachine.headSent() + state = .inRequest(requestStateMachine, close: close) + return state.modify(with: action) + } + } } extension HTTP1ConnectionStateMachine { @@ -369,34 +417,41 @@ extension HTTP1ConnectionStateMachine { } extension HTTP1ConnectionStateMachine.State { - fileprivate mutating func modify(with action: HTTPRequestStateMachine.Action) -> HTTP1ConnectionStateMachine.Action { + fileprivate mutating func modify(with action: HTTPRequestStateMachine.Action) -> HTTP1ConnectionStateMachine.Action + { switch action { - case .sendRequestHead(let head, let startBody): - return .sendRequestHead(head, startBody: startBody) + case .sendRequestHead(let head, let sendEnd): + return .sendRequestHead(head, sendEnd: sendEnd) + case .notifyRequestHeadSendSuccessfully(let resumeRequestBodyStream, let startIdleTimer): + return .notifyRequestHeadSendSuccessfully( + resumeRequestBodyStream: resumeRequestBodyStream, + startIdleTimer: startIdleTimer + ) case .pauseRequestBodyStream: return .pauseRequestBodyStream case .resumeRequestBodyStream: return .resumeRequestBodyStream - case .sendBodyPart(let part): - return .sendBodyPart(part) - case .sendRequestEnd: - return .sendRequestEnd + case .sendBodyPart(let part, let writePromise): + return .sendBodyPart(part, writePromise) + case .sendRequestEnd(let writePromise): + return .sendRequestEnd(writePromise) case .forwardResponseHead(let head, let pauseRequestBodyStream): return .forwardResponseHead(head, pauseRequestBodyStream: pauseRequestBodyStream) case .forwardResponseBodyParts(let parts): return .forwardResponseBodyParts(parts) case .succeedRequest(let finalAction, let finalParts): guard case .inRequest(_, close: let close) = self else { - preconditionFailure("Invalid state: \(self)") + fatalError("Invalid state: \(self)") } - let newFinalAction: HTTP1ConnectionStateMachine.Action.FinalStreamAction + let newFinalAction: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction switch finalAction { case .close: self = .closing newFinalAction = .close - case .sendRequestEnd: - newFinalAction = .sendRequestEnd + case .sendRequestEnd(let writePromise): + self = .idle + newFinalAction = .sendRequestEnd(writePromise, shouldClose: close) case .none: self = .idle newFinalAction = close ? .close : .informConnectionIsIdle @@ -406,13 +461,16 @@ extension HTTP1ConnectionStateMachine.State { case .failRequest(let error, let finalAction): switch self { case .initialized: - preconditionFailure("Invalid state: \(self)") + fatalError("Invalid state: \(self)") case .idle: - preconditionFailure("How can we fail a task, if we are idle") - case .inRequest(_, close: let close): - if close || finalAction == .close { + fatalError("How can we fail a task, if we are idle") + case .inRequest(_, let close): + if case .close(let promise) = finalAction { + self = .closing + return .failRequest(error, .close(promise)) + } else if close { self = .closing - return .failRequest(error, .close) + return .failRequest(error, .close(nil)) } else { self = .idle return .failRequest(error, .informConnectionIsIdle) @@ -425,7 +483,7 @@ extension HTTP1ConnectionStateMachine.State { return .failRequest(error, .none) case .modifying: - preconditionFailure("Invalid state: \(self)") + fatalError("Invalid state: \(self)") } case .read: @@ -433,6 +491,12 @@ extension HTTP1ConnectionStateMachine.State { case .wait: return .wait + + case .failSendBodyPart(let error, let writePromise): + return .failSendBodyPart(error, writePromise) + + case .failSendStreamFinished(let error, let writePromise): + return .failSendStreamFinished(error, writePromise) } } } @@ -444,14 +508,14 @@ extension HTTP1ConnectionStateMachine: CustomStringConvertible { return ".initialized" case .idle: return ".idle" - case .inRequest(let request, close: let close): + case .inRequest(let request, let close): return ".inRequest(\(request), closeAfterRequest: \(close))" case .closing: return ".closing" case .closed: return ".closed" case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift index 8b2a50738..61350dfd7 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift @@ -35,8 +35,16 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { private var request: HTTPExecutableRequest? { didSet { - if let newRequest = self.request, let idleReadTimeout = newRequest.requestOptions.idleReadTimeout { - self.idleReadTimeoutStateMachine = .init(timeAmount: idleReadTimeout) + if let newRequest = self.request { + if let idleReadTimeout = newRequest.requestOptions.idleReadTimeout { + self.idleReadTimeoutStateMachine = .init(timeAmount: idleReadTimeout) + } + if let idleWriteTimeout = newRequest.requestOptions.idleWriteTimeout { + self.idleWriteTimeoutStateMachine = .init( + timeAmount: idleWriteTimeout, + isWritabilityEnabled: self.channelContext?.channel.isWritable ?? false + ) + } } else { self.idleReadTimeoutStateMachine = nil } @@ -46,13 +54,24 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { private var idleReadTimeoutStateMachine: IdleReadStateMachine? private var idleReadTimeoutTimer: Scheduled? + private var idleWriteTimeoutStateMachine: IdleWriteStateMachine? + private var idleWriteTimeoutTimer: Scheduled? + + /// Cancelling a task in NIO does *not* guarantee that the task will not execute under certain race conditions. + /// We therefore give each timer an ID and increase the ID every time we reset or cancel it. + /// We check in the task if the timer ID has changed in the meantime and do not execute any action if has changed. + private var currentIdleReadTimeoutTimerID: Int = 0 + private var currentIdleWriteTimeoutTimerID: Int = 0 + init(eventLoop: EventLoop) { self.eventLoop = eventLoop } func handlerAdded(context: ChannelHandlerContext) { - assert(context.eventLoop === self.eventLoop, - "The handler must be added to a channel that runs on the eventLoop it was initialized with.") + assert( + context.eventLoop === self.eventLoop, + "The handler must be added to a channel that runs on the eventLoop it was initialized with." + ) self.channelContext = context let isWritable = context.channel.isActive && context.channel.isWritable @@ -77,6 +96,10 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { } func channelWritabilityChanged(context: ChannelHandlerContext) { + if let timeoutAction = self.idleWriteTimeoutStateMachine?.channelWritabilityChanged(context: context) { + self.runTimeoutAction(timeoutAction, context: context) + } + let action = self.state.writabilityChanged(writable: context.channel.isWritable) self.run(action, context: context) } @@ -110,6 +133,10 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // a single request. self.request = request + if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() { + self.runTimeoutAction(timeoutAction, context: context) + } + request.willExecuteRequest(self) let action = self.state.startRequest( @@ -140,22 +167,44 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { private func run(_ action: HTTPRequestStateMachine.Action, context: ChannelHandlerContext) { switch action { - case .sendRequestHead(let head, let startBody): - self.sendRequestHead(head, startBody: startBody, context: context) - + case .sendRequestHead(let head, let sendEnd): + self.sendRequestHead(head, sendEnd: sendEnd, context: context) + case .notifyRequestHeadSendSuccessfully(let resumeRequestBodyStream, let startIdleTimer): + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet + self.request!.requestHeadSent() + if resumeRequestBodyStream, let request = self.request { + // The above request head send notification might lead the request to mark itself as + // cancelled, which in turn might pop the request of the handler. For this reason we + // must check if the request is still present here. + request.resumeRequestBodyStream() + } + if startIdleTimer { + if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(readTimeoutAction, context: context) + } + + if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(writeTimeoutAction, context: context) + } + } case .pauseRequestBodyStream: // We can force unwrap the request here, as we have just validated in the state machine, // that the request is neither failed nor finished yet self.request!.pauseRequestBodyStream() - case .sendBodyPart(let data): - context.writeAndFlush(self.wrapOutboundOut(.body(data)), promise: nil) + case .sendBodyPart(let data, let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.body(data)), promise: writePromise) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + case .sendRequestEnd(let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) + + if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(readTimeoutAction, context: context) + } - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) + if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(writeTimeoutAction, context: context) } case .read: @@ -169,7 +218,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // that the request is neither failed nor finished yet self.request!.resumeRequestBodyStream() - case .forwardResponseHead(let head, pauseRequestBodyStream: let pauseRequestBodyStream): + case .forwardResponseHead(let head, let pauseRequestBodyStream): // We can force unwrap the request here, as we have just validated in the state machine, // that the request is neither failed nor finished yet self.request!.receiveResponseHead(head) @@ -185,17 +234,18 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // that the request is neither failed nor finished yet self.request!.receiveResponseBodyParts(parts) - case .failRequest(let error, _): + case .failRequest(let error, let finalAction): // We can force unwrap the request here, as we have just validated in the state machine, // that the request object is still present. self.request!.fail(error) self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) // No matter the error reason, we must always make sure the h2 stream is closed. Only // once the h2 stream is closed, it is released from the h2 multiplexer. The // HTTPRequestStateMachine may signal finalAction: .none in the error case (as this is // the right result for HTTP/1). In the h2 case we MUST always close. - self.runFinalAction(.close, context: context) + self.runFailedFinalAction(finalAction, context: context, error: error) case .succeedRequest(let finalAction, let finalParts): // We can force unwrap the request here, as we have just validated in the state machine, @@ -203,44 +253,54 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.request!.succeedRequest(finalParts) self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) - self.runFinalAction(finalAction, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) + self.runSuccessfulFinalAction(finalAction, context: context) + + case .failSendBodyPart(let error, let writePromise), .failSendStreamFinished(let error, let writePromise): + writePromise?.fail(error) } } - private func sendRequestHead(_ head: HTTPRequestHead, startBody: Bool, context: ChannelHandlerContext) { - if startBody { - context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) - - // The above write might trigger an error, which may lead to a call to `errorCaught`, - // which in turn, may fail the request and pop it from the handler. For this reason - // we must check if the request is still present here. - guard let request = self.request else { return } - request.requestHeadSent() - request.resumeRequestBodyStream() - } else { + private func sendRequestHead(_ head: HTTPRequestHead, sendEnd: Bool, context: ChannelHandlerContext) { + if sendEnd { context.write(self.wrapOutboundOut(.head(head)), promise: nil) context.write(self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() + } else { + context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) + } + self.run(self.state.headSent(), context: context) + } - // The above write might trigger an error, which may lead to a call to `errorCaught`, - // which in turn, may fail the request and pop it from the handler. For this reason - // we must check if the request is still present here. - guard let request = self.request else { return } - request.requestHeadSent() + private func runSuccessfulFinalAction( + _ action: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction, + context: ChannelHandlerContext + ) { + switch action { + case .close, .none: + // The actions returned here come from an `HTTPRequestStateMachine` that assumes http/1.1 + // semantics. For this reason we can ignore the close here, since an h2 stream is closed + // after every request anyway. + break - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) - } + case .sendRequestEnd(let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) } } - private func runFinalAction(_ action: HTTPRequestStateMachine.Action.FinalStreamAction, context: ChannelHandlerContext) { - switch action { - case .close: - context.close(promise: nil) + private func runFailedFinalAction( + _ action: HTTPRequestStateMachine.Action.FinalFailedRequestAction, + context: ChannelHandlerContext, + error: Error + ) { + // We must close the http2 stream after the request has finished. Since the request failed, + // we have no idea what the h2 streams state was. To be on the save side, we explicitly close + // the h2 stream. This will break a reference cycle in HTTP2Connection. + context.close(promise: nil) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + switch action { + case .close(let writePromise): + writePromise?.fail(error) case .none: break @@ -252,8 +312,9 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { case .startIdleReadTimeoutTimer(let timeAmount): assert(self.idleReadTimeoutTimer == nil, "Expected there is no timeout timer so far.") + let timerID = self.currentIdleReadTimeoutTimerID self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { - guard self.idleReadTimeoutTimer != nil else { return } + guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) } @@ -263,14 +324,17 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { oldTimer.cancel() } + self.currentIdleReadTimeoutTimerID &+= 1 + let timerID = self.currentIdleReadTimeoutTimerID self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { - guard self.idleReadTimeoutTimer != nil else { return } + guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) } case .clearIdleReadTimeoutTimer: if let oldTimer = self.idleReadTimeoutTimer { self.idleReadTimeoutTimer = nil + self.currentIdleReadTimeoutTimerID &+= 1 oldTimer.cancel() } @@ -279,29 +343,69 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { } } + private func runTimeoutAction(_ action: IdleWriteStateMachine.Action, context: ChannelHandlerContext) { + switch action { + case .startIdleWriteTimeoutTimer(let timeAmount): + assert(self.idleWriteTimeoutTimer == nil, "Expected there is no timeout timer so far.") + + let timerID = self.currentIdleWriteTimeoutTimerID + self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + guard self.currentIdleWriteTimeoutTimerID == timerID else { return } + let action = self.state.idleWriteTimeoutTriggered() + self.run(action, context: context) + } + case .resetIdleWriteTimeoutTimer(let timeAmount): + if let oldTimer = self.idleWriteTimeoutTimer { + oldTimer.cancel() + } + + self.currentIdleWriteTimeoutTimerID &+= 1 + let timerID = self.currentIdleWriteTimeoutTimerID + self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + guard self.currentIdleWriteTimeoutTimerID == timerID else { return } + let action = self.state.idleWriteTimeoutTriggered() + self.run(action, context: context) + } + case .clearIdleWriteTimeoutTimer: + if let oldTimer = self.idleWriteTimeoutTimer { + self.idleWriteTimeoutTimer = nil + self.currentIdleWriteTimeoutTimerID &+= 1 + oldTimer.cancel() + } + case .none: + break + } + } + // MARK: Private HTTPRequestExecutor - private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest) { + private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) + { guard self.request === request, let context = self.channelContext else { // Because the HTTPExecutableRequest may run in a different thread to our eventLoop, // calls from the HTTPExecutableRequest to our ChannelHandler may arrive here after // the request has been popped by the state machine or the ChannelHandler has been // removed from the Channel pipeline. This is a normal threading issue, noone has // screwed up. + promise?.fail(HTTPClientError.requestStreamCancelled) return } - let action = self.state.requestStreamPartReceived(data) + if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() { + self.runTimeoutAction(timeoutAction, context: context) + } + + let action = self.state.requestStreamPartReceived(data, promise: promise) self.run(action, context: context) } - private func finishRequestBodyStream0(_ request: HTTPExecutableRequest) { + private func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` return } - let action = self.state.requestStreamFinished() + let action = self.state.requestStreamFinished(promise: promise) self.run(action, context: context) } @@ -321,28 +425,35 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { return } + if let timeoutAction = self.idleWriteTimeoutStateMachine?.cancelRequest() { + self.runTimeoutAction(timeoutAction, context: context) + } + let action = self.state.requestCancelled() self.run(action, context: context) } } +@available(*, unavailable) +extension HTTP2ClientRequestHandler: Sendable {} + extension HTTP2ClientRequestHandler: HTTPRequestExecutor { - func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest) { + func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { if self.eventLoop.inEventLoop { - self.writeRequestBodyPart0(data, request: request) + self.writeRequestBodyPart0(data, request: request, promise: promise) } else { self.eventLoop.execute { - self.writeRequestBodyPart0(data, request: request) + self.writeRequestBodyPart0(data, request: request, promise: promise) } } } - func finishRequestBodyStream(_ request: HTTPExecutableRequest) { + func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { if self.eventLoop.inEventLoop { - self.finishRequestBodyStream0(request) + self.finishRequestBodyStream0(request, promise: promise) } else { self.eventLoop.execute { - self.finishRequestBodyStream0(request) + self.finishRequestBodyStream0(request, promise: promise) } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift index 8eb189adc..5e4ae6e01 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift @@ -15,6 +15,7 @@ import Logging import NIOCore import NIOHTTP2 +import NIOHTTPCompression protocol HTTP2ConnectionDelegate { func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) @@ -28,6 +29,8 @@ struct HTTP2PushNotSupportedError: Error {} struct HTTP2ReceivedGoAwayBeforeSettingsError: Error {} final class HTTP2Connection { + internal static let defaultSettings = nioDefaultSettings + [HTTP2Setting(parameter: .enablePush, value: 0)] + let channel: Channel let multiplexer: HTTP2StreamMultiplexer let logger: Logger @@ -76,25 +79,33 @@ final class HTTP2Connection { /// We use this channel set to remember, which open streams we need to inform that /// we want to close the connection. The channels shall than cancel their currently running - /// request. + /// request. This property must only be accessed from the connections `EventLoop`. private var openStreams = Set() let id: HTTPConnectionPool.Connection.ID + let decompression: HTTPClient.Decompression + let maximumConnectionUses: Int? var closeFuture: EventLoopFuture { self.channel.closeFuture } - init(channel: Channel, - connectionID: HTTPConnectionPool.Connection.ID, - delegate: HTTP2ConnectionDelegate, - logger: Logger) { + init( + channel: Channel, + connectionID: HTTPConnectionPool.Connection.ID, + decompression: HTTPClient.Decompression, + maximumConnectionUses: Int?, + delegate: HTTP2ConnectionDelegate, + logger: Logger + ) { self.channel = channel self.id = connectionID + self.decompression = decompression + self.maximumConnectionUses = maximumConnectionUses self.logger = logger self.multiplexer = HTTP2StreamMultiplexer( mode: .client, channel: channel, - targetWindowSize: 8 * 1024 * 1024, // 8mb + targetWindowSize: 8 * 1024 * 1024, // 8mb outboundBufferSizeHighWatermark: 8196, outboundBufferSizeLowWatermark: 4092, inboundStreamInitializer: { channel -> EventLoopFuture in @@ -107,7 +118,7 @@ final class HTTP2Connection { deinit { guard case .closed = self.state else { - preconditionFailure("Connection must be closed, before we can deinit it") + preconditionFailure("Connection must be closed, before we can deinit it. Current state: \(self.state)") } } @@ -115,11 +126,19 @@ final class HTTP2Connection { channel: Channel, connectionID: HTTPConnectionPool.Connection.ID, delegate: HTTP2ConnectionDelegate, - configuration: HTTPClient.Configuration, + decompression: HTTPClient.Decompression, + maximumConnectionUses: Int?, logger: Logger ) -> EventLoopFuture<(HTTP2Connection, Int)> { - let connection = HTTP2Connection(channel: channel, connectionID: connectionID, delegate: delegate, logger: logger) - return connection.start().map { maxStreams in (connection, maxStreams) } + let connection = HTTP2Connection( + channel: channel, + connectionID: connectionID, + decompression: decompression, + maximumConnectionUses: maximumConnectionUses, + delegate: delegate, + logger: logger + ) + return connection._start0().map { maxStreams in (connection, maxStreams) } } func executeRequest(_ request: HTTPExecutableRequest) { @@ -145,7 +164,7 @@ final class HTTP2Connection { } func close(promise: EventLoopPromise?) { - return self.channel.close(mode: .all, promise: promise) + self.channel.close(mode: .all, promise: promise) } func close() -> EventLoopFuture { @@ -154,15 +173,23 @@ final class HTTP2Connection { return promise.futureResult } - private func start() -> EventLoopFuture { + func _start0() -> EventLoopFuture { self.channel.eventLoop.assertInEventLoop() let readyToAcceptConnectionsPromise = self.channel.eventLoop.makePromise(of: Int.self) self.state = .starting(readyToAcceptConnectionsPromise) self.channel.closeFuture.whenComplete { _ in - self.state = .closed - self.delegate.http2ConnectionClosed(self) + switch self.state { + case .initialized, .closed: + preconditionFailure("invalid state \(self.state)") + case .starting(let readyToAcceptConnectionsPromise): + self.state = .closed + readyToAcceptConnectionsPromise.fail(HTTPClientError.remoteConnectionClosed) + case .active, .closing: + self.state = .closed + self.delegate.http2ConnectionClosed(self) + } } do { @@ -173,8 +200,12 @@ final class HTTP2Connection { // can be scheduled on this connection. let sync = self.channel.pipeline.syncOperations - let http2Handler = NIOHTTP2Handler(mode: .client, initialSettings: nioDefaultSettings) - let idleHandler = HTTP2IdleHandler(delegate: self, logger: self.logger) + let http2Handler = NIOHTTP2Handler(mode: .client, initialSettings: Self.defaultSettings) + let idleHandler = HTTP2IdleHandler( + delegate: self, + logger: self.logger, + maximumConnectionUses: self.maximumConnectionUses + ) try sync.addHandler(http2Handler, position: .last) try sync.addHandler(idleHandler, position: .last) @@ -196,7 +227,8 @@ final class HTTP2Connection { case .active: let createStreamChannelPromise = self.channel.eventLoop.makePromise(of: Channel.self) - self.multiplexer.createStreamChannel(promise: createStreamChannelPromise) { channel -> EventLoopFuture in + self.multiplexer.createStreamChannel(promise: createStreamChannelPromise) { + channel -> EventLoopFuture in do { // the connection may have been asked to shutdown while we created the child. in // this @@ -208,9 +240,14 @@ final class HTTP2Connection { // We only support http/2 over an https connection – using the Application-Layer // Protocol Negotiation (ALPN). For this reason it is safe to fix this to `.https`. let translate = HTTP2FramePayloadToHTTP1ClientCodec(httpProtocol: .https) - let handler = HTTP2ClientRequestHandler(eventLoop: channel.eventLoop) - try channel.pipeline.syncOperations.addHandler(translate) + + if case .enabled(let limit) = self.decompression { + let decompressHandler = NIOHTTPResponseDecompressor(limit: limit) + try channel.pipeline.syncOperations.addHandler(decompressHandler) + } + + let handler = HTTP2ClientRequestHandler(eventLoop: channel.eventLoop) try channel.pipeline.syncOperations.addHandler(handler) // We must add the new channel to the list of open channels BEFORE we write the @@ -218,7 +255,7 @@ final class HTTP2Connection { // before. let box = ChannelBox(channel) self.openStreams.insert(box) - self.channel.closeFuture.whenComplete { _ in + channel.closeFuture.whenComplete { _ in self.openStreams.remove(box) } @@ -243,16 +280,31 @@ final class HTTP2Connection { private func shutdown0() { self.channel.eventLoop.assertInEventLoop() - self.state = .closing + switch self.state { + case .active: + self.state = .closing + + // inform all open streams, that the currently running request should be cancelled. + for box in self.openStreams { + box.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) + } + + // inform the idle connection handler, that connection should be closed, once all streams + // are closed. + self.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) + + case .closed, .closing: + // we are already closing/closed and we need to tolerate this + break - // inform all open streams, that the currently running request should be cancelled. - self.openStreams.forEach { box in - box.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) + case .initialized, .starting: + preconditionFailure("invalid state \(self.state)") } + } - // inform the idle connection handler, that connection should be closed, once all streams - // are closed. - self.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) + func __forTesting_getStreamChannels() -> [Channel] { + self.channel.eventLoop.preconditionInEventLoop() + return self.openStreams.map { $0.channel } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift index 8978e1a86..64a151489 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift @@ -35,9 +35,10 @@ final class HTTP2IdleHandler: ChannelDuplexH let logger: Logger let delegate: Delegate - private var state: StateMachine = .init() + private var state: StateMachine - init(delegate: Delegate, logger: Logger) { + init(delegate: Delegate, logger: Logger, maximumConnectionUses: Int? = nil) { + self.state = StateMachine(maximumUses: maximumConnectionUses) self.delegate = delegate self.logger = logger } @@ -140,19 +141,23 @@ extension HTTP2IdleHandler { } enum State { - case initialized - case connected - case active(openStreams: Int, maxStreams: Int) + case initialized(maximumUses: Int?) + case connected(remainingUses: Int?) + case active(openStreams: Int, maxStreams: Int, remainingUses: Int?) case closing(openStreams: Int, maxStreams: Int) case closed } - var state: State = .initialized + var state: State + + init(maximumUses: Int?) { + self.state = .initialized(maximumUses: maximumUses) + } mutating func channelActive() { switch self.state { - case .initialized: - self.state = .connected + case .initialized(let maximumUses): + self.state = .connected(remainingUses: maximumUses) case .connected, .active, .closing, .closed: break @@ -168,44 +173,60 @@ extension HTTP2IdleHandler { mutating func settingsReceived(_ settings: HTTP2Settings) -> Action { switch self.state { - case .initialized, .closed: + case .initialized: preconditionFailure("Invalid state: \(self.state)") - case .connected: + case .connected(let remainingUses): // a settings frame might have multiple entries for `maxConcurrentStreams`. We are // only interested in the last value! If no `maxConcurrentStreams` is set, we assume // the http/2 default of 100. let maxStreams = settings.last(where: { $0.parameter == .maxConcurrentStreams })?.value ?? 100 - self.state = .active(openStreams: 0, maxStreams: maxStreams) + self.state = .active(openStreams: 0, maxStreams: maxStreams, remainingUses: remainingUses) return .notifyConnectionNewMaxStreamsSettings(maxStreams) - case .active(openStreams: let openStreams, maxStreams: let maxStreams): - if let newMaxStreams = settings.last(where: { $0.parameter == .maxConcurrentStreams })?.value, newMaxStreams != maxStreams { - self.state = .active(openStreams: openStreams, maxStreams: newMaxStreams) + case .active(let openStreams, let maxStreams, let remainingUses): + if let newMaxStreams = settings.last(where: { $0.parameter == .maxConcurrentStreams })?.value, + newMaxStreams != maxStreams + { + self.state = .active( + openStreams: openStreams, + maxStreams: newMaxStreams, + remainingUses: remainingUses + ) return .notifyConnectionNewMaxStreamsSettings(newMaxStreams) } return .nothing case .closing: return .nothing + + case .closed: + // We may receive a Settings frame after we have called connection close, because of + // packages being delivered from the incoming buffer. + return .nothing } } mutating func goAwayReceived() -> Action { switch self.state { - case .initialized, .closed: + case .initialized: preconditionFailure("Invalid state: \(self.state)") case .connected: self.state = .closing(openStreams: 0, maxStreams: 0) return .notifyConnectionGoAwayReceived(close: true) - case .active(let openStreams, let maxStreams): + case .active(let openStreams, let maxStreams, _): self.state = .closing(openStreams: openStreams, maxStreams: maxStreams) return .notifyConnectionGoAwayReceived(close: openStreams == 0) case .closing: return .notifyConnectionGoAwayReceived(close: false) + + case .closed: + // We may receive a GoAway frame after we have called connection close, because of + // packages being delivered from the incoming buffer. + return .nothing } } @@ -218,7 +239,7 @@ extension HTTP2IdleHandler { self.state = .closing(openStreams: 0, maxStreams: 0) return .close - case .active(let openStreams, let maxStreams): + case .active(let openStreams, let maxStreams, _): if openStreams == 0 { self.state = .closed return .close @@ -234,10 +255,22 @@ extension HTTP2IdleHandler { mutating func streamCreated() -> Action { switch self.state { - case .active(var openStreams, let maxStreams): + case .initialized, .connected: + preconditionFailure("Invalid state: \(self.state)") + + case .active(var openStreams, let maxStreams, let remainingUses): openStreams += 1 - self.state = .active(openStreams: openStreams, maxStreams: maxStreams) - return .nothing + let remainingUses = remainingUses.map { $0 - 1 } + self.state = .active(openStreams: openStreams, maxStreams: maxStreams, remainingUses: remainingUses) + + if remainingUses == 0 { + // Treat running out of connection uses as if we received a GOAWAY frame. This + // will notify the delegate (i.e. connection pool) that the connection can no + // longer be used. + return self.goAwayReceived() + } else { + return .nothing + } case .closing(var openStreams, let maxStreams): // A stream might be opened, while we are closing because of race conditions. For @@ -246,17 +279,22 @@ extension HTTP2IdleHandler { self.state = .closing(openStreams: openStreams, maxStreams: maxStreams) return .nothing - case .initialized, .connected, .closed: - preconditionFailure("Invalid state: \(self.state)") + case .closed: + // We may receive a events after we have called connection close, because of + // internal races. We should just ignore these cases. + return .nothing } } mutating func streamClosed() -> Action { switch self.state { - case .active(var openStreams, let maxStreams): + case .initialized, .connected: + preconditionFailure("Invalid state: \(self.state)") + + case .active(var openStreams, let maxStreams, let remainingUses): openStreams -= 1 assert(openStreams >= 0) - self.state = .active(openStreams: openStreams, maxStreams: maxStreams) + self.state = .active(openStreams: openStreams, maxStreams: maxStreams, remainingUses: remainingUses) return .notifyConnectionStreamClosed(currentlyAvailable: maxStreams - openStreams) case .closing(var openStreams, let maxStreams): @@ -269,8 +307,10 @@ extension HTTP2IdleHandler { self.state = .closing(openStreams: openStreams, maxStreams: maxStreams) return .nothing - case .initialized, .connected, .closed: - preconditionFailure("Invalid state: \(self.state)") + case .closed: + // We may receive a events after we have called connection close, because of + // internal races. We should just ignore these cases. + return .nothing } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift index 4a3338697..cb3ec0bf5 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift @@ -20,7 +20,9 @@ import NIOPosix import NIOSOCKS import NIOSSL import NIOTLS + #if canImport(Network) +import Network import NIOTransportServices #endif @@ -31,14 +33,17 @@ extension HTTPConnectionPool { let tlsConfiguration: TLSConfiguration let sslContextCache: SSLContextCache - init(key: ConnectionPool.Key, - tlsConfiguration: TLSConfiguration?, - clientConfiguration: HTTPClient.Configuration, - sslContextCache: SSLContextCache) { + init( + key: ConnectionPool.Key, + tlsConfiguration: TLSConfiguration?, + clientConfiguration: HTTPClient.Configuration, + sslContextCache: SSLContextCache + ) { self.key = key self.clientConfiguration = clientConfiguration self.sslContextCache = sslContextCache - self.tlsConfiguration = tlsConfiguration ?? clientConfiguration.tlsConfiguration ?? .makeClientConfiguration() + self.tlsConfiguration = + tlsConfiguration ?? clientConfiguration.tlsConfiguration ?? .makeClientConfiguration() } } } @@ -47,6 +52,7 @@ protocol HTTPConnectionRequester { func http1ConnectionCreated(_: HTTP1Connection) func http2ConnectionCreated(_: HTTP2Connection, maximumStreams: Int) func failedToCreateHTTPConnection(_: HTTPConnectionPool.Connection.ID, error: Error) + func waitingForConnectivity(_: HTTPConnectionPool.Connection.ID, error: Error) } extension HTTPConnectionPool.ConnectionFactory { @@ -62,7 +68,13 @@ extension HTTPConnectionPool.ConnectionFactory { var logger = logger logger[metadataKey: "ahc-connection-id"] = "\(connectionID)" - self.makeChannel(connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, logger: logger).whenComplete { result in + self.makeChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop, + logger: logger + ).whenComplete { result in switch result { case .success(.http1_1(let channel)): do { @@ -70,7 +82,7 @@ extension HTTPConnectionPool.ConnectionFactory { channel: channel, connectionID: connectionID, delegate: http1ConnectionDelegate, - configuration: self.clientConfiguration, + decompression: self.clientConfiguration.decompression, logger: logger ) requester.http1ConnectionCreated(connection) @@ -82,7 +94,8 @@ extension HTTPConnectionPool.ConnectionFactory { channel: channel, connectionID: connectionID, delegate: http2ConnectionDelegate, - configuration: self.clientConfiguration, + decompression: self.clientConfiguration.decompression, + maximumConnectionUses: self.clientConfiguration.maximumUsesPerConnection, logger: logger ).whenComplete { result in switch result { @@ -104,40 +117,8 @@ extension HTTPConnectionPool.ConnectionFactory { case http2(Channel) } - func makeHTTP1Channel( - connectionID: HTTPConnectionPool.Connection.ID, - deadline: NIODeadline, - eventLoop: EventLoop, - logger: Logger - ) -> EventLoopFuture { - self.makeChannel( - connectionID: connectionID, - deadline: deadline, - eventLoop: eventLoop, - logger: logger - ).flatMapThrowing { negotiated -> Channel in - - guard case .http1_1(let channel) = negotiated else { - preconditionFailure("Expected to create http/1.1 connections only for now") - } - - // add the http1.1 channel handlers - let syncOperations = channel.pipeline.syncOperations - try syncOperations.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes) - - switch self.clientConfiguration.decompression { - case .disabled: - () - case .enabled(let limit): - let decompressHandler = NIOHTTPResponseDecompressor(limit: limit) - try syncOperations.addHandler(decompressHandler) - } - - return channel - } - } - - func makeChannel( + func makeChannel( + requester: Requester, connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, eventLoop: EventLoop, @@ -150,6 +131,7 @@ extension HTTPConnectionPool.ConnectionFactory { case .socks: channelFuture = self.makeSOCKSProxyChannel( proxy, + requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, @@ -158,6 +140,7 @@ extension HTTPConnectionPool.ConnectionFactory { case .http: channelFuture = self.makeHTTPProxyChannel( proxy, + requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, @@ -165,7 +148,13 @@ extension HTTPConnectionPool.ConnectionFactory { ) } } else { - channelFuture = self.makeNonProxiedChannel(deadline: deadline, eventLoop: eventLoop, logger: logger) + channelFuture = self.makeNonProxiedChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop, + logger: logger + ) } // let's map `ChannelError.connectTimeout` into a `HTTPClientError.connectTimeout` @@ -179,30 +168,55 @@ extension HTTPConnectionPool.ConnectionFactory { } } - private func makeNonProxiedChannel( + private func makeNonProxiedChannel( + requester: Requester, + connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, eventLoop: EventLoop, logger: Logger ) -> EventLoopFuture { switch self.key.scheme { case .http, .httpUnix, .unix: - return self.makePlainChannel(deadline: deadline, eventLoop: eventLoop).map { .http1_1($0) } + return self.makePlainChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop + ).map { .http1_1($0) } case .https, .httpsUnix: - return self.makeTLSChannel(deadline: deadline, eventLoop: eventLoop, logger: logger).flatMapThrowing { - channel, negotiated in + return self.makeTLSChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop, + logger: logger + ).flatMapThrowing { + channel, + negotiated in try self.matchALPNToHTTPVersion(negotiated, channel: channel) } } } - private func makePlainChannel(deadline: NIODeadline, eventLoop: EventLoop) -> EventLoopFuture { + private func makePlainChannel( + requester: Requester, + connectionID: HTTPConnectionPool.Connection.ID, + deadline: NIODeadline, + eventLoop: EventLoop + ) -> EventLoopFuture { precondition(!self.key.scheme.usesTLS, "Unexpected scheme") - return self.makePlainBootstrap(deadline: deadline, eventLoop: eventLoop).connect(target: self.key.connectionTarget) + return self.makePlainBootstrap( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop + ).connect(target: self.key.connectionTarget) } - private func makeHTTPProxyChannel( + private func makeHTTPProxyChannel( _ proxy: HTTPClient.Configuration.Proxy, + requester: Requester, connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, eventLoop: EventLoop, @@ -211,7 +225,12 @@ extension HTTPConnectionPool.ConnectionFactory { // A proxy connection starts with a plain text connection to the proxy server. After // the connection has been established with the proxy server, the connection might be // upgraded to TLS before we send our first request. - let bootstrap = self.makePlainBootstrap(deadline: deadline, eventLoop: eventLoop) + let bootstrap = self.makePlainBootstrap( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop + ) return bootstrap.connect(host: proxy.host, port: proxy.port).flatMap { channel in let encoder = HTTPRequestEncoder() let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes)) @@ -231,20 +250,21 @@ extension HTTPConnectionPool.ConnectionFactory { // The proxyEstablishedFuture is set as soon as the HTTP1ProxyConnectHandler is in a // pipeline. It is created in HTTP1ProxyConnectHandler's handlerAdded method. - return proxyHandler.proxyEstablishedFuture!.flatMap { - channel.pipeline.removeHandler(proxyHandler).flatMap { - channel.pipeline.removeHandler(decoder).flatMap { - channel.pipeline.removeHandler(encoder) - } - } + return proxyHandler.proxyEstablishedFuture!.assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(proxyHandler).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(decoder).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(encoder) + }.nonisolated() + }.nonisolated() }.flatMap { self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) - } + }.nonisolated() } } - private func makeSOCKSProxyChannel( + private func makeSOCKSProxyChannel( _ proxy: HTTPClient.Configuration.Proxy, + requester: Requester, connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, eventLoop: EventLoop, @@ -253,7 +273,12 @@ extension HTTPConnectionPool.ConnectionFactory { // A proxy connection starts with a plain text connection to the proxy server. After // the connection has been established with the proxy server, the connection might be // upgraded to TLS before we send our first request. - let bootstrap = self.makePlainBootstrap(deadline: deadline, eventLoop: eventLoop) + let bootstrap = self.makePlainBootstrap( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop + ) return bootstrap.connect(host: proxy.host, port: proxy.port).flatMap { channel in let socksConnectHandler = SOCKSClientHandler(targetAddress: SOCKSAddress(self.key.connectionTarget)) let socksEventHandler = SOCKSEventsHandler(deadline: deadline) @@ -267,13 +292,13 @@ extension HTTPConnectionPool.ConnectionFactory { // The socksEstablishedFuture is set as soon as the SOCKSEventsHandler is in a // pipeline. It is created in SOCKSEventsHandler's handlerAdded method. - return socksEventHandler.socksEstablishedFuture!.flatMap { - channel.pipeline.removeHandler(socksEventHandler).flatMap { - channel.pipeline.removeHandler(socksConnectHandler) - } + return socksEventHandler.socksEstablishedFuture!.assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(socksEventHandler).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(socksConnectHandler) + }.nonisolated() }.flatMap { self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) - } + }.nonisolated() } } @@ -301,7 +326,7 @@ extension HTTPConnectionPool.ConnectionFactory { } let tlsEventHandler = TLSEventsHandler(deadline: deadline) - let sslServerHostname = self.key.connectionTarget.sslServerHostname + let sslServerHostname = self.key.serverNameIndicator let sslContextFuture = self.sslContextCache.sslContext( tlsConfiguration: tlsConfig, eventLoop: channel.eventLoop, @@ -331,14 +356,33 @@ extension HTTPConnectionPool.ConnectionFactory { } } - private func makePlainBootstrap(deadline: NIODeadline, eventLoop: EventLoop) -> NIOClientTCPBootstrapProtocol { + private func makePlainBootstrap( + requester: Requester, + connectionID: HTTPConnectionPool.Connection.ID, + deadline: NIODeadline, + eventLoop: EventLoop + ) -> NIOClientTCPBootstrapProtocol { #if canImport(Network) - if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { - return tsBootstrap + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), + let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) + { + return + tsBootstrap + .channelOption( + NIOTSChannelOptions.waitForActivity, + value: self.clientConfiguration.networkFrameworkWaitForConnectivity + ) + .channelOption( + NIOTSChannelOptions.multipathServiceType, + value: self.clientConfiguration.enableMultipath ? .handover : .disabled + ) .connectTimeout(deadline - NIODeadline.now()) .channelInitializer { channel in do { try channel.pipeline.syncOperations.addHandler(HTTPClient.NWErrorHandler()) + try channel.pipeline.syncOperations.addHandler( + NWWaitingHandler(requester: requester, connectionID: connectionID) + ) return channel.eventLoop.makeSucceededVoidFuture() } catch { return channel.eventLoop.makeFailedFuture(error) @@ -348,32 +392,50 @@ extension HTTPConnectionPool.ConnectionFactory { #endif if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) { - return nioBootstrap + return + nioBootstrap .connectTimeout(deadline - NIODeadline.now()) + .enableMPTCP(clientConfiguration.enableMultipath) } preconditionFailure("No matching bootstrap found") } - private func makeTLSChannel(deadline: NIODeadline, eventLoop: EventLoop, logger: Logger) -> EventLoopFuture<(Channel, String?)> { + private func makeTLSChannel( + requester: Requester, + connectionID: HTTPConnectionPool.Connection.ID, + deadline: NIODeadline, + eventLoop: EventLoop, + logger: Logger + ) -> EventLoopFuture<(Channel, String?)> { precondition(self.key.scheme.usesTLS, "Unexpected scheme") let bootstrapFuture = self.makeTLSBootstrap( + requester: requester, + connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, logger: logger ) var channelFuture = bootstrapFuture.flatMap { bootstrap -> EventLoopFuture in - return bootstrap.connect(target: self.key.connectionTarget) + bootstrap.connect(target: self.key.connectionTarget) }.flatMap { channel -> EventLoopFuture<(Channel, String?)> in - // It is save to use `try!` here, since we are sure, that a `TLSEventsHandler` exists - // within the pipeline. It is added in `makeTLSBootstrap`. - let tlsEventHandler = try! channel.pipeline.syncOperations.handler(type: TLSEventsHandler.self) - - // The tlsEstablishedFuture is set as soon as the TLSEventsHandler is in a - // pipeline. It is created in TLSEventsHandler's handlerAdded method. - return tlsEventHandler.tlsEstablishedFuture!.flatMap { negotiated in - channel.pipeline.removeHandler(tlsEventHandler).map { (channel, negotiated) } + do { + // if the channel is closed before flatMap is executed, all ChannelHandler are removed + // and TLSEventsHandler is therefore not present either + let tlsEventHandler = try channel.pipeline.syncOperations.handler(type: TLSEventsHandler.self) + + // The tlsEstablishedFuture is set as soon as the TLSEventsHandler is in a + // pipeline. It is created in TLSEventsHandler's handlerAdded method. + return tlsEventHandler.tlsEstablishedFuture!.flatMap { negotiated in + channel.pipeline.removeHandler(tlsEventHandler).map { (channel, negotiated) } + } + } catch { + assert( + channel.isActive == false, + "if the channel is still active then TLSEventsHandler must be present but got error \(error)" + ) + return channel.eventLoop.makeFailedFuture(HTTPClientError.remoteConnectionClosed) } } @@ -387,8 +449,13 @@ extension HTTPConnectionPool.ConnectionFactory { return channelFuture } - private func makeTLSBootstrap(deadline: NIODeadline, eventLoop: EventLoop, logger: Logger) - -> EventLoopFuture { + private func makeTLSBootstrap( + requester: Requester, + connectionID: HTTPConnectionPool.Connection.ID, + deadline: NIODeadline, + eventLoop: EventLoop, + logger: Logger + ) -> EventLoopFuture { var tlsConfig = self.tlsConfiguration switch self.clientConfiguration.httpVersion.configuration { case .automatic: @@ -402,17 +469,33 @@ extension HTTPConnectionPool.ConnectionFactory { } #if canImport(Network) - if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), + let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) + { // create NIOClientTCPBootstrap with NIOTS TLS provider - let bootstrapFuture = tlsConfig.getNWProtocolTLSOptions(on: eventLoop).map { + let bootstrapFuture = tlsConfig.getNWProtocolTLSOptions( + on: eventLoop, + serverNameIndicatorOverride: key.serverNameIndicatorOverride + ).map { options -> NIOClientTCPBootstrapProtocol in tsBootstrap + .channelOption( + NIOTSChannelOptions.waitForActivity, + value: self.clientConfiguration.networkFrameworkWaitForConnectivity + ) + .channelOption( + NIOTSChannelOptions.multipathServiceType, + value: self.clientConfiguration.enableMultipath ? .handover : .disabled + ) .connectTimeout(deadline - NIODeadline.now()) .tlsOptions(options) .channelInitializer { channel in do { try channel.pipeline.syncOperations.addHandler(HTTPClient.NWErrorHandler()) + try channel.pipeline.syncOperations.addHandler( + NWWaitingHandler(requester: requester, connectionID: connectionID) + ) // we don't need to set a TLS deadline for NIOTS connections, since the // TLS handshake is part of the TS connection bootstrap. If the TLS // handshake times out the complete connection creation will be failed. @@ -427,7 +510,6 @@ extension HTTPConnectionPool.ConnectionFactory { } #endif - let sslServerHostname = self.key.connectionTarget.sslServerHostname let sslContextFuture = sslContextCache.sslContext( tlsConfiguration: tlsConfig, eventLoop: eventLoop, @@ -436,13 +518,14 @@ extension HTTPConnectionPool.ConnectionFactory { let bootstrap = ClientBootstrap(group: eventLoop) .connectTimeout(deadline - NIODeadline.now()) + .enableMPTCP(clientConfiguration.enableMultipath) .channelInitializer { channel in sslContextFuture.flatMap { sslContext -> EventLoopFuture in do { let sync = channel.pipeline.syncOperations let sslHandler = try NIOSSLClientHandler( context: sslContext, - serverHostname: sslServerHostname + serverHostname: self.key.serverNameIndicator ) let tlsEventHandler = TLSEventsHandler(deadline: deadline) @@ -481,6 +564,12 @@ extension Scheme { } } +extension ConnectionPool.Key { + var serverNameIndicator: String? { + serverNameIndicatorOverride ?? connectionTarget.sslServerHostname + } +} + extension ConnectionTarget { fileprivate var sslServerHostname: String? { switch self { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift index 1a1760908..3fdf93752 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +import Atomics import Logging import NIOConcurrencyHelpers import NIOCore @@ -34,13 +35,15 @@ extension HTTPConnectionPool { private var state: State = .active private var _pools: [Key: HTTPConnectionPool] = [:] - private let lock = Lock() + private let lock = NIOLock() private let sslContextCache = SSLContextCache() - init(eventLoopGroup: EventLoopGroup, - configuration: HTTPClient.Configuration, - backgroundActivityLogger logger: Logger) { + init( + eventLoopGroup: EventLoopGroup, + configuration: HTTPClient.Configuration, + backgroundActivityLogger logger: Logger + ) { self.eventLoopGroup = eventLoopGroup self.configuration = configuration self.logger = logger @@ -117,7 +120,7 @@ extension HTTPConnectionPool { promise?.succeed(false) case .shutdown(let pools): - pools.values.forEach { pool in + for pool in pools.values { pool.shutdown() } } @@ -139,7 +142,9 @@ extension HTTPConnectionPool.Manager: HTTPConnectionPoolDelegate { case .shuttingDown(let promise, let soFarUnclean): guard self._pools.removeValue(forKey: pool.key) === pool else { - preconditionFailure("Expected that the pool was created by this manager and is known for this reason.") + preconditionFailure( + "Expected that the pool was created by this manager and is known for this reason." + ) } if self._pools.isEmpty { @@ -153,7 +158,7 @@ extension HTTPConnectionPool.Manager: HTTPConnectionPoolDelegate { } switch closeAction { - case .close(let promise, unclean: let unclean): + case .close(let promise, let unclean): promise?.succeed(unclean) case .wait: break @@ -162,17 +167,17 @@ extension HTTPConnectionPool.Manager: HTTPConnectionPoolDelegate { } extension HTTPConnectionPool.Connection.ID { - static var globalGenerator = Generator() + static let globalGenerator = Generator() struct Generator { - private let atomic: NIOAtomic + private let atomic: ManagedAtomic init() { - self.atomic = .makeAtomic(value: 0) + self.atomic = .init(0) } func next() -> Int { - return self.atomic.add(1) + self.atomic.loadThenWrappingIncrement(ordering: .relaxed) } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift index 764ad2093..eebe4d029 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift @@ -21,8 +21,11 @@ protocol HTTPConnectionPoolDelegate { func connectionPoolDidShutdown(_ pool: HTTPConnectionPool, unclean: Bool) } -final class HTTPConnectionPool { - private let stateLock = Lock() +final class HTTPConnectionPool: + // TODO: Refactor to use `NIOLockedValueBox` which will allow this to be checked + @unchecked Sendable +{ + private let stateLock = NIOLock() private var _state: StateMachine /// The connection idle timeout timers. Protected by the stateLock private var _idleTimer = [Connection.ID: Scheduled]() @@ -44,14 +47,16 @@ final class HTTPConnectionPool { let delegate: HTTPConnectionPoolDelegate - init(eventLoopGroup: EventLoopGroup, - sslContextCache: SSLContextCache, - tlsConfiguration: TLSConfiguration?, - clientConfiguration: HTTPClient.Configuration, - key: ConnectionPool.Key, - delegate: HTTPConnectionPoolDelegate, - idGenerator: Connection.ID.Generator, - backgroundActivityLogger logger: Logger) { + init( + eventLoopGroup: EventLoopGroup, + sslContextCache: SSLContextCache, + tlsConfiguration: TLSConfiguration?, + clientConfiguration: HTTPClient.Configuration, + key: ConnectionPool.Key, + delegate: HTTPConnectionPoolDelegate, + idGenerator: Connection.ID.Generator, + backgroundActivityLogger logger: Logger + ) { self.eventLoopGroup = eventLoopGroup self.connectionFactory = ConnectionFactory( key: key, @@ -70,7 +75,11 @@ final class HTTPConnectionPool { self._state = StateMachine( idGenerator: idGenerator, - maximumConcurrentHTTP1Connections: clientConfiguration.connectionPool.concurrentHTTP1ConnectionsPerHostSoftLimit + maximumConcurrentHTTP1Connections: clientConfiguration.connectionPool + .concurrentHTTP1ConnectionsPerHostSoftLimit, + retryConnectionEstablishment: clientConfiguration.connectionPool.retryConnectionEstablishment, + preferHTTP1: clientConfiguration.httpVersion == .http1Only, + maximumConnectionUses: clientConfiguration.maximumUsesPerConnection ) } @@ -147,9 +156,7 @@ final class HTTPConnectionPool { self.unlocked = Unlocked(connection: .none, request: .none) switch stateMachineAction.request { - case .cancelRequestTimeout(let requestID): - self.locked.request = .cancelRequestTimeout(requestID) - case .executeRequest(let request, let connection, cancelTimeout: let cancelTimeout): + case .executeRequest(let request, let connection, let cancelTimeout): if cancelTimeout { self.locked.request = .cancelRequestTimeout(request.id) } @@ -157,7 +164,7 @@ final class HTTPConnectionPool { case .executeRequestsAndCancelTimeouts(let requests, let connection): self.locked.request = .cancelRequestTimeouts(requests) self.unlocked.request = .executeRequests(requests, connection) - case .failRequest(let request, let error, cancelTimeout: let cancelTimeout): + case .failRequest(let request, let error, let cancelTimeout): if cancelTimeout { self.locked.request = .cancelRequestTimeout(request.id) } @@ -174,15 +181,15 @@ final class HTTPConnectionPool { switch stateMachineAction.connection { case .createConnection(let connectionID, on: let eventLoop): self.unlocked.connection = .createConnection(connectionID, on: eventLoop) - case .scheduleBackoffTimer(let connectionID, backoff: let backoff, on: let eventLoop): + case .scheduleBackoffTimer(let connectionID, let backoff, on: let eventLoop): self.locked.connection = .scheduleBackoffTimer(connectionID, backoff: backoff, on: eventLoop) case .scheduleTimeoutTimer(let connectionID, on: let eventLoop): self.locked.connection = .scheduleTimeoutTimer(connectionID, on: eventLoop) case .cancelTimeoutTimer(let connectionID): self.locked.connection = .cancelTimeoutTimer(connectionID) - case .closeConnection(let connection, isShutdown: let isShutdown): + case .closeConnection(let connection, let isShutdown): self.unlocked.connection = .closeConnection(connection, isShutdown: isShutdown) - case .cleanupConnections(var cleanupContext, isShutdown: let isShutdown): + case .cleanupConnections(var cleanupContext, let isShutdown): // self.locked.connection = .cancelBackoffTimers(cleanupContext.connectBackoff) cleanupContext.connectBackoff = [] @@ -220,7 +227,7 @@ final class HTTPConnectionPool { private func runLockedConnectionAction(_ action: Actions.ConnectionAction.Locked) { switch action { - case .scheduleBackoffTimer(let connectionID, backoff: let backoff, on: let eventLoop): + case .scheduleBackoffTimer(let connectionID, let backoff, on: let eventLoop): self.scheduleConnectionStartBackoffTimer(connectionID, backoff, on: eventLoop) case .scheduleTimeoutTimer(let connectionID, on: let eventLoop): @@ -248,7 +255,7 @@ final class HTTPConnectionPool { self.cancelRequestTimeout(requestID) case .cancelRequestTimeouts(let requests): - requests.forEach { self.cancelRequestTimeout($0.id) } + for request in requests { self.cancelRequestTimeout(request.id) } case .none: break @@ -265,10 +272,13 @@ final class HTTPConnectionPool { case .createConnection(let connectionID, let eventLoop): self.createConnection(connectionID, on: eventLoop) - case .closeConnection(let connection, isShutdown: let isShutdown): - self.logger.trace("close connection", metadata: [ - "ahc-connection-id": "\(connection.id)", - ]) + case .closeConnection(let connection, let isShutdown): + self.logger.trace( + "close connection", + metadata: [ + "ahc-connection-id": "\(connection.id)" + ] + ) // we are not interested in the close promise... connection.close(promise: nil) @@ -277,7 +287,7 @@ final class HTTPConnectionPool { self.delegate.connectionPoolDidShutdown(self, unclean: unclean) } - case .cleanupConnections(let cleanupContext, isShutdown: let isShutdown): + case .cleanupConnections(let cleanupContext, let isShutdown): for connection in cleanupContext.close { connection.close(promise: nil) } @@ -314,13 +324,13 @@ final class HTTPConnectionPool { connection.executeRequest(request.req) case .executeRequests(let requests, let connection): - requests.forEach { connection.executeRequest($0.req) } + for request in requests { connection.executeRequest(request.req) } case .failRequest(let request, let error): request.req.fail(error) case .failRequests(let requests, let error): - requests.forEach { $0.req.fail(error) } + for request in requests { request.req.fail(error) } case .none: break @@ -328,9 +338,12 @@ final class HTTPConnectionPool { } private func createConnection(_ connectionID: Connection.ID, on eventLoop: EventLoop) { - self.logger.trace("Opening fresh connection", metadata: [ - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.trace( + "Opening fresh connection", + metadata: [ + "ahc-connection-id": "\(connectionID)" + ] + ) // Even though this function is called make it actually creates/establishes a connection. // TBD: Should we rename it? To what? self.connectionFactory.makeConnection( @@ -373,9 +386,12 @@ final class HTTPConnectionPool { } private func scheduleIdleTimerForConnection(_ connectionID: Connection.ID, on eventLoop: EventLoop) { - self.logger.trace("Schedule idle connection timeout timer", metadata: [ - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.trace( + "Schedule idle connection timeout timer", + metadata: [ + "ahc-connection-id": "\(connectionID)" + ] + ) let scheduled = eventLoop.scheduleTask(in: self.idleConnectionTimeout) { // there might be a race between a cancelTimer call and the triggering // of this scheduled task. both want to acquire the lock @@ -393,9 +409,12 @@ final class HTTPConnectionPool { } private func cancelIdleTimerForConnection(_ connectionID: Connection.ID) { - self.logger.trace("Cancel idle connection timeout timer", metadata: [ - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.trace( + "Cancel idle connection timeout timer", + metadata: [ + "ahc-connection-id": "\(connectionID)" + ] + ) guard let cancelTimer = self._idleTimer.removeValue(forKey: connectionID) else { preconditionFailure("Expected to have an idle timer for connection \(connectionID) at this point.") } @@ -407,9 +426,12 @@ final class HTTPConnectionPool { _ timeAmount: TimeAmount, on eventLoop: EventLoop ) { - self.logger.trace("Schedule connection creation backoff timer", metadata: [ - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.trace( + "Schedule connection creation backoff timer", + metadata: [ + "ahc-connection-id": "\(connectionID)" + ] + ) let scheduled = eventLoop.scheduleTask(in: timeAmount) { // there might be a race between a backoffTimer and the pool shutting down. @@ -438,53 +460,81 @@ final class HTTPConnectionPool { extension HTTPConnectionPool: HTTPConnectionRequester { func http1ConnectionCreated(_ connection: HTTP1Connection) { - self.logger.trace("successfully created connection", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/1.1", - ]) + self.logger.trace( + "successfully created connection", + metadata: [ + "ahc-connection-id": "\(connection.id)", + "ahc-http-version": "http/1.1", + ] + ) self.modifyStateAndRunActions { $0.newHTTP1ConnectionCreated(.http1_1(connection)) } } func http2ConnectionCreated(_ connection: HTTP2Connection, maximumStreams: Int) { - self.logger.trace("successfully created connection", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - "ahc-max-streams": "\(maximumStreams)", - ]) + self.logger.trace( + "successfully created connection", + metadata: [ + "ahc-connection-id": "\(connection.id)", + "ahc-http-version": "http/2", + "ahc-max-streams": "\(maximumStreams)", + ] + ) self.modifyStateAndRunActions { $0.newHTTP2ConnectionCreated(.http2(connection), maxConcurrentStreams: maximumStreams) } } func failedToCreateHTTPConnection(_ connectionID: HTTPConnectionPool.Connection.ID, error: Error) { - self.logger.debug("connection attempt failed", metadata: [ - "ahc-error": "\(error)", - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.debug( + "connection attempt failed", + metadata: [ + "ahc-error": "\(error)", + "ahc-connection-id": "\(connectionID)", + ] + ) self.modifyStateAndRunActions { $0.failedToCreateNewConnection(error, connectionID: connectionID) } } + + func waitingForConnectivity(_ connectionID: HTTPConnectionPool.Connection.ID, error: Error) { + self.logger.debug( + "waiting for connectivity", + metadata: [ + "ahc-error": "\(error)", + "ahc-connection-id": "\(connectionID)", + ] + ) + self.modifyStateAndRunActions { + $0.waitingForConnectivity(error, connectionID: connectionID) + } + } } extension HTTPConnectionPool: HTTP1ConnectionDelegate { func http1ConnectionClosed(_ connection: HTTP1Connection) { - self.logger.debug("connection closed", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/1.1", - ]) + self.logger.debug( + "connection closed", + metadata: [ + "ahc-connection-id": "\(connection.id)", + "ahc-http-version": "http/1.1", + ] + ) self.modifyStateAndRunActions { $0.http1ConnectionClosed(connection.id) } } func http1ConnectionReleased(_ connection: HTTP1Connection) { - self.logger.trace("releasing connection", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/1.1", - ]) + self.logger.trace( + "releasing connection", + metadata: [ + "ahc-connection-id": "\(connection.id)", + "ahc-http-version": "http/1.1", + ] + ) self.modifyStateAndRunActions { $0.http1ConnectionReleased(connection.id) } @@ -493,41 +543,53 @@ extension HTTPConnectionPool: HTTP1ConnectionDelegate { extension HTTPConnectionPool: HTTP2ConnectionDelegate { func http2Connection(_ connection: HTTP2Connection, newMaxStreamSetting: Int) { - self.logger.debug("new max stream setting", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - "ahc-max-streams": "\(newMaxStreamSetting)", - ]) + self.logger.debug( + "new max stream setting", + metadata: [ + "ahc-connection-id": "\(connection.id)", + "ahc-http-version": "http/2", + "ahc-max-streams": "\(newMaxStreamSetting)", + ] + ) self.modifyStateAndRunActions { $0.newHTTP2MaxConcurrentStreamsReceived(connection.id, newMaxStreams: newMaxStreamSetting) } } func http2ConnectionGoAwayReceived(_ connection: HTTP2Connection) { - self.logger.debug("connection go away received", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - ]) + self.logger.debug( + "connection go away received", + metadata: [ + "ahc-connection-id": "\(connection.id)", + "ahc-http-version": "http/2", + ] + ) self.modifyStateAndRunActions { $0.http2ConnectionGoAwayReceived(connection.id) } } func http2ConnectionClosed(_ connection: HTTP2Connection) { - self.logger.debug("connection closed", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - ]) + self.logger.debug( + "connection closed", + metadata: [ + "ahc-connection-id": "\(connection.id)", + "ahc-http-version": "http/2", + ] + ) self.modifyStateAndRunActions { $0.http2ConnectionClosed(connection.id) } } func http2ConnectionStreamClosed(_ connection: HTTP2Connection, availableStreams: Int) { - self.logger.trace("stream closed", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - ]) + self.logger.trace( + "stream closed", + metadata: [ + "ahc-connection-id": "\(connection.id)", + "ahc-http-version": "http/2", + ] + ) self.modifyStateAndRunActions { $0.http2ConnectionStreamClosed(connection.id) } @@ -631,7 +693,9 @@ extension HTTPConnectionPool { return lhsConn.id == rhsConn.id case (.http2(let lhsConn), .http2(let rhsConn)): return lhsConn.id == rhsConn.id - case (.__testOnly_connection(let lhsID, let lhsEventLoop), .__testOnly_connection(let rhsID, let rhsEventLoop)): + case ( + .__testOnly_connection(let lhsID, let lhsEventLoop), .__testOnly_connection(let rhsID, let rhsEventLoop) + ): return lhsID == rhsID && lhsEventLoop === rhsEventLoop default: return false @@ -712,7 +776,7 @@ struct EventLoopID: Hashable { } static func __testOnly_fakeID(_ id: Int) -> EventLoopID { - return EventLoopID(.__testOnly_fakeID(id)) + EventLoopID(.__testOnly_fakeID(id)) } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift index 2477e1154..e8c07e50f 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift @@ -132,7 +132,7 @@ import NIOSSL /// /// Use this handle to cancel the request, while it is waiting for a free connection, to execute the request. /// This protocol is only intended to be implemented by the `HTTPConnectionPool`. -protocol HTTPRequestScheduler { +protocol HTTPRequestScheduler: Sendable { /// Informs the task queuer that a request has been cancelled. func cancelRequest(_: HTTPSchedulableRequest) } @@ -180,12 +180,12 @@ protocol HTTPRequestExecutor { /// Writes a body part into the channel pipeline /// /// This method may be **called on any thread**. The executor needs to ensure thread safety. - func writeRequestBodyPart(_: IOData, request: HTTPExecutableRequest) + func writeRequestBodyPart(_: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) /// Signals that the request body stream has finished /// /// This method may be **called on any thread**. The executor needs to ensure thread safety. - func finishRequestBodyStream(_ task: HTTPExecutableRequest) + func finishRequestBodyStream(_ task: HTTPExecutableRequest, promise: EventLoopPromise?) /// Signals that more bytes from response body stream can be consumed. /// diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift index 90578bc87..5c5b893e0 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift @@ -104,8 +104,8 @@ extension HTTPRequestStateMachine { // forwarded to the user. case .waitingForRead, - .waitingForDemand, - .waitingForReadOrDemand: + .waitingForDemand, + .waitingForReadOrDemand: return nil case .modifying: @@ -174,8 +174,8 @@ extension HTTPRequestStateMachine { return (buffer, .none) case .waitingForReadOrDemand(let buffer), - .waitingForRead(let buffer), - .waitingForDemand(let buffer): + .waitingForRead(let buffer), + .waitingForDemand(let buffer): // Normally this code path should never be hit. However there is one way to trigger // this: // diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index fa520a865..e06389360 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -20,21 +20,24 @@ struct HTTPRequestStateMachine { fileprivate enum State { /// The initial state machine state. The only valid mutation is `start()`. The state will /// transitions to: - /// - `.waitForChannelToBecomeWritable` - /// - `.running(.streaming, .initialized)` (if the Channel is writable and if a request body is expected) - /// - `.running(.endSent, .initialized)` (if the Channel is writable and no request body is expected) + /// - `.waitForChannelToBecomeWritable` (if the channel becomes non writable while sending the header) + /// - `.sendingHead` if the channel is writable case initialized + /// Waiting for the channel to be writable. Valid transitions are: - /// - `.running(.streaming, .initialized)` (once the Channel is writable again and if a request body is expected) - /// - `.running(.endSent, .initialized)` (once the Channel is writable again and no request body is expected) + /// - `.running(.streaming, .waitingForHead)` (once the Channel is writable again and if a request body is expected) + /// - `.running(.endSent, .waitingForHead)` (once the Channel is writable again and no request body is expected) /// - `.failed` (if a connection error occurred) case waitForChannelToBecomeWritable(HTTPRequestHead, RequestFramingMetadata) + /// A request is on the wire. Valid transitions are: /// - `.finished` /// - `.failed` case running(RequestState, ResponseState) + /// The request has completed successfully case finished + /// The request has failed case failed(Error) @@ -55,7 +58,7 @@ struct HTTPRequestStateMachine { /// The request is streaming its request body. `expectedBodyLength` has a value, if the request header contained /// a `"content-length"` header field. If the request header contained a `"transfer-encoding" = "chunked"` /// header field, the `expectedBodyLength` is `nil`. - case streaming(expectedBodyLength: Int?, sentBodyBytes: Int, producer: ProducerControlState) + case streaming(expectedBodyLength: Int64?, sentBodyBytes: Int64, producer: ProducerControlState) /// The request has sent its request body and end. case endSent } @@ -70,21 +73,38 @@ struct HTTPRequestStateMachine { } enum Action { - /// A action to execute, when we consider a request "done". - enum FinalStreamAction { + /// A action to execute, when we consider a successful request "done". + enum FinalSuccessfulRequestAction { /// Close the connection case close /// If the server has replied, with a status of 200...300 before all data was sent, a request is considered succeeded, /// as soon as we wrote the request end onto the wire. - case sendRequestEnd + /// + /// The promise is an optional write promise. + case sendRequestEnd(EventLoopPromise?) + /// Do nothing. This is action is used, if the request failed, before we the request head was written onto the wire. + /// This might happen if the request is cancelled, or the request failed the soundness check. + case none + } + + /// A action to execute, when we consider a failed request "done". + enum FinalFailedRequestAction { + /// Close the connection + case close(EventLoopPromise?) /// Do nothing. This is action is used, if the request failed, before we the request head was written onto the wire. /// This might happen if the request is cancelled, or the request failed the soundness check. case none } - case sendRequestHead(HTTPRequestHead, startBody: Bool) - case sendBodyPart(IOData) - case sendRequestEnd + case sendRequestHead(HTTPRequestHead, sendEnd: Bool) + case notifyRequestHeadSendSuccessfully( + resumeRequestBodyStream: Bool, + startIdleTimer: Bool + ) + case sendBodyPart(IOData, EventLoopPromise?) + case sendRequestEnd(EventLoopPromise?) + case failSendBodyPart(Error, EventLoopPromise?) + case failSendStreamFinished(Error, EventLoopPromise?) case pauseRequestBodyStream case resumeRequestBodyStream @@ -92,8 +112,8 @@ struct HTTPRequestStateMachine { case forwardResponseHead(HTTPResponseHead, pauseRequestBodyStream: Bool) case forwardResponseBodyParts(CircularBuffer) - case failRequest(Error, FinalStreamAction) - case succeedRequest(FinalStreamAction, CircularBuffer) + case failRequest(Error, FinalFailedRequestAction) + case succeedRequest(FinalSuccessfulRequestAction, CircularBuffer) case read case wait @@ -141,10 +161,10 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized, - .running(.streaming(_, _, producer: .producing), _), - .running(.endSent, _), - .finished, - .failed: + .running(.streaming(_, _, producer: .producing), _), + .running(.endSent, _), + .finished, + .failed: return .wait case .waitForChannelToBecomeWritable(let head, let metadata): @@ -176,11 +196,11 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(.streaming(_, _, producer: .paused), _), - .running(.endSent, _), - .finished, - .failed: + .waitForChannelToBecomeWritable, + .running(.streaming(_, _, producer: .paused), _), + .running(.endSent, _), + .finished, + .failed: return .wait case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .producing), let responseState): @@ -199,20 +219,24 @@ struct HTTPRequestStateMachine { mutating func errorHappened(_ error: Error) -> Action { if let error = error as? NIOSSLError, - error == .uncleanShutdown, - let action = self.handleNIOSSLUncleanShutdownError() { + error == .uncleanShutdown, + let action = self.handleNIOSSLUncleanShutdownError() + { return action } switch self.state { case .initialized: - preconditionFailure("After the state machine has been initialized, start must be called immediately. Thus this state is unreachable") + preconditionFailure( + "After the state machine has been initialized, start must be called immediately. Thus this state is unreachable" + ) case .waitForChannelToBecomeWritable: // the request failed, before it was sent onto the wire. self.state = .failed(error) return .failRequest(error, .none) + case .running: self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(nil)) case .finished, .failed: // ignore error @@ -226,14 +250,14 @@ struct HTTPRequestStateMachine { private mutating func handleNIOSSLUncleanShutdownError() -> Action? { switch self.state { case .running(.streaming, .waitingForHead), - .running(.endSent, .waitingForHead): + .running(.endSent, .waitingForHead): // if we received a NIOSSL.uncleanShutdown before we got an answer we should handle // this like a normal connection close. We will receive a call to channelInactive after // this error. return .wait case .running(.streaming, .receivingBody(let responseHead, _)), - .running(.endSent, .receivingBody(let responseHead, _)): + .running(.endSent, .receivingBody(let responseHead, _)): // This code is only reachable for request and responses, which we expect to have a body. // We depend on logic from the HTTPResponseDecoder here. The decoder will emit an // HTTPResponsePart.end right after the HTTPResponsePart.head, for every request with a @@ -242,7 +266,9 @@ struct HTTPRequestStateMachine { // For this reason we only need to check the "content-length" or "transfer-encoding" // headers here to determine if we are potentially in an EOF terminated response. - if responseHead.headers.contains(name: "content-length") || responseHead.headers.contains(name: "transfer-encoding") { + if responseHead.headers.contains(name: "content-length") + || responseHead.headers.contains(name: "transfer-encoding") + { // If we have already received the response head, the parser will ensure that we // receive a complete response, if the content-length or transfer-encoding header // was set. In this case we can ignore the NIOSSLError.uncleanShutdown. We will see @@ -254,19 +280,21 @@ struct HTTPRequestStateMachine { // we have received all necessary bytes. For this reason we forward the uncleanShutdown // error to the user. self.state = .failed(NIOSSLError.uncleanShutdown) - return .failRequest(NIOSSLError.uncleanShutdown, .close) + return .failRequest(NIOSSLError.uncleanShutdown, .close(nil)) case .waitForChannelToBecomeWritable, .running, .finished, .failed, .initialized, .modifying: return nil } } - mutating func requestStreamPartReceived(_ part: IOData) -> Action { + mutating func requestStreamPartReceived(_ part: IOData, promise: EventLoopPromise?) -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(.endSent, _): - preconditionFailure("We must be in the request streaming phase, if we receive further body parts. Invalid state: \(self.state)") + .waitForChannelToBecomeWritable, + .running(.endSent, _): + preconditionFailure( + "We must be in the request streaming phase, if we receive further body parts. Invalid state: \(self.state)" + ) case .running(.streaming(_, _, let producerState), .receivingBody(let head, _)) where head.status.code >= 300: // If we have already received a response head with status >= 300, we won't send out any @@ -274,7 +302,7 @@ struct HTTPRequestStateMachine { // won't be interested. We expect that the producer has been informed to pause // producing. assert(producerState == .paused) - return .wait + return .failSendBodyPart(HTTPClientError.requestStreamCancelled, promise) case .running(.streaming(let expectedBodyLength, var sentBodyBytes, let producerState), let responseState): // We don't check the producer state here: @@ -287,13 +315,13 @@ struct HTTPRequestStateMachine { // pause. The reason for this is as follows: There might be thread synchronization // situations in which the producer might not have received the plea to pause yet. - if let expected = expectedBodyLength, sentBodyBytes + part.readableBytes > expected { + if let expected = expectedBodyLength, sentBodyBytes + Int64(part.readableBytes) > expected { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(promise)) } - sentBodyBytes += part.readableBytes + sentBodyBytes += Int64(part.readableBytes) let requestState: RequestState = .streaming( expectedBodyLength: expectedBodyLength, @@ -303,10 +331,10 @@ struct HTTPRequestStateMachine { self.state = .running(requestState, responseState) - return .sendBodyPart(part) + return .sendBodyPart(part, promise) - case .failed: - return .wait + case .failed(let error): + return .failSendBodyPart(error, promise) case .finished: // A request may be finished, before we have send all parts. This might be the case if @@ -318,54 +346,59 @@ struct HTTPRequestStateMachine { // We may still receive something, here because of potential race conditions with the // producing thread. - return .wait + return .failSendBodyPart(HTTPClientError.requestStreamCancelled, promise) case .modifying: preconditionFailure("Invalid state: \(self.state)") } } - mutating func requestStreamFinished() -> Action { + mutating func requestStreamFinished(promise: EventLoopPromise?) -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(.endSent, _): - preconditionFailure("A request body stream end is only expected if we are in state request streaming. Invalid state: \(self.state)") + .waitForChannelToBecomeWritable, + .running(.endSent, _): + preconditionFailure( + "A request body stream end is only expected if we are in state request streaming. Invalid state: \(self.state)" + ) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .waitingForHead): if let expected = expectedBodyLength, expected != sentBodyBytes { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(promise)) } self.state = .running(.endSent, .waitingForHead) - return .sendRequestEnd + return .sendRequestEnd(promise) - case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .receivingBody(let head, let streamState)): + case .running( + .streaming(let expectedBodyLength, let sentBodyBytes, _), + .receivingBody(let head, let streamState) + ): assert(head.status.code < 300) if let expected = expectedBodyLength, expected != sentBodyBytes { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(promise)) } self.state = .running(.endSent, .receivingBody(head, streamState)) - return .sendRequestEnd + return .sendRequestEnd(promise) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .endReceived): if let expected = expectedBodyLength, expected != sentBodyBytes { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(promise)) } self.state = .finished - return .succeedRequest(.sendRequestEnd, .init()) + return .succeedRequest(.sendRequestEnd(promise), .init()) - case .failed: - return .wait + case .failed(let error): + return .failSendStreamFinished(error, promise) case .finished: // A request may be finished, before we have send all parts. This might be the case if @@ -377,7 +410,7 @@ struct HTTPRequestStateMachine { // We may still receive something, here because of potential race conditions with the // producing thread. - return .wait + return .failSendStreamFinished(HTTPClientError.requestStreamCancelled, promise) case .modifying: preconditionFailure("Invalid state: \(self.state)") @@ -398,7 +431,7 @@ struct HTTPRequestStateMachine { case .running: let error = HTTPClientError.cancelled self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(nil)) case .finished: return .wait @@ -435,11 +468,11 @@ struct HTTPRequestStateMachine { mutating func read() -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(_, .waitingForHead), - .running(_, .endReceived), - .finished, - .failed: + .waitForChannelToBecomeWritable, + .running(_, .waitingForHead), + .running(_, .endReceived), + .finished, + .failed: // If we are not in the middle of streaming the response body, we always want to get // more data... return .read @@ -472,11 +505,11 @@ struct HTTPRequestStateMachine { mutating func channelReadComplete() -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(_, .waitingForHead), - .running(_, .endReceived), - .finished, - .failed: + .waitForChannelToBecomeWritable, + .running(_, .waitingForHead), + .running(_, .endReceived), + .finished, + .failed: return .wait case .running(let requestState, .receivingBody(let head, var streamState)): @@ -507,7 +540,9 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized, .waitForChannelToBecomeWritable: - preconditionFailure("How can we receive a response head before sending a request head ourselves") + preconditionFailure( + "How can we receive a response head before sending a request head ourselves \(self.state)" + ) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .paused), .waitingForHead): self.state = .running( @@ -525,7 +560,11 @@ struct HTTPRequestStateMachine { return .forwardResponseHead(head, pauseRequestBodyStream: true) } else { self.state = .running( - .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: .producing), + .streaming( + expectedBodyLength: expectedBodyLength, + sentBodyBytes: sentBodyBytes, + producer: .producing + ), .receivingBody(head, .init()) ) return .forwardResponseHead(head, pauseRequestBodyStream: false) @@ -536,7 +575,9 @@ struct HTTPRequestStateMachine { return .forwardResponseHead(head, pauseRequestBodyStream: false) case .running(_, .receivingBody), .running(_, .endReceived), .finished: - preconditionFailure("How can we successfully finish the request, before having received a head. Invalid state: \(self.state)") + preconditionFailure( + "How can we successfully finish the request, before having received a head. Invalid state: \(self.state)" + ) case .failed: return .wait @@ -548,10 +589,14 @@ struct HTTPRequestStateMachine { mutating func receivedHTTPResponseBodyPart(_ body: ByteBuffer) -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable: - preconditionFailure("How can we receive a response head before sending a request head ourselves. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response head before completely sending a request head ourselves. Invalid state: \(self.state)" + ) case .running(_, .waitingForHead): - preconditionFailure("How can we receive a response body, if we haven't received a head. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response body, if we haven't received a head. Invalid state: \(self.state)" + ) case .running(let requestState, .receivingBody(let head, var responseStreamState)): return self.avoidingStateMachineCoW { state -> Action in @@ -561,7 +606,9 @@ struct HTTPRequestStateMachine { } case .running(_, .endReceived), .finished: - preconditionFailure("How can we successfully finish the request, before having received a head. Invalid state: \(self.state)") + preconditionFailure( + "How can we successfully finish the request, before having received a head. Invalid state: \(self.state)" + ) case .failed: return .wait @@ -574,20 +621,31 @@ struct HTTPRequestStateMachine { private mutating func receivedHTTPResponseEnd() -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable: - preconditionFailure("How can we receive a response head before sending a request head ourselves. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response end before completely sending a request head ourselves. Invalid state: \(self.state)" + ) case .running(_, .waitingForHead): - preconditionFailure("How can we receive a response end, if we haven't a received a head. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response end, if we haven't a received a head. Invalid state: \(self.state)" + ) - case .running(.streaming(let expectedBodyLength, let sentBodyBytes, let producerState), .receivingBody(let head, var responseStreamState)) - where head.status.code < 300: + case .running( + .streaming(let expectedBodyLength, let sentBodyBytes, let producerState), + .receivingBody(let head, var responseStreamState) + ) + where head.status.code < 300: return self.avoidingStateMachineCoW { state -> Action in let (remainingBuffer, connectionAction) = responseStreamState.end() switch connectionAction { case .none: state = .running( - .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: producerState), + .streaming( + expectedBodyLength: expectedBodyLength, + sentBodyBytes: sentBodyBytes, + producer: producerState + ), .endReceived ) return .forwardResponseBodyParts(remainingBuffer) @@ -597,13 +655,16 @@ struct HTTPRequestStateMachine { // the request is still uploading, we will not be able to finish the upload. For // this reason we can fail the request here. state = .failed(HTTPClientError.remoteConnectionClosed) - return .failRequest(HTTPClientError.remoteConnectionClosed, .close) + return .failRequest(HTTPClientError.remoteConnectionClosed, .close(nil)) } } case .running(.streaming(_, _, let producerState), .receivingBody(let head, var responseStreamState)): assert(head.status.code >= 300) - assert(producerState == .paused, "Expected to have paused the request body stream, when the head was received. Invalid state: \(self.state)") + assert( + producerState == .paused, + "Expected to have paused the request body stream, when the head was received. Invalid state: \(self.state)" + ) return self.avoidingStateMachineCoW { state -> Action in // We can ignore the connectionAction from the responseStreamState, since the @@ -626,7 +687,9 @@ struct HTTPRequestStateMachine { } case .running(_, .endReceived), .finished: - preconditionFailure("How can we receive a response end, if another one was already received. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response end, if another one was already received. Invalid state: \(self.state)" + ) case .failed: return .wait @@ -639,9 +702,11 @@ struct HTTPRequestStateMachine { mutating func demandMoreResponseBodyParts() -> Action { switch self.state { case .initialized, - .running(_, .waitingForHead), - .waitForChannelToBecomeWritable: - preconditionFailure("The response is expected to only ask for more data after the response head was forwarded") + .running(_, .waitingForHead), + .waitForChannelToBecomeWritable: + preconditionFailure( + "The response is expected to only ask for more data after the response head was forwarded \(self.state)" + ) case .running(let requestState, .receivingBody(let head, var responseStreamState)): return self.avoidingStateMachineCoW { state -> Action in @@ -651,8 +716,8 @@ struct HTTPRequestStateMachine { } case .running(_, .endReceived), - .finished, - .failed: + .finished, + .failed: return .wait case .modifying: @@ -663,14 +728,16 @@ struct HTTPRequestStateMachine { mutating func idleReadTimeoutTriggered() -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(.streaming, _): - preconditionFailure("We only schedule idle read timeouts after we have sent the complete request. Invalid state: \(self.state)") + .waitForChannelToBecomeWritable, + .running(.streaming, _): + preconditionFailure( + "We only schedule idle read timeouts after we have sent the complete request. Invalid state: \(self.state)" + ) case .running(.endSent, .waitingForHead), .running(.endSent, .receivingBody): let error = HTTPClientError.readTimeout self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(nil)) case .running(.endSent, .endReceived): preconditionFailure("Invalid state. This state should be: .finished") @@ -683,19 +750,84 @@ struct HTTPRequestStateMachine { } } + mutating func idleWriteTimeoutTriggered() -> Action { + switch self.state { + case .initialized, + .waitForChannelToBecomeWritable: + preconditionFailure( + "We only schedule idle write timeouts while the request is being sent. Invalid state: \(self.state)" + ) + + case .running(.streaming, _): + let error = HTTPClientError.writeTimeout + self.state = .failed(error) + return .failRequest(error, .close(nil)) + + case .running(.endSent, _): + preconditionFailure("Invalid state. This state should be: .finished") + + case .finished, .failed: + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + private mutating func startSendingRequest(head: HTTPRequestHead, metadata: RequestFramingMetadata) -> Action { - switch metadata.body { - case .stream: - self.state = .running(.streaming(expectedBodyLength: nil, sentBodyBytes: 0, producer: .producing), .waitingForHead) - return .sendRequestHead(head, startBody: true) - case .fixedSize(0): + let length = metadata.body.expectedLength + if length == 0 { // no body self.state = .running(.endSent, .waitingForHead) - return .sendRequestHead(head, startBody: false) - case .fixedSize(let length): - // length is greater than zero and we therefore have a body to send - self.state = .running(.streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .producing), .waitingForHead) - return .sendRequestHead(head, startBody: true) + return .sendRequestHead(head, sendEnd: true) + } else { + self.state = .running( + .streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .paused), + .waitingForHead + ) + return .sendRequestHead(head, sendEnd: false) + } + } + + mutating func headSent() -> Action { + switch self.state { + case .initialized, .waitForChannelToBecomeWritable, .finished: + preconditionFailure("Not a valid transition after `.sendingHeader`: \(self.state)") + + case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .paused), let responseState): + let startProducing = self.isChannelWritable && expectedBodyLength != sentBodyBytes + self.state = .running( + .streaming( + expectedBodyLength: expectedBodyLength, + sentBodyBytes: sentBodyBytes, + producer: startProducing ? .producing : .paused + ), + responseState + ) + return .notifyRequestHeadSendSuccessfully( + resumeRequestBodyStream: startProducing, + startIdleTimer: false + ) + case .running(.endSent, _): + return .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + case .running(.streaming(_, _, producer: .producing), _): + preconditionFailure( + "request body producing can not start before we have successfully send the header \(self.state)" + ) + case .failed: + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } +} + +extension RequestFramingMetadata.Body { + var expectedLength: Int64? { + switch self { + case .fixedSize(let length): return length + case .stream: return nil } } } @@ -754,7 +886,8 @@ extension HTTPRequestStateMachine: CustomStringConvertible { case .waitForChannelToBecomeWritable: return "HTTPRequestStateMachine(.waitForChannelToBecomeWritable, isWritable: \(self.isChannelWritable))" case .running(let requestState, let responseState): - return "HTTPRequestStateMachine(.running(request: \(requestState), response: \(responseState)), isWritable: \(self.isChannelWritable))" + return + "HTTPRequestStateMachine(.running(request: \(requestState), response: \(responseState)), isWritable: \(self.isChannelWritable))" case .finished: return "HTTPRequestStateMachine(.finished, isWritable: \(self.isChannelWritable))" case .failed(let error): @@ -768,7 +901,7 @@ extension HTTPRequestStateMachine: CustomStringConvertible { extension HTTPRequestStateMachine.RequestState: CustomStringConvertible { var description: String { switch self { - case .streaming(expectedBodyLength: let expected, let sent, producer: let producer): + case .streaming(expectedBodyLength: let expected, let sent, let producer): return ".streaming(sent: \(expected != nil ? String(expected!) : "-"), sent: \(sent), producer: \(producer)" case .endSent: return ".endSent" diff --git a/Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift b/Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift index 38d90e057..58ba694a7 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift @@ -12,11 +12,13 @@ // //===----------------------------------------------------------------------===// +import NIOCore + /// - Note: use `HTTPClientRequest.Body.Length` if you want to expose `RequestBodyLength` publicly @usableFromInline -internal enum RequestBodyLength: Hashable { +internal enum RequestBodyLength: Hashable, Sendable { /// size of the request body is not known before starting the request case unknown /// size of the request body is fixed and exactly `count` bytes - case known(_ count: Int) + case known(_ count: Int64) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift b/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift index 98080e364..033060a99 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift @@ -15,7 +15,7 @@ struct RequestFramingMetadata: Hashable { enum Body: Hashable { case stream - case fixedSize(Int) + case fixedSize(Int64) } var connectionClose: Bool diff --git a/Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift b/Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift index 2092498d8..903f962e5 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift @@ -17,16 +17,28 @@ import NIOCore struct RequestOptions { /// The maximal `TimeAmount` that is allowed to pass between `channelRead`s from the Channel. var idleReadTimeout: TimeAmount? + /// The maximal `TimeAmount` that is allowed to pass between `write`s into the Channel. + var idleWriteTimeout: TimeAmount? + /// DNS overrides. + var dnsOverride: [String: String] - init(idleReadTimeout: TimeAmount?) { + init( + idleReadTimeout: TimeAmount?, + idleWriteTimeout: TimeAmount?, + dnsOverride: [String: String] + ) { self.idleReadTimeout = idleReadTimeout + self.idleWriteTimeout = idleWriteTimeout + self.dnsOverride = dnsOverride } } extension RequestOptions { static func fromClientConfiguration(_ configuration: HTTPClient.Configuration) -> Self { RequestOptions( - idleReadTimeout: configuration.timeout.read + idleReadTimeout: configuration.timeout.read, + idleWriteTimeout: configuration.timeout.write, + dnsOverride: configuration.dnsOverride ) } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift index 4aec9f6fe..71d8f15f1 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift @@ -13,8 +13,13 @@ //===----------------------------------------------------------------------===// import NIOCore + #if canImport(Darwin) import func Darwin.pow +#elseif canImport(Musl) +import func Musl.pow +#elseif canImport(Android) +import func Android.pow #else import func Glibc.pow #endif diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1Connections.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1Connections.swift index cdbf02394..15138a141 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1Connections.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1Connections.swift @@ -19,15 +19,15 @@ extension HTTPConnectionPool { private struct HTTP1ConnectionState { enum State { /// the connection is creating a connection. Valid transitions are to: .backingOff, .idle, and .closed - case starting + case starting(maximumUses: Int?) /// the connection is waiting to retry the establishing a connection. Valid transitions are to: .closed. /// This means, the connection can be removed from the connections without cancelling external /// state. The connection state can then be replaced by a new one. case backingOff /// the connection is idle for a new request. Valid transitions to: .leased and .closed - case idle(Connection, since: NIODeadline) + case idle(Connection, since: NIODeadline, remainingUses: Int?) /// the connection is leased and running for a request. Valid transitions to: .idle and .closed - case leased(Connection) + case leased(Connection, remainingUses: Int?) /// the connection is closed. final state. case closed } @@ -36,10 +36,10 @@ extension HTTPConnectionPool { let connectionID: Connection.ID let eventLoop: EventLoop - init(connectionID: Connection.ID, eventLoop: EventLoop) { + init(connectionID: Connection.ID, eventLoop: EventLoop, maximumUses: Int?) { self.connectionID = connectionID self.eventLoop = eventLoop - self.state = .starting + self.state = .starting(maximumUses: maximumUses) } var isConnecting: Bool { @@ -69,6 +69,19 @@ extension HTTPConnectionPool { } } + var idleAndNoRemainingUses: Bool { + switch self.state { + case .idle(_, since: _, let remainingUses): + if let remainingUses = remainingUses { + return remainingUses <= 0 + } else { + return false + } + case .backingOff, .starting, .leased, .closed: + return false + } + } + var canOrWillBeAbleToExecuteRequests: Bool { switch self.state { case .leased, .backingOff, .idle, .starting: @@ -89,7 +102,7 @@ extension HTTPConnectionPool { var idleSince: NIODeadline? { switch self.state { - case .idle(_, since: let idleSince): + case .idle(_, since: let idleSince, _): return idleSince case .backingOff, .starting, .leased, .closed: return nil @@ -107,8 +120,8 @@ extension HTTPConnectionPool { mutating func connected(_ connection: Connection) { switch self.state { - case .starting: - self.state = .idle(connection, since: .now()) + case .starting(maximumUses: let maxUses): + self.state = .idle(connection, since: .now(), remainingUses: maxUses) case .backingOff, .idle, .leased, .closed: preconditionFailure("Invalid state: \(self.state)") } @@ -126,8 +139,8 @@ extension HTTPConnectionPool { mutating func lease() -> Connection { switch self.state { - case .idle(let connection, since: _): - self.state = .leased(connection) + case .idle(let connection, since: _, let remainingUses): + self.state = .leased(connection, remainingUses: remainingUses.map { $0 - 1 }) return connection case .backingOff, .starting, .leased, .closed: preconditionFailure("Invalid state: \(self.state)") @@ -136,8 +149,8 @@ extension HTTPConnectionPool { mutating func release() { switch self.state { - case .leased(let connection): - self.state = .idle(connection, since: .now()) + case .leased(let connection, let remainingUses): + self.state = .idle(connection, since: .now(), remainingUses: remainingUses) case .backingOff, .starting, .idle, .closed: preconditionFailure("Invalid state: \(self.state)") } @@ -145,7 +158,7 @@ extension HTTPConnectionPool { mutating func close() -> Connection { switch self.state { - case .idle(let connection, since: _): + case .idle(let connection, since: _, remainingUses: _): self.state = .closed return connection case .backingOff, .starting, .leased, .closed: @@ -188,14 +201,16 @@ extension HTTPConnectionPool { return .removeConnection case .starting: return .keepConnection - case .idle(let connection, since: _): + case .idle(let connection, since: _, remainingUses: _): context.close.append(connection) return .removeConnection - case .leased(let connection): + case .leased(let connection, remainingUses: _): context.cancel.append(connection) return .keepConnection case .closed: - preconditionFailure("Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)") + preconditionFailure( + "Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)" + ) } } @@ -212,14 +227,16 @@ extension HTTPConnectionPool { case .backingOff: context.backingOff.append((self.connectionID, self.eventLoop)) return .removeConnection - case .idle(let connection, since: _): + case .idle(let connection, since: _, remainingUses: _): // Idle connections can be removed right away context.close.append(connection) return .removeConnection case .leased: return .keepConnection case .closed: - preconditionFailure("Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)") + preconditionFailure( + "Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)" + ) } } } @@ -243,13 +260,16 @@ extension HTTPConnectionPool { /// The index after which you will find the connections for requests with `EventLoop` /// requirements in `connections`. private var overflowIndex: Array.Index + /// The number of times each connection can be used before it is closed and replaced. + private let maximumConnectionUses: Int? - init(maximumConcurrentConnections: Int, generator: Connection.ID.Generator) { + init(maximumConcurrentConnections: Int, generator: Connection.ID.Generator, maximumConnectionUses: Int?) { self.connections = [] - self.connections.reserveCapacity(maximumConcurrentConnections) + self.connections.reserveCapacity(min(maximumConcurrentConnections, 1024)) self.overflowIndex = self.connections.endIndex self.maximumConcurrentConnections = maximumConcurrentConnections self.generator = generator + self.maximumConnectionUses = maximumConnectionUses } var stats: Stats { @@ -291,7 +311,7 @@ extension HTTPConnectionPool { } private var maximumAdditionalGeneralPurposeConnections: Int { - self.maximumConcurrentConnections - (self.overflowIndex - 1) + self.maximumConcurrentConnections - (self.overflowIndex) } /// Is there at least one connection that is able to run requests @@ -300,7 +320,7 @@ extension HTTPConnectionPool { } func startingEventLoopConnections(on eventLoop: EventLoop) -> Int { - return self.connections[self.overflowIndex.. Connection.ID { precondition(self.canGrow) - let connection = HTTP1ConnectionState(connectionID: self.generator.next(), eventLoop: eventLoop) + let connection = HTTP1ConnectionState( + connectionID: self.generator.next(), + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) self.connections.insert(connection, at: self.overflowIndex) self.overflowIndex = self.connections.index(after: self.overflowIndex) return connection.connectionID } mutating func createNewOverflowConnection(on eventLoop: EventLoop) -> Connection.ID { - let connection = HTTP1ConnectionState(connectionID: self.generator.next(), eventLoop: eventLoop) + let connection = HTTP1ConnectionState( + connectionID: self.generator.next(), + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) self.connections.append(connection) return connection.connectionID } @@ -369,7 +399,10 @@ extension HTTPConnectionPool { guard let index = self.connections.firstIndex(where: { $0.connectionID == connection.id }) else { preconditionFailure("There is a new connection that we didn't request!") } - precondition(connection.eventLoop === self.connections[index].eventLoop, "Expected the new connection to be on EL") + precondition( + connection.eventLoop === self.connections[index].eventLoop, + "Expected the new connection to be on EL" + ) self.connections[index].connected(connection) let context = self.generateIdleConnectionContextForConnection(at: index) return (index, context) @@ -484,7 +517,8 @@ extension HTTPConnectionPool { precondition(self.connections[index].isClosed) let newConnection = HTTP1ConnectionState( connectionID: self.generator.next(), - eventLoop: self.connections[index].eventLoop + eventLoop: self.connections[index].eventLoop, + maximumUses: self.maximumConnectionUses ) self.connections[index] = newConnection @@ -562,7 +596,12 @@ extension HTTPConnectionPool { backingOff: [(Connection.ID, EventLoop)] ) { for (connectionID, eventLoop) in starting { - let newConnection = HTTP1ConnectionState(connectionID: connectionID, eventLoop: eventLoop) + let newConnection = HTTP1ConnectionState( + connectionID: connectionID, + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) + self.connections.insert(newConnection, at: self.overflowIndex) /// If we can grow, we mark the connection as a general purpose connection. /// Otherwise, it will be an overflow connection which is only used once for requests with a required event loop @@ -572,9 +611,14 @@ extension HTTPConnectionPool { } for (connectionID, eventLoop) in backingOff { - var backingOffConnection = HTTP1ConnectionState(connectionID: connectionID, eventLoop: eventLoop) + var backingOffConnection = HTTP1ConnectionState( + connectionID: connectionID, + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) // TODO: Maybe we want to add a static init for backing off connections to HTTP1ConnectionState backingOffConnection.failedToConnect() + self.connections.insert(backingOffConnection, at: self.overflowIndex) /// If we can grow, we mark the connection as a general purpose connection. /// Otherwise, it will be an overflow connection which is only used once for requests with a required event loop @@ -602,21 +646,23 @@ extension HTTPConnectionPool { ) -> [(Connection.ID, EventLoop)] { // create new connections for requests with a required event loop - // we may already start connections for those requests and do not want to start to many + // we may already start connections for those requests and do not want to start too many let startingRequiredEventLoopConnectionCount = Dictionary( self.connections[self.overflowIndex.. [(Connection.ID, EventLoop)] in // We need a connection for each queued request with a required event loop. // Therefore, we look how many request we have queued for a given `eventLoop` and // how many connections we are already starting on the given `eventLoop`. // If we have not enough, we will create additional connections to have at least // on connection per request. - let connectionsToStart = requestCount - startingRequiredEventLoopConnectionCount[eventLoop.id, default: 0] + let connectionsToStart = + requestCount - startingRequiredEventLoopConnectionCount[eventLoop.id, default: 0] return stride(from: 0, to: connectionsToStart, by: 1).lazy.map { _ in (self.createNewOverflowConnection(on: eventLoop), eventLoop) } @@ -631,7 +677,8 @@ extension HTTPConnectionPool { // event loop we will continue with the event loop with the second most queued requests // and so on and so forth. The `generalPurposeRequestCountGroupedByPreferredEventLoop` // array is already ordered so we can just iterate over it without sorting by request count. - let newGeneralPurposeConnections: [(Connection.ID, EventLoop)] = generalPurposeRequestCountGroupedByPreferredEventLoop + let newGeneralPurposeConnections: [(Connection.ID, EventLoop)] = + generalPurposeRequestCountGroupedByPreferredEventLoop // we do not want to allocated intermediate arrays. .lazy // we flatten the grouped list of event loops by lazily repeating the event loop @@ -690,7 +737,8 @@ extension HTTPConnectionPool { } else { use = .eventLoop(eventLoop) } - return IdleConnectionContext(eventLoop: eventLoop, use: use) + let hasNoRemainingUses = self.connections[index].idleAndNoRemainingUses + return IdleConnectionContext(eventLoop: eventLoop, use: use, shouldBeClosed: hasNoRemainingUses) } private func findIdleConnection(onPreferred preferredEL: EventLoop) -> Int? { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift index 6b3f7352e..09b1dc85e 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift @@ -17,6 +17,7 @@ import NIOCore extension HTTPConnectionPool { struct HTTP1StateMachine { typealias Action = HTTPConnectionPool.StateMachine.Action + typealias RequestAction = HTTPConnectionPool.StateMachine.RequestAction typealias ConnectionMigrationAction = HTTPConnectionPool.StateMachine.ConnectionMigrationAction typealias EstablishedAction = HTTPConnectionPool.StateMachine.EstablishedAction typealias EstablishedConnectionAction = HTTPConnectionPool.StateMachine.EstablishedConnectionAction @@ -29,16 +30,23 @@ extension HTTPConnectionPool { private(set) var requests: RequestQueue private(set) var lifecycleState: StateMachine.LifecycleState + /// The property was introduced to fail fast during testing. + /// Otherwise this should always be true and not turned off. + private let retryConnectionEstablishment: Bool init( idGenerator: Connection.ID.Generator, maximumConcurrentConnections: Int, + retryConnectionEstablishment: Bool, + maximumConnectionUses: Int?, lifecycleState: StateMachine.LifecycleState ) { self.connections = HTTP1Connections( maximumConcurrentConnections: maximumConcurrentConnections, - generator: idGenerator + generator: idGenerator, + maximumConnectionUses: maximumConnectionUses ) + self.retryConnectionEstablishment = retryConnectionEstablishment self.requests = RequestQueue() self.lifecycleState = lifecycleState @@ -72,7 +80,10 @@ extension HTTPConnectionPool { requests: RequestQueue ) -> ConnectionMigrationAction { precondition(self.connections.isEmpty, "expected an empty state machine but connections are not empty") - precondition(self.http2Connections == nil, "expected an empty state machine but http2Connections are not nil") + precondition( + self.http2Connections == nil, + "expected an empty state machine but http2Connections are not nil" + ) precondition(self.requests.isEmpty, "expected an empty state machine but requests are not empty") self.requests = requests @@ -92,7 +103,8 @@ extension HTTPConnectionPool { let createConnections = self.connections.createConnectionsAfterMigrationIfNeeded( requiredEventLoopOfPendingRequests: requests.requestCountGroupedByRequiredEventLoop(), - generalPurposeRequestCountGroupedByPreferredEventLoop: requests.generalPurposeRequestCountGroupedByPreferredEventLoop() + generalPurposeRequestCountGroupedByPreferredEventLoop: + requests.generalPurposeRequestCountGroupedByPreferredEventLoop() ) if !http2Connections.isEmpty { @@ -219,6 +231,19 @@ extension HTTPConnectionPool { switch self.lifecycleState { case .running: + guard self.retryConnectionEstablishment else { + guard let (index, _) = self.connections.failConnection(connectionID) else { + preconditionFailure( + "A connection attempt failed, that the state machine knows nothing about. Somewhere state was lost." + ) + } + self.connections.removeConnection(at: index) + + return .init( + request: self.failAllRequests(reason: error), + connection: .none + ) + } // We don't care how many waiting requests we have at this point, we will schedule a // retry. More tasks, may appear until the backoff has completed. The final // decision about the retry will be made in `connectionCreationBackoffDone(_:)` @@ -241,6 +266,12 @@ extension HTTPConnectionPool { } } + mutating func waitingForConnectivity(_ error: Error, connectionID: Connection.ID) -> Action { + self.lastConnectFailure = error + + return .init(request: .none, connection: .none) + } + mutating func connectionCreationBackoffDone(_ connectionID: Connection.ID) -> Action { switch self.lifecycleState { case .running: @@ -270,7 +301,10 @@ extension HTTPConnectionPool { return .none } - precondition(self.lifecycleState == .running, "If we are shutting down, we must not have any idle connections") + precondition( + self.lifecycleState == .running, + "If we are shutting down, we must not have any idle connections" + ) return .init( request: .none, @@ -317,9 +351,11 @@ extension HTTPConnectionPool { mutating func cancelRequest(_ requestID: Request.ID) -> Action { // 1. check requests in queue - if self.requests.remove(requestID) != nil { + if let request = self.requests.remove(requestID) { + // Use the last connection error to let the user know why the request was never scheduled + let error = self.lastConnectFailure ?? HTTPClientError.cancelled return .init( - request: .cancelRequestTimeout(requestID), + request: .failRequest(request, error, cancelTimeout: true), connection: .none ) } @@ -372,11 +408,20 @@ extension HTTPConnectionPool { ) -> EstablishedAction { switch self.lifecycleState { case .running: - switch context.use { - case .generalPurpose: - return self.nextActionForIdleGeneralPurposeConnection(at: index, context: context) - case .eventLoop: - return self.nextActionForIdleEventLoopConnection(at: index, context: context) + // Close the connection if it's expired. + if context.shouldBeClosed { + let connection = self.connections.closeConnection(at: index) + return .init( + request: .none, + connection: .closeConnection(connection, isShutdown: .no) + ) + } else { + switch context.use { + case .generalPurpose: + return self.nextActionForIdleGeneralPurposeConnection(at: index, context: context) + case .eventLoop: + return self.nextActionForIdleEventLoopConnection(at: index, context: context) + } } case .shuttingDown(let unclean): assert(self.requests.isEmpty) @@ -515,9 +560,18 @@ extension HTTPConnectionPool { return .none } + private mutating func failAllRequests(reason error: Error) -> RequestAction { + let allRequests = self.requests.removeAll() + guard !allRequests.isEmpty else { + return .none + } + return .failRequestsAndCancelTimeouts(allRequests, error) + } + // MARK: HTTP2 - mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action { + mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action + { // The `http2Connections` are optional here: // Connections report events back to us, if they are in a shutdown that was // initiated by the state machine. For this reason this callback might be invoked @@ -619,6 +673,7 @@ extension HTTPConnectionPool.HTTP1StateMachine: CustomStringConvertible { let stats = self.connections.stats let queued = self.requests.count - return "connections: [connecting: \(stats.connecting) | backoff: \(stats.backingOff) | leased: \(stats.leased) | idle: \(stats.idle)], queued: \(queued)" + return + "connections: [connecting: \(stats.connecting) | backoff: \(stats.backingOff) | leased: \(stats.leased) | idle: \(stats.idle)], queued: \(queued)" } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2Connections.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2Connections.swift index 7aa504d03..dbb6b2d30 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2Connections.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2Connections.swift @@ -18,12 +18,12 @@ extension HTTPConnectionPool { private struct HTTP2ConnectionState { private enum State { /// the pool is establishing a connection. Valid transitions are to: .backingOff, .active and .closed - case starting + case starting(maximumUses: Int?) /// the connection is waiting to retry to establish a connection. Valid transitions are to .closed. /// From .closed a new connection state must be created for a retry. case backingOff /// the connection is active and is able to run requests. Valid transitions are to: .draining and .closed - case active(Connection, maxStreams: Int, usedStreams: Int, lastIdle: NIODeadline) + case active(Connection, maxStreams: Int, usedStreams: Int, lastIdle: NIODeadline, remainingUses: Int?) /// the connection is active and is running requests. No new requests must be scheduled. /// Valid transitions to: .draining and .closed case draining(Connection, maxStreams: Int, usedStreams: Int) @@ -71,8 +71,12 @@ extension HTTPConnectionPool { /// A request can be scheduled on the connection var isAvailable: Bool { switch self.state { - case .active(_, let maxStreams, let usedStreams, _): - return usedStreams < maxStreams + case .active(_, let maxStreams, let usedStreams, _, let remainingUses): + if let remainingUses = remainingUses { + return usedStreams < maxStreams && remainingUses > 0 + } else { + return usedStreams < maxStreams + } case .starting, .backingOff, .draining, .closed: return false } @@ -82,7 +86,7 @@ extension HTTPConnectionPool { /// Every idle connection is available, but not every available connection is idle. var isIdle: Bool { switch self.state { - case .active(_, _, let usedStreams, _): + case .active(_, _, let usedStreams, _, _): return usedStreams == 0 case .starting, .backingOff, .draining, .closed: return false @@ -112,9 +116,19 @@ extension HTTPConnectionPool { case .active, .draining, .backingOff, .closed: preconditionFailure("Invalid state: \(self.state)") - case .starting: - self.state = .active(conn, maxStreams: maxStreams, usedStreams: 0, lastIdle: .now()) - return maxStreams + case .starting(let maxUses): + self.state = .active( + conn, + maxStreams: maxStreams, + usedStreams: 0, + lastIdle: .now(), + remainingUses: maxUses + ) + if let maxUses = maxUses { + return min(maxStreams, maxUses) + } else { + return maxStreams + } } } @@ -127,9 +141,20 @@ extension HTTPConnectionPool { case .starting, .backingOff, .closed: preconditionFailure("Invalid state for updating max concurrent streams: \(self.state)") - case .active(let conn, _, let usedStreams, let lastIdle): - self.state = .active(conn, maxStreams: maxStreams, usedStreams: usedStreams, lastIdle: lastIdle) - return max(maxStreams - usedStreams, 0) + case .active(let conn, _, let usedStreams, let lastIdle, let remainingUses): + self.state = .active( + conn, + maxStreams: maxStreams, + usedStreams: usedStreams, + lastIdle: lastIdle, + remainingUses: remainingUses + ) + let availableStreams = max(maxStreams - usedStreams, 0) + if let remainingUses = remainingUses { + return min(remainingUses, availableStreams) + } else { + return availableStreams + } case .draining(let conn, _, let usedStreams): self.state = .draining(conn, maxStreams: maxStreams, usedStreams: usedStreams) @@ -142,7 +167,7 @@ extension HTTPConnectionPool { case .starting, .backingOff, .closed: preconditionFailure("Invalid state for draining a connection: \(self.state)") - case .active(let conn, let maxStreams, let usedStreams, _): + case .active(let conn, let maxStreams, let usedStreams, _, _): self.state = .draining(conn, maxStreams: maxStreams, usedStreams: usedStreams) return conn.eventLoop @@ -176,10 +201,20 @@ extension HTTPConnectionPool { case .starting, .backingOff, .draining, .closed: preconditionFailure("Invalid state for leasing a stream: \(self.state)") - case .active(let conn, let maxStreams, var usedStreams, let lastIdle): + case .active(let conn, let maxStreams, var usedStreams, let lastIdle, let remainingUses): usedStreams += count precondition(usedStreams <= maxStreams, "tried to lease a connection which is not available") - self.state = .active(conn, maxStreams: maxStreams, usedStreams: usedStreams, lastIdle: lastIdle) + precondition( + remainingUses.map { $0 >= count } ?? true, + "tried to lease streams from a connection which does not have enough remaining streams" + ) + self.state = .active( + conn, + maxStreams: maxStreams, + usedStreams: usedStreams, + lastIdle: lastIdle, + remainingUses: remainingUses.map { $0 - count } + ) return conn } } @@ -191,14 +226,26 @@ extension HTTPConnectionPool { case .starting, .backingOff, .closed: preconditionFailure("Invalid state: \(self.state)") - case .active(let conn, let maxStreams, var usedStreams, var lastIdle): + case .active(let conn, let maxStreams, var usedStreams, var lastIdle, let remainingUses): precondition(usedStreams > 0, "we cannot release more streams than we have leased") usedStreams &-= 1 if usedStreams == 0 { lastIdle = .now() } - self.state = .active(conn, maxStreams: maxStreams, usedStreams: usedStreams, lastIdle: lastIdle) - return max(maxStreams &- usedStreams, 0) + + self.state = .active( + conn, + maxStreams: maxStreams, + usedStreams: usedStreams, + lastIdle: lastIdle, + remainingUses: remainingUses + ) + let availableStreams = max(maxStreams &- usedStreams, 0) + if let remainingUses = remainingUses { + return min(availableStreams, remainingUses) + } else { + return availableStreams + } case .draining(let conn, let maxStreams, var usedStreams): precondition(usedStreams > 0, "we cannot release more streams than we have leased") @@ -210,7 +257,7 @@ extension HTTPConnectionPool { mutating func close() -> Connection { switch self.state { - case .active(let conn, _, 0, _): + case .active(let conn, _, 0, _, _): self.state = .closed return conn @@ -247,7 +294,7 @@ extension HTTPConnectionPool { context.connectBackoff.append(self.connectionID) return .removeConnection - case .active(let connection, _, let usedStreams, _): + case .active(let connection, _, let usedStreams, _, _): precondition(usedStreams >= 0) if usedStreams == 0 { context.close.append(connection) @@ -262,7 +309,9 @@ extension HTTPConnectionPool { return .keepConnection case .closed: - preconditionFailure("Unexpected state for cleanup: Did not expect to have closed connections in the state machine.") + preconditionFailure( + "Unexpected state for cleanup: Did not expect to have closed connections in the state machine." + ) } } @@ -274,7 +323,7 @@ extension HTTPConnectionPool { case .backingOff: stats.backingOffConnections &+= 1 - case .active(_, let maxStreams, let usedStreams, _): + case .active(_, let maxStreams, let usedStreams, _, _): stats.availableStreams += max(maxStreams - usedStreams, 0) stats.leasedStreams += usedStreams stats.availableConnections &+= 1 @@ -304,7 +353,7 @@ extension HTTPConnectionPool { context.starting.append((self.connectionID, self.eventLoop)) return .removeConnection - case .active(let connection, _, let usedStreams, _): + case .active(let connection, _, let usedStreams, _, _): precondition(usedStreams >= 0) if usedStreams == 0 { context.close.append(connection) @@ -321,14 +370,16 @@ extension HTTPConnectionPool { return .removeConnection case .closed: - preconditionFailure("Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)") + preconditionFailure( + "Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)" + ) } } - init(connectionID: Connection.ID, eventLoop: EventLoop) { + init(connectionID: Connection.ID, eventLoop: EventLoop, maximumUses: Int?) { self.connectionID = connectionID self.eventLoop = eventLoop - self.state = .starting + self.state = .starting(maximumUses: maximumUses) } } @@ -337,6 +388,8 @@ extension HTTPConnectionPool { private let generator: Connection.ID.Generator /// The connections states private var connections: [HTTP2ConnectionState] + /// The number of times each connection can be used before it is closed and replaced. + private let maximumConnectionUses: Int? var isEmpty: Bool { self.connections.isEmpty @@ -348,9 +401,10 @@ extension HTTPConnectionPool { } } - init(generator: Connection.ID.Generator) { + init(generator: Connection.ID.Generator, maximumConnectionUses: Int?) { self.generator = generator self.connections = [] + self.maximumConnectionUses = maximumConnectionUses } // MARK: Migration @@ -365,12 +419,20 @@ extension HTTPConnectionPool { backingOff: [(Connection.ID, EventLoop)] ) { for (connectionID, eventLoop) in starting { - let newConnection = HTTP2ConnectionState(connectionID: connectionID, eventLoop: eventLoop) + let newConnection = HTTP2ConnectionState( + connectionID: connectionID, + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) self.connections.append(newConnection) } for (connectionID, eventLoop) in backingOff { - var backingOffConnection = HTTP2ConnectionState(connectionID: connectionID, eventLoop: eventLoop) + var backingOffConnection = HTTP2ConnectionState( + connectionID: connectionID, + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) // TODO: Maybe we want to add a static init for backing off connections to HTTP2ConnectionState backingOffConnection.failedToConnect() self.connections.append(backingOffConnection) @@ -476,7 +538,11 @@ extension HTTPConnectionPool { "we should not create more than one connection per event loop" ) - let connection = HTTP2ConnectionState(connectionID: self.generator.next(), eventLoop: eventLoop) + let connection = HTTP2ConnectionState( + connectionID: self.generator.next(), + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) self.connections.append(connection) return connection.connectionID } @@ -489,11 +555,17 @@ extension HTTPConnectionPool { /// - Returns: An index and an ``EstablishedConnectionContext`` to determine the next action for the now idle connection. /// Call ``leaseStreams(at:count:)`` or ``closeConnection(at:)`` with the supplied index after /// this. - mutating func newHTTP2ConnectionEstablished(_ connection: Connection, maxConcurrentStreams: Int) -> (Int, EstablishedConnectionContext) { + mutating func newHTTP2ConnectionEstablished( + _ connection: Connection, + maxConcurrentStreams: Int + ) -> (Int, EstablishedConnectionContext) { guard let index = self.connections.firstIndex(where: { $0.connectionID == connection.id }) else { preconditionFailure("There is a new connection that we didn't request!") } - precondition(connection.eventLoop === self.connections[index].eventLoop, "Expected the new connection to be on EL") + precondition( + connection.eventLoop === self.connections[index].eventLoop, + "Expected the new connection to be on EL" + ) let availableStreams = self.connections[index].connected(connection, maxStreams: maxConcurrentStreams) let context = EstablishedConnectionContext( availableStreams: availableStreams, @@ -661,7 +733,8 @@ extension HTTPConnectionPool { precondition(self.connections[index].isClosed) let newConnection = HTTP2ConnectionState( connectionID: self.generator.next(), - eventLoop: self.connections[index].eventLoop + eventLoop: self.connections[index].eventLoop, + maximumUses: self.maximumConnectionUses ) self.connections[index] = newConnection diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift index d3e6fbdcd..2372cab4b 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift @@ -18,6 +18,7 @@ import NIOHTTP2 extension HTTPConnectionPool { struct HTTP2StateMachine { typealias Action = HTTPConnectionPool.StateMachine.Action + typealias RequestAction = HTTPConnectionPool.StateMachine.RequestAction typealias ConnectionMigrationAction = HTTPConnectionPool.StateMachine.ConnectionMigrationAction typealias EstablishedAction = HTTPConnectionPool.StateMachine.EstablishedAction typealias EstablishedConnectionAction = HTTPConnectionPool.StateMachine.EstablishedConnectionAction @@ -33,16 +34,25 @@ extension HTTPConnectionPool { private let idGenerator: Connection.ID.Generator private(set) var lifecycleState: StateMachine.LifecycleState + /// The property was introduced to fail fast during testing. + /// Otherwise this should always be true and not turned off. + private let retryConnectionEstablishment: Bool init( idGenerator: Connection.ID.Generator, - lifecycleState: StateMachine.LifecycleState + retryConnectionEstablishment: Bool, + lifecycleState: StateMachine.LifecycleState, + maximumConnectionUses: Int? ) { self.idGenerator = idGenerator self.requests = RequestQueue() - self.connections = HTTP2Connections(generator: idGenerator) + self.connections = HTTP2Connections( + generator: idGenerator, + maximumConnectionUses: maximumConnectionUses + ) self.lifecycleState = lifecycleState + self.retryConnectionEstablishment = retryConnectionEstablishment } mutating func migrateFromHTTP1( @@ -75,7 +85,10 @@ extension HTTPConnectionPool { requests: RequestQueue ) -> ConnectionMigrationAction { precondition(self.connections.isEmpty, "expected an empty state machine but connections are not empty") - precondition(self.http1Connections == nil, "expected an empty state machine but http1Connections are not nil") + precondition( + self.http1Connections == nil, + "expected an empty state machine but http1Connections are not nil" + ) precondition(self.requests.isEmpty, "expected an empty state machine but requests are not empty") self.requests = requests @@ -85,7 +98,7 @@ extension HTTPConnectionPool { self.connections = http2Connections } - var http1Connections = http1Connections // make http1Connections mutable + var http1Connections = http1Connections // make http1Connections mutable let context = http1Connections.migrateToHTTP2() self.connections.migrateFromHTTP1( starting: context.starting, @@ -207,7 +220,10 @@ extension HTTPConnectionPool { .init(self._newHTTP2ConnectionEstablished(connection, maxConcurrentStreams: maxConcurrentStreams)) } - private mutating func _newHTTP2ConnectionEstablished(_ connection: Connection, maxConcurrentStreams: Int) -> EstablishedAction { + private mutating func _newHTTP2ConnectionEstablished( + _ connection: Connection, + maxConcurrentStreams: Int + ) -> EstablishedAction { self.failedConsecutiveConnectionAttempts = 0 self.lastConnectFailure = nil if self.connections.hasActiveConnection(for: connection.eventLoop) { @@ -288,8 +304,14 @@ extension HTTPConnectionPool { } } - mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action { - guard let (index, context) = self.connections.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) else { + mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action + { + guard + let (index, context) = self.connections.newHTTP2MaxConcurrentStreamsReceived( + connectionID, + newMaxStreams: newMaxStreams + ) + else { // When a connection close is initiated by the connection pool, the connection will // still report further events (like newMaxConcurrentStreamsReceived) to the state // machine. In those cases we must ignore the event. @@ -333,15 +355,15 @@ extension HTTPConnectionPool { // we need to start a new on connection in two cases: let needGeneralPurposeConnection = // 1. if we have general purpose requests - !self.requests.isEmpty(for: nil) && + !self.requests.isEmpty(for: nil) // and no connection starting or active - !context.hasGeneralPurposeConnection + && !context.hasGeneralPurposeConnection let needRequiredEventLoopConnection = // 2. or if we have requests for a required event loop - !self.requests.isEmpty(for: eventLoop) && + !self.requests.isEmpty(for: eventLoop) // and no connection starting or active for the given event loop - !context.hasConnectionOnSpecifiedEventLoop + && !context.hasConnectionOnSpecifiedEventLoop guard needGeneralPurposeConnection || needRequiredEventLoopConnection else { // otherwise we can remove the connection @@ -349,7 +371,8 @@ extension HTTPConnectionPool { return .none } - let (newConnectionID, previousEventLoop) = self.connections.createNewConnectionByReplacingClosedConnection(at: index) + let (newConnectionID, previousEventLoop) = self.connections + .createNewConnectionByReplacingClosedConnection(at: index) precondition(previousEventLoop === eventLoop) return .init( @@ -401,9 +424,44 @@ extension HTTPConnectionPool { self.failedConsecutiveConnectionAttempts += 1 self.lastConnectFailure = error - let eventLoop = self.connections.backoffNextConnectionAttempt(connectionID) - let backoff = calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) - return .init(request: .none, connection: .scheduleBackoffTimer(connectionID, backoff: backoff, on: eventLoop)) + switch self.lifecycleState { + case .running: + guard self.retryConnectionEstablishment else { + guard let (index, _) = self.connections.failConnection(connectionID) else { + preconditionFailure( + "A connection attempt failed, that the state machine knows nothing about. Somewhere state was lost." + ) + } + self.connections.removeConnection(at: index) + + return .init( + request: self.failAllRequests(reason: error), + connection: .none + ) + } + + let eventLoop = self.connections.backoffNextConnectionAttempt(connectionID) + let backoff = calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) + return .init( + request: .none, + connection: .scheduleBackoffTimer(connectionID, backoff: backoff, on: eventLoop) + ) + case .shuttingDown: + guard let (index, context) = self.connections.failConnection(connectionID) else { + preconditionFailure( + "A connection attempt failed, that the state machine knows nothing about. Somewhere state was lost." + ) + } + return self.nextActionForFailedConnection(at: index, on: context.eventLoop) + case .shutDown: + preconditionFailure("If the pool is already shutdown, all connections must have been torn down.") + } + } + + mutating func waitingForConnectivity(_ error: Error, connectionID: Connection.ID) -> Action { + self.lastConnectFailure = error + + return .init(request: .none, connection: .none) } mutating func connectionCreationBackoffDone(_ connectionID: Connection.ID) -> Action { @@ -416,6 +474,14 @@ extension HTTPConnectionPool { return self.nextActionForFailedConnection(at: index, on: context.eventLoop) } + private mutating func failAllRequests(reason error: Error) -> RequestAction { + let allRequests = self.requests.removeAll() + guard !allRequests.isEmpty else { + return .none + } + return .failRequestsAndCancelTimeouts(allRequests, error) + } + mutating func timeoutRequest(_ requestID: Request.ID) -> Action { // 1. check requests in queue if let request = self.requests.remove(requestID) { @@ -439,9 +505,11 @@ extension HTTPConnectionPool { mutating func cancelRequest(_ requestID: Request.ID) -> Action { // 1. check requests in queue - if self.requests.remove(requestID) != nil { + if let request = self.requests.remove(requestID) { + // Use the last connection error to let the user know why the request was never scheduled + let error = self.lastConnectFailure ?? HTTPClientError.cancelled return .init( - request: .cancelRequestTimeout(requestID), + request: .failRequest(request, error, cancelTimeout: true), connection: .none ) } @@ -459,7 +527,10 @@ extension HTTPConnectionPool { return .none } - precondition(self.lifecycleState == .running, "If we are shutting down, we must not have any idle connections") + precondition( + self.lifecycleState == .running, + "If we are shutting down, we must not have any idle connections" + ) return .init( request: .none, @@ -512,7 +583,10 @@ extension HTTPConnectionPool { case .shuttingDown(let unclean): if self.connections.isEmpty { // if the http2connections are empty as well, there are no more connections. Shutdown completed. - return .init(request: .none, connection: .closeConnection(connection, isShutdown: .yes(unclean: unclean))) + return .init( + request: .none, + connection: .closeConnection(connection, isShutdown: .yes(unclean: unclean)) + ) } else { return .init(request: .none, connection: .closeConnection(connection, isShutdown: .no)) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift index 4d912633c..6dfd4223e 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift @@ -61,7 +61,6 @@ extension HTTPConnectionPool { case failRequestsAndCancelTimeouts([Request], Error) case scheduleRequestTimeout(for: Request, on: EventLoop) - case cancelRequestTimeout(Request.ID) case none } @@ -97,24 +96,52 @@ extension HTTPConnectionPool { let idGenerator: Connection.ID.Generator let maximumConcurrentHTTP1Connections: Int - - init(idGenerator: Connection.ID.Generator, maximumConcurrentHTTP1Connections: Int) { + /// The property was introduced to fail fast during testing. + /// Otherwise this should always be true and not turned off. + private let retryConnectionEstablishment: Bool + let maximumConnectionUses: Int? + + init( + idGenerator: Connection.ID.Generator, + maximumConcurrentHTTP1Connections: Int, + retryConnectionEstablishment: Bool, + preferHTTP1: Bool, + maximumConnectionUses: Int? + ) { self.maximumConcurrentHTTP1Connections = maximumConcurrentHTTP1Connections + self.retryConnectionEstablishment = retryConnectionEstablishment self.idGenerator = idGenerator - let http1State = HTTP1StateMachine( - idGenerator: idGenerator, - maximumConcurrentConnections: maximumConcurrentHTTP1Connections, - lifecycleState: .running - ) - self.state = .http1(http1State) + self.maximumConnectionUses = maximumConnectionUses + + if preferHTTP1 { + let http1State = HTTP1StateMachine( + idGenerator: idGenerator, + maximumConcurrentConnections: maximumConcurrentHTTP1Connections, + retryConnectionEstablishment: retryConnectionEstablishment, + maximumConnectionUses: maximumConnectionUses, + lifecycleState: .running + ) + self.state = .http1(http1State) + } else { + let http2State = HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: retryConnectionEstablishment, + lifecycleState: .running, + maximumConnectionUses: maximumConnectionUses + ) + self.state = .http2(http2State) + } } mutating func executeRequest(_ request: Request) -> Action { - self.state.modify(http1: { http1 in - http1.executeRequest(request) - }, http2: { http2 in - http2.executeRequest(request) - }) + self.state.modify( + http1: { http1 in + http1.executeRequest(request) + }, + http2: { http2 in + http2.executeRequest(request) + } + ) } mutating func newHTTP1ConnectionCreated(_ connection: Connection) -> Action { @@ -128,6 +155,8 @@ extension HTTPConnectionPool { var http1StateMachine = HTTP1StateMachine( idGenerator: self.idGenerator, maximumConcurrentConnections: self.maximumConcurrentHTTP1Connections, + retryConnectionEstablishment: self.retryConnectionEstablishment, + maximumConnectionUses: self.maximumConnectionUses, lifecycleState: http2StateMachine.lifecycleState ) @@ -148,7 +177,9 @@ extension HTTPConnectionPool { var http2StateMachine = HTTP2StateMachine( idGenerator: self.idGenerator, - lifecycleState: http1StateMachine.lifecycleState + retryConnectionEstablishment: self.retryConnectionEstablishment, + lifecycleState: http1StateMachine.lifecycleState, + maximumConnectionUses: self.maximumConnectionUses ) let migrationAction = http2StateMachine.migrateFromHTTP1( http1Connections: http1StateMachine.connections, @@ -171,52 +202,82 @@ extension HTTPConnectionPool { } } - mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action { - self.state.modify(http1: { http1 in - http1.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) - }, http2: { http2 in - http2.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) - }) + mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action + { + self.state.modify( + http1: { http1 in + http1.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) + }, + http2: { http2 in + http2.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) + } + ) } mutating func http2ConnectionGoAwayReceived(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http2ConnectionGoAwayReceived(connectionID) - }, http2: { http2 in - http2.http2ConnectionGoAwayReceived(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http2ConnectionGoAwayReceived(connectionID) + }, + http2: { http2 in + http2.http2ConnectionGoAwayReceived(connectionID) + } + ) } mutating func http2ConnectionClosed(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http2ConnectionClosed(connectionID) - }, http2: { http2 in - http2.http2ConnectionClosed(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http2ConnectionClosed(connectionID) + }, + http2: { http2 in + http2.http2ConnectionClosed(connectionID) + } + ) } mutating func http2ConnectionStreamClosed(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http2ConnectionStreamClosed(connectionID) - }, http2: { http2 in - http2.http2ConnectionStreamClosed(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http2ConnectionStreamClosed(connectionID) + }, + http2: { http2 in + http2.http2ConnectionStreamClosed(connectionID) + } + ) } mutating func failedToCreateNewConnection(_ error: Error, connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.failedToCreateNewConnection(error, connectionID: connectionID) - }, http2: { http2 in - http2.failedToCreateNewConnection(error, connectionID: connectionID) - }) + self.state.modify( + http1: { http1 in + http1.failedToCreateNewConnection(error, connectionID: connectionID) + }, + http2: { http2 in + http2.failedToCreateNewConnection(error, connectionID: connectionID) + } + ) + } + + mutating func waitingForConnectivity(_ error: Error, connectionID: Connection.ID) -> Action { + self.state.modify( + http1: { http1 in + http1.waitingForConnectivity(error, connectionID: connectionID) + }, + http2: { http2 in + http2.waitingForConnectivity(error, connectionID: connectionID) + } + ) } mutating func connectionCreationBackoffDone(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.connectionCreationBackoffDone(connectionID) - }, http2: { http2 in - http2.connectionCreationBackoffDone(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.connectionCreationBackoffDone(connectionID) + }, + http2: { http2 in + http2.connectionCreationBackoffDone(connectionID) + } + ) } /// A request has timed out. @@ -225,11 +286,14 @@ extension HTTPConnectionPool { /// request, but don't need to cancel the timer (it already triggered). If a request is cancelled /// we don't need to fail it but we need to cancel its timeout timer. mutating func timeoutRequest(_ requestID: Request.ID) -> Action { - self.state.modify(http1: { http1 in - http1.timeoutRequest(requestID) - }, http2: { http2 in - http2.timeoutRequest(requestID) - }) + self.state.modify( + http1: { http1 in + http1.timeoutRequest(requestID) + }, + http2: { http2 in + http2.timeoutRequest(requestID) + } + ) } /// A request was cancelled. @@ -238,44 +302,59 @@ extension HTTPConnectionPool { /// need to cancel its timeout timer. If a request times out, we need to fail the request, but don't /// need to cancel the timer (it already triggered). mutating func cancelRequest(_ requestID: Request.ID) -> Action { - self.state.modify(http1: { http1 in - http1.cancelRequest(requestID) - }, http2: { http2 in - http2.cancelRequest(requestID) - }) + self.state.modify( + http1: { http1 in + http1.cancelRequest(requestID) + }, + http2: { http2 in + http2.cancelRequest(requestID) + } + ) } mutating func connectionIdleTimeout(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.connectionIdleTimeout(connectionID) - }, http2: { http2 in - http2.connectionIdleTimeout(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.connectionIdleTimeout(connectionID) + }, + http2: { http2 in + http2.connectionIdleTimeout(connectionID) + } + ) } /// A connection has been closed mutating func http1ConnectionClosed(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http1ConnectionClosed(connectionID) - }, http2: { http2 in - http2.http1ConnectionClosed(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http1ConnectionClosed(connectionID) + }, + http2: { http2 in + http2.http1ConnectionClosed(connectionID) + } + ) } mutating func http1ConnectionReleased(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http1ConnectionReleased(connectionID) - }, http2: { http2 in - http2.http1ConnectionReleased(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http1ConnectionReleased(connectionID) + }, + http2: { http2 in + http2.http1ConnectionReleased(connectionID) + } + ) } mutating func shutdown() -> Action { - return self.state.modify(http1: { http1 in - http1.shutdown() - }, http2: { http2 in - http2.shutdown() - }) + self.state.modify( + http1: { http1 in + http1.shutdown() + }, + http2: { http2 in + http2.shutdown() + } + ) } } } @@ -326,7 +405,10 @@ extension HTTPConnectionPool.StateMachine { enum EstablishedConnectionAction { case none case scheduleTimeoutTimer(HTTPConnectionPool.Connection.ID, on: EventLoop) - case closeConnection(HTTPConnectionPool.Connection, isShutdown: HTTPConnectionPool.StateMachine.ConnectionAction.IsShutdown) + case closeConnection( + HTTPConnectionPool.Connection, + isShutdown: HTTPConnectionPool.StateMachine.ConnectionAction.IsShutdown + ) } } @@ -367,8 +449,7 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction { case .closeConnection(let connection, let isShutdown): guard isShutdown == .no else { precondition( - migrationAction.closeConnections.isEmpty && - migrationAction.createConnections.isEmpty, + migrationAction.closeConnections.isEmpty && migrationAction.createConnections.isEmpty, "migration actions are not supported during shutdown" ) return .closeConnection(connection, isShutdown: isShutdown) diff --git a/Sources/AsyncHTTPClient/DeconstructedURL.swift b/Sources/AsyncHTTPClient/DeconstructedURL.swift index 020c17455..52042bce3 100644 --- a/Sources/AsyncHTTPClient/DeconstructedURL.swift +++ b/Sources/AsyncHTTPClient/DeconstructedURL.swift @@ -48,9 +48,16 @@ extension DeconstructedURL { switch scheme { case .http, .https: + #if !canImport(Darwin) && compiler(>=6.0) + guard let urlHost = url.host, !urlHost.isEmpty else { + throw HTTPClientError.emptyHost + } + let host = urlHost.trimIPv6Brackets() + #else guard let host = url.host, !host.isEmpty else { throw HTTPClientError.emptyHost } + #endif self.init( scheme: scheme, connectionTarget: .init(remoteHost: host, port: url.port ?? scheme.defaultPort), @@ -81,3 +88,26 @@ extension DeconstructedURL { } } } + +#if !canImport(Darwin) && compiler(>=6.0) +extension String { + @inlinable internal func trimIPv6Brackets() -> String { + var utf8View = self.utf8[...] + + var modified = false + if utf8View.first == UInt8(ascii: "[") { + utf8View = utf8View.dropFirst() + modified = true + } + if utf8View.last == UInt8(ascii: "]") { + utf8View = utf8View.dropLast() + modified = true + } + + if modified { + return String(Substring(utf8View)) + } + return self + } +} +#endif diff --git a/Sources/AsyncHTTPClient/Docs.docc/index.md b/Sources/AsyncHTTPClient/Docs.docc/index.md new file mode 100644 index 000000000..37033e043 --- /dev/null +++ b/Sources/AsyncHTTPClient/Docs.docc/index.md @@ -0,0 +1,325 @@ +# ``AsyncHTTPClient`` + +This package provides simple HTTP Client library built on top of SwiftNIO. + +## Overview + +This library provides the following: +- First class support for Swift Concurrency (since version 1.9.0) +- Asynchronous and non-blocking request methods +- Simple follow-redirects (cookie headers are dropped) +- Streaming body download +- TLS support +- Automatic HTTP/2 over HTTPS (since version 1.7.0) +- Cookie parsing (but not storage) + +### Getting Started + +#### Adding the dependency + +Add the following entry in your Package.swift to start using HTTPClient: + +```swift +.package(url: "https://github.com/swift-server/async-http-client.git", from: "1.9.0") +``` +and `AsyncHTTPClient` dependency to your target: +```swift +.target(name: "MyApp", dependencies: [.product(name: "AsyncHTTPClient", package: "async-http-client")]), +``` + +#### Request-Response API + +The code snippet below illustrates how to make a simple GET request to a remote server. + +```swift +import AsyncHTTPClient + +/// MARK: - Using Swift Concurrency +let request = HTTPClientRequest(url: "https://apple.com/") +let response = try await httpClient.execute(request, timeout: .seconds(30)) +print("HTTP head", response) +if response.status == .ok { + let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB + // handle body +} else { + // handle remote error +} + + +/// MARK: - Using SwiftNIO EventLoopFuture +HTTPClient.shared.get(url: "https://apple.com/").whenComplete { result in + switch result { + case .failure(let error): + // process error + case .success(let response): + if response.status == .ok { + // handle response + } else { + // handle remote error + } + } +} +``` + +You should always shut down ``HTTPClient`` instances you created using ``HTTPClient/shutdown()-9gcpw``. Please note that you must not call ``HTTPClient/shutdown()-9gcpw`` before all requests of the HTTP client have finished, or else the in-flight requests will likely fail because their network connections are interrupted. + +#### async/await examples + +Examples for the async/await API can be found in the [`Examples` folder](https://github.com/swift-server/async-http-client/tree/main/Examples) in the repository. + +### Usage guide + +The default HTTP Method is `GET`. In case you need to have more control over the method, or you want to add headers or body, use the ``HTTPClientRequest`` struct: + +#### Using Swift Concurrency + +```swift +import AsyncHTTPClient + +do { + var request = HTTPClientRequest(url: "https://apple.com/") + request.method = .POST + request.headers.add(name: "User-Agent", value: "Swift HTTPClient") + request.body = .bytes(ByteBuffer(string: "some data")) + + let response = try await HTTPClient.shared.execute(request, timeout: .seconds(30)) + if response.status == .ok { + // handle response + } else { + // handle remote error + } +} catch { + // handle error +} +``` + +#### Using SwiftNIO EventLoopFuture + +```swift +import AsyncHTTPClient + +var request = try HTTPClient.Request(url: "https://apple.com/", method: .POST) +request.headers.add(name: "User-Agent", value: "Swift HTTPClient") +request.body = .string("some-body") + +HTTPClient.shared.execute(request: request).whenComplete { result in + switch result { + case .failure(let error): + // process error + case .success(let response): + if response.status == .ok { + // handle response + } else { + // handle remote error + } + } +} +``` + +#### Redirects following +Enable follow-redirects behavior using the client configuration: +```swift +let httpClient = HTTPClient(eventLoopGroupProvider: .singleton, + configuration: HTTPClient.Configuration(followRedirects: true)) +``` + +#### Timeouts +Timeouts (connect and read) can also be set using the client configuration: +```swift +let timeout = HTTPClient.Configuration.Timeout(connect: .seconds(1), read: .seconds(1)) +let httpClient = HTTPClient(eventLoopGroupProvider: .singleton, + configuration: HTTPClient.Configuration(timeout: timeout)) +``` +or on a per-request basis: +```swift +httpClient.execute(request: request, deadline: .now() + .milliseconds(1)) +``` + +#### Streaming +When dealing with larger amount of data, it's critical to stream the response body instead of aggregating in-memory. +The following example demonstrates how to count the number of bytes in a streaming response body: + +##### Using Swift Concurrency +```swift +do { + let request = HTTPClientRequest(url: "https://apple.com/") + let response = try await HTTPClient.shared.execute(request, timeout: .seconds(30)) + print("HTTP head", response) + + // if defined, the content-length headers announces the size of the body + let expectedBytes = response.headers.first(name: "content-length").flatMap(Int.init) + + var receivedBytes = 0 + // asynchronously iterates over all body fragments + // this loop will automatically propagate backpressure correctly + for try await buffer in response.body { + // for this example, we are just interested in the size of the fragment + receivedBytes += buffer.readableBytes + + if let expectedBytes = expectedBytes { + // if the body size is known, we calculate a progress indicator + let progress = Double(receivedBytes) / Double(expectedBytes) + print("progress: \(Int(progress * 100))%") + } + } + print("did receive \(receivedBytes) bytes") +} catch { + print("request failed:", error) +} +``` + +##### Using HTTPClientResponseDelegate and SwiftNIO EventLoopFuture + +```swift +import NIOCore +import NIOHTTP1 + +class CountingDelegate: HTTPClientResponseDelegate { + typealias Response = Int + + var count = 0 + + func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { + // this is executed right after request head was sent, called once + } + + func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) { + // this is executed when request body part is sent, could be called zero or more times + } + + func didSendRequest(task: HTTPClient.Task) { + // this is executed when request is fully sent, called once + } + + func didReceiveHead( + task: HTTPClient.Task, + _ head: HTTPResponseHead + ) -> EventLoopFuture { + // this is executed when we receive HTTP response head part of the request + // (it contains response code and headers), called once in case backpressure + // is needed, all reads will be paused until returned future is resolved + return task.eventLoop.makeSucceededFuture(()) + } + + func didReceiveBodyPart( + task: HTTPClient.Task, + _ buffer: ByteBuffer + ) -> EventLoopFuture { + // this is executed when we receive parts of the response body, could be called zero or more times + count += buffer.readableBytes + // in case backpressure is needed, all reads will be paused until returned future is resolved + return task.eventLoop.makeSucceededFuture(()) + } + + func didFinishRequest(task: HTTPClient.Task) throws -> Int { + // this is called when the request is fully read, called once + // this is where you return a result or throw any errors you require to propagate to the client + return count + } + + func didReceiveError(task: HTTPClient.Task, _ error: Error) { + // this is called when we receive any network-related error, called once + } +} + +let request = try HTTPClient.Request(url: "https://apple.com/") +let delegate = CountingDelegate() + +httpClient.execute(request: request, delegate: delegate).futureResult.whenSuccess { count in + print(count) +} +``` + +#### File downloads + +Based on the `HTTPClientResponseDelegate` example above you can build more complex delegates, +the built-in `FileDownloadDelegate` is one of them. It allows streaming the downloaded data +asynchronously, while reporting the download progress at the same time, like in the following +example: + +```swift +let request = try HTTPClient.Request( + url: "https://swift.org/builds/development/ubuntu1804/latest-build.yml" +) + +let delegate = try FileDownloadDelegate(path: "/tmp/latest-build.yml", reportProgress: { + if let totalBytes = $0.totalBytes { + print("Total bytes count: \(totalBytes)") + } + print("Downloaded \($0.receivedBytes) bytes so far") +}) + +HTTPClient.shared.execute(request: request, delegate: delegate).futureResult + .whenSuccess { progress in + if let totalBytes = progress.totalBytes { + print("Final total bytes count: \(totalBytes)") + } + print("Downloaded finished with \(progress.receivedBytes) bytes downloaded") + } +``` + +#### Unix Domain Socket Paths +Connecting to servers bound to socket paths is easy: +```swift +HTTPClient.shared.execute( + .GET, + socketPath: "/tmp/myServer.socket", + urlPath: "/path/to/resource" +).whenComplete (...) +``` + +Connecting over TLS to a unix domain socket path is possible as well: +```swift +HTTPClient.shared.execute( + .POST, + secureSocketPath: "/tmp/myServer.socket", + urlPath: "/path/to/resource", + body: .string("hello") +).whenComplete (...) +``` + +Direct URLs can easily be constructed to be executed in other scenarios: +```swift +let socketPathBasedURL = URL( + httpURLWithSocketPath: "/tmp/myServer.socket", + uri: "/path/to/resource" +) +let secureSocketPathBasedURL = URL( + httpsURLWithSocketPath: "/tmp/myServer.socket", + uri: "/path/to/resource" +) +``` + +#### Disabling HTTP/2 +The exclusive use of HTTP/1 is possible by setting ``HTTPClient/Configuration/httpVersion-swift.property`` to ``HTTPClient/Configuration/HTTPVersion-swift.struct/http1Only`` on the ``HTTPClient/Configuration``: +```swift +var configuration = HTTPClient.Configuration() +configuration.httpVersion = .http1Only +let client = HTTPClient( + eventLoopGroupProvider: .singleton, + configuration: configuration +) +``` + +### Security + +AsyncHTTPClient's security process is documented on [GitHub](https://github.com/swift-server/async-http-client/blob/main/SECURITY.md). + +## Topics + +### HTTPClient + +- ``HTTPClient`` +- ``HTTPClientRequest`` +- ``HTTPClientResponse`` + +### HTTP Client Delegates + +- ``HTTPClientResponseDelegate`` +- ``ResponseAccumulator`` +- ``FileDownloadDelegate`` +- ``HTTPClientCopyingDelegate`` + +### Errors + +- ``HTTPClientError`` diff --git a/Sources/AsyncHTTPClient/FileDownloadDelegate.swift b/Sources/AsyncHTTPClient/FileDownloadDelegate.swift index 6f046dce9..b21499843 100644 --- a/Sources/AsyncHTTPClient/FileDownloadDelegate.swift +++ b/Sources/AsyncHTTPClient/FileDownloadDelegate.swift @@ -20,7 +20,7 @@ import NIOPosix public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// The response type for this delegate: the total count of bytes as reported by the response /// "Content-Length" header (if available) and the count of bytes downloaded. - public struct Progress { + public struct Progress: Sendable { public var totalBytes: Int? public var receivedBytes: Int } @@ -30,17 +30,18 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { public typealias Response = Progress private let filePath: String - private let io: NonBlockingFileIO - private let reportHead: ((HTTPResponseHead) -> Void)? - private let reportProgress: ((Progress) -> Void)? + private(set) var fileIOThreadPool: NIOThreadPool? + private let reportHead: ((HTTPClient.Task, HTTPResponseHead) -> Void)? + private let reportProgress: ((HTTPClient.Task, Progress) -> Void)? private var fileHandleFuture: EventLoopFuture? private var writeFuture: EventLoopFuture? /// Initializes a new file download delegate. + /// /// - parameters: /// - path: Path to a file you'd like to write the download to. - /// - pool: A thread pool to use for asynchronous file I/O. + /// - pool: A thread pool to use for asynchronous file I/O. If nil, a shared thread pool will be used. Defaults to nil. /// - reportHead: A closure called when the response head is available. /// - reportProgress: A closure called when a body chunk has been downloaded, with /// the total byte count and download byte count passed to it as arguments. The callbacks @@ -48,26 +49,95 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// as controlled by `EventLoopPreference`. public init( path: String, - pool: NIOThreadPool = NIOThreadPool(numberOfThreads: 1), - reportHead: ((HTTPResponseHead) -> Void)? = nil, - reportProgress: ((Progress) -> Void)? = nil + pool: NIOThreadPool? = nil, + reportHead: ((HTTPClient.Task, HTTPResponseHead) -> Void)? = nil, + reportProgress: ((HTTPClient.Task, Progress) -> Void)? = nil ) throws { - pool.start() - self.io = NonBlockingFileIO(threadPool: pool) + if let pool = pool { + self.fileIOThreadPool = pool + } else { + // we should use the shared thread pool from the HTTPClient which + // we will get from the `HTTPClient.Task` + self.fileIOThreadPool = nil + } + self.filePath = path self.reportHead = reportHead self.reportProgress = reportProgress } + /// Initializes a new file download delegate. + /// + /// - parameters: + /// - path: Path to a file you'd like to write the download to. + /// - pool: A thread pool to use for asynchronous file I/O. + /// - reportHead: A closure called when the response head is available. + /// - reportProgress: A closure called when a body chunk has been downloaded, with + /// the total byte count and download byte count passed to it as arguments. The callbacks + /// will be invoked in the same threading context that the delegate itself is invoked, + /// as controlled by `EventLoopPreference`. + public convenience init( + path: String, + pool: NIOThreadPool, + reportHead: ((HTTPResponseHead) -> Void)? = nil, + reportProgress: ((Progress) -> Void)? = nil + ) throws { + try self.init( + path: path, + pool: .some(pool), + reportHead: reportHead.map { reportHead in + { _, head in + reportHead(head) + } + }, + reportProgress: reportProgress.map { reportProgress in + { _, head in + reportProgress(head) + } + } + ) + } + + /// Initializes a new file download delegate and uses the shared thread pool of the ``HTTPClient`` for file I/O. + /// + /// - parameters: + /// - path: Path to a file you'd like to write the download to. + /// - reportHead: A closure called when the response head is available. + /// - reportProgress: A closure called when a body chunk has been downloaded, with + /// the total byte count and download byte count passed to it as arguments. The callbacks + /// will be invoked in the same threading context that the delegate itself is invoked, + /// as controlled by `EventLoopPreference`. + public convenience init( + path: String, + reportHead: ((HTTPResponseHead) -> Void)? = nil, + reportProgress: ((Progress) -> Void)? = nil + ) throws { + try self.init( + path: path, + pool: nil, + reportHead: reportHead.map { reportHead in + { _, head in + reportHead(head) + } + }, + reportProgress: reportProgress.map { reportProgress in + { _, head in + reportProgress(head) + } + } + ) + } + public func didReceiveHead( task: HTTPClient.Task, _ head: HTTPResponseHead ) -> EventLoopFuture { - self.reportHead?(head) + self.reportHead?(task, head) if let totalBytesString = head.headers.first(name: "Content-Length"), - let totalBytes = Int(totalBytesString) { + let totalBytes = Int(totalBytesString) + { self.progress.totalBytes = totalBytes } @@ -78,24 +148,33 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { task: HTTPClient.Task, _ buffer: ByteBuffer ) -> EventLoopFuture { + let threadPool: NIOThreadPool = { + guard let pool = self.fileIOThreadPool else { + let pool = task.fileIOThreadPool + self.fileIOThreadPool = pool + return pool + } + return pool + }() + let io = NonBlockingFileIO(threadPool: threadPool) self.progress.receivedBytes += buffer.readableBytes - self.reportProgress?(self.progress) + self.reportProgress?(task, self.progress) let writeFuture: EventLoopFuture if let fileHandleFuture = self.fileHandleFuture { writeFuture = fileHandleFuture.flatMap { - self.io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) + io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) } } else { - let fileHandleFuture = self.io.openFile( - path: self.filePath, + let fileHandleFuture = io.openFile( + _deprecatedPath: self.filePath, mode: .write, flags: .allowFileCreation(), eventLoop: task.eventLoop ) self.fileHandleFuture = fileHandleFuture writeFuture = fileHandleFuture.flatMap { - self.io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) + io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) } } diff --git a/Sources/AsyncHTTPClient/FoundationExtensions.swift b/Sources/AsyncHTTPClient/FoundationExtensions.swift index 545da756b..452cb7b13 100644 --- a/Sources/AsyncHTTPClient/FoundationExtensions.swift +++ b/Sources/AsyncHTTPClient/FoundationExtensions.swift @@ -39,7 +39,16 @@ extension HTTPClient.Cookie { /// - maxAge: The cookie's age in seconds, defaults to nil. /// - httpOnly: Whether this cookie should be used by HTTP servers only, defaults to false. /// - secure: Whether this cookie should only be sent using secure channels, defaults to false. - public init(name: String, value: String, path: String = "/", domain: String? = nil, expires: Date? = nil, maxAge: Int? = nil, httpOnly: Bool = false, secure: Bool = false) { + public init( + name: String, + value: String, + path: String = "/", + domain: String? = nil, + expires: Date? = nil, + maxAge: Int? = nil, + httpOnly: Bool = false, + secure: Bool = false + ) { // FIXME: This should be failable and validate the inputs // (for example, checking that the strings are ASCII, path begins with "/", domain is not empty, etc). self.init( @@ -59,8 +68,8 @@ extension HTTPClient.Body { /// Create and stream body using `Data`. /// /// - parameters: - /// - bytes: Body `Data` representation. + /// - data: Body `Data` representation. public static func data(_ data: Data) -> HTTPClient.Body { - return self.bytes(data) + self.bytes(data) } } diff --git a/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift b/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift index 75fc28de4..ea272a137 100644 --- a/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift +++ b/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift @@ -12,17 +12,29 @@ // //===----------------------------------------------------------------------===// +import CAsyncHTTPClient +import NIOCore import NIOHTTP1 + +#if canImport(xlocale) +import xlocale +#elseif canImport(locale_h) +import locale_h +#endif + #if canImport(Darwin) import Darwin +#elseif canImport(Musl) +import Musl +#elseif canImport(Android) +import Android #elseif canImport(Glibc) import Glibc #endif -import CAsyncHTTPClient extension HTTPClient { /// A representation of an HTTP cookie. - public struct Cookie { + public struct Cookie: Sendable { /// The name of the cookie. public var name: String /// The cookie's string value. @@ -45,7 +57,6 @@ extension HTTPClient { /// - parameters: /// - header: String representation of the `Set-Cookie` response header. /// - defaultDomain: Default domain to use if cookie was sent without one. - /// - returns: nil if the header is invalid. public init?(header: String, defaultDomain: String) { // The parsing of "Set-Cookie" headers is defined by Section 5.2, RFC-6265: // https://datatracker.ietf.org/doc/html/rfc6265#section-5.2 @@ -126,7 +137,16 @@ extension HTTPClient { /// - maxAge: The cookie's age in seconds, defaults to nil. /// - httpOnly: Whether this cookie should be used by HTTP servers only, defaults to false. /// - secure: Whether this cookie should only be sent using secure channels, defaults to false. - internal init(name: String, value: String, path: String = "/", domain: String? = nil, expires_timestamp: Int64? = nil, maxAge: Int? = nil, httpOnly: Bool = false, secure: Bool = false) { + internal init( + name: String, + value: String, + path: String = "/", + domain: String? = nil, + expires_timestamp: Int64? = nil, + maxAge: Int? = nil, + httpOnly: Bool = false, + secure: Bool = false + ) { self.name = name self.value = value self.path = path @@ -142,7 +162,7 @@ extension HTTPClient { extension HTTPClient.Response { /// List of HTTP cookies returned by the server. public var cookies: [HTTPClient.Cookie] { - return self.headers["set-cookie"].compactMap { HTTPClient.Cookie(header: $0, defaultDomain: self.host) } + self.headers["set-cookie"].compactMap { HTTPClient.Cookie(header: $0, defaultDomain: self.host) } } } @@ -212,7 +232,8 @@ private func parseTimestamp(_ utf8: String.UTF8View.SubSequence, format: String) } private func parseCookieTime(_ timestampUTF8: String.UTF8View.SubSequence) -> Int64? { - if timestampUTF8.contains(where: { $0 < 0x20 /* Control characters */ || $0 == 0x7F /* DEL */ }) { + // 0x20: Control characters or 0x7F: DEL + if timestampUTF8.contains(where: { $0 < 0x20 || $0 == 0x7F }) { return nil } var timestampUTF8 = timestampUTF8 @@ -225,8 +246,8 @@ private func parseCookieTime(_ timestampUTF8: String.UTF8View.SubSequence) -> In } guard var timeComponents = parseTimestamp(timestampUTF8, format: "%a, %d %b %Y %H:%M:%S") - ?? parseTimestamp(timestampUTF8, format: "%a, %d-%b-%y %H:%M:%S") - ?? parseTimestamp(timestampUTF8, format: "%a %b %d %H:%M:%S %Y") + ?? parseTimestamp(timestampUTF8, format: "%a, %d-%b-%y %H:%M:%S") + ?? parseTimestamp(timestampUTF8, format: "%a %b %d %H:%M:%S %Y") else { return nil } diff --git a/Sources/AsyncHTTPClient/HTTPClient+Proxy.swift b/Sources/AsyncHTTPClient/HTTPClient+Proxy.swift index 4d2b9388f..e95c828ce 100644 --- a/Sources/AsyncHTTPClient/HTTPClient+Proxy.swift +++ b/Sources/AsyncHTTPClient/HTTPClient+Proxy.swift @@ -12,6 +12,8 @@ // //===----------------------------------------------------------------------===// +import NIOCore + extension HTTPClient.Configuration { /// Proxy server configuration /// Specifies the remote address of an HTTP proxy. @@ -23,7 +25,7 @@ extension HTTPClient.Configuration { /// If a `TLSConfiguration` is used in conjunction with `HTTPClient.Configuration.Proxy`, /// TLS will be established _after_ successful proxy, between your client /// and the destination server. - public struct Proxy { + public struct Proxy: Sendable, Hashable { enum ProxyType: Hashable { case http(HTTPClient.Authorization?) case socks @@ -36,7 +38,10 @@ extension HTTPClient.Configuration { /// Specifies Proxy server authorization. public var authorization: HTTPClient.Authorization? { set { - precondition(self.type == .http(self.authorization), "SOCKS authorization support is not yet implemented.") + precondition( + self.type == .http(self.authorization), + "SOCKS authorization support is not yet implemented." + ) self.type = .http(newValue) } @@ -58,7 +63,7 @@ extension HTTPClient.Configuration { /// - host: proxy server host. /// - port: proxy server port. public static func server(host: String, port: Int) -> Proxy { - return .init(host: host, port: port, type: .http(nil)) + .init(host: host, port: port, type: .http(nil)) } /// Create a HTTP proxy. @@ -68,7 +73,7 @@ extension HTTPClient.Configuration { /// - port: proxy server port. /// - authorization: proxy server authorization. public static func server(host: String, port: Int, authorization: HTTPClient.Authorization? = nil) -> Proxy { - return .init(host: host, port: port, type: .http(authorization)) + .init(host: host, port: port, type: .http(authorization)) } /// Create a SOCKSv5 proxy. @@ -76,7 +81,7 @@ extension HTTPClient.Configuration { /// - parameter port: The SOCKSv5 proxy port, defaults to 1080. /// - returns: A new instance of `Proxy` configured to connect to a `SOCKSv5` server. public static func socksServer(host: String, port: Int = 1080) -> Proxy { - return .init(host: host, port: port, type: .socks) + .init(host: host, port: port, type: .socks) } } } diff --git a/Sources/AsyncHTTPClient/HTTPClient+StructuredConcurrency.swift b/Sources/AsyncHTTPClient/HTTPClient+StructuredConcurrency.swift new file mode 100644 index 000000000..f7d471f10 --- /dev/null +++ b/Sources/AsyncHTTPClient/HTTPClient+StructuredConcurrency.swift @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIO + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClient { + #if compiler(>=6.0) + /// Start & automatically shut down a new ``HTTPClient``. + /// + /// This method allows to start & automatically dispose of a ``HTTPClient`` following the principle of Structured Concurrency. + /// The ``HTTPClient`` is guaranteed to be shut down upon return, whether `body` throws or not. + /// + /// This may be particularly useful if you cannot use the shared singleton (``HTTPClient/shared``). + public static func withHTTPClient( + eventLoopGroup: any EventLoopGroup = HTTPClient.defaultEventLoopGroup, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger? = nil, + isolation: isolated (any Actor)? = #isolation, + _ body: (HTTPClient) async throws -> Return + ) async throws -> Return { + let logger = (backgroundActivityLogger ?? HTTPClient.loggingDisabled) + let httpClient = HTTPClient( + eventLoopGroup: eventLoopGroup, + configuration: configuration, + backgroundActivityLogger: logger + ) + return try await asyncDo { + try await body(httpClient) + } finally: { _ in + try await httpClient.shutdown() + } + } + #else + /// Start & automatically shut down a new ``HTTPClient``. + /// + /// This method allows to start & automatically dispose of a ``HTTPClient`` following the principle of Structured Concurrency. + /// The ``HTTPClient`` is guaranteed to be shut down upon return, whether `body` throws or not. + /// + /// This may be particularly useful if you cannot use the shared singleton (``HTTPClient/shared``). + public static func withHTTPClient( + eventLoopGroup: any EventLoopGroup = HTTPClient.defaultEventLoopGroup, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger? = nil, + _ body: (HTTPClient) async throws -> Return + ) async throws -> Return { + let logger = (backgroundActivityLogger ?? HTTPClient.loggingDisabled) + let httpClient = HTTPClient( + eventLoopGroup: eventLoopGroup, + configuration: configuration, + backgroundActivityLogger: logger + ) + return try await asyncDo { + try await body(httpClient) + } finally: { _ in + try await httpClient.shutdown() + } + } + #endif +} diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 9301094ef..f1655c7c5 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +import Atomics import Foundation import Logging import NIOConcurrencyHelpers @@ -25,7 +26,7 @@ import NIOTransportServices extension Logger { private func requestInfo(_ request: HTTPClient.Request) -> Logger.Metadata.Value { - return "\(request.method) \(request.url)" + "\(request.method) \(request.url)" } func attachingRequestInformation(_ request: HTTPClient.Request, requestID: Int) -> Logger { @@ -36,15 +37,14 @@ extension Logger { } } -let globalRequestID = NIOAtomic.makeAtomic(value: 0) +let globalRequestID = ManagedAtomic(0) /// HTTPClient class provides API for request execution. /// /// Example: /// /// ```swift -/// let client = HTTPClient(eventLoopGroupProvider: .createNew) -/// client.get(url: "https://swift.org", deadline: .now() + .seconds(1)).whenComplete { result in +/// HTTPClient.shared.get(url: "https://swift.org", deadline: .now() + .seconds(1)).whenComplete { result in /// switch result { /// case .failure(let error): /// // process error @@ -57,57 +57,110 @@ let globalRequestID = NIOAtomic.makeAtomic(value: 0) /// } /// } /// ``` -/// -/// It is important to close the client instance, for example in a defer statement, after use to cleanly shutdown the underlying NIO `EventLoopGroup`: -/// -/// ```swift -/// try client.syncShutdown() -/// ``` public class HTTPClient { + /// The `EventLoopGroup` in use by this ``HTTPClient``. + /// + /// All HTTP transactions will occur on loops owned by this group. public let eventLoopGroup: EventLoopGroup - let eventLoopGroupProvider: EventLoopGroupProvider let configuration: Configuration let poolManager: HTTPConnectionPool.Manager + + /// Shared thread pool used for file IO. It is lazily created on first access of ``Task/fileIOThreadPool``. + private var fileIOThreadPool: NIOThreadPool? + private let fileIOThreadPoolLock = NIOLock() + private var state: State - private let stateLock = Lock() + private let stateLock = NIOLock() + private let canBeShutDown: Bool - internal static let loggingDisabled = Logger(label: "AHC-do-not-log", factory: { _ in SwiftLogNoOpLogHandler() }) + static let loggingDisabled = Logger(label: "AHC-do-not-log", factory: { _ in SwiftLogNoOpLogHandler() }) - /// Create an `HTTPClient` with specified `EventLoopGroup` provider and configuration. + /// Create an ``HTTPClient`` with specified `EventLoopGroup` provider and configuration. /// /// - parameters: /// - eventLoopGroupProvider: Specify how `EventLoopGroup` will be created. /// - configuration: Client configuration. - public convenience init(eventLoopGroupProvider: EventLoopGroupProvider, - configuration: Configuration = Configuration()) { - self.init(eventLoopGroupProvider: eventLoopGroupProvider, - configuration: configuration, - backgroundActivityLogger: HTTPClient.loggingDisabled) + public convenience init( + eventLoopGroupProvider: EventLoopGroupProvider, + configuration: Configuration = Configuration() + ) { + self.init( + eventLoopGroupProvider: eventLoopGroupProvider, + configuration: configuration, + backgroundActivityLogger: HTTPClient.loggingDisabled + ) + } + + /// Create an ``HTTPClient`` with specified `EventLoopGroup` and configuration. + /// + /// - parameters: + /// - eventLoopGroup: Specify how `EventLoopGroup` will be created. + /// - configuration: Client configuration. + public convenience init( + eventLoopGroup: EventLoopGroup = HTTPClient.defaultEventLoopGroup, + configuration: Configuration = Configuration() + ) { + self.init( + eventLoopGroupProvider: .shared(eventLoopGroup), + configuration: configuration, + backgroundActivityLogger: HTTPClient.loggingDisabled + ) } - /// Create an `HTTPClient` with specified `EventLoopGroup` provider and configuration. + /// Create an ``HTTPClient`` with specified `EventLoopGroup` provider and configuration. /// /// - parameters: /// - eventLoopGroupProvider: Specify how `EventLoopGroup` will be created. /// - configuration: Client configuration. - public required init(eventLoopGroupProvider: EventLoopGroupProvider, - configuration: Configuration = Configuration(), - backgroundActivityLogger: Logger) { - self.eventLoopGroupProvider = eventLoopGroupProvider - switch self.eventLoopGroupProvider { + /// - backgroundActivityLogger: The logger to use for background activity logs. + public convenience init( + eventLoopGroupProvider: EventLoopGroupProvider, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger + ) { + let eventLoopGroup: any EventLoopGroup + + switch eventLoopGroupProvider { case .shared(let group): - self.eventLoopGroup = group - case .createNew: - #if canImport(Network) - if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) { - self.eventLoopGroup = NIOTSEventLoopGroup() - } else { - self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - } - #else - self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - #endif + eventLoopGroup = group + default: // handle `.createNew` without a deprecation warning + eventLoopGroup = HTTPClient.defaultEventLoopGroup } + + self.init( + eventLoopGroup: eventLoopGroup, + configuration: configuration, + backgroundActivityLogger: backgroundActivityLogger + ) + } + + /// Create an ``HTTPClient`` with specified `EventLoopGroup` and configuration. + /// + /// - parameters: + /// - eventLoopGroup: The `EventLoopGroup` that the ``HTTPClient`` will use. + /// - configuration: Client configuration. + /// - backgroundActivityLogger: The `Logger` that will be used to log background any activity that's not associated with a request. + public convenience init( + eventLoopGroup: any EventLoopGroup = HTTPClient.defaultEventLoopGroup, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger + ) { + self.init( + eventLoopGroup: eventLoopGroup, + configuration: configuration, + backgroundActivityLogger: backgroundActivityLogger, + canBeShutDown: true + ) + } + + internal required init( + eventLoopGroup: EventLoopGroup, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger, + canBeShutDown: Bool + ) { + self.canBeShutDown = canBeShutDown + self.eventLoopGroup = eventLoopGroup self.configuration = configuration self.poolManager = HTTPConnectionPool.Manager( eventLoopGroup: self.eventLoopGroup, @@ -124,20 +177,27 @@ public class HTTPClient { case .shutDown: break case .shuttingDown: - preconditionFailure(""" - This state should be totally unreachable. While the HTTPClient is shutting down a \ - reference cycle should exist, that prevents it from deinit. - """) + preconditionFailure( + """ + This state should be totally unreachable. While the HTTPClient is shutting down a \ + reference cycle should exist, that prevents it from deinit. + """ + ) case .upAndRunning: - preconditionFailure(""" - Client not shut down before the deinit. Please call client.syncShutdown() when no \ - longer needed. Otherwise memory will leak. - """) + preconditionFailure( + """ + Client not shut down before the deinit. Please call client.shutdown() when no \ + longer needed. Otherwise memory will leak. + """ + ) } } } /// Shuts down the client and `EventLoopGroup` if it was created by the client. + /// + /// This method blocks the thread indefinitely, prefer using ``shutdown()-96ayw``. + @available(*, noasync, message: "syncShutdown() can block indefinitely, prefer shutdown()", renamed: "shutdown()") public func syncShutdown() throws { try self.syncShutdown(requiresCleanClose: false) } @@ -152,62 +212,68 @@ public class HTTPClient { /// throw the appropriate error if needed. For instance, if its internal connection pool has any non-released connections, /// this indicate shutdown was called too early before tasks were completed or explicitly canceled. /// In general, setting this parameter to `true` should make it easier and faster to catch related programming errors. - internal func syncShutdown(requiresCleanClose: Bool) throws { + func syncShutdown(requiresCleanClose: Bool) throws { if let eventLoop = MultiThreadedEventLoopGroup.currentEventLoop { - preconditionFailure(""" - BUG DETECTED: syncShutdown() must not be called when on an EventLoop. - Calling syncShutdown() on any EventLoop can lead to deadlocks. - Current eventLoop: \(eventLoop) - """) + preconditionFailure( + """ + BUG DETECTED: syncShutdown() must not be called when on an EventLoop. + Calling syncShutdown() on any EventLoop can lead to deadlocks. + Current eventLoop: \(eventLoop) + """ + ) } - let errorStorageLock = Lock() - var errorStorage: Error? + let errorStorage: NIOLockedValueBox = NIOLockedValueBox(nil) let continuation = DispatchWorkItem {} - self.shutdown(requiresCleanClose: requiresCleanClose, queue: DispatchQueue(label: "async-http-client.shutdown")) { error in + self.shutdown(requiresCleanClose: requiresCleanClose, queue: DispatchQueue(label: "async-http-client.shutdown")) + { error in if let error = error { - errorStorageLock.withLock { + errorStorage.withLockedValue { errorStorage in errorStorage = error } } continuation.perform() } continuation.wait() - try errorStorageLock.withLock { + try errorStorage.withLockedValue { errorStorage in if let error = errorStorage { throw error } } } - /// Shuts down the client and event loop gracefully. This function is clearly an outlier in that it uses a completion + /// Shuts down the client and event loop gracefully. + /// + /// This function is clearly an outlier in that it uses a completion /// callback instead of an EventLoopFuture. The reason for that is that NIO's EventLoopFutures will call back on an event loop. /// The virtue of this function is to shut the event loop down. To work around that we call back on a DispatchQueue /// instead. - public func shutdown(queue: DispatchQueue = .global(), _ callback: @escaping (Error?) -> Void) { + @preconcurrency public func shutdown( + queue: DispatchQueue = .global(), + _ callback: @Sendable @escaping (Error?) -> Void + ) { self.shutdown(requiresCleanClose: false, queue: queue, callback) } - private func shutdownEventLoop(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { - self.stateLock.withLock { - switch self.eventLoopGroupProvider { - case .shared: - self.state = .shutDown - queue.async { - callback(nil) - } - case .createNew: - switch self.state { - case .shuttingDown: - self.state = .shutDown - self.eventLoopGroup.shutdownGracefully(queue: queue, callback) - case .shutDown, .upAndRunning: - assertionFailure("The only valid state at this point is \(String(describing: State.shuttingDown))") - } + /// Shuts down the ``HTTPClient`` and releases its resources. + public func shutdown() -> EventLoopFuture { + let promise = self.eventLoopGroup.any().makePromise(of: Void.self) + self.shutdown(queue: .global()) { error in + if let error = error { + promise.fail(error) + } else { + promise.succeed(()) } } + return promise.futureResult } - private func shutdown(requiresCleanClose: Bool, queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { + private func shutdown(requiresCleanClose: Bool, queue: DispatchQueue, _ callback: @escaping ShutdownCallback) { + guard self.canBeShutDown else { + queue.async { + callback(HTTPClientError.shutdownUnsupported) + } + return + } do { try self.stateLock.withLock { guard case .upAndRunning = self.state else { @@ -227,7 +293,7 @@ public class HTTPClient { case .failure: preconditionFailure("Shutting down the connection pool must not fail, ever.") case .success(let unclean): - let (callback, uncleanError) = self.stateLock.withLock { () -> ((Error?) -> Void, Error?) in + let (callback, uncleanError) = self.stateLock.withLock { () -> (ShutdownCallback, Error?) in guard case .shuttingDown(let requiresClean, callback: let callback) = self.state else { preconditionFailure("Why did the pool manager shut down, if it was not instructed to") } @@ -235,22 +301,32 @@ public class HTTPClient { let error: Error? = (requiresClean && unclean) ? HTTPClientError.uncleanShutdown : nil return (callback, error) } - - self.shutdownEventLoop(queue: queue) { error in - let reportedError = error ?? uncleanError - callback(reportedError) + self.stateLock.withLock { + self.state = .shutDown + } + queue.async { + callback(uncleanError) } } } } + private func makeOrGetFileIOThreadPool() -> NIOThreadPool { + self.fileIOThreadPoolLock.withLock { + guard let fileIOThreadPool = self.fileIOThreadPool else { + return NIOThreadPool.singleton + } + return fileIOThreadPool + } + } + /// Execute `GET` request using specified URL. /// /// - parameters: /// - url: Remote URL. /// - deadline: Point in time by which the request must complete. public func get(url: String, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.get(url: url, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.get(url: url, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `GET` request using specified URL. @@ -260,7 +336,7 @@ public class HTTPClient { /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. public func get(url: String, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.GET, url: url, deadline: deadline, logger: logger) + self.execute(.GET, url: url, deadline: deadline, logger: logger) } /// Execute `POST` request using specified URL. @@ -270,7 +346,7 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. public func post(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.post(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.post(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `POST` request using specified URL. @@ -280,8 +356,13 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func post(url: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.POST, url: url, body: body, deadline: deadline, logger: logger) + public func post( + url: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger + ) -> EventLoopFuture { + self.execute(.POST, url: url, body: body, deadline: deadline, logger: logger) } /// Execute `PATCH` request using specified URL. @@ -291,7 +372,7 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. public func patch(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.patch(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.patch(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `PATCH` request using specified URL. @@ -301,8 +382,13 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func patch(url: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.PATCH, url: url, body: body, deadline: deadline, logger: logger) + public func patch( + url: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger + ) -> EventLoopFuture { + self.execute(.PATCH, url: url, body: body, deadline: deadline, logger: logger) } /// Execute `PUT` request using specified URL. @@ -312,7 +398,7 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. public func put(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.put(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.put(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `PUT` request using specified URL. @@ -322,8 +408,13 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func put(url: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.PUT, url: url, body: body, deadline: deadline, logger: logger) + public func put( + url: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger + ) -> EventLoopFuture { + self.execute(.PUT, url: url, body: body, deadline: deadline, logger: logger) } /// Execute `DELETE` request using specified URL. @@ -332,7 +423,7 @@ public class HTTPClient { /// - url: Remote URL. /// - deadline: The time when the request must have been completed by. public func delete(url: String, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.delete(url: url, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.delete(url: url, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `DELETE` request using specified URL. @@ -342,7 +433,7 @@ public class HTTPClient { /// - deadline: The time when the request must have been completed by. /// - logger: The logger to use for this request. public func delete(url: String, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.DELETE, url: url, deadline: deadline, logger: logger) + self.execute(.DELETE, url: url, deadline: deadline, logger: logger) } /// Execute arbitrary HTTP request using specified URL. @@ -353,7 +444,13 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(_ method: HTTPMethod = .GET, url: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger? = nil) -> EventLoopFuture { + public func execute( + _ method: HTTPMethod = .GET, + url: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger? = nil + ) -> EventLoopFuture { do { let request = try Request(url: url, method: method, body: body) return self.execute(request: request, deadline: deadline, logger: logger ?? HTTPClient.loggingDisabled) @@ -371,7 +468,14 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(_ method: HTTPMethod = .GET, socketPath: String, urlPath: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger? = nil) -> EventLoopFuture { + public func execute( + _ method: HTTPMethod = .GET, + socketPath: String, + urlPath: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger? = nil + ) -> EventLoopFuture { do { guard let url = URL(httpURLWithSocketPath: socketPath, uri: urlPath) else { throw HTTPClientError.invalidURL @@ -392,7 +496,14 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(_ method: HTTPMethod = .GET, secureSocketPath: String, urlPath: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger? = nil) -> EventLoopFuture { + public func execute( + _ method: HTTPMethod = .GET, + secureSocketPath: String, + urlPath: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger? = nil + ) -> EventLoopFuture { do { guard let url = URL(httpsURLWithSocketPath: secureSocketPath, uri: urlPath) else { throw HTTPClientError.invalidURL @@ -410,7 +521,7 @@ public class HTTPClient { /// - request: HTTP request to execute. /// - deadline: Point in time by which the request must complete. public func execute(request: Request, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.execute(request: request, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.execute(request: request, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute arbitrary HTTP request using specified URL. @@ -430,26 +541,40 @@ public class HTTPClient { /// - request: HTTP request to execute. /// - eventLoop: NIO Event Loop preference. /// - deadline: Point in time by which the request must complete. - public func execute(request: Request, eventLoop: EventLoopPreference, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.execute(request: request, - eventLoop: eventLoop, - deadline: deadline, - logger: HTTPClient.loggingDisabled) + public func execute( + request: Request, + eventLoop: EventLoopPreference, + deadline: NIODeadline? = nil + ) -> EventLoopFuture { + self.execute( + request: request, + eventLoop: eventLoop, + deadline: deadline, + logger: HTTPClient.loggingDisabled + ) } /// Execute arbitrary HTTP request and handle response processing using provided delegate. /// /// - parameters: /// - request: HTTP request to execute. - /// - eventLoop: NIO Event Loop preference. + /// - eventLoopPreference: NIO Event Loop preference. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(request: Request, - eventLoop eventLoopPreference: EventLoopPreference, - deadline: NIODeadline? = nil, - logger: Logger?) -> EventLoopFuture { + public func execute( + request: Request, + eventLoop eventLoopPreference: EventLoopPreference, + deadline: NIODeadline? = nil, + logger: Logger? + ) -> EventLoopFuture { let accumulator = ResponseAccumulator(request: request) - return self.execute(request: request, delegate: accumulator, eventLoop: eventLoopPreference, deadline: deadline, logger: logger).futureResult + return self.execute( + request: request, + delegate: accumulator, + eventLoop: eventLoopPreference, + deadline: deadline, + logger: logger + ).futureResult } /// Execute arbitrary HTTP request and handle response processing using provided delegate. @@ -458,10 +583,12 @@ public class HTTPClient { /// - request: HTTP request to execute. /// - delegate: Delegate to process response parts. /// - deadline: Point in time by which the request must complete. - public func execute(request: Request, - delegate: Delegate, - deadline: NIODeadline? = nil) -> Task { - return self.execute(request: request, delegate: delegate, deadline: deadline, logger: HTTPClient.loggingDisabled) + public func execute( + request: Request, + delegate: Delegate, + deadline: NIODeadline? = nil + ) -> Task { + self.execute(request: request, delegate: delegate, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute arbitrary HTTP request and handle response processing using provided delegate. @@ -471,11 +598,13 @@ public class HTTPClient { /// - delegate: Delegate to process response parts. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(request: Request, - delegate: Delegate, - deadline: NIODeadline? = nil, - logger: Logger) -> Task { - return self.execute(request: request, delegate: delegate, eventLoop: .indifferent, deadline: deadline, logger: logger) + public func execute( + request: Request, + delegate: Delegate, + deadline: NIODeadline? = nil, + logger: Logger + ) -> Task { + self.execute(request: request, delegate: delegate, eventLoop: .indifferent, deadline: deadline, logger: logger) } /// Execute arbitrary HTTP request and handle response processing using provided delegate. @@ -483,18 +612,21 @@ public class HTTPClient { /// - parameters: /// - request: HTTP request to execute. /// - delegate: Delegate to process response parts. - /// - eventLoop: NIO Event Loop preference. + /// - eventLoopPreference: NIO Event Loop preference. /// - deadline: Point in time by which the request must complete. - /// - logger: The logger to use for this request. - public func execute(request: Request, - delegate: Delegate, - eventLoop eventLoopPreference: EventLoopPreference, - deadline: NIODeadline? = nil) -> Task { - return self.execute(request: request, - delegate: delegate, - eventLoop: eventLoopPreference, - deadline: deadline, - logger: HTTPClient.loggingDisabled) + public func execute( + request: Request, + delegate: Delegate, + eventLoop eventLoopPreference: EventLoopPreference, + deadline: NIODeadline? = nil + ) -> Task { + self.execute( + request: request, + delegate: delegate, + eventLoop: eventLoopPreference, + deadline: deadline, + logger: HTTPClient.loggingDisabled + ) } /// Execute arbitrary HTTP request and handle response processing using provided delegate. @@ -502,7 +634,7 @@ public class HTTPClient { /// - parameters: /// - request: HTTP request to execute. /// - delegate: Delegate to process response parts. - /// - eventLoop: NIO Event Loop preference. + /// - eventLoopPreference: NIO Event Loop preference. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. public func execute( @@ -510,14 +642,14 @@ public class HTTPClient { delegate: Delegate, eventLoop eventLoopPreference: EventLoopPreference, deadline: NIODeadline? = nil, - logger originalLogger: Logger? + logger: Logger? ) -> Task { self._execute( request: request, delegate: delegate, eventLoop: eventLoopPreference, deadline: deadline, - logger: originalLogger, + logger: logger, redirectState: RedirectState( self.configuration.redirectConfiguration.mode, initialURL: request.url.absoluteString @@ -541,34 +673,51 @@ public class HTTPClient { logger originalLogger: Logger?, redirectState: RedirectState? ) -> Task { - let logger = (originalLogger ?? HTTPClient.loggingDisabled).attachingRequestInformation(request, requestID: globalRequestID.add(1)) + let logger = (originalLogger ?? HTTPClient.loggingDisabled).attachingRequestInformation( + request, + requestID: globalRequestID.wrappingIncrementThenLoad(ordering: .relaxed) + ) let taskEL: EventLoop switch eventLoopPreference.preference { case .indifferent: // if possible we want a connection on the current `EventLoop` taskEL = self.eventLoopGroup.any() case .delegate(on: let eventLoop): - precondition(self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, "Provided EventLoop must be part of clients EventLoopGroup.") + precondition( + self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, + "Provided EventLoop must be part of clients EventLoopGroup." + ) taskEL = eventLoop case .delegateAndChannel(on: let eventLoop): - precondition(self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, "Provided EventLoop must be part of clients EventLoopGroup.") + precondition( + self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, + "Provided EventLoop must be part of clients EventLoopGroup." + ) taskEL = eventLoop case .testOnly_exact(_, delegateOn: let delegateEL): taskEL = delegateEL } - logger.trace("selected EventLoop for task given the preference", - metadata: ["ahc-eventloop": "\(taskEL)", - "ahc-el-preference": "\(eventLoopPreference)"]) + + logger.trace( + "selected EventLoop for task given the preference", + metadata: [ + "ahc-eventloop": "\(taskEL)", + "ahc-el-preference": "\(eventLoopPreference)", + ] + ) let failedTask: Task? = self.stateLock.withLock { - switch state { + switch self.state { case .upAndRunning: return nil case .shuttingDown, .shutDown: logger.debug("client is shutting down, failing request") - return Task.failedTask(eventLoop: taskEL, - error: HTTPClientError.alreadyShutdown, - logger: logger) + return Task.failedTask( + eventLoop: taskEL, + error: HTTPClientError.alreadyShutdown, + logger: logger, + makeOrGetFileIOThreadPool: self.makeOrGetFileIOThreadPool + ) } } @@ -591,14 +740,18 @@ public class HTTPClient { } }() - let task = Task(eventLoop: taskEL, logger: logger) + let task = Task( + eventLoop: taskEL, + logger: logger, + makeOrGetFileIOThreadPool: self.makeOrGetFileIOThreadPool + ) do { let requestBag = try RequestBag( request: request, eventLoopPreference: eventLoopPreference, task: task, redirectHandler: redirectHandler, - connectionDeadline: .now() + (self.configuration.timeout.connect ?? .seconds(10)), + connectionDeadline: .now() + (self.configuration.timeout.connectionCreationTimeout), requestOptions: .fromClientConfiguration(self.configuration), delegate: delegate ) @@ -606,7 +759,7 @@ public class HTTPClient { var deadlineSchedule: Scheduled? if let deadline = deadline { deadlineSchedule = taskEL.scheduleTask(deadline: deadline) { - requestBag.fail(HTTPClientError.deadlineExceeded) + requestBag.deadlineExceeded() } task.promise.futureResult.whenComplete { _ in @@ -622,11 +775,22 @@ public class HTTPClient { return task } - /// `HTTPClient` configuration. + /// ``HTTPClient`` configuration. public struct Configuration { /// TLS configuration, defaults to `TLSConfiguration.makeClientConfiguration()`. public var tlsConfiguration: Optional - /// Enables following 3xx redirects automatically, defaults to `RedirectConfiguration()`. + + /// Sometimes it can be useful to connect to one host e.g. `x.example.com` but + /// request and validate the certificate chain as if we would connect to `y.example.com`. + /// ``dnsOverride`` allows to do just that by mapping host names which we will request and validate the certificate chain, to a different + /// host name which will be used to actually connect to. + /// + /// **Example:** if ``dnsOverride`` is set to `["example.com": "localhost"]` and we execute a request with a + /// `url` of `https://example.com/`, the ``HTTPClient`` will actually open a connection to `localhost` instead of `example.com`. + /// ``HTTPClient`` will still request certificates from the server for `example.com` and validate them as if we would connect to `example.com`. + public var dnsOverride: [String: String] = [:] + + /// Enables following 3xx redirects automatically. /// /// Following redirects are supported: /// - `301: Moved Permanently` @@ -637,7 +801,8 @@ public class HTTPClient { /// - `307: Temporary Redirect` /// - `308: Permanent Redirect` public var redirectConfiguration: RedirectConfiguration - /// Default client timeout, defaults to no `read` timeout and 10 seconds `connect` timeout. + /// Default client timeout, defaults to no ``Timeout-swift.struct/read`` timeout + /// and 10 seconds ``Timeout-swift.struct/connect`` timeout. public var timeout: Timeout /// Connection pool configuration. public var connectionPool: ConnectionPool @@ -646,15 +811,42 @@ public class HTTPClient { /// Enables automatic body decompression. Supported algorithms are gzip and deflate. public var decompression: Decompression /// Ignore TLS unclean shutdown error, defaults to `false`. - @available(*, deprecated, message: "AsyncHTTPClient now correctly supports handling unexpected SSL connection drops. This property is ignored") + @available( + *, + deprecated, + message: + "AsyncHTTPClient now correctly supports handling unexpected SSL connection drops. This property is ignored" + ) public var ignoreUncleanSSLShutdown: Bool { get { false } set {} } - /// is set to `.automatic` by default which will use HTTP/2 if run over https and the server supports it, otherwise HTTP/1 + /// What HTTP versions to use. + /// + /// Set to ``HTTPVersion-swift.struct/automatic`` by default which will use HTTP/2 if run over https and the server supports it, otherwise HTTP/1 public var httpVersion: HTTPVersion + /// Whether ``HTTPClient`` will let Network.framework sit in the `.waiting` state awaiting new network changes, or fail immediately. Defaults to `true`, + /// which is the recommended setting. Only set this to `false` when attempting to trigger a particular error path. + public var networkFrameworkWaitForConnectivity: Bool + + /// The maximum number of times each connection can be used before it is replaced with a new one. Use `nil` (the default) + /// if no limit should be applied to each connection. + /// + /// - Precondition: The value must be greater than zero. + public var maximumUsesPerConnection: Int? { + willSet { + if let newValue = newValue, newValue <= 0 { + fatalError("maximumUsesPerConnection must be greater than zero or nil") + } + } + } + + /// Whether ``HTTPClient`` will use Multipath TCP or not + /// By default, don't use it + public var enableMultipath: Bool + public init( tlsConfiguration: TLSConfiguration? = nil, redirectConfiguration: RedirectConfiguration? = nil, @@ -671,14 +863,18 @@ public class HTTPClient { self.proxy = proxy self.decompression = decompression self.httpVersion = .automatic + self.networkFrameworkWaitForConnectivity = true + self.enableMultipath = false } - public init(tlsConfiguration: TLSConfiguration? = nil, - redirectConfiguration: RedirectConfiguration? = nil, - timeout: Timeout = Timeout(), - proxy: Proxy? = nil, - ignoreUncleanSSLShutdown: Bool = false, - decompression: Decompression = .disabled) { + public init( + tlsConfiguration: TLSConfiguration? = nil, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled + ) { self.init( tlsConfiguration: tlsConfiguration, redirectConfiguration: redirectConfiguration, @@ -690,49 +886,59 @@ public class HTTPClient { ) } - public init(certificateVerification: CertificateVerification, - redirectConfiguration: RedirectConfiguration? = nil, - timeout: Timeout = Timeout(), - maximumAllowedIdleTimeInConnectionPool: TimeAmount = .seconds(60), - proxy: Proxy? = nil, - ignoreUncleanSSLShutdown: Bool = false, - decompression: Decompression = .disabled) { + public init( + certificateVerification: CertificateVerification, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + maximumAllowedIdleTimeInConnectionPool: TimeAmount = .seconds(60), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled + ) { var tlsConfig = TLSConfiguration.makeClientConfiguration() tlsConfig.certificateVerification = certificateVerification - self.init(tlsConfiguration: tlsConfig, - redirectConfiguration: redirectConfiguration, - timeout: timeout, - connectionPool: ConnectionPool(), - proxy: proxy, - ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, - decompression: decompression) + self.init( + tlsConfiguration: tlsConfig, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + connectionPool: ConnectionPool(idleTimeout: maximumAllowedIdleTimeInConnectionPool), + proxy: proxy, + ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, + decompression: decompression + ) } - public init(certificateVerification: CertificateVerification, - redirectConfiguration: RedirectConfiguration? = nil, - timeout: Timeout = Timeout(), - connectionPool: TimeAmount = .seconds(60), - proxy: Proxy? = nil, - ignoreUncleanSSLShutdown: Bool = false, - decompression: Decompression = .disabled, - backgroundActivityLogger: Logger?) { + public init( + certificateVerification: CertificateVerification, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + connectionPool: TimeAmount = .seconds(60), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled, + backgroundActivityLogger: Logger? + ) { var tlsConfig = TLSConfiguration.makeClientConfiguration() tlsConfig.certificateVerification = certificateVerification - self.init(tlsConfiguration: tlsConfig, - redirectConfiguration: redirectConfiguration, - timeout: timeout, - connectionPool: ConnectionPool(), - proxy: proxy, - ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, - decompression: decompression) + self.init( + tlsConfiguration: tlsConfig, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + connectionPool: ConnectionPool(idleTimeout: connectionPool), + proxy: proxy, + ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, + decompression: decompression + ) } - public init(certificateVerification: CertificateVerification, - redirectConfiguration: RedirectConfiguration? = nil, - timeout: Timeout = Timeout(), - proxy: Proxy? = nil, - ignoreUncleanSSLShutdown: Bool = false, - decompression: Decompression = .disabled) { + public init( + certificateVerification: CertificateVerification, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled + ) { self.init( certificateVerification: certificateVerification, redirectConfiguration: redirectConfiguration, @@ -749,11 +955,16 @@ public class HTTPClient { public enum EventLoopGroupProvider { /// `EventLoopGroup` will be provided by the user. Owner of this group is responsible for its lifecycle. case shared(EventLoopGroup) - /// `EventLoopGroup` will be created by the client. When `syncShutdown` is called, created `EventLoopGroup` will be shut down as well. + /// The original intention of this was that ``HTTPClient`` would create and own its own `EventLoopGroup` to + /// facilitate use in programs that are not already using SwiftNIO. + /// Since https://github.com/apple/swift-nio/pull/2471 however, SwiftNIO does provide a global, shared singleton + /// `EventLoopGroup`s that we can use. ``HTTPClient`` is no longer able to create & own its own + /// `EventLoopGroup` which solves a whole host of issues around shutdown. + @available(*, deprecated, renamed: "singleton", message: "Please use the singleton EventLoopGroup explicitly") case createNew } - /// Specifies how the library will treat event loop passed by the user. + /// Specifies how the library will treat the event loop passed by the user. public struct EventLoopPreference { enum Preference { /// Event Loop will be selected by the library. @@ -781,7 +992,7 @@ public class HTTPClient { /// `EventLoop` but will not establish a new network connection just to satisfy the `EventLoop` preference if /// another existing connection on a different `EventLoop` is readily available from a connection pool. public static func delegate(on eventLoop: EventLoop) -> EventLoopPreference { - return EventLoopPreference(.delegate(on: eventLoop)) + EventLoopPreference(.delegate(on: eventLoop)) } /// The delegate and the `Channel` will be run on the specified EventLoop. @@ -789,34 +1000,70 @@ public class HTTPClient { /// Use this for use-cases where you prefer a new connection to be established over re-using an existing /// connection that might be on a different `EventLoop`. public static func delegateAndChannel(on eventLoop: EventLoop) -> EventLoopPreference { - return EventLoopPreference(.delegateAndChannel(on: eventLoop)) + EventLoopPreference(.delegateAndChannel(on: eventLoop)) } } /// Specifies decompression settings. - public enum Decompression { + public enum Decompression: Sendable { /// Decompression is disabled. case disabled /// Decompression is enabled. case enabled(limit: NIOHTTPDecompression.DecompressionLimit) } + typealias ShutdownCallback = @Sendable (Error?) -> Void + enum State { case upAndRunning - case shuttingDown(requiresCleanClose: Bool, callback: (Error?) -> Void) + case shuttingDown(requiresCleanClose: Bool, callback: ShutdownCallback) case shutDown } } +extension HTTPClient.EventLoopGroupProvider { + /// Shares ``HTTPClient/defaultEventLoopGroup`` which is a singleton `EventLoopGroup` suitable for the platform. + public static var singleton: Self { + .shared(HTTPClient.defaultEventLoopGroup) + } +} + +extension HTTPClient { + /// Returns the default `EventLoopGroup` singleton, automatically selecting the best for the platform. + /// + /// This will select the concrete `EventLoopGroup` depending which platform this is running on. + public static var defaultEventLoopGroup: EventLoopGroup { + #if canImport(Network) + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) { + return NIOTSEventLoopGroup.singleton + } else { + return MultiThreadedEventLoopGroup.singleton + } + #else + return MultiThreadedEventLoopGroup.singleton + #endif + } +} + +extension HTTPClient.Configuration: Sendable {} + +extension HTTPClient.EventLoopGroupProvider: Sendable {} +extension HTTPClient.EventLoopPreference: Sendable {} + +// HTTPClient is thread-safe because its shared mutable state is protected through a lock +extension HTTPClient: @unchecked Sendable {} + extension HTTPClient.Configuration { /// Timeout configuration. - public struct Timeout { - /// Specifies connect timeout. If no connect timeout is given, a default 30 seconds timeout will applied. + public struct Timeout: Sendable { + /// Specifies connect timeout. If no connect timeout is given, a default 10 seconds timeout will be applied. public var connect: TimeAmount? /// Specifies read timeout. public var read: TimeAmount? + /// Specifies the maximum amount of time without bytes being written by the client before closing the connection. + public var write: TimeAmount? - /// internal connection creation timeout. Defaults the connect timeout to always contain a value. + /// Internal connection creation timeout. Defaults the connect timeout to always contain a value. var connectionCreationTimeout: TimeAmount { self.connect ?? .seconds(10) } @@ -824,17 +1071,35 @@ extension HTTPClient.Configuration { /// Create timeout. /// /// - parameters: - /// - connect: `connect` timeout. Will default to 10 seconds, if no value is - /// provided. See `var connectionCreationTimeout` + /// - connect: `connect` timeout. Will default to 10 seconds, if no value is provided. /// - read: `read` timeout. - public init(connect: TimeAmount? = nil, read: TimeAmount? = nil) { + public init( + connect: TimeAmount? = nil, + read: TimeAmount? = nil + ) { self.connect = connect self.read = read } + + /// Create timeout. + /// + /// - parameters: + /// - connect: `connect` timeout. Will default to 10 seconds, if no value is provided. + /// - read: `read` timeout. + /// - write: `write` timeout. + public init( + connect: TimeAmount? = nil, + read: TimeAmount? = nil, + write: TimeAmount + ) { + self.connect = connect + self.read = read + self.write = write + } } /// Specifies redirect processing settings. - public struct RedirectConfiguration { + public struct RedirectConfiguration: Sendable { enum Mode { /// Redirects are not followed. case disallow @@ -862,21 +1127,34 @@ extension HTTPClient.Configuration { /// - allowCycles: Whether cycles are allowed. /// /// - warning: Cycle detection will keep all visited URLs in memory which means a malicious server could use this as a denial-of-service vector. - public static func follow(max: Int, allowCycles: Bool) -> RedirectConfiguration { return .init(configuration: .follow(max: max, allowCycles: allowCycles)) } + public static func follow(max: Int, allowCycles: Bool) -> RedirectConfiguration { + .init(configuration: .follow(max: max, allowCycles: allowCycles)) + } } /// Connection pool configuration. - public struct ConnectionPool: Hashable { + public struct ConnectionPool: Hashable, Sendable { /// Specifies amount of time connections are kept idle in the pool. After this time has passed without a new /// request the connections are closed. - public var idleTimeout: TimeAmount + public var idleTimeout: TimeAmount = .seconds(60) /// The maximum number of connections that are kept alive in the connection pool per host. If requests with /// an explicit eventLoopRequirement are sent, this number might be exceeded due to overflow connections. - public var concurrentHTTP1ConnectionsPerHostSoftLimit: Int + public var concurrentHTTP1ConnectionsPerHostSoftLimit: Int = 8 + + /// If true, ``HTTPClient`` will try to create new connections on connection failure with an exponential backoff. + /// Requests will only fail after the ``HTTPClient/Configuration/Timeout-swift.struct/connect`` timeout exceeded. + /// If false, all requests that have no assigned connection will fail immediately after a connection could not be established. + /// Defaults to `true`. + /// - warning: We highly recommend leaving this on. + /// It is very common that connections establishment is flaky at scale. + /// ``HTTPClient`` will automatically mitigate these kind of issues if this flag is turned on. + public var retryConnectionEstablishment: Bool = true + + public init() {} - public init(idleTimeout: TimeAmount = .seconds(60)) { - self.init(idleTimeout: idleTimeout, concurrentHTTP1ConnectionsPerHostSoftLimit: 8) + public init(idleTimeout: TimeAmount) { + self.idleTimeout = idleTimeout } public init(idleTimeout: TimeAmount, concurrentHTTP1ConnectionsPerHostSoftLimit: Int) { @@ -885,19 +1163,19 @@ extension HTTPClient.Configuration { } } - public struct HTTPVersion { - internal enum Configuration { + public struct HTTPVersion: Sendable, Hashable { + enum Configuration { case http1Only case automatic } - /// we only use HTTP/1, even if the server would supports HTTP/2 + /// We will only use HTTP/1, even if the server would supports HTTP/2 public static let http1Only: Self = .init(configuration: .http1Only) /// HTTP/2 is used if we connect to a server with HTTPS and the server supports HTTP/2, otherwise we use HTTP/1 public static let automatic: Self = .init(configuration: .automatic) - internal var configuration: Configuration + var configuration: Configuration } } @@ -911,6 +1189,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case emptyScheme case unsupportedScheme(String) case readTimeout + case writeTimeout case remoteConnectionClosed case cancelled case identityCodingIncorrectlyPresent @@ -924,6 +1203,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case uncleanShutdown case traceRequestWithBody case invalidHeaderFieldNames([String]) + case invalidHeaderFieldValues([String]) case bodyLengthMismatch case writeAfterRequestSent @available(*, deprecated, message: "AsyncHTTPClient now silently corrects invalid headers.") @@ -937,6 +1217,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case getConnectionFromPoolTimeout case deadlineExceeded case httpEndReceivedAfterHeadWith1xx + case shutdownUnsupported } private var code: Code @@ -946,7 +1227,83 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { } public var description: String { - return "HTTPClientError.\(String(describing: self.code))" + "HTTPClientError.\(String(describing: self.code))" + } + + /// Short description of the error that can be used in case a bounded set of error descriptions is expected, e.g. to + /// include in metric labels. For this reason the description must not contain associated values. + public var shortDescription: String { + // When adding new cases here, do *not* include dynamic (associated) values in the description. + switch self.code { + case .invalidURL: + return "Invalid URL" + case .emptyHost: + return "Empty host" + case .missingSocketPath: + return "Missing socket path" + case .alreadyShutdown: + return "Already shutdown" + case .emptyScheme: + return "Empty scheme" + case .unsupportedScheme: + return "Unsupported scheme" + case .readTimeout: + return "Read timeout" + case .writeTimeout: + return "Write timeout" + case .remoteConnectionClosed: + return "Remote connection closed" + case .cancelled: + return "Cancelled" + case .identityCodingIncorrectlyPresent: + return "Identity coding incorrectly present" + case .chunkedSpecifiedMultipleTimes: + return "Chunked specified multiple times" + case .invalidProxyResponse: + return "Invalid proxy response" + case .contentLengthMissing: + return "Content length missing" + case .proxyAuthenticationRequired: + return "Proxy authentication required" + case .redirectLimitReached: + return "Redirect limit reached" + case .redirectCycleDetected: + return "Redirect cycle detected" + case .uncleanShutdown: + return "Unclean shutdown" + case .traceRequestWithBody: + return "Trace request with body" + case .invalidHeaderFieldNames: + return "Invalid header field names" + case .invalidHeaderFieldValues: + return "Invalid header field values" + case .bodyLengthMismatch: + return "Body length mismatch" + case .writeAfterRequestSent: + return "Write after request sent" + case .incompatibleHeaders: + return "Incompatible headers" + case .connectTimeout: + return "Connect timeout" + case .socksHandshakeTimeout: + return "SOCKS handshake timeout" + case .httpProxyHandshakeTimeout: + return "HTTP proxy handshake timeout" + case .tlsHandshakeTimeout: + return "TLS handshake timeout" + case .serverOfferedUnsupportedApplicationProtocol: + return "Server offered unsupported application protocol" + case .requestStreamCancelled: + return "Request stream cancelled" + case .getConnectionFromPoolTimeout: + return "Get connection from pool timeout" + case .deadlineExceeded: + return "Deadline exceeded" + case .httpEndReceivedAfterHeadWith1xx: + return "HTTP end received after head with 1xx" + case .shutdownUnsupported: + return "The global singleton HTTP client cannot be shut down" + } } /// URL provided is invalid. @@ -960,9 +1317,13 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { /// URL does not contain scheme. public static let emptyScheme = HTTPClientError(code: .emptyScheme) /// Provided URL scheme is not supported, supported schemes are: `http` and `https` - public static func unsupportedScheme(_ scheme: String) -> HTTPClientError { return HTTPClientError(code: .unsupportedScheme(scheme)) } - /// Request timed out. + public static func unsupportedScheme(_ scheme: String) -> HTTPClientError { + HTTPClientError(code: .unsupportedScheme(scheme)) + } + /// Request timed out while waiting for response. public static let readTimeout = HTTPClientError(code: .readTimeout) + /// Request timed out. + public static let writeTimeout = HTTPClientError(code: .writeTimeout) /// Remote connection was closed unexpectedly. public static let remoteConnectionClosed = HTTPClientError(code: .remoteConnectionClosed) /// Request was cancelled. @@ -987,7 +1348,13 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { /// A body was sent in a request with method TRACE. public static let traceRequestWithBody = HTTPClientError(code: .traceRequestWithBody) /// Header field names contain invalid characters. - public static func invalidHeaderFieldNames(_ names: [String]) -> HTTPClientError { return HTTPClientError(code: .invalidHeaderFieldNames(names)) } + public static func invalidHeaderFieldNames(_ names: [String]) -> HTTPClientError { + HTTPClientError(code: .invalidHeaderFieldNames(names)) + } + /// Header field values contain invalid characters. + public static func invalidHeaderFieldValues(_ values: [String]) -> HTTPClientError { + HTTPClientError(code: .invalidHeaderFieldValues(values)) + } /// Body length is not equal to `Content-Length`. public static let bodyLengthMismatch = HTTPClientError(code: .bodyLengthMismatch) /// Body part was written after request was fully sent. @@ -1005,7 +1372,12 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { public static let tlsHandshakeTimeout = HTTPClientError(code: .tlsHandshakeTimeout) /// The remote server only offered an unsupported application protocol public static func serverOfferedUnsupportedApplicationProtocol(_ proto: String) -> HTTPClientError { - return HTTPClientError(code: .serverOfferedUnsupportedApplicationProtocol(proto)) + HTTPClientError(code: .serverOfferedUnsupportedApplicationProtocol(proto)) + } + + /// The globally shared singleton ``HTTPClient`` cannot be shut down. + public static var shutdownUnsupported: HTTPClientError { + HTTPClientError(code: .shutdownUnsupported) } /// The request deadline was exceeded. The request was cancelled because of this. @@ -1022,6 +1394,11 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { /// - Tasks are not processed fast enough on the existing connections, to process all waiters in time public static let getConnectionFromPoolTimeout = HTTPClientError(code: .getConnectionFromPoolTimeout) - @available(*, deprecated, message: "AsyncHTTPClient now correctly supports informational headers. For this reason `httpEndReceivedAfterHeadWith1xx` will not be thrown anymore.") + @available( + *, + deprecated, + message: + "AsyncHTTPClient now correctly supports informational headers. For this reason `httpEndReceivedAfterHeadWith1xx` will not be thrown anymore." + ) public static let httpEndReceivedAfterHeadWith1xx = HTTPClientError(code: .httpEndReceivedAfterHeadWith1xx) } diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index c1ce39632..38b930638 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -12,21 +12,25 @@ // //===----------------------------------------------------------------------===// +import Algorithms import Foundation import Logging import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 +import NIOPosix import NIOSSL extension HTTPClient { - /// Represent request body. + /// A request body. public struct Body { - /// Chunk provider. + /// A streaming uploader. + /// + /// ``StreamWriter`` abstracts public struct StreamWriter { let closure: (IOData) -> EventLoopFuture - /// Create new StreamWriter + /// Create new ``HTTPClient/Body/StreamWriter`` /// /// - parameters: /// - closure: function that will be called to write actual bytes to the channel. @@ -39,19 +43,98 @@ extension HTTPClient { /// - parameters: /// - data: `IOData` to write. public func write(_ data: IOData) -> EventLoopFuture { - return self.closure(data) + self.closure(data) + } + + @inlinable + func writeChunks( + of bytes: Bytes, + maxChunkSize: Int + ) -> EventLoopFuture where Bytes.Element == UInt8 { + // `StreamWriter` is has design issues, for example + // - https://github.com/swift-server/async-http-client/issues/194 + // - https://github.com/swift-server/async-http-client/issues/264 + // - We're not told the EventLoop the task runs on and the user is free to return whatever EL they + // want. + // One important consideration then is that we must lock around the iterator because we could be hopping + // between threads. + typealias Iterator = EnumeratedSequence>.Iterator + typealias Chunk = (offset: Int, element: ChunksOfCountCollection.Element) + + func makeIteratorAndFirstChunk( + bytes: Bytes + ) -> ( + iterator: NIOLockedValueBox, + chunk: Chunk + )? { + var iterator = bytes.chunks(ofCount: maxChunkSize).enumerated().makeIterator() + guard let chunk = iterator.next() else { + return nil + } + + return (NIOLockedValueBox(iterator), chunk) + } + + guard let (iterator, chunk) = makeIteratorAndFirstChunk(bytes: bytes) else { + return self.write(IOData.byteBuffer(.init())) + } + + @Sendable // can't use closure here as we recursively call ourselves which closures can't do + func writeNextChunk(_ chunk: Chunk, allDone: EventLoopPromise) { + if let nextElement = iterator.withLockedValue({ $0.next() }) { + self.write(.byteBuffer(ByteBuffer(bytes: chunk.element))).map { + let index = nextElement.offset + if (index + 1) % 4 == 0 { + // Let's not stack-overflow if the futures insta-complete which they at least in HTTP/2 + // mode. + // Also, we must frequently return to the EventLoop because we may get the pause signal + // from another thread. If we fail to do that promptly, we may balloon our body chunks + // into memory. + allDone.futureResult.eventLoop.execute { + writeNextChunk(nextElement, allDone: allDone) + } + } else { + writeNextChunk(nextElement, allDone: allDone) + } + }.cascadeFailure(to: allDone) + } else { + self.write(.byteBuffer(ByteBuffer(bytes: chunk.element))).cascade(to: allDone) + } + } + + // HACK (again, we're not told the right EventLoop): Let's write 0 bytes to make the user tell us... + return self.write(.byteBuffer(ByteBuffer())).flatMapWithEventLoop { (_, loop) in + let allDone = loop.makePromise(of: Void.self) + writeNextChunk(chunk, allDone: allDone) + return allDone.futureResult + } } } - /// Body size. if nil,`Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` + /// Body size. If nil,`Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` /// header is set with the given `length`. - public var length: Int? + @available(*, deprecated, renamed: "contentLength") + public var length: Int? { + get { + self.contentLength.flatMap { Int($0) } + } + set { + self.contentLength = newValue.flatMap { Int64($0) } + } + } + + /// Body size. If nil,`Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` + /// header is set with the given `contentLength`. + public var contentLength: Int64? + /// Body chunk provider. - public var stream: (StreamWriter) -> EventLoopFuture + public var stream: @Sendable (StreamWriter) -> EventLoopFuture + + @usableFromInline typealias StreamCallback = @Sendable (StreamWriter) -> EventLoopFuture @inlinable - init(length: Int?, stream: @escaping (StreamWriter) -> EventLoopFuture) { - self.length = length + init(contentLength: Int64?, stream: @escaping StreamCallback) { + self.contentLength = contentLength.flatMap { $0 } self.stream = stream } @@ -60,29 +143,53 @@ extension HTTPClient { /// - parameters: /// - buffer: Body `ByteBuffer` representation. public static func byteBuffer(_ buffer: ByteBuffer) -> Body { - return Body(length: buffer.readableBytes) { writer in + Body(contentLength: Int64(buffer.readableBytes)) { writer in writer.write(.byteBuffer(buffer)) } } - /// Create and stream body using `StreamWriter`. + /// Create and stream body using ``StreamWriter``. /// /// - parameters: /// - length: Body size. If nil, `Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` /// header is set with the given `length`. /// - stream: Body chunk provider. - public static func stream(length: Int? = nil, _ stream: @escaping (StreamWriter) -> EventLoopFuture) -> Body { - return Body(length: length, stream: stream) + @_disfavoredOverload + @preconcurrency + public static func stream( + length: Int? = nil, + _ stream: @Sendable @escaping (StreamWriter) -> EventLoopFuture + ) -> Body { + Body(contentLength: length.flatMap { Int64($0) }, stream: stream) + } + + /// Create and stream body using ``StreamWriter``. + /// + /// - parameters: + /// - contentLength: Body size. If nil, `Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` + /// header is set with the given `contentLength`. + /// - stream: Body chunk provider. + public static func stream( + contentLength: Int64? = nil, + _ stream: @Sendable @escaping (StreamWriter) -> EventLoopFuture + ) -> Body { + Body(contentLength: contentLength, stream: stream) } /// Create and stream body using a collection of bytes. /// /// - parameters: - /// - data: Body binary representation. + /// - bytes: Body binary representation. + @preconcurrency @inlinable - public static func bytes(_ bytes: Bytes) -> Body where Bytes: RandomAccessCollection, Bytes.Element == UInt8 { - return Body(length: bytes.count) { writer in - writer.write(.byteBuffer(ByteBuffer(bytes: bytes))) + public static func bytes(_ bytes: Bytes) -> Body + where Bytes: RandomAccessCollection, Bytes: Sendable, Bytes.Element == UInt8 { + Body(contentLength: Int64(bytes.count)) { writer in + if bytes.count <= bagOfBytesToByteBufferConversionChunkSize { + return writer.write(.byteBuffer(ByteBuffer(bytes: bytes))) + } else { + return writer.writeChunks(of: bytes, maxChunkSize: bagOfBytesToByteBufferConversionChunkSize) + } } } @@ -91,13 +198,17 @@ extension HTTPClient { /// - parameters: /// - string: Body `String` representation. public static func string(_ string: String) -> Body { - return Body(length: string.utf8.count) { writer in - writer.write(.byteBuffer(ByteBuffer(string: string))) + Body(contentLength: Int64(string.utf8.count)) { writer in + if string.utf8.count <= bagOfBytesToByteBufferConversionChunkSize { + return writer.write(.byteBuffer(ByteBuffer(string: string))) + } else { + return writer.writeChunks(of: string.utf8, maxChunkSize: bagOfBytesToByteBufferConversionChunkSize) + } } } } - /// Represent HTTP request. + /// Represents an HTTP request. public struct Request { /// Request HTTP method, defaults to `GET`. public let method: HTTPMethod @@ -123,7 +234,6 @@ extension HTTPClient { /// /// - parameters: /// - url: Remote `URL`. - /// - version: HTTP version. /// - method: HTTP method. /// - headers: Custom HTTP headers. /// - body: Request body. @@ -132,7 +242,12 @@ extension HTTPClient { /// - `emptyScheme` if URL does not contain HTTP scheme. /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. - public init(url: String, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { + public init( + url: String, + method: HTTPMethod = .GET, + headers: HTTPHeaders = HTTPHeaders(), + body: Body? = nil + ) throws { try self.init(url: url, method: method, headers: headers, body: body, tlsConfiguration: nil) } @@ -140,7 +255,6 @@ extension HTTPClient { /// /// - parameters: /// - url: Remote `URL`. - /// - version: HTTP version. /// - method: HTTP method. /// - headers: Custom HTTP headers. /// - body: Request body. @@ -150,7 +264,13 @@ extension HTTPClient { /// - `emptyScheme` if URL does not contain HTTP scheme. /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. - public init(url: String, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil, tlsConfiguration: TLSConfiguration?) throws { + public init( + url: String, + method: HTTPMethod = .GET, + headers: HTTPHeaders = HTTPHeaders(), + body: Body? = nil, + tlsConfiguration: TLSConfiguration? + ) throws { guard let url = URL(string: url) else { throw HTTPClientError.invalidURL } @@ -170,7 +290,8 @@ extension HTTPClient { /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. /// - `missingSocketPath` if URL does not contains a socketPath as an encoded host. - public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { + public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws + { try self.init(url: url, method: method, headers: headers, body: body, tlsConfiguration: nil) } @@ -187,7 +308,13 @@ extension HTTPClient { /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. /// - `missingSocketPath` if URL does not contains a socketPath as an encoded host. - public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil, tlsConfiguration: TLSConfiguration?) throws { + public init( + url: URL, + method: HTTPMethod = .GET, + headers: HTTPHeaders = HTTPHeaders(), + body: Body? = nil, + tlsConfiguration: TLSConfiguration? + ) throws { self.deconstructedURL = try DeconstructedURL(url: url) self.url = url @@ -220,14 +347,26 @@ extension HTTPClient { head.headers.addHostIfNeeded(for: self.deconstructedURL) - let metadata = try head.headers.validateAndSetTransportFraming(method: self.method, bodyLength: .init(self.body)) + let metadata = try head.headers.validateAndSetTransportFraming( + method: self.method, + bodyLength: .init(self.body) + ) return (head, metadata) } + + /// Set basic auth for a request. + /// + /// - parameters: + /// - username: the username to authenticate with + /// - password: authentication password associated with the username + public mutating func setBasicAuth(username: String, password: String) { + self.headers.setBasicAuth(username: username, password: password) + } } - /// Represent HTTP response. - public struct Response { + /// Represents an HTTP response. + public struct Response: Sendable { /// Remote host of the request. public var host: String /// Response HTTP status. @@ -263,7 +402,13 @@ extension HTTPClient { /// - version: Response HTTP version. /// - headers: Reponse HTTP headers. /// - body: Response body. - public init(host: String, status: HTTPResponseStatus, version: HTTPVersion, headers: HTTPHeaders, body: ByteBuffer?) { + public init( + host: String, + status: HTTPResponseStatus, + version: HTTPVersion, + headers: HTTPHeaders, + body: ByteBuffer? + ) { self.host = host self.status = status self.version = version @@ -272,8 +417,8 @@ extension HTTPClient { } } - /// HTTP authentication - public struct Authorization: Hashable { + /// HTTP authentication. + public struct Authorization: Hashable, Sendable { private enum Scheme: Hashable { case Basic(String) case Bearer(String) @@ -285,18 +430,24 @@ extension HTTPClient { self.scheme = scheme } + /// HTTP basic auth. public static func basic(username: String, password: String) -> HTTPClient.Authorization { - return .basic(credentials: Base64.encode(bytes: "\(username):\(password)".utf8)) + .basic(credentials: Base64.encode(bytes: "\(username):\(password)".utf8)) } + /// HTTP basic auth. + /// + /// This version uses the raw string directly. public static func basic(credentials: String) -> HTTPClient.Authorization { - return .init(scheme: .Basic(credentials)) + .init(scheme: .Basic(credentials)) } + /// HTTP bearer auth public static func bearer(tokens: String) -> HTTPClient.Authorization { - return .init(scheme: .Bearer(tokens)) + .init(scheme: .Bearer(tokens)) } + /// The header string for this auth field. public var headerValue: String { switch self.scheme { case .Basic(let credentials): @@ -308,7 +459,11 @@ extension HTTPClient { } } -public class ResponseAccumulator: HTTPClientResponseDelegate { +/// The default ``HTTPClientResponseDelegate``. +/// +/// This ``HTTPClientResponseDelegate`` buffers a complete HTTP response in memory. It does not stream the response body in. +/// The resulting ``Response`` type is ``HTTPClient/Response``. +public final class ResponseAccumulator: HTTPClientResponseDelegate { public typealias Response = HTTPClient.Response enum State { @@ -319,16 +474,66 @@ public class ResponseAccumulator: HTTPClientResponseDelegate { case error(Error) } + public struct ResponseTooBigError: Error, CustomStringConvertible { + public var maxBodySize: Int + public init(maxBodySize: Int) { + self.maxBodySize = maxBodySize + } + + public var description: String { + "ResponseTooBigError: received response body exceeds maximum accepted size of \(self.maxBodySize) bytes" + } + } + var state = State.idle - let request: HTTPClient.Request + let requestMethod: HTTPMethod + let requestHost: String + + static let maxByteBufferSize = Int(UInt32.max) - public init(request: HTTPClient.Request) { - self.request = request + /// Maximum size in bytes of the HTTP response body that ``ResponseAccumulator`` will accept + /// until it will abort the request and throw an ``ResponseTooBigError``. + /// + /// Default is 2^32. + /// - precondition: not allowed to exceed 2^32 because `ByteBuffer` can not store more bytes + public let maxBodySize: Int + + public convenience init(request: HTTPClient.Request) { + self.init(request: request, maxBodySize: Self.maxByteBufferSize) + } + + /// - Parameters: + /// - request: The corresponding request of the response this delegate will be accumulating. + /// - maxBodySize: Maximum size in bytes of the HTTP response body that ``ResponseAccumulator`` will accept + /// until it will abort the request and throw an ``ResponseTooBigError``. + /// Default is 2^32. + /// - precondition: maxBodySize is not allowed to exceed 2^32 because `ByteBuffer` can not store more bytes + /// - warning: You can use ``ResponseAccumulator`` for just one request. + /// If you start another request, you need to initiate another ``ResponseAccumulator``. + public init(request: HTTPClient.Request, maxBodySize: Int) { + precondition(maxBodySize >= 0, "maxBodyLength is not allowed to be negative") + precondition( + maxBodySize <= Self.maxByteBufferSize, + "maxBodyLength is not allowed to exceed 2^32 because ByteBuffer can not store more bytes" + ) + self.requestMethod = request.method + self.requestHost = request.host + self.maxBodySize = maxBodySize } public func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { switch self.state { case .idle: + if self.requestMethod != .HEAD, + let contentLength = head.headers.first(name: "Content-Length"), + let announcedBodySize = Int(contentLength), + announcedBodySize > self.maxBodySize + { + let error = ResponseTooBigError(maxBodySize: maxBodySize) + self.state = .error(error) + return task.eventLoop.makeFailedFuture(error) + } + self.state = .head(head) case .head: preconditionFailure("head already set") @@ -347,8 +552,20 @@ public class ResponseAccumulator: HTTPClientResponseDelegate { case .idle: preconditionFailure("no head received before body") case .head(let head): + guard part.readableBytes <= self.maxBodySize else { + let error = ResponseTooBigError(maxBodySize: self.maxBodySize) + self.state = .error(error) + return task.eventLoop.makeFailedFuture(error) + } self.state = .body(head, part) case .body(let head, var body): + let newBufferSize = body.writerIndex + part.readableBytes + guard newBufferSize <= self.maxBodySize else { + let error = ResponseTooBigError(maxBodySize: self.maxBodySize) + self.state = .error(error) + return task.eventLoop.makeFailedFuture(error) + } + // The compiler can't prove that `self.state` is dead here (and it kinda isn't, there's // a cross-module call in the way) so we need to drop the original reference to `body` in // `self.state` or we'll get a CoW. To fix that we temporarily set the state to `.end` (which @@ -374,9 +591,21 @@ public class ResponseAccumulator: HTTPClientResponseDelegate { case .idle: preconditionFailure("no head received before end") case .head(let head): - return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: nil) + return Response( + host: self.requestHost, + status: head.status, + version: head.version, + headers: head.headers, + body: nil + ) case .body(let head, let body): - return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: body) + return Response( + host: self.requestHost, + status: head.status, + version: head.version, + headers: head.headers, + body: body + ) case .end: preconditionFailure("request already processed") case .error(let error): @@ -385,32 +614,34 @@ public class ResponseAccumulator: HTTPClientResponseDelegate { } } -/// `HTTPClientResponseDelegate` allows an implementation to receive notifications about request processing and to control how response parts are processed. +/// ``HTTPClientResponseDelegate`` allows an implementation to receive notifications about request processing and to control how response parts are processed. +/// /// You can implement this protocol if you need fine-grained control over an HTTP request/response, for example, if you want to inspect the response /// headers before deciding whether to accept a response body, or if you want to stream your request body. Pass an instance of your conforming -/// class to the `HTTPClient.execute()` method and this package will call each delegate method appropriately as the request takes place./ +/// class to the ``HTTPClient/execute(request:delegate:eventLoop:deadline:)`` method and this package will call each delegate method appropriately as the request takes place. /// /// ### Backpressure /// -/// A `HTTPClientResponseDelegate` can be used to exert backpressure on the server response. This is achieved by way of the futures returned from -/// `didReceiveHead` and `didReceiveBodyPart`. The following functions are part of the "backpressure system" in the delegate: +/// A ``HTTPClientResponseDelegate`` can be used to exert backpressure on the server response. This is achieved by way of the futures returned from +/// ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd`` and ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v``. +/// The following functions are part of the "backpressure system" in the delegate: /// -/// - `didReceiveHead` -/// - `didReceiveBodyPart` -/// - `didFinishRequest` -/// - `didReceiveError` +/// - ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd`` +/// - ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v`` +/// - ``HTTPClientResponseDelegate/didFinishRequest(task:)`` +/// - ``HTTPClientResponseDelegate/didReceiveError(task:_:)-fhsg`` /// -/// The first three methods are strictly _exclusive_, with that exclusivity managed by the futures returned by `didReceiveHead` and -/// `didReceiveBodyPart`. What this means is that until the returned future is completed, none of these three methods will be called -/// again. This allows delegates to rate limit the server to a capacity it can manage. `didFinishRequest` does not return a future, +/// The first three methods are strictly _exclusive_, with that exclusivity managed by the futures returned by ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd`` and +/// ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v``. What this means is that until the returned future is completed, none of these three methods will be called +/// again. This allows delegates to rate limit the server to a capacity it can manage. ``HTTPClientResponseDelegate/didFinishRequest(task:)`` does not return a future, /// as we are expecting no more data from the server at this time. /// -/// `didReceiveError` is somewhat special: it signals the end of this regime. `didRecieveError` is not exclusive: it may be called at -/// any time, even if a returned future is not yet completed. `didReceiveError` is terminal, meaning that once it has been called none -/// of these four methods will be called again. This can be used as a signal to abandon all outstanding work. +/// ``HTTPClientResponseDelegate/didReceiveError(task:_:)-fhsg`` is somewhat special: it signals the end of this regime. ``HTTPClientResponseDelegate/didReceiveError(task:_:)-fhsg`` +/// is not exclusive: it may be called at any time, even if a returned future is not yet completed. ``HTTPClientResponseDelegate/didReceiveError(task:_:)-fhsg`` is terminal, meaning +/// that once it has been called none of these four methods will be called again. This can be used as a signal to abandon all outstanding work. /// /// - note: This delegate is strongly held by the `HTTPTaskHandler` -/// for the duration of the `Request` processing and will be +/// for the duration of the ``HTTPClient/Request`` processing and will be /// released together with the `HTTPTaskHandler` when channel is closed. /// Users of the library are not required to keep a reference to the /// object that implements this protocol, but may do so if needed. @@ -428,7 +659,7 @@ public protocol HTTPClientResponseDelegate: AnyObject { /// /// - parameters: /// - task: Current request context. - /// - part: Request body `Part`. + /// - part: Request body part. func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) /// Called when the request is fully sent. Will be called once. @@ -451,7 +682,7 @@ public protocol HTTPClientResponseDelegate: AnyObject { /// You must return an `EventLoopFuture` that you complete when you have finished processing the body part. /// You can create an already succeeded future by calling `task.eventLoop.makeSucceededFuture(())`. /// - /// This function will not be called until the future returned by `didReceiveHead` has completed. + /// This function will not be called until the future returned by ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd`` has completed. /// /// This function will not be called for subsequent body parts until the previous future returned by a /// call to this function completes. @@ -464,19 +695,22 @@ public protocol HTTPClientResponseDelegate: AnyObject { /// Called when error was thrown during request execution. Will be called zero or one time only. Request processing will be stopped after that. /// - /// This function may be called at any time: it does not respect the backpressure exerted by `didReceiveHead` and `didReceiveBodyPart`. - /// All outstanding work may be cancelled when this is received. Once called, no further calls will be made to `didReceiveHead`, `didReceiveBodyPart`, - /// or `didFinishRequest`. + /// This function may be called at any time: it does not respect the backpressure exerted by ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd`` + /// and ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v``. + /// All outstanding work may be cancelled when this is received. Once called, no further calls will be made to + /// ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd``, ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v``, + /// or ``HTTPClientResponseDelegate/didFinishRequest(task:)``. /// /// - parameters: /// - task: Current request context. /// - error: Error that occured during response processing. func didReceiveError(task: HTTPClient.Task, _ error: Error) - /// Called when the complete HTTP request is finished. You must return an instance of your `Response` associated type. Will be called once, except if an error occurred. + /// Called when the complete HTTP request is finished. You must return an instance of your ``Response`` associated type. Will be called once, except if an error occurred. /// - /// This function will not be called until all futures returned by `didReceiveHead` and `didReceiveBodyPart` have completed. Once called, - /// no further calls will be made to `didReceiveHead`, `didReceiveBodyPart`, or `didReceiveError`. + /// This function will not be called until all futures returned by ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd`` and ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v`` + /// have completed. Once called, no further calls will be made to ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd``, ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v``, + /// or ``HTTPClientResponseDelegate/didReceiveError(task:_:)-fhsg``. /// /// - parameters: /// - task: Current request context. @@ -485,20 +719,38 @@ public protocol HTTPClientResponseDelegate: AnyObject { } extension HTTPClientResponseDelegate { + /// Default implementation of ``HTTPClientResponseDelegate/didSendRequest(task:)-9od5p``. + /// + /// By default, this does nothing. public func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) {} + /// Default implementation of ``HTTPClientResponseDelegate/didSendRequestPart(task:_:)-4qxap``. + /// + /// By default, this does nothing. public func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) {} + /// Default implementation of ``HTTPClientResponseDelegate/didSendRequest(task:)-3vqgm``. + /// + /// By default, this does nothing. public func didSendRequest(task: HTTPClient.Task) {} + /// Default implementation of ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd``. + /// + /// By default, this does nothing. public func didReceiveHead(task: HTTPClient.Task, _: HTTPResponseHead) -> EventLoopFuture { - return task.eventLoop.makeSucceededFuture(()) + task.eventLoop.makeSucceededVoidFuture() } + /// Default implementation of ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v``. + /// + /// By default, this does nothing. public func didReceiveBodyPart(task: HTTPClient.Task, _: ByteBuffer) -> EventLoopFuture { - return task.eventLoop.makeSucceededFuture(()) + task.eventLoop.makeSucceededVoidFuture() } + /// Default implementation of ``HTTPClientResponseDelegate/didReceiveError(task:_:)-fhsg``. + /// + /// By default, this does nothing. public func didReceiveError(task: HTTPClient.Task, _: Error) {} } @@ -507,7 +759,7 @@ extension URL { if self.path.isEmpty { return "/" } - return URLComponents(url: self, resolvingAgainstBaseURL: false)?.percentEncodedPath ?? self.path + return URLComponents(url: self, resolvingAgainstBaseURL: true)?.percentEncodedPath ?? self.path } var uri: String { @@ -521,7 +773,7 @@ extension URL { } func hasTheSameOrigin(as other: URL) -> Bool { - return self.host == other.host && self.scheme == other.scheme && self.port == other.port + self.host == other.host && self.scheme == other.scheme && self.port == other.port } /// Initializes a newly created HTTP URL connecting to a unix domain socket path. The socket path is encoded as the URL's host, replacing percent encoding invalid path characters, and will use the "http+unix" scheme. @@ -556,18 +808,21 @@ extension URL { } protocol HTTPClientTaskDelegate { - func cancel() + func fail(_ error: Error) } extension HTTPClient { - /// Response execution context. Will be created by the library and could be used for obtaining + /// Response execution context. + /// + /// Will be created by the library and could be used for obtaining /// `EventLoopFuture` of the execution or cancellation of the execution. public final class Task { /// The `EventLoop` the delegate will be executed on. public let eventLoop: EventLoop + /// The `Logger` used by the `Task` for logging. + public let logger: Logger // We are okay to store the logger here because a Task is for only one request. let promise: EventLoopPromise - let logger: Logger // We are okay to store the logger here because a Task is for only one request. var isCancelled: Bool { self.lock.withLock { self._isCancelled } @@ -584,57 +839,102 @@ extension HTTPClient { private var _isCancelled: Bool = false private var _taskDelegate: HTTPClientTaskDelegate? - private let lock = Lock() + private let lock = NIOLock() + private let makeOrGetFileIOThreadPool: () -> NIOThreadPool + + /// The shared thread pool of a ``HTTPClient`` used for file IO. It is lazily created on first access. + internal var fileIOThreadPool: NIOThreadPool { + self.makeOrGetFileIOThreadPool() + } - init(eventLoop: EventLoop, logger: Logger) { + init(eventLoop: EventLoop, logger: Logger, makeOrGetFileIOThreadPool: @escaping () -> NIOThreadPool) { self.eventLoop = eventLoop self.promise = eventLoop.makePromise() self.logger = logger - } - - static func failedTask(eventLoop: EventLoop, error: Error, logger: Logger) -> Task { - let task = self.init(eventLoop: eventLoop, logger: logger) + self.makeOrGetFileIOThreadPool = makeOrGetFileIOThreadPool + } + + static func failedTask( + eventLoop: EventLoop, + error: Error, + logger: Logger, + makeOrGetFileIOThreadPool: @escaping () -> NIOThreadPool + ) -> Task { + let task = self.init( + eventLoop: eventLoop, + logger: logger, + makeOrGetFileIOThreadPool: makeOrGetFileIOThreadPool + ) task.promise.fail(error) return task } /// `EventLoopFuture` for the response returned by this request. public var futureResult: EventLoopFuture { - return self.promise.futureResult + self.promise.futureResult } /// Waits for execution of this request to complete. /// - /// - returns: The value of the `EventLoopFuture` when it completes. - /// - throws: The error value of the `EventLoopFuture` if it errors. + /// - returns: The value of ``futureResult`` when it completes. + /// - throws: The error value of ``futureResult`` if it errors. + @available(*, noasync, message: "wait() can block indefinitely, prefer get()", renamed: "get()") public func wait() throws -> Response { - return try self.promise.futureResult.wait() + try self.promise.futureResult.wait() } - /// Cancels the request execution. + /// Provides the result of this request. + /// + /// - warning: This method may violates Structured Concurrency because doesn't respect cancellation. + /// + /// - returns: The value of ``futureResult`` when it completes. + /// - throws: The error value of ``futureResult`` if it errors. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public func get() async throws -> Response { + try await self.promise.futureResult.get() + } + + /// Initiate cancellation of a HTTP request. + /// + /// This method will return immeidately and doesn't wait for the cancellation to complete. public func cancel() { + self.fail(reason: HTTPClientError.cancelled) + } + + /// Initiate cancellation of a HTTP request with an `error`. + /// + /// This method will return immeidately and doesn't wait for the cancellation to complete. + /// + /// - Parameter error: the error that is used to fail the promise + public func fail(reason error: Error) { let taskDelegate = self.lock.withLock { () -> HTTPClientTaskDelegate? in self._isCancelled = true return self._taskDelegate } - taskDelegate?.cancel() + taskDelegate?.fail(error) } - func succeed(promise: EventLoopPromise?, - with value: Response, - delegateType: Delegate.Type, - closing: Bool) { + func succeed( + promise: EventLoopPromise?, + with value: Response, + delegateType: Delegate.Type, + closing: Bool + ) { promise?.succeed(value) } - func fail(with error: Error, - delegateType: Delegate.Type) { + func fail( + with error: Error, + delegateType: Delegate.Type + ) { self.promise.fail(error) } } } +extension HTTPClient.Task: @unchecked Sendable {} + internal struct TaskCancelEvent {} // MARK: - RedirectHandler @@ -693,7 +993,7 @@ extension RequestBodyLength { self = .known(0) return } - guard let length = body.length else { + guard let length = body.contentLength else { self = .unknown return } diff --git a/Sources/AsyncHTTPClient/LRUCache.swift b/Sources/AsyncHTTPClient/LRUCache.swift index 0a01da0d2..f8b58c36a 100644 --- a/Sources/AsyncHTTPClient/LRUCache.swift +++ b/Sources/AsyncHTTPClient/LRUCache.swift @@ -52,9 +52,11 @@ struct LRUCache { @discardableResult mutating func append(key: Key, value: Value) -> Value { - let newElement = Element(generation: self.generation, - key: key, - value: value) + let newElement = Element( + generation: self.generation, + key: key, + value: value + ) if let found = self.bumpGenerationAndFindIndex(key: key) { self.elements[found] = newElement return value diff --git a/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift b/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift index 4334bb9f9..148b4a4c4 100644 --- a/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift +++ b/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift @@ -12,15 +12,17 @@ // //===----------------------------------------------------------------------===// -#if canImport(Network) -import Network -#endif import NIOCore import NIOHTTP1 import NIOTransportServices +#if canImport(Network) +import Network +#endif + extension HTTPClient { #if canImport(Network) + /// A wrapper for `POSIX` errors thrown by `Network.framework`. public struct NWPOSIXError: Error, CustomStringConvertible { /// POSIX error code (enum) public let errorCode: POSIXErrorCode @@ -37,11 +39,12 @@ extension HTTPClient { self.reason = reason } - public var description: String { return self.reason } + public var description: String { self.reason } } + /// A wrapper for TLS errors thrown by `Network.framework`. public struct NWTLSError: Error, CustomStringConvertible { - /// TLS error status. List of TLS errors can be found in + /// TLS error status. List of TLS errors can be found in `` public let status: OSStatus /// actual reason, in human readable form @@ -56,11 +59,11 @@ extension HTTPClient { self.reason = reason } - public var description: String { return self.reason } + public var description: String { self.reason } } #endif - class NWErrorHandler: ChannelInboundHandler { + final class NWErrorHandler: ChannelInboundHandler { typealias InboundIn = HTTPClientResponsePart func errorCaught(context: ChannelHandlerContext, error: Error) { @@ -73,9 +76,9 @@ extension HTTPClient { if let error = error as? NWError { switch error { case .tls(let status): - return NWTLSError(status, reason: error.localizedDescription) + return NWTLSError(status, reason: String(describing: error)) case .posix(let errorCode): - return NWPOSIXError(errorCode, reason: error.localizedDescription) + return NWPOSIXError(errorCode, reason: String(describing: error)) default: return error } diff --git a/Sources/AsyncHTTPClient/NIOTransportServices/NWWaitingHandler.swift b/Sources/AsyncHTTPClient/NIOTransportServices/NWWaitingHandler.swift new file mode 100644 index 000000000..d7c6055ec --- /dev/null +++ b/Sources/AsyncHTTPClient/NIOTransportServices/NWWaitingHandler.swift @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if canImport(Network) +import Network +import NIOCore +import NIOHTTP1 +import NIOTransportServices + +@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) +final class NWWaitingHandler: ChannelInboundHandler { + typealias InboundIn = Any + typealias InboundOut = Any + + private var requester: Requester + private let connectionID: HTTPConnectionPool.Connection.ID + + init(requester: Requester, connectionID: HTTPConnectionPool.Connection.ID) { + self.requester = requester + self.connectionID = connectionID + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + if let waitingEvent = event as? NIOTSNetworkEvents.WaitingForConnectivity { + self.requester.waitingForConnectivity( + self.connectionID, + error: HTTPClient.NWErrorHandler.translateError(waitingEvent.transientError) + ) + } + context.fireUserInboundEventTriggered(event) + } +} +#endif diff --git a/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift b/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift index e20f52634..ef505e3b7 100644 --- a/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift +++ b/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift @@ -57,7 +57,7 @@ extension TLSVersion { } } -@available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) +@available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) extension TLSConfiguration { /// Dispatch queue used by Network framework TLS to control certificate verification static var tlsDispatchQueue = DispatchQueue(label: "TLSDispatch") @@ -66,11 +66,14 @@ extension TLSConfiguration { /// /// - Parameter eventLoop: EventLoop to wait for creation of options on /// - Returns: Future holding NWProtocolTLS Options - func getNWProtocolTLSOptions(on eventLoop: EventLoop) -> EventLoopFuture { + func getNWProtocolTLSOptions( + on eventLoop: EventLoop, + serverNameIndicatorOverride: String? + ) -> EventLoopFuture { let promise = eventLoop.makePromise(of: NWProtocolTLS.Options.self) Self.tlsDispatchQueue.async { do { - let options = try self.getNWProtocolTLSOptions() + let options = try self.getNWProtocolTLSOptions(serverNameIndicatorOverride: serverNameIndicatorOverride) promise.succeed(options) } catch { promise.fail(error) @@ -82,27 +85,42 @@ extension TLSConfiguration { /// create NWProtocolTLS.Options for use with NIOTransportServices from the NIOSSL TLSConfiguration /// /// - Returns: Equivalent NWProtocolTLS Options - func getNWProtocolTLSOptions() throws -> NWProtocolTLS.Options { + func getNWProtocolTLSOptions(serverNameIndicatorOverride: String?) throws -> NWProtocolTLS.Options { let options = NWProtocolTLS.Options() let useMTELGExplainer = """ - You can still use this configuration option on macOS if you initialize HTTPClient \ - with a MultiThreadedEventLoopGroup. Please note that using MultiThreadedEventLoopGroup \ - will make AsyncHTTPClient use NIO on BSD Sockets and not Network.framework (which is the preferred \ - platform networking stack). - """ + You can still use this configuration option on macOS if you initialize HTTPClient \ + with a MultiThreadedEventLoopGroup. Please note that using MultiThreadedEventLoopGroup \ + will make AsyncHTTPClient use NIO on BSD Sockets and not Network.framework (which is the preferred \ + platform networking stack). + """ + + if let serverNameIndicatorOverride = serverNameIndicatorOverride { + serverNameIndicatorOverride.withCString { serverNameIndicatorOverride in + sec_protocol_options_set_tls_server_name(options.securityProtocolOptions, serverNameIndicatorOverride) + } + } // minimum TLS protocol if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) { - sec_protocol_options_set_min_tls_protocol_version(options.securityProtocolOptions, self.minimumTLSVersion.nwTLSProtocolVersion) + sec_protocol_options_set_min_tls_protocol_version( + options.securityProtocolOptions, + self.minimumTLSVersion.nwTLSProtocolVersion + ) } else { - sec_protocol_options_set_tls_min_version(options.securityProtocolOptions, self.minimumTLSVersion.sslProtocol) + sec_protocol_options_set_tls_min_version( + options.securityProtocolOptions, + self.minimumTLSVersion.sslProtocol + ) } // maximum TLS protocol if let maximumTLSVersion = self.maximumTLSVersion { if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) { - sec_protocol_options_set_max_tls_protocol_version(options.securityProtocolOptions, maximumTLSVersion.nwTLSProtocolVersion) + sec_protocol_options_set_max_tls_protocol_version( + options.securityProtocolOptions, + maximumTLSVersion.nwTLSProtocolVersion + ) } else { sec_protocol_options_set_tls_max_version(options.securityProtocolOptions, maximumTLSVersion.sslProtocol) } @@ -115,11 +133,6 @@ extension TLSConfiguration { } } - // the certificate chain - if self.certificateChain.count > 0 { - preconditionFailure("TLSConfiguration.certificateChain is not supported. \(useMTELGExplainer)") - } - // cipher suites if self.cipherSuites.count > 0 { // TODO: Requires NIOSSL to provide list of cipher values before we can continue @@ -160,8 +173,10 @@ extension TLSConfiguration { break } - precondition(self.certificateVerification != .noHostnameVerification, - "TLSConfiguration.certificateVerification = .noHostnameVerification is not supported. \(useMTELGExplainer)") + precondition( + self.certificateVerification != .noHostnameVerification, + "TLSConfiguration.certificateVerification = .noHostnameVerification is not supported. \(useMTELGExplainer)" + ) if certificateVerification != .fullVerification || trustRoots != nil { // add verify block to control certificate verification @@ -195,7 +210,8 @@ extension TLSConfiguration { } } } - }, Self.tlsDispatchQueue + }, + Self.tlsDispatchQueue ) } return options diff --git a/Sources/AsyncHTTPClient/RedirectState.swift b/Sources/AsyncHTTPClient/RedirectState.swift index c4e427ef1..95de2d508 100644 --- a/Sources/AsyncHTTPClient/RedirectState.swift +++ b/Sources/AsyncHTTPClient/RedirectState.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -import struct Foundation.URL import NIOHTTP1 +import struct Foundation.URL + typealias RedirectMode = HTTPClient.Configuration.RedirectConfiguration.Mode struct RedirectState { diff --git a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift index 456317f11..37b2a42f0 100644 --- a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift +++ b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift @@ -12,18 +12,33 @@ // //===----------------------------------------------------------------------===// -import struct Foundation.URL import NIOCore import NIOHTTP1 +import struct Foundation.URL + +extension HTTPClient { + /// The maximum body size allowed, before a redirect response is cancelled. 3KB. + /// + /// Why 3KB? We feel like this is a good compromise between potentially reusing the + /// connection in HTTP/1.1 mode (if we load all data from the redirect response we can + /// reuse the connection) and not being to wasteful in the amount of data that is thrown + /// away being transferred. + fileprivate static let maxBodySizeRedirectResponse = 1024 * 3 +} + extension RequestBag { struct StateMachine { fileprivate enum State { - case initialized - case queued(HTTPRequestScheduler) + case initialized(RedirectHandler?) + case queued(HTTPRequestScheduler, RedirectHandler?) + /// if the deadline was exceeded while in the `.queued(_:)` state, + /// we wait until the request pool fails the request with a potential more descriptive error message, + /// if a connection failure has occured while the request was queued. + case deadlineExceededWhileQueued case executing(HTTPRequestExecutor, RequestStreamState, ResponseStreamState) case finished(error: Error?) - case redirected(HTTPResponseHead, URL) + case redirected(HTTPRequestExecutor, RedirectHandler, Int, HTTPResponseHead, URL) case modifying } @@ -41,23 +56,22 @@ extension RequestBag { case eof } - case initialized + case initialized(RedirectHandler?) case buffering(CircularBuffer, next: Next) case waitingForRemote } - private var state: State = .initialized - private let redirectHandler: RedirectHandler? + private var state: State init(redirectHandler: RedirectHandler?) { - self.redirectHandler = redirectHandler + self.state = .initialized(redirectHandler) } } } extension RequestBag.StateMachine { mutating func requestWasQueued(_ scheduler: HTTPRequestScheduler) { - guard case .initialized = self.state else { + guard case .initialized(let redirectHandler) = self.state else { // There might be a race between `requestWasQueued` and `willExecuteRequest`: // // If the request is created and passed to the HTTPClient on thread A, it will move into @@ -77,16 +91,26 @@ extension RequestBag.StateMachine { return } - self.state = .queued(scheduler) + self.state = .queued(scheduler, redirectHandler) + } + + enum WillExecuteRequestAction { + case cancelExecuter(HTTPRequestExecutor) + case failTaskAndCancelExecutor(Error, HTTPRequestExecutor) + case none } - mutating func willExecuteRequest(_ executor: HTTPRequestExecutor) -> Bool { + mutating func willExecuteRequest(_ executor: HTTPRequestExecutor) -> WillExecuteRequestAction { switch self.state { - case .initialized, .queued: - self.state = .executing(executor, .initialized, .initialized) - return true + case .initialized(let redirectHandler), .queued(_, let redirectHandler): + self.state = .executing(executor, .initialized, .initialized(redirectHandler)) + return .none + case .deadlineExceededWhileQueued: + let error: Error = HTTPClientError.deadlineExceeded + self.state = .finished(error: error) + return .failTaskAndCancelExecutor(error, executor) case .finished(error: .some): - return false + return .cancelExecuter(executor) case .executing, .redirected, .finished(error: .none), .modifying: preconditionFailure("Invalid state: \(self.state)") } @@ -100,11 +124,11 @@ extension RequestBag.StateMachine { mutating func resumeRequestBodyStream() -> ResumeProducingAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("A request stream can only be resumed, if the request was started") - case .executing(let executor, .initialized, .initialized): - self.state = .executing(executor, .producing, .initialized) + case .executing(let executor, .initialized, .initialized(let redirectHandler)): + self.state = .executing(executor, .producing, .initialized(redirectHandler)) return .startWriter case .executing(_, .producing, _): @@ -127,7 +151,11 @@ extension RequestBag.StateMachine { return .none case .finished: - preconditionFailure("Invalid state: \(self.state)") + // If this task has been cancelled we may be in an error state. As a matter of + // defensive programming, we also tolerate receiving this notification if we've ended cleanly: + // while it shouldn't happen, nothing will go wrong if we just ignore it. + // All paths through this state machine should cancel our request body stream to get here anyway. + return .none case .modifying: preconditionFailure("Invalid state: \(self.state)") @@ -136,7 +164,7 @@ extension RequestBag.StateMachine { mutating func pauseRequestBodyStream() { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("A request stream can only be paused, if the request was started") case .executing(let executor, let requestState, let responseState): switch requestState { @@ -171,7 +199,7 @@ extension RequestBag.StateMachine { mutating func writeNextRequestPart(_ part: IOData, taskEventLoop: EventLoop) -> WriteAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("Invalid state: \(self.state)") case .executing(let executor, let requestState, let responseState): switch requestState { @@ -217,7 +245,7 @@ extension RequestBag.StateMachine { mutating func finishRequestBodyStream(_ result: Result) -> FinishAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("Invalid state: \(self.state)") case .executing(let executor, let requestState, let responseState): switch requestState { @@ -255,33 +283,51 @@ extension RequestBag.StateMachine { } } + enum ReceiveResponseHeadAction { + case none + case forwardResponseHead(HTTPResponseHead) + case signalBodyDemand(HTTPRequestExecutor) + case redirect(HTTPRequestExecutor, RedirectHandler, HTTPResponseHead, URL) + } + /// The response head has been received. /// /// - Parameter head: The response' head /// - Returns: Whether the response should be forwarded to the delegate. Will be `false` if the request follows a redirect. - mutating func receiveResponseHead(_ head: HTTPResponseHead) -> Bool { + mutating func receiveResponseHead(_ head: HTTPResponseHead) -> ReceiveResponseHeadAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("How can we receive a response, if the request hasn't started yet.") case .executing(let executor, let requestState, let responseState): - guard case .initialized = responseState else { + guard case .initialized(let redirectHandler) = responseState else { preconditionFailure("If we receive a response, we must not have received something else before") } - if let redirectURL = self.redirectHandler?.redirectTarget( - status: head.status, - responseHeaders: head.headers - ) { - self.state = .redirected(head, redirectURL) - return false + if let redirectHandler = redirectHandler, + let redirectURL = redirectHandler.redirectTarget( + status: head.status, + responseHeaders: head.headers + ) + { + // If we will redirect, we need to consume the response's body ASAP, to be able to + // reuse the existing connection. We will consume a response body, if the body is + // smaller than 3kb. + switch head.contentLength { + case .some(0...(HTTPClient.maxBodySizeRedirectResponse)), .none: + self.state = .redirected(executor, redirectHandler, 0, head, redirectURL) + return .signalBodyDemand(executor) + case .some: + self.state = .finished(error: HTTPClientError.cancelled) + return .redirect(executor, redirectHandler, head, redirectURL) + } } else { self.state = .executing(executor, requestState, .buffering(.init(), next: .askExecutorForMore)) - return true + return .forwardResponseHead(head) } case .redirected: preconditionFailure("This state can only be reached after we have received a HTTP head") case .finished(error: .some): - return false + return .none case .finished(error: .none): preconditionFailure("How can the request be finished without error, before receiving response head?") case .modifying: @@ -289,16 +335,25 @@ extension RequestBag.StateMachine { } } - mutating func receiveResponseBodyParts(_ buffer: CircularBuffer) -> ByteBuffer? { + enum ReceiveResponseBodyAction { + case none + case forwardResponsePart(ByteBuffer) + case signalBodyDemand(HTTPRequestExecutor) + case redirect(HTTPRequestExecutor, RedirectHandler, HTTPResponseHead, URL) + } + + mutating func receiveResponseBodyParts(_ buffer: CircularBuffer) -> ReceiveResponseBodyAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("How can we receive a response body part, if the request hasn't started yet.") case .executing(_, _, .initialized): preconditionFailure("If we receive a response body, we must have received a head before") case .executing(let executor, let requestState, .buffering(var currentBuffer, next: let next)): guard case .askExecutorForMore = next else { - preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") + preconditionFailure( + "If we have received an error or eof before, why did we get another body part? Next: \(next)" + ) } self.state = .modifying @@ -308,17 +363,30 @@ extension RequestBag.StateMachine { currentBuffer.append(contentsOf: buffer) } self.state = .executing(executor, requestState, .buffering(currentBuffer, next: next)) - return nil + return .none case .executing(let executor, let requestState, .waitingForRemote): - var buffer = buffer - let first = buffer.removeFirst() - self.state = .executing(executor, requestState, .buffering(buffer, next: .askExecutorForMore)) - return first - case .redirected: - // ignore body - return nil + if buffer.count > 0 { + var buffer = buffer + let first = buffer.removeFirst() + self.state = .executing(executor, requestState, .buffering(buffer, next: .askExecutorForMore)) + return .forwardResponsePart(first) + } else { + return .none + } + case .redirected(let executor, let redirectHandler, var receivedBytes, let head, let redirectURL): + let partsLength = buffer.reduce(into: 0) { $0 += $1.readableBytes } + receivedBytes += partsLength + + if receivedBytes > HTTPClient.maxBodySizeRedirectResponse { + self.state = .finished(error: HTTPClientError.cancelled) + return .redirect(executor, redirectHandler, head, redirectURL) + } else { + self.state = .redirected(executor, redirectHandler, receivedBytes, head, redirectURL) + return .signalBodyDemand(executor) + } + case .finished(error: .some): - return nil + return .none case .finished(error: .none): preconditionFailure("How can the request be finished without error, before receiving response head?") case .modifying: @@ -335,14 +403,16 @@ extension RequestBag.StateMachine { mutating func succeedRequest(_ newChunks: CircularBuffer?) -> ReceiveResponseEndAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("How can we receive a response body part, if the request hasn't started yet.") case .executing(_, _, .initialized): preconditionFailure("If we receive a response body, we must have received a head before") case .executing(let executor, let requestState, .buffering(var buffer, next: let next)): guard case .askExecutorForMore = next else { - preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") + preconditionFailure( + "If we have received an error or eof before, why did we get another body part? Next: \(next)" + ) } if buffer.isEmpty, let newChunks = newChunks, !newChunks.isEmpty { @@ -364,9 +434,9 @@ extension RequestBag.StateMachine { self.state = .executing(executor, requestState, .buffering(newChunks, next: .eof)) return .consume(first) - case .redirected(let head, let redirectURL): + case .redirected(_, let redirectHandler, _, let head, let redirectURL): self.state = .finished(error: nil) - return .redirect(self.redirectHandler!, head, redirectURL) + return .redirect(redirectHandler, head, redirectURL) case .finished(error: .some): return .none @@ -397,10 +467,12 @@ extension RequestBag.StateMachine { private mutating func failWithConsumptionError(_ error: Error) -> ConsumeAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("Invalid state: \(self.state)") case .executing(_, _, .initialized): - preconditionFailure("Invalid state: Must have received response head, before this method is called for the first time") + preconditionFailure( + "Invalid state: Must have received response head, before this method is called for the first time" + ) case .executing(_, _, .buffering(_, next: .error(let connectionError))): // if an error was received from the connection, we fail the task with the one @@ -413,17 +485,23 @@ extension RequestBag.StateMachine { return .failTask(error, executorToCancel: executor) case .executing(_, _, .waitingForRemote): - preconditionFailure("Invalid state... We just returned from a consumption function. We can't already be waiting") + preconditionFailure( + "Invalid state... We just returned from a consumption function. We can't already be waiting" + ) case .redirected: - preconditionFailure("Invalid state... Redirect don't call out to delegate functions. Thus we should never land here.") + preconditionFailure( + "Invalid state... Redirect don't call out to delegate functions. Thus we should never land here." + ) case .finished(error: .some): // don't overwrite existing errors return .doNothing case .finished(error: .none): - preconditionFailure("Invalid state... If no error occured, this must not be called, after the request was finished") + preconditionFailure( + "Invalid state... If no error occured, this must not be called, after the request was finished" + ) case .modifying: preconditionFailure() @@ -432,11 +510,13 @@ extension RequestBag.StateMachine { private mutating func consumeMoreBodyData() -> ConsumeAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("Invalid state: \(self.state)") case .executing(_, _, .initialized): - preconditionFailure("Invalid state: Must have received response head, before this method is called for the first time") + preconditionFailure( + "Invalid state: Must have received response head, before this method is called for the first time" + ) case .executing(let executor, let requestState, .buffering(var buffer, next: .askExecutorForMore)): self.state = .modifying @@ -466,7 +546,9 @@ extension RequestBag.StateMachine { return .failTask(error, executorToCancel: nil) case .executing(_, _, .waitingForRemote): - preconditionFailure("Invalid state... We just returned from a consumption function. We can't already be waiting") + preconditionFailure( + "Invalid state... We just returned from a consumption function. We can't already be waiting" + ) case .redirected: return .doNothing @@ -475,15 +557,42 @@ extension RequestBag.StateMachine { return .doNothing case .finished(error: .none): - preconditionFailure("Invalid state... If no error occurred, this must not be called, after the request was finished") + preconditionFailure( + "Invalid state... If no error occurred, this must not be called, after the request was finished" + ) case .modifying: preconditionFailure() } } + enum DeadlineExceededAction { + case cancelScheduler(HTTPRequestScheduler?) + case fail(FailAction) + } + + mutating func deadlineExceeded() -> DeadlineExceededAction { + switch self.state { + case .queued(let queuer, _): + /// We do not fail the request immediately because we want to give the scheduler a chance of throwing a better error message + /// We therefore depend on the scheduler failing the request after we cancel the request. + self.state = .deadlineExceededWhileQueued + return .cancelScheduler(queuer) + + case .initialized, + .deadlineExceededWhileQueued, + .executing, + .finished, + .redirected, + .modifying: + /// if we are not in the queued state, we can fail early by just calling down to `self.fail(_:)` + /// which does the appropriate state transition for us. + return .fail(self.fail(HTTPClientError.deadlineExceeded)) + } + } + enum FailAction { - case failTask(HTTPRequestScheduler?, HTTPRequestExecutor?) + case failTask(Error, HTTPRequestScheduler?, HTTPRequestExecutor?) case cancelExecutor(HTTPRequestExecutor) case none } @@ -492,31 +601,44 @@ extension RequestBag.StateMachine { switch self.state { case .initialized: self.state = .finished(error: error) - return .failTask(nil, nil) - case .queued(let queuer): + return .failTask(error, nil, nil) + case .queued(let queuer, _): self.state = .finished(error: error) - return .failTask(queuer, nil) + return .failTask(error, queuer, nil) case .executing(let executor, let requestState, .buffering(_, next: .eof)): self.state = .executing(executor, requestState, .buffering(.init(), next: .error(error))) return .cancelExecutor(executor) case .executing(let executor, _, .buffering(_, next: .askExecutorForMore)): self.state = .finished(error: error) - return .failTask(nil, executor) + return .failTask(error, nil, executor) case .executing(let executor, _, .buffering(_, next: .error(_))): // this would override another error, let's keep the first one return .cancelExecutor(executor) case .executing(let executor, _, .initialized): self.state = .finished(error: error) - return .failTask(nil, executor) + return .failTask(error, nil, executor) case .executing(let executor, _, .waitingForRemote): self.state = .finished(error: error) - return .failTask(nil, executor) + return .failTask(error, nil, executor) case .redirected: self.state = .finished(error: error) - return .failTask(nil, nil) + return .failTask(error, nil, nil) case .finished(.none): // An error occurred after the request has finished. Ignore... return .none + case .deadlineExceededWhileQueued: + let realError: Error = { + if (error as? HTTPClientError) == .cancelled { + /// if we just get a `HTTPClientError.cancelled` we can use the original cancellation reason + /// to give a more descriptive error to the user. + return HTTPClientError.deadlineExceeded + } else { + /// otherwise we already had an intermediate connection error which we should present to the user instead + return error + } + }() + self.state = .finished(error: realError) + return .failTask(realError, nil, nil) case .finished(.some(_)): // this might happen, if the stream consumer has failed... let's just drop the data return .none @@ -525,3 +647,12 @@ extension RequestBag.StateMachine { } } } + +extension HTTPResponseHead { + var contentLength: Int? { + guard let header = self.headers.first(name: "content-length") else { + return nil + } + return Int(header) + } +} diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift index 9a40e9ff5..f2720d9ef 100644 --- a/Sources/AsyncHTTPClient/RequestBag.swift +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -19,17 +19,30 @@ import NIOHTTP1 import NIOSSL final class RequestBag { + /// Defends against the call stack getting too large when consuming body parts. + /// + /// If the response body comes in lots of tiny chunks, we'll deliver those tiny chunks to users + /// one at a time. + private static var maxConsumeBodyPartStackDepth: Int { + 50 + } + + let poolKey: ConnectionPool.Key + let task: HTTPClient.Task var eventLoop: EventLoop { self.task.eventLoop } private let delegate: Delegate - private let request: HTTPClient.Request + private var request: HTTPClient.Request // the request state is synchronized on the task eventLoop private var state: StateMachine + // the consume body part stack depth is synchronized on the task event loop. + private var consumeBodyPartStackDepth: Int + // MARK: HTTPClientTask properties var logger: Logger { @@ -45,16 +58,20 @@ final class RequestBag { let eventLoopPreference: HTTPClient.EventLoopPreference - init(request: HTTPClient.Request, - eventLoopPreference: HTTPClient.EventLoopPreference, - task: HTTPClient.Task, - redirectHandler: RedirectHandler?, - connectionDeadline: NIODeadline, - requestOptions: RequestOptions, - delegate: Delegate) throws { + init( + request: HTTPClient.Request, + eventLoopPreference: HTTPClient.EventLoopPreference, + task: HTTPClient.Task, + redirectHandler: RedirectHandler?, + connectionDeadline: NIODeadline, + requestOptions: RequestOptions, + delegate: Delegate + ) throws { + self.poolKey = .init(request, dnsOverride: requestOptions.dnsOverride) self.eventLoopPreference = eventLoopPreference self.task = task self.state = .init(redirectHandler: redirectHandler) + self.consumeBodyPartStackDepth = 0 self.request = request self.connectionDeadline = connectionDeadline self.requestOptions = requestOptions @@ -81,8 +98,16 @@ final class RequestBag { private func willExecuteRequest0(_ executor: HTTPRequestExecutor) { self.task.eventLoop.assertInEventLoop() - if !self.state.willExecuteRequest(executor) { - return executor.cancelRequest(self) + let action = self.state.willExecuteRequest(executor) + switch action { + case .cancelExecuter(let executor): + executor.cancelRequest(self) + case .failTaskAndCancelExecutor(let error, let executor): + self.delegate.didReceiveError(task: self.task, error) + self.task.fail(with: error, delegateType: Delegate.self) + executor.cancelRequest(self) + case .none: + break } } @@ -106,6 +131,7 @@ final class RequestBag { guard let body = self.request.body else { preconditionFailure("Expected to have a body, if the `HTTPRequestStateMachine` resume a request stream") } + self.request.body = nil let writer = HTTPClient.Body.StreamWriter { self.writeNextRequestPart($0) @@ -154,8 +180,11 @@ final class RequestBag { return self.task.eventLoop.makeFailedFuture(error) case .write(let part, let writer, let future): - writer.writeRequestBodyPart(part, request: self) - self.delegate.didSendRequestPart(task: self.task, part) + let promise = self.task.eventLoop.makePromise(of: Void.self) + promise.futureResult.whenSuccess { + self.delegate.didSendRequestPart(task: self.task, part) + } + writer.writeRequestBodyPart(part, request: self, promise: promise) return future } } @@ -168,11 +197,12 @@ final class RequestBag { switch action { case .none: break - case .forwardStreamFinished(let writer, let promise): - writer.finishRequestBodyStream(self) - promise?.succeed(()) - - self.delegate.didSendRequest(task: self.task) + case .forwardStreamFinished(let writer, let writerPromise): + let promise = writerPromise ?? self.task.eventLoop.makePromise(of: Void.self) + promise.futureResult.whenSuccess { + self.delegate.didSendRequest(task: self.task) + } + writer.finishRequestBodyStream(self, promise: promise) case .forwardStreamFailureAndFailTask(let writer, let error, let promise): writer.cancelRequest(self) @@ -196,33 +226,49 @@ final class RequestBag { self.task.eventLoop.assertInEventLoop() // runs most likely on channel eventLoop - let forwardToDelegate = self.state.receiveResponseHead(head) + switch self.state.receiveResponseHead(head) { + case .none: + break - guard forwardToDelegate else { return } + case .signalBodyDemand(let executor): + executor.demandResponseBodyStream(self) - self.delegate.didReceiveHead(task: self.task, head) - .hop(to: self.task.eventLoop) - .whenComplete { result in - // After the head received, let's start to consume body data - self.consumeMoreBodyData0(resultOfPreviousConsume: result) - } + case .redirect(let executor, let handler, let head, let newURL): + handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + executor.cancelRequest(self) + + case .forwardResponseHead(let head): + self.delegate.didReceiveHead(task: self.task, head) + .hop(to: self.task.eventLoop) + .whenComplete { result in + // After the head received, let's start to consume body data + self.consumeMoreBodyData0(resultOfPreviousConsume: result) + } + } } private func receiveResponseBodyParts0(_ buffer: CircularBuffer) { self.task.eventLoop.assertInEventLoop() - let maybeForwardBuffer = self.state.receiveResponseBodyParts(buffer) + switch self.state.receiveResponseBodyParts(buffer) { + case .none: + break - guard let forwardBuffer = maybeForwardBuffer else { - return - } + case .signalBodyDemand(let executor): + executor.demandResponseBodyStream(self) - self.delegate.didReceiveBodyPart(task: self.task, forwardBuffer) - .hop(to: self.task.eventLoop) - .whenComplete { result in - // on task el - self.consumeMoreBodyData0(resultOfPreviousConsume: result) - } + case .redirect(let executor, let handler, let head, let newURL): + handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + executor.cancelRequest(self) + + case .forwardResponsePart(let part): + self.delegate.didReceiveBodyPart(task: self.task, part) + .hop(to: self.task.eventLoop) + .whenComplete { result in + // on task el + self.consumeMoreBodyData0(resultOfPreviousConsume: result) + } + } } private func succeedRequest0(_ buffer: CircularBuffer?) { @@ -236,14 +282,7 @@ final class RequestBag { self.delegate.didReceiveBodyPart(task: self.task, buffer) .hop(to: self.task.eventLoop) .whenComplete { - switch $0 { - case .success: - self.consumeMoreBodyData0(resultOfPreviousConsume: $0) - case .failure(let error): - // if in the response stream consumption an error has occurred, we need to - // cancel the running request and fail the task. - self.fail(error) - } + self.consumeMoreBodyData0(resultOfPreviousConsume: $0) } case .succeedRequest: @@ -262,18 +301,36 @@ final class RequestBag { private func consumeMoreBodyData0(resultOfPreviousConsume result: Result) { self.task.eventLoop.assertInEventLoop() + // We get defensive here about the maximum stack depth. It's possible for the `didReceiveBodyPart` + // future to be returned to us completed. If it is, we will recurse back into this method. To + // break that recursion we have a max stack depth which we increment and decrement in this method: + // if it gets too large, instead of recurring we'll insert an `eventLoop.execute`, which will + // manually break the recursion and unwind the stack. + // + // Note that we don't bother starting this at the various other call sites that _begin_ stacks + // that risk ending up in this loop. That's because we don't need an accurate count: our limit is + // a best-effort target anyway, one stack frame here or there does not put us at risk. We're just + // trying to prevent ourselves looping out of control. + self.consumeBodyPartStackDepth += 1 + defer { + self.consumeBodyPartStackDepth -= 1 + assert(self.consumeBodyPartStackDepth >= 0) + } + let consumptionAction = self.state.consumeMoreBodyData(resultOfPreviousConsume: result) switch consumptionAction { case .consume(let byteBuffer): self.delegate.didReceiveBodyPart(task: self.task, byteBuffer) .hop(to: self.task.eventLoop) - .whenComplete { - switch $0 { - case .success: - self.consumeMoreBodyData0(resultOfPreviousConsume: $0) - case .failure(let error): - self.fail(error) + .whenComplete { result in + if self.consumeBodyPartStackDepth < Self.maxConsumeBodyPartStackDepth { + self.consumeMoreBodyData0(resultOfPreviousConsume: result) + } else { + // We need to unwind the stack, let's take a break. + self.task.eventLoop.execute { + self.consumeMoreBodyData0(resultOfPreviousConsume: result) + } } } @@ -300,8 +357,12 @@ final class RequestBag { let action = self.state.fail(error) + self.executeFailAction0(action) + } + + private func executeFailAction0(_ action: RequestBag.StateMachine.FailAction) { switch action { - case .failTask(let scheduler, let executor): + case .failTask(let error, let scheduler, let executor): scheduler?.cancelRequest(self) executor?.cancelRequest(self) self.failTask0(error) @@ -311,13 +372,31 @@ final class RequestBag { break } } -} -extension RequestBag: HTTPSchedulableRequest { - var poolKey: ConnectionPool.Key { - ConnectionPool.Key(self.request) + func deadlineExceeded0() { + self.task.eventLoop.assertInEventLoop() + let action = self.state.deadlineExceeded() + + switch action { + case .cancelScheduler(let scheduler): + scheduler?.cancelRequest(self) + case .fail(let failAction): + self.executeFailAction0(failAction) + } + } + + func deadlineExceeded() { + if self.task.eventLoop.inEventLoop { + self.deadlineExceeded0() + } else { + self.task.eventLoop.execute { + self.deadlineExceeded0() + } + } } +} +extension RequestBag: HTTPSchedulableRequest, HTTPClientTaskDelegate { var tlsConfiguration: TLSConfiguration? { self.request.tlsConfiguration } @@ -358,8 +437,8 @@ extension RequestBag: HTTPExecutableRequest { case .indifferent: return self.task.eventLoop case .delegate(let eventLoop), - .delegateAndChannel(on: let eventLoop), - .testOnly_exact(channelOn: let eventLoop, delegateOn: _): + .delegateAndChannel(on: let eventLoop), + .testOnly_exact(channelOn: let eventLoop, delegateOn: _): return eventLoop } } @@ -434,15 +513,3 @@ extension RequestBag: HTTPExecutableRequest { } } } - -extension RequestBag: HTTPClientTaskDelegate { - func cancel() { - if self.task.eventLoop.inEventLoop { - self.fail0(HTTPClientError.cancelled) - } else { - self.task.eventLoop.execute { - self.fail0(HTTPClientError.cancelled) - } - } - } -} diff --git a/Sources/AsyncHTTPClient/RequestValidation.swift b/Sources/AsyncHTTPClient/RequestValidation.swift index e23c35423..f338e06a9 100644 --- a/Sources/AsyncHTTPClient/RequestValidation.swift +++ b/Sources/AsyncHTTPClient/RequestValidation.swift @@ -21,6 +21,7 @@ extension HTTPHeaders { bodyLength: RequestBodyLength ) throws -> RequestFramingMetadata { try self.validateFieldNames() + try self.validateFieldValues() if case .TRACE = method { switch bodyLength { @@ -49,23 +50,23 @@ extension HTTPHeaders { let satisfy = name.utf8.allSatisfy { char -> Bool in switch char { case UInt8(ascii: "a")...UInt8(ascii: "z"), - UInt8(ascii: "A")...UInt8(ascii: "Z"), - UInt8(ascii: "0")...UInt8(ascii: "9"), - UInt8(ascii: "!"), - UInt8(ascii: "#"), - UInt8(ascii: "$"), - UInt8(ascii: "%"), - UInt8(ascii: "&"), - UInt8(ascii: "'"), - UInt8(ascii: "*"), - UInt8(ascii: "+"), - UInt8(ascii: "-"), - UInt8(ascii: "."), - UInt8(ascii: "^"), - UInt8(ascii: "_"), - UInt8(ascii: "`"), - UInt8(ascii: "|"), - UInt8(ascii: "~"): + UInt8(ascii: "A")...UInt8(ascii: "Z"), + UInt8(ascii: "0")...UInt8(ascii: "9"), + UInt8(ascii: "!"), + UInt8(ascii: "#"), + UInt8(ascii: "$"), + UInt8(ascii: "%"), + UInt8(ascii: "&"), + UInt8(ascii: "'"), + UInt8(ascii: "*"), + UInt8(ascii: "+"), + UInt8(ascii: "-"), + UInt8(ascii: "."), + UInt8(ascii: "^"), + UInt8(ascii: "_"), + UInt8(ascii: "`"), + UInt8(ascii: "|"), + UInt8(ascii: "~"): return true default: return false @@ -80,6 +81,56 @@ extension HTTPHeaders { } } + private func validateFieldValues() throws { + let invalidValues = self.compactMap { _, value -> String? in + let satisfy = value.utf8.allSatisfy { char -> Bool in + /// Validates a byte of a given header field value against the definition in RFC 9110. + /// + /// The spec in [RFC 9110](https://httpwg.org/specs/rfc9110.html#fields.values) defines the valid + /// characters as the following: + /// + /// ``` + /// field-value = *field-content + /// field-content = field-vchar + /// [ 1*( SP / HTAB / field-vchar ) field-vchar ] + /// field-vchar = VCHAR / obs-text + /// obs-text = %x80-FF + /// ``` + /// + /// Additionally, it makes the following note: + /// + /// "Field values containing CR, LF, or NUL characters are invalid and dangerous, due to the + /// varying ways that implementations might parse and interpret those characters; a recipient + /// of CR, LF, or NUL within a field value MUST either reject the message or replace each of + /// those characters with SP before further processing or forwarding of that message. Field + /// values containing other CTL characters are also invalid; however, recipients MAY retain + /// such characters for the sake of robustness when they appear within a safe context (e.g., + /// an application-specific quoted string that will not be processed by any downstream HTTP + /// parser)." + /// + /// As we cannot guarantee the context is safe, this code will reject all ASCII control characters + /// directly _except_ for HTAB, which is explicitly allowed. + switch char { + case UInt8(ascii: "\t"): + // HTAB, explicitly allowed. + return true + case 0...0x1f, 0x7F: + // ASCII control character, forbidden. + return false + default: + // Printable or non-ASCII, allowed. + return true + } + } + + return satisfy ? nil : value + } + + guard invalidValues.count == 0 else { + throw HTTPClientError.invalidHeaderFieldValues(invalidValues) + } + } + private mutating func setTransportFraming( method: HTTPMethod, bodyLength: RequestBodyLength @@ -115,13 +166,14 @@ extension HTTPHeaders { mutating func addHostIfNeeded(for url: DeconstructedURL) { // if no host header was set, let's use the url host guard !self.contains(name: "host"), - var host = url.connectionTarget.host + var host = url.connectionTarget.host else { return } // if the request uses a non-default port, we need to add it after the host if let port = url.connectionTarget.port, - port != url.scheme.defaultPort { + port != url.scheme.defaultPort + { host += ":\(port)" } self.add(name: "host", value: host) diff --git a/Sources/AsyncHTTPClient/SSLContextCache.swift b/Sources/AsyncHTTPClient/SSLContextCache.swift index 31ed106a0..599003e56 100644 --- a/Sources/AsyncHTTPClient/SSLContextCache.swift +++ b/Sources/AsyncHTTPClient/SSLContextCache.swift @@ -18,40 +18,50 @@ import NIOConcurrencyHelpers import NIOCore import NIOSSL -class SSLContextCache { - private let lock = Lock() +final class SSLContextCache { + private let lock = NIOLock() private var sslContextCache = LRUCache() private let offloadQueue = DispatchQueue(label: "io.github.swift-server.AsyncHTTPClient.SSLContextCache") } extension SSLContextCache { - func sslContext(tlsConfiguration: TLSConfiguration, - eventLoop: EventLoop, - logger: Logger) -> EventLoopFuture { + func sslContext( + tlsConfiguration: TLSConfiguration, + eventLoop: EventLoop, + logger: Logger + ) -> EventLoopFuture { let eqTLSConfiguration = BestEffortHashableTLSConfiguration(wrapping: tlsConfiguration) let sslContext = self.lock.withLock { self.sslContextCache.find(key: eqTLSConfiguration) } if let sslContext = sslContext { - logger.trace("found SSL context in cache", - metadata: ["ahc-tls-config": "\(tlsConfiguration)"]) + logger.trace( + "found SSL context in cache", + metadata: ["ahc-tls-config": "\(tlsConfiguration)"] + ) return eventLoop.makeSucceededFuture(sslContext) } - logger.trace("creating new SSL context", - metadata: ["ahc-tls-config": "\(tlsConfiguration)"]) + logger.trace( + "creating new SSL context", + metadata: ["ahc-tls-config": "\(tlsConfiguration)"] + ) let newSSLContext = self.offloadQueue.asyncWithFuture(eventLoop: eventLoop) { try NIOSSLContext(configuration: tlsConfiguration) } newSSLContext.whenSuccess { (newSSLContext: NIOSSLContext) -> Void in self.lock.withLock { () -> Void in - self.sslContextCache.append(key: eqTLSConfiguration, - value: newSSLContext) + self.sslContextCache.append( + key: eqTLSConfiguration, + value: newSSLContext + ) } } return newSSLContext } } + +extension SSLContextCache: @unchecked Sendable {} diff --git a/Sources/AsyncHTTPClient/Singleton.swift b/Sources/AsyncHTTPClient/Singleton.swift new file mode 100644 index 000000000..0ddf1bc40 --- /dev/null +++ b/Sources/AsyncHTTPClient/Singleton.swift @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +extension HTTPClient { + /// A globally shared, singleton ``HTTPClient``. + /// + /// The returned client uses the following settings: + /// - configuration is ``HTTPClient/Configuration/singletonConfiguration`` (matching the platform's default/prevalent browser as well as possible) + /// - `EventLoopGroup` is ``HTTPClient/defaultEventLoopGroup`` (matching the platform default) + /// - logging is disabled + public static var shared: HTTPClient { + globallySharedHTTPClient + } +} + +private let globallySharedHTTPClient: HTTPClient = { + let httpClient = HTTPClient( + eventLoopGroup: HTTPClient.defaultEventLoopGroup, + configuration: .singletonConfiguration, + backgroundActivityLogger: HTTPClient.loggingDisabled, + canBeShutDown: false + ) + return httpClient +}() diff --git a/Sources/AsyncHTTPClient/StringConvertibleInstances.swift b/Sources/AsyncHTTPClient/StringConvertibleInstances.swift index f75fb0d87..61d4b067a 100644 --- a/Sources/AsyncHTTPClient/StringConvertibleInstances.swift +++ b/Sources/AsyncHTTPClient/StringConvertibleInstances.swift @@ -14,6 +14,6 @@ extension HTTPClient.EventLoopPreference: CustomStringConvertible { public var description: String { - return "\(self.preference)" + "\(self.preference)" } } diff --git a/Sources/AsyncHTTPClient/StructuredConcurrencyHelpers.swift b/Sources/AsyncHTTPClient/StructuredConcurrencyHelpers.swift new file mode 100644 index 000000000..25f1225e0 --- /dev/null +++ b/Sources/AsyncHTTPClient/StructuredConcurrencyHelpers.swift @@ -0,0 +1,83 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// swift-format-ignore +// Note: Whitespace changes are used to workaround compiler bug +// https://github.com/swiftlang/swift/issues/79285 + +#if compiler(>=6.0) +@inlinable +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +internal func asyncDo( + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ body: () async throws -> sending R, finally: sending @escaping ((any Error)?) async throws -> Void) async throws -> sending R { + let result: R + do { + result = try await body() + } catch { + // `body` failed, we need to invoke `finally` with the `error`. + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(error) + }.value + throw error + } + + // `body` succeeded, we need to invoke `finally` with `nil` (no error). + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(nil) + }.value + return result +} +#else +@inlinable +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +internal func asyncDo( + _ body: () async throws -> R, + finally: @escaping @Sendable ((any Error)?) async throws -> Void +) async throws -> R { + let result: R + do { + result = try await body() + } catch { + // `body` failed, we need to invoke `finally` with the `error`. + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(error) + }.value + throw error + } + + // `body` succeeded, we need to invoke `finally` with `nil` (no error). + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(nil) + }.value + return result +} +#endif diff --git a/Sources/AsyncHTTPClient/Utils.swift b/Sources/AsyncHTTPClient/Utils.swift index f4154df3d..abdd5bbc2 100644 --- a/Sources/AsyncHTTPClient/Utils.swift +++ b/Sources/AsyncHTTPClient/Utils.swift @@ -14,21 +14,26 @@ import NIOCore +/// An ``HTTPClientResponseDelegate`` that wraps a callback. +/// +/// ``HTTPClientCopyingDelegate`` discards most parts of a HTTP response, but streams the body +/// to the `chunkHandler` provided on ``init(chunkHandler:)``. This is mostly useful for testing. public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate { public typealias Response = Void let chunkHandler: (ByteBuffer) -> EventLoopFuture - public init(chunkHandler: @escaping (ByteBuffer) -> EventLoopFuture) { + @preconcurrency + public init(chunkHandler: @Sendable @escaping (ByteBuffer) -> EventLoopFuture) { self.chunkHandler = chunkHandler } public func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - return self.chunkHandler(buffer) + self.chunkHandler(buffer) } public func didFinishRequest(task: HTTPClient.Task) throws { - return () + () } } @@ -39,7 +44,12 @@ public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate { /// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. @inlinable internal func debugOnly(_ body: () -> Void) { - assert({ body(); return true }()) + assert( + { + body() + return true + }() + ) } extension BidirectionalCollection where Element: Equatable { @@ -56,8 +66,8 @@ extension BidirectionalCollection where Element: Equatable { guard self[ourIdx] == suffix[suffixIdx] else { return false } } guard suffixIdx == suffix.startIndex else { - return false // Exhausted self, but 'suffix' has elements remaining. + return false // Exhausted self, but 'suffix' has elements remaining. } - return true // Exhausted 'other' without finding a mismatch. + return true // Exhausted 'other' without finding a mismatch. } } diff --git a/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c b/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c index 2a09d04c9..6342da89f 100644 --- a/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c +++ b/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c @@ -15,7 +15,6 @@ #if __APPLE__ #include #elif __linux__ - #define _GNU_SOURCE #include #endif @@ -32,7 +31,11 @@ bool swiftahc_cshims_strptime(const char * string, const char * format, struct t bool swiftahc_cshims_strptime_l(const char * string, const char * format, struct tm * result, void * locale) { // The pointer cast is fine as long we make sure it really points to a locale_t. +#if defined(__musl__) || defined(__ANDROID__) + const char * firstNonProcessed = strptime(string, format, result); +#else const char * firstNonProcessed = strptime_l(string, format, result, (locale_t)locale); +#endif if (firstNonProcessed) { return *firstNonProcessed == 0; } diff --git a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests+XCTest.swift b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests+XCTest.swift deleted file mode 100644 index 6a8d923c7..000000000 --- a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests+XCTest.swift +++ /dev/null @@ -1,46 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// AsyncAwaitEndToEndTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension AsyncAwaitEndToEndTests { - static var allTests: [(String, (AsyncAwaitEndToEndTests) -> () throws -> Void)] { - return [ - ("testSimpleGet", testSimpleGet), - ("testSimplePost", testSimplePost), - ("testPostWithByteBuffer", testPostWithByteBuffer), - ("testPostWithSequenceOfUInt8", testPostWithSequenceOfUInt8), - ("testPostWithCollectionOfUInt8", testPostWithCollectionOfUInt8), - ("testPostWithRandomAccessCollectionOfUInt8", testPostWithRandomAccessCollectionOfUInt8), - ("testPostWithAsyncSequenceOfByteBuffers", testPostWithAsyncSequenceOfByteBuffers), - ("testPostWithAsyncSequenceOfUInt8", testPostWithAsyncSequenceOfUInt8), - ("testPostWithFragmentedAsyncSequenceOfByteBuffers", testPostWithFragmentedAsyncSequenceOfByteBuffers), - ("testPostWithFragmentedAsyncSequenceOfLargeByteBuffers", testPostWithFragmentedAsyncSequenceOfLargeByteBuffers), - ("testCanceling", testCanceling), - ("testDeadline", testDeadline), - ("testImmediateDeadline", testImmediateDeadline), - ("testInvalidURL", testInvalidURL), - ("testRedirectChangesHostHeader", testRedirectChangesHostHeader), - ("testShutdown", testShutdown), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift index 2cd056225..c580164a0 100644 --- a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift +++ b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift @@ -12,14 +12,18 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore +import NIOFoundationCompat +import NIOHTTP1 import NIOPosix +import NIOSSL import XCTest +@testable import AsyncHTTPClient + private func makeDefaultHTTPClient( - eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .createNew + eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .singleton ) -> HTTPClient { var config = HTTPClient.Configuration() config.tlsConfiguration = .clientDefault @@ -32,10 +36,30 @@ private func makeDefaultHTTPClient( ) } +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) final class AsyncAwaitEndToEndTests: XCTestCase { + var clientGroup: EventLoopGroup! + var serverGroup: EventLoopGroup! + + override func setUp() { + XCTAssertNil(self.clientGroup) + XCTAssertNil(self.serverGroup) + + self.clientGroup = getDefaultEventLoopGroup(numberOfThreads: 1) + self.serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + } + + override func tearDown() { + XCTAssertNotNil(self.clientGroup) + XCTAssertNoThrow(try self.clientGroup.syncShutdownGracefully()) + self.clientGroup = nil + + XCTAssertNotNil(self.serverGroup) + XCTAssertNoThrow(try self.serverGroup.syncShutdownGracefully()) + self.serverGroup = nil + } + func testSimpleGet() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -44,21 +68,20 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.status, .ok) XCTAssertEqual(response.version, .http2) } - #endif } func testSimplePost() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -67,21 +90,20 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.status, .ok) XCTAssertEqual(response.version, .http2) } - #endif } func testPostWithByteBuffer() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -92,21 +114,22 @@ final class AsyncAwaitEndToEndTests: XCTestCase { request.method = .POST request.body = .bytes(ByteBuffer(string: "1234")) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], ["4"]) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } - #endif } func testPostWithSequenceOfUInt8() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -115,23 +138,24 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/") request.method = .POST - request.body = .bytes(AnySequence("1234".utf8), length: .unknown) + request.body = .bytes(AnySendableSequence("1234".utf8), length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } - #endif } func testPostWithCollectionOfUInt8() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -140,23 +164,24 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/") request.method = .POST - request.body = .bytes(AnyCollection("1234".utf8), length: .unknown) + request.body = .bytes(AnySendableCollection("1234".utf8), length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } - #endif } func testPostWithRandomAccessCollectionOfUInt8() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -167,21 +192,82 @@ final class AsyncAwaitEndToEndTests: XCTestCase { request.method = .POST request.body = .bytes(ByteBuffer(string: "1234").readableBytesView) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], ["4"]) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } - #endif + } + + struct AsyncSequenceByteBufferGenerator: AsyncSequence, Sendable, AsyncIteratorProtocol { + typealias Element = ByteBuffer + + let chunkSize: Int + let totalChunks: Int + let buffer: ByteBuffer + var chunksGenerated: Int = 0 + + init(chunkSize: Int, totalChunks: Int) { + self.chunkSize = chunkSize + self.totalChunks = totalChunks + self.buffer = ByteBuffer(repeating: 1, count: self.chunkSize) + } + + mutating func next() async throws -> ByteBuffer? { + guard self.chunksGenerated < self.totalChunks else { return nil } + + self.chunksGenerated += 1 + return self.buffer + } + + func makeAsyncIterator() -> AsyncSequenceByteBufferGenerator { + self + } + } + + func testEchoStreamThatHas3GBInTotal() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let bin = HTTPBin(.http1_1()) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let client: HTTPClient = makeDefaultHTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup)) + defer { XCTAssertNoThrow(try client.syncShutdown()) } + + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + + var request = HTTPClientRequest(url: "http://localhost:\(bin.port)/") + request.method = .POST + + let sequence = AsyncSequenceByteBufferGenerator( + chunkSize: 4_194_304, // 4MB chunk + totalChunks: 768 // Total = 3GB + ) + request.body = .stream(sequence, length: .unknown) + + let response: HTTPClientResponse = try await client.execute( + request, + deadline: .now() + .seconds(30), + logger: logger + ) + XCTAssertEqual(response.headers["content-length"], []) + + var receivedBytes: Int64 = 0 + for try await part in response.body { + receivedBytes += Int64(part.readableBytes) + } + XCTAssertEqual(receivedBytes, 3_221_225_472) // 3GB } func testPostWithAsyncSequenceOfByteBuffers() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -190,27 +276,31 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/") request.method = .POST - request.body = .stream([ - ByteBuffer(string: "1"), - ByteBuffer(string: "2"), - ByteBuffer(string: "34"), - ].asAsyncSequence(), length: .unknown) + request.body = .stream( + [ + ByteBuffer(string: "1"), + ByteBuffer(string: "2"), + ByteBuffer(string: "34"), + ].async, + length: .unknown + ) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } - #endif } func testPostWithAsyncSequenceOfUInt8() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -219,23 +309,24 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/") request.method = .POST - request.body = .stream("1234".utf8.asAsyncSequence(), length: .unknown) + request.body = .stream("1234".utf8.async, length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } - #endif } func testPostWithFragmentedAsyncSequenceOfByteBuffers() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -247,9 +338,11 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let streamWriter = AsyncSequenceWriter() request.body = .stream(streamWriter, length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) let fragments = [ @@ -260,24 +353,25 @@ final class AsyncAwaitEndToEndTests: XCTestCase { var bodyIterator = response.body.makeAsyncIterator() for expectedFragment in fragments { streamWriter.write(expectedFragment) - guard let actualFragment = await XCTAssertNoThrowWithResult( - try await bodyIterator.next() - ) else { return } + guard + let actualFragment = await XCTAssertNoThrowWithResult( + try await bodyIterator.next() + ) + else { return } XCTAssertEqual(expectedFragment, actualFragment) } streamWriter.end() - guard let lastResult = await XCTAssertNoThrowWithResult( - try await bodyIterator.next() - ) else { return } + guard + let lastResult = await XCTAssertNoThrowWithResult( + try await bodyIterator.next() + ) + else { return } XCTAssertEqual(lastResult, nil) } - #endif } func testPostWithFragmentedAsyncSequenceOfLargeByteBuffers() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -289,9 +383,11 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let streamWriter = AsyncSequenceWriter() request.body = .stream(streamWriter, length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) let fragments = [ @@ -303,24 +399,25 @@ final class AsyncAwaitEndToEndTests: XCTestCase { var bodyIterator = response.body.makeAsyncIterator() for expectedFragment in fragments { streamWriter.write(expectedFragment) - guard let actualFragment = await XCTAssertNoThrowWithResult( - try await bodyIterator.next() - ) else { return } + guard + let actualFragment = await XCTAssertNoThrowWithResult( + try await bodyIterator.next() + ) + else { return } XCTAssertEqual(expectedFragment, actualFragment) } streamWriter.end() - guard let lastResult = await XCTAssertNoThrowWithResult( - try await bodyIterator.next() - ) else { return } + guard + let lastResult = await XCTAssertNoThrowWithResult( + try await bodyIterator.next() + ) + else { return } XCTAssertEqual(lastResult, nil) } - #endif } func testCanceling() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest(timeout: 5) { let bin = HTTPBin(.http2(compress: false)) defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -337,15 +434,40 @@ final class AsyncAwaitEndToEndTests: XCTestCase { } task.cancel() await XCTAssertThrowsError(try await task.value) { error in - XCTAssertEqual(error as? HTTPClientError, .cancelled) + XCTAssertTrue(error is CancellationError, "unexpected error \(error)") + } + } + } + + func testCancelingResponseBody() { + XCTAsyncTest(timeout: 5) { + let bin = HTTPBin(.http2(compress: false)) { _ in + HTTPEchoHandler() } + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/handler") + request.method = .POST + let streamWriter = AsyncSequenceWriter() + request.body = .stream(streamWriter, length: .unknown) + let response = try await client.execute(request, deadline: .now() + .seconds(2), logger: logger) + streamWriter.write(.init(bytes: [1])) + let task = Task { + try await response.body.collect(upTo: 1024 * 1024) + } + task.cancel() + + await XCTAssertThrowsError(try await task.value) { error in + XCTAssertTrue(error is CancellationError, "unexpected error \(error)") + } + + streamWriter.end() } - #endif } func testDeadline() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest(timeout: 5) { let bin = HTTPBin(.http2(compress: false)) defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -361,16 +483,18 @@ final class AsyncAwaitEndToEndTests: XCTestCase { guard let error = error as? HTTPClientError else { return XCTFail("unexpected error \(error)") } - // a race between deadline and connect timer can result in either error - XCTAssertTrue([.deadlineExceeded, .connectTimeout].contains(error)) + // a race between deadline and connect timer can result in either error. + // If closing happens really fast we might shutdown the pipeline before we fail the request. + // If the pipeline is closed we may receive a `.remoteConnectionClosed`. + XCTAssertTrue( + [.deadlineExceeded, .connectTimeout, .remoteConnectionClosed].contains(error), + "unexpected error \(error)" + ) } } - #endif } func testImmediateDeadline() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest(timeout: 5) { let bin = HTTPBin(.http2(compress: false)) defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -386,32 +510,223 @@ final class AsyncAwaitEndToEndTests: XCTestCase { guard let error = error as? HTTPClientError else { return XCTFail("unexpected error \(error)") } - // a race between deadline and connect timer can result in either error - XCTAssertTrue([.deadlineExceeded, .connectTimeout].contains(error)) + // a race between deadline and connect timer can result in either error. + // If closing happens really fast we might shutdown the pipeline before we fail the request. + // If the pipeline is closed we may receive a `.remoteConnectionClosed`. + XCTAssertTrue( + [.deadlineExceeded, .connectTimeout, .remoteConnectionClosed].contains(error), + "unexpected error \(error)" + ) + } + } + } + + func testConnectTimeout() { + XCTAsyncTest(timeout: 60) { + #if os(Linux) + // 198.51.100.254 is reserved for documentation only and therefore should not accept any TCP connection + let url = "http://198.51.100.254/get" + #else + // on macOS we can use the TCP backlog behaviour when the queue is full to simulate a non reachable server. + // this makes this test a bit more stable if `198.51.100.254` actually responds to connection attempt. + // The backlog behaviour on Linux can not be used to simulate a non-reachable server. + // Linux sends a `SYN/ACK` back even if the `backlog` queue is full as it has two queues. + // The second queue is not limit by `ChannelOptions.backlog` but by `/proc/sys/net/ipv4/tcp_max_syn_backlog`. + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let serverChannel = try await ServerBootstrap(group: self.serverGroup) + .serverChannelOption(ChannelOptions.backlog, value: 1) + .serverChannelOption(ChannelOptions.autoRead, value: false) + .bind(host: "127.0.0.1", port: 0) + .get() + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + let port = serverChannel.localAddress!.port! + let firstClientChannel = try await ClientBootstrap(group: self.serverGroup) + .connect(host: "127.0.0.1", port: port) + .get() + defer { + XCTAssertNoThrow(try firstClientChannel.close().wait()) } + let url = "http://localhost:\(port)/get" + #endif + + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(timeout: .init(connect: .milliseconds(100), read: .milliseconds(150))) + ) + + defer { + XCTAssertNoThrow(try httpClient.syncShutdown()) + } + + let request = HTTPClientRequest(url: url) + let start = NIODeadline.now() + await XCTAssertThrowsError(try await httpClient.execute(request, deadline: .now() + .seconds(30))) { + XCTAssertEqualTypeAndValue($0, HTTPClientError.connectTimeout) + let end = NIODeadline.now() + let duration = end - start + + // We give ourselves 10x slack in order to be confident that even on slow machines this assertion passes. + // It's 30x smaller than our other timeout though. + XCTAssertLessThan(duration, .seconds(1)) + } + } + } + + func testSelfSignedCertificateIsRejectedWithCorrectErrorIfRequestDeadlineIsExceeded() { + XCTAsyncTest(timeout: 5) { + /// key + cert was created with the follwing command: + /// openssl req -x509 -newkey rsa:4096 -keyout self_signed_key.pem -out self_signed_cert.pem -sha256 -days 99999 -nodes -subj '/CN=localhost' + let certPath = Bundle.module.path(forResource: "self_signed_cert", ofType: "pem")! + let keyPath = Bundle.module.path(forResource: "self_signed_key", ofType: "pem")! + let key = try NIOSSLPrivateKey(file: keyPath, format: .pem) + let configuration = TLSConfiguration.makeServerConfiguration( + certificateChain: try NIOSSLCertificate.fromPEMFile(certPath).map { .certificate($0) }, + privateKey: .privateKey(key) + ) + let sslContext = try NIOSSLContext(configuration: configuration) + let serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try serverGroup.syncShutdownGracefully()) } + let server = ServerBootstrap(group: serverGroup) + .childChannelInitializer { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } + } + let serverChannel = try await server.bind(host: "localhost", port: 0).get() + defer { XCTAssertNoThrow(try serverChannel.close().wait()) } + let port = serverChannel.localAddress!.port! + + let config = HTTPClient.Configuration() + .enableFastFailureModeForTesting() + + let localClient = HTTPClient(eventLoopGroupProvider: .singleton, configuration: config) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + let request = HTTPClientRequest(url: "https://localhost:\(port)") + await XCTAssertThrowsError(try await localClient.execute(request, deadline: .now() + .seconds(2))) { + error in + #if canImport(Network) + guard let nwTLSError = error as? HTTPClient.NWTLSError else { + XCTFail("could not cast \(error) of type \(type(of: error)) to \(HTTPClient.NWTLSError.self)") + return + } + XCTAssertEqual(nwTLSError.status, errSSLBadCert, "unexpected tls error: \(nwTLSError)") + #else + guard let sslError = error as? NIOSSLError, + case .handshakeFailed(.sslError) = sslError + else { + XCTFail("unexpected error \(error)") + return + } + #endif + } + } + } + + func testDnsOverride() { + XCTAsyncTest(timeout: 5) { + /// key + cert was created with the following code (depends on swift-certificates) + /// ``` + /// import X509 + /// import CryptoKit + /// import Foundation + /// + /// let privateKey = P384.Signing.PrivateKey() + /// let name = try DistinguishedName { + /// OrganizationName("Self Signed") + /// CommonName("localhost") + /// } + /// let certificate = try Certificate( + /// version: .v3, + /// serialNumber: .init(), + /// publicKey: .init(privateKey.publicKey), + /// notValidBefore: Date(), + /// notValidAfter: Date().advanced(by: 365 * 24 * 3600), + /// issuer: name, + /// subject: name, + /// signatureAlgorithm: .ecdsaWithSHA384, + /// extensions: try .init { + /// SubjectAlternativeNames([.dnsName("example.com")]) + /// try ExtendedKeyUsage([.serverAuth]) + /// }, + /// issuerPrivateKey: .init(privateKey) + /// ) + /// ``` + let certPath = Bundle.module.path(forResource: "example.com.cert", ofType: "pem")! + let keyPath = Bundle.module.path(forResource: "example.com.private-key", ofType: "pem")! + let key = try NIOSSLPrivateKey(file: keyPath, format: .pem) + let localhostCert = try NIOSSLCertificate.fromPEMFile(certPath) + let configuration = TLSConfiguration.makeServerConfiguration( + certificateChain: localhostCert.map { .certificate($0) }, + privateKey: .privateKey(key) + ) + let bin = HTTPBin(.http2(tlsConfiguration: configuration)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + var config = HTTPClient.Configuration() + .enableFastFailureModeForTesting() + var tlsConfig = TLSConfiguration.makeClientConfiguration() + + tlsConfig.trustRoots = .certificates(localhostCert) + config.tlsConfiguration = tlsConfig + // this is the actual configuration under test + config.dnsOverride = ["example.com": "localhost"] + + let localClient = HTTPClient(eventLoopGroupProvider: .singleton, configuration: config) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + let request = HTTPClientRequest(url: "https://example.com:\(bin.port)/echohostheader") + let response = await XCTAssertNoThrowWithResult( + try await localClient.execute(request, deadline: .now() + .seconds(2)) + ) + XCTAssertEqual(response?.status, .ok) + XCTAssertEqual(response?.version, .http2) + var body = try await response?.body.collect(upTo: 1024) + let readableBytes = body?.readableBytes ?? 0 + let responseInfo = try body?.readJSONDecodable(RequestInfo.self, length: readableBytes) + XCTAssertEqual(responseInfo?.data, "example.com\(bin.port == 443 ? "" : ":\(bin.port)")") } - #endif } func testInvalidURL() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest(timeout: 5) { let client = makeDefaultHTTPClient() defer { XCTAssertNoThrow(try client.syncShutdown()) } let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) - let request = HTTPClientRequest(url: "") // invalid URL + let request = HTTPClientRequest(url: "") // invalid URL - await XCTAssertThrowsError(try await client.execute(request, deadline: .now() + .seconds(2), logger: logger)) { + await XCTAssertThrowsError( + try await client.execute(request, deadline: .now() + .seconds(2), logger: logger) + ) { XCTAssertEqual($0 as? HTTPClientError, .invalidURL) } } - #endif + } + + func testInsanelyHighConcurrentHTTP1ConnectionLimitDoesNotCrash() async throws { + let bin = HTTPBin(.http1_1(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + var httpClientConfig = HTTPClient.Configuration() + httpClientConfig.connectionPool = .init( + idleTimeout: .hours(1), + concurrentHTTP1ConnectionsPerHostSoftLimit: Int.max + ) + httpClientConfig.timeout = .init(connect: .seconds(10), read: .seconds(100), write: .seconds(100)) + + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: httpClientConfig) + defer { XCTAssertNoThrow(try httpClient.syncShutdown()) } + + let request = HTTPClientRequest(url: "http://localhost:\(bin.port)") + _ = try await httpClient.execute(request, deadline: .now() + .seconds(2)) } func testRedirectChangesHostHeader() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -419,14 +734,21 @@ final class AsyncAwaitEndToEndTests: XCTestCase { defer { XCTAssertNoThrow(try client.syncShutdown()) } let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://127.0.0.1:\(bin.port)/redirect/target") - request.headers.replaceOrAdd(name: "X-Target-Redirect-URL", value: "https://localhost:\(bin.port)/echohostheader") + request.headers.replaceOrAdd( + name: "X-Target-Redirect-URL", + value: "https://localhost:\(bin.port)/echohostheader" + ) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { + return + } + guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect(upTo: 1024)) else { return } - guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect()) else { return } var maybeRequestInfo: RequestInfo? XCTAssertNoThrow(maybeRequestInfo = try JSONDecoder().decode(RequestInfo.self, from: body)) guard let requestInfo = maybeRequestInfo else { return } @@ -435,12 +757,9 @@ final class AsyncAwaitEndToEndTests: XCTestCase { XCTAssertEqual(response.version, .http2) XCTAssertEqual(requestInfo.data, "localhost:\(bin.port)") } - #endif } func testShutdown() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let client = makeDefaultHTTPClient() try await client.shutdown() @@ -448,17 +767,289 @@ final class AsyncAwaitEndToEndTests: XCTestCase { XCTAssertEqualTypeAndValue(error, HTTPClientError.alreadyShutdown) } } - #endif } -} -#if compiler(>=5.5.2) && canImport(_Concurrency) -extension AsyncSequence where Element == ByteBuffer { - func collect() async rethrows -> ByteBuffer { - try await self.reduce(into: ByteBuffer()) { accumulatingBuffer, nextBuffer in - var nextBuffer = nextBuffer - accumulatingBuffer.writeBuffer(&nextBuffer) + /// Regression test for https://github.com/swift-server/async-http-client/issues/612 + func testCancelingBodyDoesNotCrash() { + XCTAsyncTest { + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let bin = HTTPBin(.http2(compress: true)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let request = HTTPClientRequest(url: "https://127.0.0.1:\(bin.port)/mega-chunked") + let response = try await client.execute(request, deadline: .now() + .seconds(10)) + + await XCTAssertThrowsError(try await response.body.collect(upTo: 100)) { error in + XCTAssert(error is NIOTooManyBytesError) + } + } + } + + func testAsyncSequenceReuse() { + XCTAsyncTest { + let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/") + request.method = .POST + request.body = .stream( + [ + ByteBuffer(string: "1"), + ByteBuffer(string: "2"), + ByteBuffer(string: "34"), + ].async, + length: .unknown + ) + + guard + let response1 = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } + XCTAssertEqual(response1.headers["content-length"], []) + guard + let body = await XCTAssertNoThrowWithResult( + try await response1.body.collect(upTo: 1024) + ) + else { return } + XCTAssertEqual(body, ByteBuffer(string: "1234")) + + guard + let response2 = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } + XCTAssertEqual(response2.headers["content-length"], []) + guard + let body = await XCTAssertNoThrowWithResult( + try await response2.body.collect(upTo: 1024) + ) + else { return } + XCTAssertEqual(body, ByteBuffer(string: "1234")) + } + } + + func testRejectsInvalidCharactersInHeaderFieldNames_http1() { + self._rejectsInvalidCharactersInHeaderFieldNames(mode: .http1_1(ssl: true)) + } + + func testRejectsInvalidCharactersInHeaderFieldNames_http2() { + self._rejectsInvalidCharactersInHeaderFieldNames(mode: .http2(compress: false)) + } + + private func _rejectsInvalidCharactersInHeaderFieldNames(mode: HTTPBin.Mode) { + XCTAsyncTest { + let bin = HTTPBin(mode) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + + // The spec in [RFC 9110](https://httpwg.org/specs/rfc9110.html#fields.values) defines the valid + // characters as the following: + // + // ``` + // field-name = token + // + // token = 1*tchar + // + // tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" + // / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" + // / DIGIT / ALPHA + // ; any VCHAR, except delimiters + let weirdAllowedFieldName = "!#$%&'*+-.^_`|~0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + request.headers.add(name: weirdAllowedFieldName, value: "present") + + // This should work fine. + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { + return + } + + XCTAssertEqual(response.status, .ok) + + // Now, let's confirm all other bytes are rejected. We want to stay within the ASCII space as the HTTPHeaders type will forbid anything else. + for byte in UInt8(0)...UInt8(127) { + // Skip bytes that we already believe are allowed. + if weirdAllowedFieldName.utf8.contains(byte) { + continue + } + let forbiddenFieldName = weirdAllowedFieldName + String(decoding: [byte], as: UTF8.self) + + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + request.headers.add(name: forbiddenFieldName, value: "present") + + await XCTAssertThrowsError( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) { error in + XCTAssertEqual(error as? HTTPClientError, .invalidHeaderFieldNames([forbiddenFieldName])) + } + } + } + } + + func testRejectsInvalidCharactersInHeaderFieldValues_http1() { + self._rejectsInvalidCharactersInHeaderFieldValues(mode: .http1_1(ssl: true)) + } + + func testRejectsInvalidCharactersInHeaderFieldValues_http2() { + self._rejectsInvalidCharactersInHeaderFieldValues(mode: .http2(compress: false)) + } + + private func _rejectsInvalidCharactersInHeaderFieldValues(mode: HTTPBin.Mode) { + XCTAsyncTest { + let bin = HTTPBin(mode) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + + // We reject all ASCII control characters except HTAB and tolerate everything else. + let weirdAllowedFieldValue = + "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" + + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + request.headers.add(name: "Weird-Value", value: weirdAllowedFieldValue) + + // This should work fine. + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { + return + } + + XCTAssertEqual(response.status, .ok) + + // Now, let's confirm all other bytes in the ASCII range ar rejected + for byte in UInt8(0)...UInt8(127) { + // Skip bytes that we already believe are allowed. + if weirdAllowedFieldValue.utf8.contains(byte) { + continue + } + let forbiddenFieldValue = weirdAllowedFieldValue + String(decoding: [byte], as: UTF8.self) + + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + request.headers.add(name: "Weird-Value", value: forbiddenFieldValue) + + await XCTAssertThrowsError( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) { error in + XCTAssertEqual(error as? HTTPClientError, .invalidHeaderFieldValues([forbiddenFieldValue])) + } + } + + // All the bytes outside the ASCII range are fine though. + for byte in UInt8(128)...UInt8(255) { + let evenWeirderAllowedValue = weirdAllowedFieldValue + String(decoding: [byte], as: UTF8.self) + + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + request.headers.add(name: "Weird-Value", value: evenWeirderAllowedValue) + + // This should work fine. + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { + return + } + + XCTAssertEqual(response.status, .ok) + } + } + } + + func testUsingGetMethodInsteadOfWait() { + XCTAsyncTest { + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let request = try HTTPClient.Request(url: "https://localhost:\(bin.port)/get") + + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request: request).get() + ) + else { + return + } + + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.version, .http2) + } + } + + func testSimpleContentLengthErrorNoBody() { + XCTAsyncTest { + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/content-length-without-body") + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } + await XCTAssertThrowsError( + try await response.body.collect(upTo: 3) + ) { + XCTAssertEqualTypeAndValue($0, NIOTooManyBytesError(maxBytes: 3)) + } } } } -#endif + +struct AnySendableSequence: @unchecked Sendable { + private let wrapped: AnySequence + init( + _ sequence: WrappedSequence + ) where WrappedSequence.Element == Element { + self.wrapped = .init(sequence) + } +} + +extension AnySendableSequence: Sequence { + func makeIterator() -> AnySequence.Iterator { + self.wrapped.makeIterator() + } +} + +struct AnySendableCollection: @unchecked Sendable { + private let wrapped: AnyCollection + init( + _ collection: WrappedCollection + ) where WrappedCollection.Element == Element { + self.wrapped = .init(collection) + } +} + +extension AnySendableCollection: Collection { + var startIndex: AnyCollection.Index { + self.wrapped.startIndex + } + + var endIndex: AnyCollection.Index { + self.wrapped.endIndex + } + + func index(after i: AnyIndex) -> AnyIndex { + self.wrapped.index(after: i) + } + + subscript(position: AnyCollection.Index) -> Element { + self.wrapped[position] + } +} diff --git a/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift b/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift index 312008959..cbab922a4 100644 --- a/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift +++ b/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift @@ -12,12 +12,12 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) import NIOConcurrencyHelpers import NIOCore +/// ``AsyncSequenceWriter`` is `Sendable` because its state is protected by a Lock @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -class AsyncSequenceWriter: AsyncSequence { +final class AsyncSequenceWriter: AsyncSequence, @unchecked Sendable { typealias AsyncIterator = Iterator struct Iterator: AsyncIteratorProtocol { @@ -33,7 +33,7 @@ class AsyncSequenceWriter: AsyncSequence { } func makeAsyncIterator() -> Iterator { - return Iterator(self) + Iterator(self) } private enum State { @@ -44,7 +44,7 @@ class AsyncSequenceWriter: AsyncSequence { } private var _state = State.buffering(.init(), nil) - private let lock = Lock() + private let lock = NIOLock() public var hasDemand: Bool { self.lock.withLock { @@ -117,7 +117,9 @@ class AsyncSequenceWriter: AsyncSequence { case .waiting: let state = self._state self.lock.unlock() - preconditionFailure("Expected that there is always only one concurrent call to next. Invalid state: \(state)") + preconditionFailure( + "Expected that there is always only one concurrent call to next. Invalid state: \(state)" + ) } } @@ -191,4 +193,3 @@ class AsyncSequenceWriter: AsyncSequence { } } } -#endif diff --git a/Tests/AsyncHTTPClientTests/ConnectionPoolSizeConfigValueIsRespectedTests.swift b/Tests/AsyncHTTPClientTests/ConnectionPoolSizeConfigValueIsRespectedTests.swift new file mode 100644 index 000000000..962791334 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/ConnectionPoolSizeConfigValueIsRespectedTests.swift @@ -0,0 +1,76 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +final class ConnectionPoolSizeConfigValueIsRespectedTests: XCTestCaseHTTPClientTestsBaseClass { + func testConnectionPoolSizeConfigValueIsRespected() { + let numberOfRequestsPerThread = 1000 + let numberOfParallelWorkers = 16 + let poolSize = 12 + + let httpBin = HTTPBin() + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 4) + defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } + + let configuration = HTTPClient.Configuration( + connectionPool: .init( + idleTimeout: .seconds(30), + concurrentHTTP1ConnectionsPerHostSoftLimit: poolSize + ) + ) + let client = HTTPClient(eventLoopGroupProvider: .shared(group), configuration: configuration) + defer { XCTAssertNoThrow(try client.syncShutdown()) } + + let g = DispatchGroup() + for workerID in 0.. Void = { _ in }) throws { let part = try self.readOutbound(as: HTTPClientRequestPart.self) @@ -77,7 +78,7 @@ extension EmbeddedChannel { channel: self, connectionID: 1, delegate: connectionDelegate, - configuration: .init(), + decompression: .disabled, logger: logger ) @@ -86,8 +87,8 @@ extension EmbeddedChannel { let decoder = try self.pipeline.syncOperations.handler(type: ByteToMessageHandler.self) let encoder = try self.pipeline.syncOperations.handler(type: HTTPRequestEncoder.self) - let removeDecoderFuture = self.pipeline.removeHandler(decoder) - let removeEncoderFuture = self.pipeline.removeHandler(encoder) + let removeDecoderFuture = self.pipeline.syncOperations.removeHandler(decoder) + let removeEncoderFuture = self.pipeline.syncOperations.removeHandler(encoder) self.embeddedEventLoop.run() @@ -111,6 +112,6 @@ public struct HTTP1EmbeddedChannelError: Error, Hashable, CustomStringConvertibl } public var description: String { - return self.reason + self.reason } } diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests+XCTest.swift deleted file mode 100644 index 8d28c15c4..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests+XCTest.swift +++ /dev/null @@ -1,36 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP1ClientChannelHandlerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP1ClientChannelHandlerTests { - static var allTests: [(String, (HTTP1ClientChannelHandlerTests) -> () throws -> Void)] { - return [ - ("testResponseBackpressure", testResponseBackpressure), - ("testWriteBackpressure", testWriteBackpressure), - ("testClientHandlerCancelsRequestIfWeWantToShutdown", testClientHandlerCancelsRequestIfWeWantToShutdown), - ("testIdleReadTimeout", testIdleReadTimeout), - ("testIdleReadTimeoutIsCanceledIfRequestIsCanceled", testIdleReadTimeoutIsCanceledIfRequestIsCanceled), - ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift index 62e42f94e..df1a2926a 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOEmbedded import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + class HTTP1ClientChannelHandlerTests: XCTestCase { func testResponseBackpressure() { let embedded = EmbeddedChannel() @@ -32,27 +33,35 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) embedded.read() @@ -113,22 +122,30 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 100) { writer in - testWriter.start(writer: writer) - })) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 100) { writer in + testWriter.start(writer: writer) + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } // the handler only writes once the channel is writable @@ -143,12 +160,14 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { testWriter.writabilityChanged(true) embedded.pipeline.fireChannelWritabilityChanged() - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .POST) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - XCTAssertEqual($0.headers.first(name: "content-length"), "100") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .POST) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + XCTAssertEqual($0.headers.first(name: "content-length"), "100") + } + ) // the next body write will be executed once we tick the el. before we make the channel // unwritable @@ -162,9 +181,11 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { embedded.embeddedEventLoop.run() - XCTAssertNoThrow(try embedded.receiveBodyAndVerify { - XCTAssertEqual($0.readableBytes, 2) - }) + XCTAssertNoThrow( + try embedded.receiveBodyAndVerify { + XCTAssertEqual($0.readableBytes, 2) + } + ) XCTAssertEqual(testWriter.written, index + 1) @@ -201,24 +222,28 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) XCTAssertTrue(embedded.isActive) @@ -247,27 +272,35 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) embedded.read() @@ -299,27 +332,35 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) embedded.read() @@ -327,7 +368,7 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) // canceling the request - requestBag.cancel() + requestBag.fail(HTTPClientError.cancelled) XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { XCTAssertEqual($0 as? HTTPClientError, .cancelled) } @@ -337,6 +378,217 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) } + func testIdleWriteTimeout() { + let embedded = EmbeddedChannel() + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + // Advance time by more than the idle write timeout (that's 1 millisecond) to trigger the timeout. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + return testWriter.start(writer: writer) + } + ) + ) + + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + testUtils.connection.executeRequest(requestBag) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testIdleWriteTimeoutRaceToEnd() { + let embedded = EmbeddedChannel() + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream { _ in + // Advance time by more than the idle write timeout (that's 1 millisecond) to trigger the timeout. + let scheduled = embedded.embeddedEventLoop.flatScheduleTask(in: .milliseconds(2)) { + embedded.embeddedEventLoop.makeSucceededVoidFuture() + } + return scheduled.futureResult + } + ) + ) + + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(5)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + embedded.pipeline.fireChannelWritabilityChanged() + testUtils.connection.executeRequest(requestBag) + let expectedHeaders: HTTPHeaders = ["host": "localhost", "Transfer-Encoding": "chunked"] + XCTAssertEqual( + try embedded.readOutbound(as: HTTPClientRequestPart.self), + .head(HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: expectedHeaders)) + ) + + // change the writability to false. + embedded.isWritable = false + embedded.pipeline.fireChannelWritabilityChanged() + embedded.embeddedEventLoop.run() + + // let the writer, write an end (while writability is false) + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + + XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) + } + + func testIdleWriteTimeoutWritabilityChanged() { + let embedded = EmbeddedChannel() + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + embedded.isWritable = false + embedded.pipeline.fireChannelWritabilityChanged() + // This should not trigger any errors or timeouts, because the timer isn't running + // as the channel is not writable. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(20)) + + // Now that the channel will become writable, this should trigger a timeout. + embedded.isWritable = true + embedded.pipeline.fireChannelWritabilityChanged() + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + + return testWriter.start(writer: writer) + } + ) + ) + + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + testUtils.connection.executeRequest(requestBag) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testIdleWriteTimeoutIsCancelledIfRequestIsCancelled() { + let embedded = EmbeddedChannel() + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 1) + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 2) { writer in + testWriter.start(writer: writer, expectedErrors: [HTTPClientError.cancelled]) + } + ) + ) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + testUtils.connection.executeRequest(requestBag) + + // canceling the request + requestBag.fail(HTTPClientError.cancelled) + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .cancelled) + } + + // the idle write timeout should be cleared because we canceled the request + // therefore advancing the time should not trigger a crash + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) + } + func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand() { let embedded = EmbeddedChannel() var maybeTestUtils: HTTP1TestTools? @@ -349,27 +601,35 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "50")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "50")]) + ) XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) embedded.read() @@ -394,6 +654,247 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { XCTAssertEqual($0 as? HTTPClientError, .remoteConnectionClosed) } } + + func testWriteHTTPHeadFails() { + struct WriteError: Error, Equatable {} + + class FailWriteHandler: ChannelOutboundHandler { + typealias OutboundIn = HTTPClientRequestPart + typealias OutboundOut = HTTPClientRequestPart + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let error = WriteError() + promise?.fail(error) + context.fireErrorCaught(error) + } + } + + let bodies: [HTTPClient.Body?] = [ + .none, + .some(.byteBuffer(ByteBuffer(string: "hello world"))), + ] + + for body in bodies { + let embedded = EmbeddedChannel() + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + XCTAssertNoThrow( + try embedded.pipeline.syncOperations.addHandler( + FailWriteHandler(), + position: .after(testUtils.readEventHandler) + ) + ) + + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: body)) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to be able to create a request bag") + } + + embedded.isWritable = false + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + embedded.write(requestBag, promise: nil) + + // the handler only writes once the channel is writable + XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .none) + embedded.isWritable = true + embedded.pipeline.fireChannelWritabilityChanged() + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? WriteError, WriteError()) + } + + XCTAssertEqual(embedded.isActive, false) + } + } + + func testHandlerClosesChannelIfLastActionIsSendEndAndItFails() { + let embedded = EmbeddedChannel() + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + testWriter.start(writer: writer) + } + ) + ) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + XCTAssertNoThrow(try embedded.pipeline.addHandler(FailEndHandler(), position: .first).wait()) + + // Execute the request and we'll receive the head. + testWriter.writabilityChanged(true) + testUtils.connection.executeRequest(requestBag) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .POST) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + XCTAssertEqual($0.headers.first(name: "content-length"), "10") + } + ) + // We're going to immediately send the response head and end. + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) + embedded.read() + + // Send the end and confirm the connection is still live. + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.end(nil))) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionClosed, 0) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionReleased, 0) + + // Ok, now we can process some reads. We expect 5 reads, but we do _not_ expect an .end, because + // the `FailEndHandler` is going to fail it. + embedded.embeddedEventLoop.run() + XCTAssertEqual(testWriter.written, 5) + for _ in 0..<5 { + XCTAssertNoThrow( + try embedded.receiveBodyAndVerify { + XCTAssertEqual($0.readableBytes, 2) + } + ) + } + + embedded.embeddedEventLoop.run() + XCTAssertNil(try embedded.readOutbound(as: HTTPClientRequestPart.self)) + + // We should have seen the connection close, and the request is complete. + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionClosed, 1) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionReleased, 0) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { error in + XCTAssertTrue(error is FailEndHandler.Error) + } + } + + func testChannelBecomesNonWritableDuringHeaderWrite() throws { + final class ChangeWritabilityOnFlush: ChannelOutboundHandler { + typealias OutboundIn = Any + func flush(context: ChannelHandlerContext) { + context.flush() + (context.channel as! EmbeddedChannel).isWritable = false + context.fireChannelWritabilityChanged() + } + } + let eventLoopGroup = EmbeddedEventLoopGroup(loops: 1) + let eventLoop = eventLoopGroup.next() as! EmbeddedEventLoop + let handler = HTTP1ClientChannelHandler( + eventLoop: eventLoop, + backgroundLogger: Logger(label: "no-op", factory: SwiftLogNoOpLogHandler.init), + connectionIdLoggerMetadata: "test connection" + ) + let channel = EmbeddedChannel( + handlers: [ + ChangeWritabilityOnFlush(), + handler, + ], + loop: eventLoop + ) + try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait() + + let request = MockHTTPExecutableRequest() + // non empty body is important to trigger this bug as we otherwise finish the request in a single flush + request.requestFramingMetadata.body = .fixedSize(1) + request.raiseErrorIfUnimplementedMethodIsCalled = false + channel.writeAndFlush(request, promise: nil) + XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent]) + } + + func testIdleWriteTimeoutOutsideOfRunningState() { + let embedded = EmbeddedChannel() + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + print("pipeline", embedded.pipeline) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/")) + guard var request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + // start a request stream we'll never write to + let streamPromise = embedded.eventLoop.makePromise(of: Void.self) + let streamCallback = { @Sendable (streamWriter: HTTPClient.Body.StreamWriter) -> EventLoopFuture in + streamPromise.futureResult + } + request.body = .init(contentLength: nil, stream: streamCallback) + + let accumulator = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests( + idleReadTimeout: .milliseconds(10), + idleWriteTimeout: .milliseconds(2) + ), + delegate: accumulator + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + testUtils.connection.executeRequest(requestBag) + + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) + + // close the pipeline to simulate a server-side close + // note this happens before we write so the idle write timeout is still running + try! embedded.pipeline.close().wait() + + // advance time to trigger the idle write timeout + // and ensure that the state machine can tolerate this + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) + } } class TestBackpressureWriter { @@ -414,7 +915,7 @@ class TestBackpressureWriter { self.finishPromise = eventLoop.makePromise(of: Void.self) } - func start(writer: HTTPClient.Body.StreamWriter) -> EventLoopFuture { + func start(writer: HTTPClient.Body.StreamWriter, expectedErrors: [HTTPClientError] = []) -> EventLoopFuture { func recursive() { XCTAssert(self.eventLoop.inEventLoop) XCTAssert(self.channelIsWritable) @@ -429,7 +930,15 @@ class TestBackpressureWriter { case .success: recursive() case .failure(let error): - XCTFail("Unexpected error: \(error)") + let isExpectedError = expectedErrors.contains { httpError in + if let castError = error as? HTTPClientError { + return castError == httpError + } + return false + } + if !isExpectedError { + XCTFail("Unexpected error: \(error)") + } } } } @@ -476,7 +985,10 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { return newPromise.futureResult case .waitingForRemote(var promiseBuffer): - assert(!promiseBuffer.isEmpty, "assert expected to be waiting if we have at least one promise in the buffer") + assert( + !promiseBuffer.isEmpty, + "assert expected to be waiting if we have at least one promise in the buffer" + ) let promise = self.eventLoop.makePromise(of: ByteBuffer?.self) promiseBuffer.append(promise) self.state = .waitingForRemote(promiseBuffer) @@ -515,7 +1027,10 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { switch self.state { case .waitingForRemote(var promiseBuffer): - assert(!promiseBuffer.isEmpty, "assert expected to be waiting if we have at least one promise in the buffer") + assert( + !promiseBuffer.isEmpty, + "assert expected to be waiting if we have at least one promise in the buffer" + ) let promise = promiseBuffer.removeFirst() if promiseBuffer.isEmpty { let newBackpressurePromise = self.eventLoop.makePromise(of: Void.self) @@ -534,7 +1049,9 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { return promise.futureResult case .buffering(.some): - preconditionFailure("Did receive response part should not be called, before the previous promise was succeeded.") + preconditionFailure( + "Did receive response part should not be called, before the previous promise was succeeded." + ) case .done, .consuming: preconditionFailure("Invalid state: \(self.state)") @@ -544,8 +1061,8 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { func didFinishRequest(task: HTTPClient.Task) throws { switch self.state { case .waitingForRemote(let promiseBuffer): - promiseBuffer.forEach { - $0.succeed(.none) + for promise in promiseBuffer { + promise.succeed(.none) } self.state = .done @@ -556,7 +1073,9 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { preconditionFailure("Invalid state: \(self.state)") case .buffering(.some): - preconditionFailure("Did receive response part should not be called, before the previous promise was succeeded.") + preconditionFailure( + "Did receive response part should not be called, before the previous promise was succeeded." + ) } } } @@ -573,3 +1092,19 @@ class ReadEventHitHandler: ChannelOutboundHandler { context.read() } } + +final class FailEndHandler: ChannelOutboundHandler { + typealias OutboundIn = HTTPClientRequestPart + typealias OutboundOut = HTTPClientRequestPart + + struct Error: Swift.Error {} + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + if case .end = self.unwrapOutboundIn(data) { + // We fail this. + promise?.fail(Self.Error()) + } else { + context.write(data, promise: promise) + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift deleted file mode 100644 index 76a37936b..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift +++ /dev/null @@ -1,48 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP1ConnectionStateMachineTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP1ConnectionStateMachineTests { - static var allTests: [(String, (HTTP1ConnectionStateMachineTests) -> () throws -> Void)] { - return [ - ("testPOSTRequestWithWriteAndReadBackpressure", testPOSTRequestWithWriteAndReadBackpressure), - ("testResponseReadingWithBackpressure", testResponseReadingWithBackpressure), - ("testAConnectionCloseHeaderInTheRequestLeadsToConnectionCloseAfterRequest", testAConnectionCloseHeaderInTheRequestLeadsToConnectionCloseAfterRequest), - ("testAHTTP1_0ResponseWithoutKeepAliveHeaderLeadsToConnectionCloseAfterRequest", testAHTTP1_0ResponseWithoutKeepAliveHeaderLeadsToConnectionCloseAfterRequest), - ("testAHTTP1_0ResponseWithKeepAliveHeaderLeadsToConnectionBeingKeptAlive", testAHTTP1_0ResponseWithKeepAliveHeaderLeadsToConnectionBeingKeptAlive), - ("testAConnectionCloseHeaderInTheResponseLeadsToConnectionCloseAfterRequest", testAConnectionCloseHeaderInTheResponseLeadsToConnectionCloseAfterRequest), - ("testNIOTriggersChannelActiveTwice", testNIOTriggersChannelActiveTwice), - ("testIdleConnectionBecomesInactive", testIdleConnectionBecomesInactive), - ("testConnectionGoesAwayWhileInRequest", testConnectionGoesAwayWhileInRequest), - ("testRequestWasCancelledWhileUploadingData", testRequestWasCancelledWhileUploadingData), - ("testCancelRequestIsIgnoredWhenConnectionIsIdle", testCancelRequestIsIgnoredWhenConnectionIsIdle), - ("testReadsAreForwardedIfConnectionIsClosing", testReadsAreForwardedIfConnectionIsClosing), - ("testChannelReadsAreIgnoredIfConnectionIsClosing", testChannelReadsAreIgnoredIfConnectionIsClosing), - ("testRequestIsCancelledWhileWaitingForWritable", testRequestIsCancelledWhileWaitingForWritable), - ("testConnectionIsClosedIfErrorHappensWhileInRequest", testConnectionIsClosedIfErrorHappensWhileInRequest), - ("testConnectionIsClosedAfterSwitchingProtocols", testConnectionIsClosedAfterSwitchingProtocols), - ("testWeDontCrashAfterEarlyHintsAndConnectionClose", testWeDontCrashAfterEarlyHintsAndConnectionClose), - ("testWeDontCrashInRaceBetweenSchedulingNewRequestAndConnectionClose", testWeDontCrashInRaceBetweenSchedulingNewRequestAndConnectionClose), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift index c8ad3d510..1c6e9659f 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOHTTP1 import NIOHTTPCompression import XCTest +@testable import AsyncHTTPClient + class HTTP1ConnectionStateMachineTests: XCTestCase { func testPOSTRequestWithWriteAndReadBackpressure() { var state = HTTP1ConnectionStateMachine() @@ -26,31 +27,38 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) - XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) let part2 = IOData.byteBuffer(ByteBuffer(bytes: [2])) let part3 = IOData.byteBuffer(ByteBuffer(bytes: [3])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) // oh the channel reports... we should slow down producing... XCTAssertEqual(state.writabilityChanged(writable: false), .pauseRequestBodyStream) // but we issued a .produceMoreRequestBodyData before... Thus, we must accept more produced // data - XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) + XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) // however when we have put the data on the channel, we should not issue further // .produceMoreRequestBodyData events // once we receive a writable event again, we can allow the producer to produce more data XCTAssertEqual(state.writabilityChanged(writable: true), .resumeRequestBodyStream) - XCTAssertEqual(state.requestStreamPartReceived(part3), .sendBodyPart(part3)) - XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd) + XCTAssertEqual(state.requestStreamPartReceived(part3, promise: nil), .sendBodyPart(part3, nil)) + XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.informConnectionIsIdle, .init([responseBody]))) @@ -64,10 +72,17 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "12"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) let part1 = ByteBuffer(bytes: 4...7) let part2 = ByteBuffer(bytes: 8...11) @@ -86,16 +101,43 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.read(), .read) } + func testWriteTimeoutAfterErrorDoesntCrash() { + var state = HTTP1ConnectionStateMachine() + XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) + + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) + let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) + + struct MyError: Error, Equatable {} + XCTAssertEqual(state.errorHappened(MyError()), .failRequest(MyError(), .close(nil))) + + // Primarily we care that we don't crash here + XCTAssertEqual(state.idleWriteTimeoutTriggered(), .wait) + } + func testAConnectionCloseHeaderInTheRequestLeadsToConnectionCloseAfterRequest() { var state = HTTP1ConnectionStateMachine() XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: ["connection": "close"]) let metadata = RequestFramingMetadata(connectionClose: true, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init([responseBody]))) @@ -108,10 +150,17 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_0, status: .ok, headers: ["content-length": "4"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init([responseBody]))) @@ -124,10 +173,21 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) - - let responseHead = HTTPResponseHead(version: .http1_0, status: .ok, headers: ["content-length": "4", "connection": "keep-alive"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_0, + status: .ok, + headers: ["content-length": "4", "connection": "keep-alive"] + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.informConnectionIsIdle, .init([responseBody]))) @@ -141,10 +201,17 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["connection": "close"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init([responseBody]))) @@ -170,9 +237,11 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) + + XCTAssertEqual(state.headSent(), .wait) } func testRequestWasCancelledWhileUploadingData() { @@ -182,13 +251,33 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) - XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) - XCTAssertEqual(state.requestCancelled(closeConnection: false), .failRequest(HTTPClientError.cancelled, .close)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) + XCTAssertEqual( + state.requestCancelled(closeConnection: false), + .failRequest(HTTPClientError.cancelled, .close(nil)) + ) + } + + func testNewRequestAfterErrorHappened() { + var state = HTTP1ConnectionStateMachine() + XCTAssertEqual(state.channelActive(isWritable: false), .fireChannelActive) + struct MyError: Error, Equatable {} + XCTAssertEqual(state.errorHappened(MyError()), .fireChannelError(MyError(), closeConnection: true)) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) + let action = state.runNewRequest(head: requestHead, metadata: metadata) + guard case .failRequest = action else { + return XCTFail("unexpected action \(action)") + } } func testCancelRequestIsIgnoredWhenConnectionIsIdle() { @@ -196,9 +285,17 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) XCTAssertEqual(state.requestCancelled(closeConnection: false), .wait, "Should be ignored.") XCTAssertEqual(state.requestCancelled(closeConnection: true), .close, "Should lead to connection closure.") - XCTAssertEqual(state.requestCancelled(closeConnection: true), .wait, "Should be ignored. Connection is already closing") + XCTAssertEqual( + state.requestCancelled(closeConnection: true), + .wait, + "Should be ignored. Connection is already closing" + ) XCTAssertEqual(state.channelInactive(), .fireChannelInactive) - XCTAssertEqual(state.requestCancelled(closeConnection: true), .wait, "Should be ignored. Connection is already closed") + XCTAssertEqual( + state.requestCancelled(closeConnection: true), + .wait, + "Should be ignored. Connection is already closed" + ) } func testReadsAreForwardedIfConnectionIsClosing() { @@ -226,7 +323,10 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) - XCTAssertEqual(state.requestCancelled(closeConnection: false), .failRequest(HTTPClientError.cancelled, .informConnectionIsIdle)) + XCTAssertEqual( + state.requestCancelled(closeConnection: false), + .failRequest(HTTPClientError.cancelled, .informConnectionIsIdle) + ) } func testConnectionIsClosedIfErrorHappensWhileInRequest() { @@ -235,13 +335,20 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: "Hello world!\n"))), .wait) XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: "Foo Bar!\n"))), .wait) let decompressionError = NIOHTTPDecompression.DecompressionError.limit - XCTAssertEqual(state.errorHappened(decompressionError), .failRequest(decompressionError, .close)) + XCTAssertEqual(state.errorHappened(decompressionError), .failRequest(decompressionError, .close(nil))) } func testConnectionIsClosedAfterSwitchingProtocols() { @@ -250,9 +357,16 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .switchingProtocols) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, [])) } @@ -262,8 +376,15 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) - let responseHead = HTTPResponseHead(version: .http1_1, status: .init(statusCode: 103, reasonPhrase: "Early Hints")) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .init(statusCode: 103, reasonPhrase: "Early Hints") + ) XCTAssertEqual(state.channelRead(.head(responseHead)), .wait) XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) } @@ -291,12 +412,20 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { case (.fireChannelInactive, .fireChannelInactive): return true + case (.fireChannelError(_, let lhsCloseConnection), .fireChannelError(_, let rhsCloseConnection)): + return lhsCloseConnection == rhsCloseConnection case (.sendRequestHead(let lhsHead, let lhsStartBody), .sendRequestHead(let rhsHead, let rhsStartBody)): return lhsHead == rhsHead && lhsStartBody == rhsStartBody - case (.sendBodyPart(let lhsData), .sendBodyPart(let rhsData)): - return lhsData == rhsData + case ( + .notifyRequestHeadSendSuccessfully(let lhsResumeRequestBodyStream, let lhsStartIdleTimer), + .notifyRequestHeadSendSuccessfully(let rhsResumeRequestBodyStream, let rhsStartIdleTimer) + ): + return lhsResumeRequestBodyStream == rhsResumeRequestBodyStream && lhsStartIdleTimer == rhsStartIdleTimer + + case (.sendBodyPart(let lhsData, let lhsPromise), .sendBodyPart(let rhsData, let rhsPromise)): + return lhsData == rhsData && lhsPromise?.futureResult == rhsPromise?.futureResult case (.sendRequestEnd, .sendRequestEnd): return true @@ -306,13 +435,19 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { case (.resumeRequestBodyStream, .resumeRequestBodyStream): return true - case (.forwardResponseHead(let lhsHead, let lhsPauseRequestBodyStream), .forwardResponseHead(let rhsHead, let rhsPauseRequestBodyStream)): + case ( + .forwardResponseHead(let lhsHead, let lhsPauseRequestBodyStream), + .forwardResponseHead(let rhsHead, let rhsPauseRequestBodyStream) + ): return lhsHead == rhsHead && lhsPauseRequestBodyStream == rhsPauseRequestBodyStream case (.forwardResponseBodyParts(let lhsData), .forwardResponseBodyParts(let rhsData)): return lhsData == rhsData - case (.succeedRequest(let lhsFinalAction, let lhsFinalBuffer), .succeedRequest(let rhsFinalAction, let rhsFinalBuffer)): + case ( + .succeedRequest(let lhsFinalAction, let lhsFinalBuffer), + .succeedRequest(let rhsFinalAction, let rhsFinalBuffer) + ): return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer case (.failRequest(_, let lhsFinalAction), .failRequest(_, let rhsFinalAction)): @@ -332,3 +467,42 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { } } } + +extension HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction: Equatable { + public static func == ( + lhs: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction, + rhs: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction + ) -> Bool { + switch (lhs, rhs) { + case (.close, .close): + return true + case (sendRequestEnd(let lhsPromise, let lhsShouldClose), sendRequestEnd(let rhsPromise, let rhsShouldClose)): + return lhsPromise?.futureResult == rhsPromise?.futureResult && lhsShouldClose == rhsShouldClose + case (informConnectionIsIdle, informConnectionIsIdle): + return true + default: + return false + } + } +} + +extension HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction: Equatable { + public static func == ( + lhs: HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction, + rhs: HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction + ) -> Bool { + switch (lhs, rhs) { + case (.close(let lhsPromise), .close(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult + case (.informConnectionIsIdle, .informConnectionIsIdle): + return true + case (.failWritePromise(let lhsPromise), .failWritePromise(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult + case (.none, .none): + return true + + default: + return false + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests+XCTest.swift deleted file mode 100644 index 95b3e5dac..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests+XCTest.swift +++ /dev/null @@ -1,42 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP1ConnectionTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP1ConnectionTests { - static var allTests: [(String, (HTTP1ConnectionTests) -> () throws -> Void)] { - return [ - ("testCreateNewConnectionWithDecompression", testCreateNewConnectionWithDecompression), - ("testCreateNewConnectionWithoutDecompression", testCreateNewConnectionWithoutDecompression), - ("testCreateNewConnectionFailureClosedIO", testCreateNewConnectionFailureClosedIO), - ("testGETRequest", testGETRequest), - ("testConnectionClosesOnCloseHeader", testConnectionClosesOnCloseHeader), - ("testConnectionClosesOnRandomlyAppearingCloseHeader", testConnectionClosesOnRandomlyAppearingCloseHeader), - ("testConnectionClosesAfterTheRequestWithoutHavingSentAnCloseHeader", testConnectionClosesAfterTheRequestWithoutHavingSentAnCloseHeader), - ("testConnectionIsClosedAfterSwitchingProtocols", testConnectionIsClosedAfterSwitchingProtocols), - ("testConnectionDropAfterEarlyHints", testConnectionDropAfterEarlyHints), - ("testConnectionIsClosedIfResponseIsReceivedBeforeRequest", testConnectionIsClosedIfResponseIsReceivedBeforeRequest), - ("testDoubleHTTPResponseLine", testDoubleHTTPResponseLine), - ("testDownloadStreamingBackpressure", testDownloadStreamingBackpressure), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift index 3575a6080..5f980bccb 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOConcurrencyHelpers import NIOCore @@ -23,6 +22,8 @@ import NIOPosix import NIOTestUtils import XCTest +@testable import AsyncHTTPClient + class HTTP1ConnectionTests: XCTestCase { func testCreateNewConnectionWithDecompression() { let embedded = EmbeddedChannel() @@ -31,16 +32,20 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) var connection: HTTP1Connection? - XCTAssertNoThrow(connection = try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: MockHTTP1ConnectionDelegate(), - configuration: .init(decompression: .enabled(limit: .ratio(4))), - logger: logger - )) + XCTAssertNoThrow( + connection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: MockHTTP1ConnectionDelegate(), + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: HTTPRequestEncoder.self)) - XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self)) + XCTAssertNotNil( + try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self) + ) XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: NIOHTTPResponseDecompressor.self)) XCTAssertNoThrow(try connection?.close().wait()) @@ -54,17 +59,22 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) - XCTAssertNoThrow(try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: MockHTTP1ConnectionDelegate(), - configuration: .init(decompression: .disabled), - logger: logger - )) + XCTAssertNoThrow( + try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: MockHTTP1ConnectionDelegate(), + decompression: .disabled, + logger: logger + ) + ) XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: HTTPRequestEncoder.self)) - XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self)) - XCTAssertThrowsError(try embedded.pipeline.syncOperations.handler(type: NIOHTTPResponseDecompressor.self)) { error in + XCTAssertNotNil( + try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self) + ) + XCTAssertThrowsError(try embedded.pipeline.syncOperations.handler(type: NIOHTTPResponseDecompressor.self)) { + error in XCTAssertEqual(error as? ChannelPipelineError, .notFound) } } @@ -78,13 +88,15 @@ class HTTP1ConnectionTests: XCTestCase { embedded.embeddedEventLoop.run() let logger = Logger(label: "test.http1.connection") - XCTAssertThrowsError(try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: MockHTTP1ConnectionDelegate(), - configuration: .init(), - logger: logger - )) + XCTAssertThrowsError( + try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: MockHTTP1ConnectionDelegate(), + decompression: .disabled, + logger: logger + ) + ) } func testGETRequest() { @@ -106,37 +118,39 @@ class HTTP1ConnectionTests: XCTestCase { channel: $0, connectionID: 0, delegate: delegate, - configuration: .init(decompression: .disabled), + decompression: .disabled, logger: logger ) } .wait() var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "http://localhost/hello/swift", - method: .POST, - body: .stream(length: 4) { writer -> EventLoopFuture in - func recursive(count: UInt8, promise: EventLoopPromise) { - guard count < 4 else { - return promise.succeed(()) - } + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/hello/swift", + method: .POST, + body: .stream(contentLength: 4) { writer -> EventLoopFuture in + @Sendable func recursive(count: UInt8, promise: EventLoopPromise) { + guard count < 4 else { + return promise.succeed(()) + } - writer.write(.byteBuffer(ByteBuffer(bytes: [count]))).whenComplete { result in - switch result { - case .failure(let error): - XCTFail("Unexpected error: \(error)") - case .success: - recursive(count: count + 1, promise: promise) + writer.write(.byteBuffer(ByteBuffer(bytes: [count]))).whenComplete { result in + switch result { + case .failure(let error): + XCTFail("Unexpected error: \(error)") + case .success: + recursive(count: count + 1, promise: promise) + } } } - } - let promise = clientEL.makePromise(of: Void.self) - recursive(count: 0, promise: promise) - return promise.futureResult - } - )) + let promise = clientEL.makePromise(of: Void.self) + recursive(count: 0, promise: promise) + return promise.futureResult + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a connection and a request") @@ -145,33 +159,39 @@ class HTTP1ConnectionTests: XCTestCase { let task = HTTPClient.Task(eventLoop: clientEL, logger: logger) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: clientEL), - task: task, - redirectHandler: nil, - connectionDeadline: .now() + .seconds(60), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: request) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: clientEL), + task: task, + redirectHandler: nil, + connectionDeadline: .now() + .seconds(60), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: request) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } connection.executeRequest(requestBag) - XCTAssertNoThrow(try server.receiveHeadAndVerify { head in - XCTAssertEqual(head.method, .POST) - XCTAssertEqual(head.uri, "/hello/swift") - XCTAssertEqual(head.headers["content-length"].first, "4") - }) + XCTAssertNoThrow( + try server.receiveHeadAndVerify { head in + XCTAssertEqual(head.method, .POST) + XCTAssertEqual(head.uri, "/hello/swift") + XCTAssertEqual(head.headers["content-length"].first, "4") + } + ) var received: UInt8 = 0 while received < 4 { - XCTAssertNoThrow(try server.receiveBodyAndVerify { body in - var body = body - while let read = body.readInteger(as: UInt8.self) { - XCTAssertEqual(received, read) - received += 1 + XCTAssertNoThrow( + try server.receiveBodyAndVerify { body in + var body = body + while let read = body.readInteger(as: UInt8.self) { + XCTAssertEqual(received, read) + received += 1 + } } - }) + ) } XCTAssertEqual(received, 4) XCTAssertNoThrow(try server.receiveEnd()) @@ -198,17 +218,23 @@ class HTTP1ConnectionTests: XCTestCase { var maybeChannel: Channel? - XCTAssertNoThrow(maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait()) + XCTAssertNoThrow( + maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait() + ) let connectionDelegate = MockConnectionDelegate() let logger = Logger(label: "test") var maybeConnection: HTTP1Connection? - XCTAssertNoThrow(maybeConnection = try eventLoop.submit { try HTTP1Connection.start( - channel: XCTUnwrap(maybeChannel), - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(), - logger: logger - ) }.wait()) + XCTAssertNoThrow( + maybeConnection = try eventLoop.submit { + try HTTP1Connection.start( + channel: XCTUnwrap(maybeChannel), + connectionID: 0, + delegate: connectionDelegate, + decompression: .disabled, + logger: logger + ) + }.wait() + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection here") } var maybeRequest: HTTPClient.Request? @@ -217,15 +243,17 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: eventLoopGroup.next()), - task: .init(eventLoop: eventLoopGroup.next(), logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: eventLoopGroup.next()), + task: .init(eventLoop: eventLoopGroup.next(), logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } connection.executeRequest(requestBag) @@ -248,21 +276,29 @@ class HTTP1ConnectionTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let closeOnRequest = (30...100).randomElement()! - let httpBin = HTTPBin(handlerFactory: { _ in SuddenlySendsCloseHeaderChannelHandler(closeOnRequest: closeOnRequest) }) + let httpBin = HTTPBin(handlerFactory: { _ in + SuddenlySendsCloseHeaderChannelHandler(closeOnRequest: closeOnRequest) + }) var maybeChannel: Channel? - XCTAssertNoThrow(maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait()) + XCTAssertNoThrow( + maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait() + ) let connectionDelegate = MockConnectionDelegate() let logger = Logger(label: "test") var maybeConnection: HTTP1Connection? - XCTAssertNoThrow(maybeConnection = try eventLoop.submit { try HTTP1Connection.start( - channel: XCTUnwrap(maybeChannel), - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(), - logger: logger - ) }.wait()) + XCTAssertNoThrow( + maybeConnection = try eventLoop.submit { + try HTTP1Connection.start( + channel: XCTUnwrap(maybeChannel), + connectionID: 0, + delegate: connectionDelegate, + decompression: .disabled, + logger: logger + ) + }.wait() + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection here") } var counter = 0 @@ -275,16 +311,20 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: eventLoopGroup.next()), - task: .init(eventLoop: eventLoopGroup.next(), logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) - guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: eventLoopGroup.next()), + task: .init(eventLoop: eventLoopGroup.next(), logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to be able to create a request bag") + } connection.executeRequest(requestBag) @@ -293,7 +333,7 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertEqual(response?.status, .ok) if response?.headers.first(name: "connection") == "close" { - break // the loop + break // the loop } else { XCTAssertEqual(httpBin.activeConnections, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, counter) @@ -306,8 +346,11 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertEqual(counter, closeOnRequest) XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) - XCTAssertEqual(connectionDelegate.hitConnectionReleased, counter - 1, - "If a close header is received connection release is not triggered.") + XCTAssertEqual( + connectionDelegate.hitConnectionReleased, + counter - 1, + "If a close header is received connection release is not triggered." + ) // we need to wait a small amount of time to see the connection close on the server try! eventLoop.scheduleTask(in: .milliseconds(200)) {}.futureResult.wait() @@ -324,17 +367,23 @@ class HTTP1ConnectionTests: XCTestCase { var maybeChannel: Channel? - XCTAssertNoThrow(maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait()) + XCTAssertNoThrow( + maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait() + ) let connectionDelegate = MockConnectionDelegate() let logger = Logger(label: "test") var maybeConnection: HTTP1Connection? - XCTAssertNoThrow(maybeConnection = try eventLoop.submit { try HTTP1Connection.start( - channel: XCTUnwrap(maybeChannel), - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(), - logger: logger - ) }.wait()) + XCTAssertNoThrow( + maybeConnection = try eventLoop.submit { + try HTTP1Connection.start( + channel: XCTUnwrap(maybeChannel), + connectionID: 0, + delegate: connectionDelegate, + decompression: .disabled, + logger: logger + ) + }.wait() + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection here") } var maybeRequest: HTTPClient.Request? @@ -343,15 +392,17 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: eventLoopGroup.next()), - task: .init(eventLoop: eventLoopGroup.next(), logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: eventLoopGroup.next()), + task: .init(eventLoop: eventLoopGroup.next(), logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } connection.executeRequest(requestBag) @@ -373,13 +424,15 @@ class HTTP1ConnectionTests: XCTestCase { var maybeConnection: HTTP1Connection? let connectionDelegate = MockConnectionDelegate() - XCTAssertNoThrow(maybeConnection = try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(decompression: .enabled(limit: .ratio(4))), - logger: logger - )) + XCTAssertNoThrow( + maybeConnection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point.") } var maybeRequest: HTTPClient.Request? @@ -388,38 +441,40 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end let responseString = """ - HTTP/1.1 101 Switching Protocols\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Accept: xAMUK7/Il9bLRFJrikq6mm8CNZI=\r\n\ - Connection: upgrade\r\n\ - date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ - \r\n\ - \r\nfoo bar baz - """ + HTTP/1.1 101 Switching Protocols\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Accept: xAMUK7/Il9bLRFJrikq6mm8CNZI=\r\n\ + Connection: upgrade\r\n\ + date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ + \r\n\ + \r\nfoo bar baz + """ XCTAssertTrue(embedded.isActive) XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) XCTAssertNoThrow(try embedded.writeInbound(ByteBuffer(string: responseString))) XCTAssertFalse(embedded.isActive) - (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. + (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) @@ -438,13 +493,15 @@ class HTTP1ConnectionTests: XCTestCase { var maybeConnection: HTTP1Connection? let connectionDelegate = MockConnectionDelegate() - XCTAssertNoThrow(maybeConnection = try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(decompression: .enabled(limit: .ratio(4))), - logger: logger - )) + XCTAssertNoThrow( + maybeConnection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point.") } var maybeRequest: HTTPClient.Request? @@ -453,28 +510,30 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end let responseString = """ - HTTP/1.1 103 Early Hints\r\n\ - date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ - \r\n\ - \r\n - """ + HTTP/1.1 103 Early Hints\r\n\ + date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ + \r\n\ + \r\n + """ XCTAssertTrue(embedded.isActive) XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) @@ -484,7 +543,7 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertTrue(embedded.isActive, "The connection remains active after the informational response head") XCTAssertNoThrow(try embedded.close().wait(), "the connection was closed") - embedded.embeddedEventLoop.run() // tick once to run futures. + embedded.embeddedEventLoop.run() // tick once to run futures. XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) @@ -500,20 +559,22 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)).wait()) let connectionDelegate = MockConnectionDelegate() - XCTAssertNoThrow(try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(decompression: .enabled(limit: .ratio(4))), - logger: logger - )) + XCTAssertNoThrow( + try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) let responseString = """ - HTTP/1.1 200 OK\r\n\ - date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ - \r\n\ - \r\n - """ + HTTP/1.1 200 OK\r\n\ + date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ + \r\n\ + \r\n + """ XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) @@ -522,7 +583,7 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertEqual($0 as? NIOHTTPDecoderError, .unsolicitedResponse) } XCTAssertFalse(embedded.isActive) - (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. + (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) } @@ -535,13 +596,15 @@ class HTTP1ConnectionTests: XCTestCase { var maybeConnection: HTTP1Connection? let connectionDelegate = MockConnectionDelegate() - XCTAssertNoThrow(maybeConnection = try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(decompression: .enabled(limit: .ratio(4))), - logger: logger - )) + XCTAssertNoThrow( + maybeConnection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point.") } var maybeRequest: HTTPClient.Request? @@ -550,32 +613,34 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } connection.executeRequest(requestBag) let responseString = """ - HTTP/1.0 200 OK\r\n\ - HTTP/1.0 200 OK\r\n\r\n - """ + HTTP/1.0 200 OK\r\n\ + HTTP/1.0 200 OK\r\n\r\n + """ - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) XCTAssertNoThrow(try embedded.writeInbound(ByteBuffer(string: responseString))) XCTAssertFalse(embedded.isActive) - (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. + (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) } @@ -595,35 +660,35 @@ class HTTP1ConnectionTests: XCTestCase { var _reads = 0 var _channel: Channel? - let lock: Lock + let lock: NIOLock let backpressurePromise: EventLoopPromise let messageReceived: EventLoopPromise init(eventLoop: EventLoop) { - self.lock = Lock() + self.lock = NIOLock() self.backpressurePromise = eventLoop.makePromise() self.messageReceived = eventLoop.makePromise() } var reads: Int { - return self.lock.withLock { + self.lock.withLock { self._reads } } func willExecuteOnChannel(_ channel: Channel) { - self.lock.withLockVoid { + self.lock.withLock { self._channel = channel } } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - return task.futureResult.eventLoop.makeSucceededVoidFuture() + task.futureResult.eventLoop.makeSucceededVoidFuture() } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { // We count a number of reads received. - self.lock.withLockVoid { + self.lock.withLock { self._reads += 1 } // We need to notify the test when first byte of the message is arrived. @@ -679,34 +744,42 @@ class HTTP1ConnectionTests: XCTestCase { defer { XCTAssertNoThrow(try httpBin.shutdown()) } var maybeChannel: Channel? - XCTAssertNoThrow(maybeChannel = try ClientBootstrap(group: eventLoopGroup) - .channelOption(ChannelOptions.maxMessagesPerRead, value: 1) - .channelOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 1)) - .connect(host: "localhost", port: httpBin.port) - .wait()) + XCTAssertNoThrow( + maybeChannel = try ClientBootstrap(group: eventLoopGroup) + .channelOption(ChannelOptions.maxMessagesPerRead, value: 1) + .channelOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 1)) + .connect(host: "localhost", port: httpBin.port) + .wait() + ) guard let channel = maybeChannel else { return XCTFail("Expected to have a channel at this point") } let connectionDelegate = MockConnectionDelegate() var maybeConnection: HTTP1Connection? - XCTAssertNoThrow(maybeConnection = try channel.eventLoop.submit { try HTTP1Connection.start( - channel: channel, - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(), - logger: logger - ) }.wait()) + XCTAssertNoThrow( + maybeConnection = try channel.eventLoop.submit { + try HTTP1Connection.start( + channel: channel, + connectionID: 0, + delegate: connectionDelegate, + decompression: .disabled, + logger: logger + ) + }.wait() + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point") } var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: HTTPClient.Request(url: "http://localhost:\(httpBin.port)/custom"), - eventLoopPreference: .delegate(on: requestEventLoop), - task: .init(eventLoop: requestEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: backpressureDelegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: HTTPClient.Request(url: "http://localhost:\(httpBin.port)/custom"), + eventLoopPreference: .delegate(on: requestEventLoop), + task: .init(eventLoop: requestEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: backpressureDelegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } backpressureDelegate.willExecuteOnChannel(connection.channel) @@ -764,7 +837,12 @@ class SuddenlySendsCloseHeaderChannelHandler: ChannelInboundHandler { break case .end: if self.closeOnRequest == self.counter { - context.write(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: ["connection": "close"]))), promise: nil) + context.write( + self.wrapOutboundOut( + .head(.init(version: .http1_1, status: .ok, headers: ["connection": "close"])) + ), + promise: nil + ) context.write(self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() self.counter += 1 @@ -805,7 +883,7 @@ class AfterRequestCloseConnectionChannelHandler: ChannelInboundHandler { } class MockConnectionDelegate: HTTP1ConnectionDelegate { - private var lock = Lock() + private var lock = NIOLock() private var _hitConnectionReleased = 0 private var _hitConnectionClosed = 0 @@ -821,13 +899,13 @@ class MockConnectionDelegate: HTTP1ConnectionDelegate { init() {} func http1ConnectionReleased(_: HTTP1Connection) { - self.lock.withLockVoid { + self.lock.withLock { self._hitConnectionReleased += 1 } } func http1ConnectionClosed(_: HTTP1Connection) { - self.lock.withLockVoid { + self.lock.withLock { self._hitConnectionClosed += 1 } } diff --git a/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests+XCTest.swift deleted file mode 100644 index 15c432037..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests+XCTest.swift +++ /dev/null @@ -1,35 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP1ProxyConnectHandlerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP1ProxyConnectHandlerTests { - static var allTests: [(String, (HTTP1ProxyConnectHandlerTests) -> () throws -> Void)] { - return [ - ("testProxyConnectWithoutAuthorizationSuccess", testProxyConnectWithoutAuthorizationSuccess), - ("testProxyConnectWithAuthorization", testProxyConnectWithAuthorization), - ("testProxyConnectWithoutAuthorizationFailure500", testProxyConnectWithoutAuthorizationFailure500), - ("testProxyConnectWithoutAuthorizationButAuthorizationNeeded", testProxyConnectWithoutAuthorizationButAuthorizationNeeded), - ("testProxyConnectReceivesBody", testProxyConnectReceivesBody), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests.swift index bbe6fab1f..d75865da2 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + class HTTP1ProxyConnectHandlerTests: XCTestCase { func testProxyConnectWithoutAuthorizationSuccess() { let embedded = EmbeddedChannel() @@ -43,6 +44,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase { XCTAssertEqual(head.method, .CONNECT) XCTAssertEqual(head.uri, "swift.org:443") + XCTAssertEqual(head.headers["host"].first, "swift.org") XCTAssertNil(head.headers["proxy-authorization"].first) XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) @@ -76,6 +78,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase { XCTAssertEqual(head.method, .CONNECT) XCTAssertEqual(head.uri, "swift.org:443") + XCTAssertEqual(head.headers["host"].first, "swift.org") XCTAssertEqual(head.headers["proxy-authorization"].first, "Basic abc123") XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) @@ -109,6 +112,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase { XCTAssertEqual(head.method, .CONNECT) XCTAssertEqual(head.uri, "swift.org:443") + XCTAssertEqual(head.headers["host"].first, "swift.org") XCTAssertNil(head.headers["proxy-authorization"].first) XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) @@ -148,6 +152,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase { XCTAssertEqual(head.method, .CONNECT) XCTAssertEqual(head.uri, "swift.org:443") + XCTAssertEqual(head.headers["host"].first, "swift.org") XCTAssertNil(head.headers["proxy-authorization"].first) XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) @@ -187,6 +192,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase { XCTAssertEqual(head.method, .CONNECT) XCTAssertEqual(head.uri, "swift.org:443") + XCTAssertEqual(head.headers["host"].first, "swift.org") XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests+XCTest.swift deleted file mode 100644 index 8fa219838..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests+XCTest.swift +++ /dev/null @@ -1,35 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP2ClientRequestHandlerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP2ClientRequestHandlerTests { - static var allTests: [(String, (HTTP2ClientRequestHandlerTests) -> () throws -> Void)] { - return [ - ("testResponseBackpressure", testResponseBackpressure), - ("testWriteBackpressure", testWriteBackpressure), - ("testIdleReadTimeout", testIdleReadTimeout), - ("testIdleReadTimeoutIsCanceledIfRequestIsCanceled", testIdleReadTimeoutIsCanceledIfRequestIsCanceled), - ("testWriteHTTPHeadFails", testWriteHTTPHeadFails), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift index e67529ad8..1f5f1b4c0 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOEmbedded import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + class HTTP2ClientRequestHandlerTests: XCTestCase { func testResponseBackpressure() { let embedded = EmbeddedChannel() @@ -34,28 +35,36 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.write(requestBag, promise: nil) XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(readEventHandler.readHitCounter, 0) embedded.read() @@ -115,22 +124,30 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 50) var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 100) { writer in - testWriter.start(writer: writer) - })) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 100) { writer in + testWriter.start(writer: writer) + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.isWritable = false @@ -143,12 +160,14 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { testWriter.writabilityChanged(true) embedded.pipeline.fireChannelWritabilityChanged() - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .POST) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - XCTAssertEqual($0.headers.first(name: "content-length"), "100") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .POST) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + XCTAssertEqual($0.headers.first(name: "content-length"), "100") + } + ) // the next body write will be executed once we tick the el. before we make the channel // unwritable @@ -162,9 +181,11 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { embedded.embeddedEventLoop.run() - XCTAssertNoThrow(try embedded.receiveBodyAndVerify { - XCTAssertEqual($0.readableBytes, 2) - }) + XCTAssertNoThrow( + try embedded.receiveBodyAndVerify { + XCTAssertEqual($0.readableBytes, 2) + } + ) XCTAssertEqual(testWriter.written, index + 1) @@ -198,27 +219,35 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.write(requestBag, promise: nil) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(readEventHandler.readHitCounter, 0) embedded.read() @@ -248,27 +277,35 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.write(requestBag, promise: nil) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(readEventHandler.readHitCounter, 0) embedded.read() @@ -276,7 +313,164 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) // canceling the request - requestBag.cancel() + requestBag.fail(HTTPClientError.cancelled) + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .cancelled) + } + + // the idle read timeout should be cleared because we canceled the request + // therefore advancing the time should not trigger a crash + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) + } + + func testIdleWriteTimeout() { + let embedded = EmbeddedChannel() + let requestHandler = HTTP2ClientRequestHandler(eventLoop: embedded.eventLoop) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandlers([requestHandler])) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + let logger = Logger(label: "test") + + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + // Advance time by more than the idle write timeout (that's 1 millisecond) to trigger the timeout. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + return testWriter.start(writer: writer) + } + ) + ) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + embedded.write(requestBag, promise: nil) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testIdleWriteTimeoutWritabilityChanged() { + let embedded = EmbeddedChannel() + let readEventHandler = ReadEventHitHandler() + let requestHandler = HTTP2ClientRequestHandler(eventLoop: embedded.eventLoop) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandlers([readEventHandler, requestHandler])) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + let logger = Logger(label: "test") + + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + embedded.isWritable = false + embedded.pipeline.fireChannelWritabilityChanged() + // This should not trigger any errors or timeouts, because the timer isn't running + // as the channel is not writable. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(20)) + + // Now that the channel will become writable, this should trigger a timeout. + embedded.isWritable = true + embedded.pipeline.fireChannelWritabilityChanged() + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + + return testWriter.start(writer: writer) + } + ) + ) + + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + embedded.write(requestBag, promise: nil) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testIdleWriteTimeoutIsCanceledIfRequestIsCanceled() { + let embedded = EmbeddedChannel() + let readEventHandler = ReadEventHitHandler() + let requestHandler = HTTP2ClientRequestHandler(eventLoop: embedded.eventLoop) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandlers([readEventHandler, requestHandler])) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + let logger = Logger(label: "test") + + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 2) { writer in + testWriter.start(writer: writer, expectedErrors: [HTTPClientError.cancelled]) + } + ) + ) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + embedded.write(requestBag, promise: nil) + + // canceling the request + requestBag.fail(HTTPClientError.cancelled) XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { XCTAssertEqual($0 as? HTTPClientError, .cancelled) } @@ -318,16 +512,20 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) - guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to be able to create a request bag") + } embedded.isWritable = false XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) @@ -335,6 +533,7 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { // the handler only writes once the channel is writable XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .none) + XCTAssertTrue(embedded.isActive) embedded.isWritable = true embedded.pipeline.fireChannelWritabilityChanged() @@ -342,7 +541,38 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { XCTAssertEqual($0 as? WriteError, WriteError()) } - XCTAssertEqual(embedded.isActive, false) + XCTAssertFalse(embedded.isActive) + } + } + + func testChannelBecomesNonWritableDuringHeaderWrite() throws { + final class ChangeWritabilityOnFlush: ChannelOutboundHandler { + typealias OutboundIn = Any + func flush(context: ChannelHandlerContext) { + context.flush() + (context.channel as! EmbeddedChannel).isWritable = false + context.fireChannelWritabilityChanged() + } } + let eventLoopGroup = EmbeddedEventLoopGroup(loops: 1) + let eventLoop = eventLoopGroup.next() as! EmbeddedEventLoop + let handler = HTTP2ClientRequestHandler( + eventLoop: eventLoop + ) + let channel = EmbeddedChannel( + handlers: [ + ChangeWritabilityOnFlush(), + handler, + ], + loop: eventLoop + ) + try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait() + + let request = MockHTTPExecutableRequest() + // non empty body is important to trigger this bug as we otherwise finish the request in a single flush + request.requestFramingMetadata.body = .fixedSize(1) + request.raiseErrorIfUnimplementedMethodIsCalled = false + channel.writeAndFlush(request, promise: nil) + XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent]) } } diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientTests+XCTest.swift deleted file mode 100644 index e7f399658..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientTests+XCTest.swift +++ /dev/null @@ -1,42 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP2ClientTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP2ClientTests { - static var allTests: [(String, (HTTP2ClientTests) -> () throws -> Void)] { - return [ - ("testSimpleGet", testSimpleGet), - ("testStreamRequestBodyWithoutKnowledgeAboutLength", testStreamRequestBodyWithoutKnowledgeAboutLength), - ("testStreamRequestBodyWithFalseKnowledgeAboutLength", testStreamRequestBodyWithFalseKnowledgeAboutLength), - ("testConcurrentRequests", testConcurrentRequests), - ("testConcurrentRequestsFromDifferentThreads", testConcurrentRequestsFromDifferentThreads), - ("testConcurrentRequestsWorkWithRequiredEventLoop", testConcurrentRequestsWorkWithRequiredEventLoop), - ("testUncleanShutdownCancelsExecutingAndQueuedTasks", testUncleanShutdownCancelsExecutingAndQueuedTasks), - ("testCancelingRunningRequest", testCancelingRunningRequest), - ("testReadTimeout", testReadTimeout), - ("testH2CanHandleRequestsThatHaveAlreadyHitTheDeadline", testH2CanHandleRequestsThatHaveAlreadyHitTheDeadline), - ("testStressCancelingRunningRequestFromDifferentThreads", testStressCancelingRunningRequestFromDifferentThreads), - ("testPlatformConnectErrorIsForwardedOnTimeout", testPlatformConnectErrorIsForwardedOnTimeout), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift index eb1ac2ddc..d6bc2de14 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift @@ -12,20 +12,23 @@ // //===----------------------------------------------------------------------===// -/* NOT @testable */ import AsyncHTTPClient // Tests that really need @testable go into HTTP2ClientInternalTests.swift -#if canImport(Network) -import Network -#endif +import AsyncHTTPClient // NOT @testable - tests that really need @testable go into HTTP2ClientInternalTests.swift import Logging import NIOCore +import NIOFoundationCompat import NIOHTTP1 +import NIOHTTP2 import NIOPosix import NIOSSL import XCTest +#if canImport(Network) +import Network +#endif + class HTTP2ClientTests: XCTestCase { func makeDefaultHTTPClient( - eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .createNew + eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .singleton ) -> HTTPClient { var config = HTTPClient.Configuration() config.tlsConfiguration = .clientDefault @@ -40,7 +43,7 @@ class HTTP2ClientTests: XCTestCase { func makeClientWithActiveHTTP2Connection( to bin: HTTPBin, - eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .createNew + eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .singleton ) -> HTTPClient { let client = self.makeDefaultHTTPClient(eventLoopGroupProvider: eventLoopGroupProvider) var response: HTTPClient.Response? @@ -68,7 +71,7 @@ class HTTP2ClientTests: XCTestCase { let client = self.makeDefaultHTTPClient() defer { XCTAssertNoThrow(try client.syncShutdown()) } var response: HTTPClient.Response? - let body = HTTPClient.Body.stream(length: nil) { writer in + let body = HTTPClient.Body.stream(contentLength: nil) { writer in writer.write(.byteBuffer(ByteBuffer(integer: UInt64(0)))).flatMap { writer.write(.byteBuffer(ByteBuffer(integer: UInt64(0)))) } @@ -84,7 +87,7 @@ class HTTP2ClientTests: XCTestCase { defer { XCTAssertNoThrow(try bin.shutdown()) } let client = self.makeDefaultHTTPClient() defer { XCTAssertNoThrow(try client.syncShutdown()) } - let body = HTTPClient.Body.stream(length: 12) { writer in + let body = HTTPClient.Body.stream(contentLength: 12) { writer in writer.write(.byteBuffer(ByteBuffer(integer: UInt64(0)))).flatMap { writer.write(.byteBuffer(ByteBuffer(integer: UInt64(0)))) } @@ -132,8 +135,8 @@ class HTTP2ClientTests: XCTestCase { let q = DispatchQueue(label: "worker \(w)") q.async(group: allDone) { func go() { - allWorkersReady.signal() // tell the driver we're ready - allWorkersGo.wait() // wait for the driver to let us go + allWorkersReady.signal() // tell the driver we're ready + allWorkersGo.wait() // wait for the driver to let us go for _ in 0..] = [] - XCTAssertNoThrow(results = try EventLoopFuture - .whenAllComplete(responses, on: clientGroup.next()) - .timeout(after: .seconds(2)) - .wait()) + XCTAssertNoThrow( + results = + try EventLoopFuture + .whenAllComplete(responses, on: clientGroup.next()) + .timeout(after: .seconds(2)) + .wait() + ) for result in results { switch result { @@ -287,7 +294,7 @@ class HTTP2ClientTests: XCTestCase { ) XCTAssertThrowsError(try task.futureResult.timeout(after: .seconds(2)).wait()) { - XCTAssertEqual($0 as? HTTPClientError, .cancelled) + XCTAssertEqualTypeAndValue($0, HTTPClientError.cancelled) } } @@ -301,7 +308,7 @@ class HTTP2ClientTests: XCTestCase { config.httpVersion = .automatic config.timeout.read = .milliseconds(100) let client = HTTPClient( - eventLoopGroupProvider: .createNew, + eventLoopGroupProvider: .singleton, configuration: config, backgroundActivityLogger: Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) ) @@ -322,7 +329,8 @@ class HTTP2ClientTests: XCTestCase { config.tlsConfiguration = tlsConfig config.httpVersion = .automatic let client = HTTPClient( - eventLoopGroupProvider: .createNew, + // TODO: Test fails if the provided ELG is a multi-threaded NIOTSEventLoopGroup (probably racy) + eventLoopGroupProvider: .shared(bin.group), configuration: config, backgroundActivityLogger: Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) ) @@ -375,7 +383,7 @@ class HTTP2ClientTests: XCTestCase { } func testPlatformConnectErrorIsForwardedOnTimeout() { - let bin = HTTPBin(.http2(compress: false)) + let bin = HTTPBin(.http2(compress: false), reusePort: true) let clientGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) let el1 = clientGroup.next() let el2 = clientGroup.next() @@ -396,27 +404,35 @@ class HTTP2ClientTests: XCTestCase { XCTAssertNoThrow(maybeRequest1 = try HTTPClient.Request(url: "https://localhost:\(bin.port)/get")) guard let request1 = maybeRequest1 else { return } - let task1 = client.execute(request: request1, delegate: ResponseAccumulator(request: request1), eventLoop: .delegateAndChannel(on: el1)) + let task1 = client.execute( + request: request1, + delegate: ResponseAccumulator(request: request1), + eventLoop: .delegateAndChannel(on: el1) + ) var response1: ResponseAccumulator.Response? XCTAssertNoThrow(response1 = try task1.wait()) XCTAssertEqual(.ok, response1?.status) XCTAssertEqual(response1?.version, .http2) let serverPort = bin.port - XCTAssertNoThrow(try bin.shutdown()) - // client is now in HTTP/2 state and the HTTPBin is closed - // start a new server on the old port which closes all connections immediately + let serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try serverGroup.syncShutdownGracefully()) } var maybeServer: Channel? - XCTAssertNoThrow(maybeServer = try ServerBootstrap(group: serverGroup) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .childChannelInitializer { channel in - channel.close() - } - .childChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .bind(host: "127.0.0.1", port: serverPort) - .wait()) + XCTAssertNoThrow( + maybeServer = try ServerBootstrap(group: serverGroup) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), value: 1) + .childChannelInitializer { channel in + channel.close() + } + .childChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .bind(host: "127.0.0.1", port: serverPort) + .wait() + ) + // shutting down the old server closes all connections immediately + XCTAssertNoThrow(try bin.shutdown()) + // client is now in HTTP/2 state and the HTTPBin is closed guard let server = maybeServer else { return } defer { XCTAssertNoThrow(try server.close().wait()) } @@ -424,7 +440,11 @@ class HTTP2ClientTests: XCTestCase { XCTAssertNoThrow(maybeRequest2 = try HTTPClient.Request(url: "https://localhost:\(serverPort)/")) guard let request2 = maybeRequest2 else { return } - let task2 = client.execute(request: request2, delegate: ResponseAccumulator(request: request2), eventLoop: .delegateAndChannel(on: el2)) + let task2 = client.execute( + request: request2, + delegate: ResponseAccumulator(request: request2), + eventLoop: .delegateAndChannel(on: el2) + ) XCTAssertThrowsError(try task2.wait()) { error in XCTAssertNil( error as? HTTPClientError, @@ -432,6 +452,97 @@ class HTTP2ClientTests: XCTestCase { ) } } + + func testMassiveDownload() { + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = self.makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + var response: HTTPClient.Response? + XCTAssertNoThrow(response = try client.get(url: "https://localhost:\(bin.port)/mega-chunked").wait()) + + XCTAssertEqual(.ok, response?.status) + XCTAssertEqual(response?.version, .http2) + XCTAssertEqual(response?.body?.readableBytes, 10_000) + } + + func testSimplePost() { + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = self.makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + var response: HTTPClient.Response? + XCTAssertNoThrow( + response = try client.post( + url: "https://localhost:\(bin.port)/post", + body: .byteBuffer(ByteBuffer(repeating: 0, count: 12345)) + ).wait() + ) + XCTAssertEqual(.ok, response?.status) + XCTAssertEqual(response?.version, .http2) + XCTAssertEqual( + String(buffer: ByteBuffer(repeating: 0, count: 12345)), + try response?.body.map { body in + try JSONDecoder().decode(RequestInfo.self, from: body) + }?.data + ) + } + + func testHugePost() { + // Regression test for https://github.com/swift-server/async-http-client/issues/784 + let group = MultiThreadedEventLoopGroup(numberOfThreads: 2) // This needs to be more than 1! + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + var serverH2Settings: HTTP2Settings = HTTP2Settings() + serverH2Settings.append(HTTP2Setting(parameter: .maxFrameSize, value: 16 * 1024 * 1024 - 1)) + serverH2Settings.append(HTTP2Setting(parameter: .initialWindowSize, value: Int(Int32.max))) + let bin = HTTPBin( + .http2(compress: false, settings: serverH2Settings) + ) + defer { XCTAssertNoThrow(try bin.shutdown()) } + var clientConfig = HTTPClient.Configuration() + clientConfig.tlsConfiguration = .clientDefault + clientConfig.tlsConfiguration?.certificateVerification = .none + clientConfig.httpVersion = .automatic + let client = HTTPClient( + eventLoopGroupProvider: .shared(group), + configuration: clientConfig, + backgroundActivityLogger: Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + ) + defer { XCTAssertNoThrow(try client.syncShutdown()) } + + let loop1 = group.next() + let loop2 = group.next() + precondition(loop1 !== loop2, "bug in test setup, need two distinct loops") + + XCTAssertNoThrow( + try client.execute( + request: .init(url: "https://localhost:\(bin.port)/get"), + eventLoop: .delegateAndChannel(on: loop1) // This will force the channel to live on `loop1`. + ).wait() + ) + var response: HTTPClient.Response? + let byteCount = 1024 * 1024 * 1024 // 1 GiB (unfortunately it has to be that big to trigger the bug) + XCTAssertNoThrow( + response = try client.execute( + request: HTTPClient.Request( + url: "https://localhost:\(bin.port)/post-respond-with-byte-count", + method: .POST, + body: .data(Data(repeating: 0, count: byteCount)) + ), + eventLoop: .delegate(on: loop2) + ).wait() + ) + XCTAssertEqual(.ok, response?.status) + XCTAssertEqual(response?.version, .http2) + XCTAssertEqual( + "\(byteCount)", + try response?.body.map { body in + try JSONDecoder().decode(RequestInfo.self, from: body) + }?.data + ) + } } private final class HeadReceivedCallback: HTTPClientResponseDelegate { @@ -458,11 +569,17 @@ private final class SendHeaderAndWaitChannelHandler: ChannelInboundHandler { let requestPart = self.unwrapInboundIn(data) switch requestPart { case .head: - context.writeAndFlush(self.wrapOutboundOut(.head(HTTPResponseHead( - version: HTTPVersion(major: 1, minor: 1), - status: .ok - )) - ), promise: nil) + context.writeAndFlush( + self.wrapOutboundOut( + .head( + HTTPResponseHead( + version: HTTPVersion(major: 1, minor: 1), + status: .ok + ) + ) + ), + promise: nil + ) case .body, .end: return } diff --git a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests+XCTest.swift deleted file mode 100644 index 9f9582d9f..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests+XCTest.swift +++ /dev/null @@ -1,34 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP2ConnectionTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP2ConnectionTests { - static var allTests: [(String, (HTTP2ConnectionTests) -> () throws -> Void)] { - return [ - ("testCreateNewConnectionFailureClosedIO", testCreateNewConnectionFailureClosedIO), - ("testSimpleGetRequest", testSimpleGetRequest), - ("testEveryDoneRequestLeadsToAStreamAvailableCall", testEveryDoneRequestLeadsToAStreamAvailableCall), - ("testCancelAllRunningRequests", testCancelAllRunningRequests), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift index fab866867..a50f1ab54 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift @@ -12,17 +12,20 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +import NIOHPACK import NIOHTTP1 +import NIOHTTP2 import NIOPosix import NIOSSL import NIOTestUtils import XCTest +@testable import AsyncHTTPClient + class HTTP2ConnectionTests: XCTestCase { func testCreateNewConnectionFailureClosedIO() { let embedded = EmbeddedChannel() @@ -33,13 +36,41 @@ class HTTP2ConnectionTests: XCTestCase { embedded.embeddedEventLoop.run() let logger = Logger(label: "test.http2.connection") - XCTAssertThrowsError(try HTTP2Connection.start( + XCTAssertThrowsError( + try HTTP2Connection.start( + channel: embedded, + connectionID: 0, + delegate: TestHTTP2ConnectionDelegate(), + decompression: .disabled, + maximumConnectionUses: nil, + logger: logger + ).wait() + ) + } + + func testConnectionToleratesShutdownEventsAfterAlreadyClosed() { + let embedded = EmbeddedChannel() + XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) + + let logger = Logger(label: "test.http2.connection") + let connection = HTTP2Connection( channel: embedded, connectionID: 0, + decompression: .disabled, + maximumConnectionUses: nil, delegate: TestHTTP2ConnectionDelegate(), - configuration: .init(), logger: logger - ).wait()) + ) + let startFuture = connection._start0() + + XCTAssertNoThrow(try embedded.close().wait()) + // to really destroy the channel we need to tick once + embedded.embeddedEventLoop.run() + + XCTAssertThrowsError(try startFuture.wait()) + + // should not crash + connection.shutdown() } func testSimpleGetRequest() { @@ -53,11 +84,12 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - ) + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") @@ -66,15 +98,17 @@ class HTTP2ConnectionTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to have a request bag at this point") } @@ -109,11 +143,13 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - )) + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") } @@ -127,15 +163,17 @@ class HTTP2ConnectionTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to have a request bag at this point") } @@ -173,11 +211,12 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - ) + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") @@ -189,15 +228,17 @@ class HTTP2ConnectionTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to have a request bag at this point") } @@ -219,6 +260,129 @@ class HTTP2ConnectionTests: XCTestCase { XCTAssertNoThrow(try http2Connection.closeFuture.wait()) } + + func testChildStreamsAreRemovedFromTheOpenChannelListOnceTheRequestIsDone() { + class SucceedPromiseOnRequestHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + let dataArrivedPromise: EventLoopPromise + let triggerResponseFuture: EventLoopFuture + + init(dataArrivedPromise: EventLoopPromise, triggerResponseFuture: EventLoopFuture) { + self.dataArrivedPromise = dataArrivedPromise + self.triggerResponseFuture = triggerResponseFuture + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + self.dataArrivedPromise.succeed(()) + + self.triggerResponseFuture.hop(to: context.eventLoop).whenSuccess { + switch self.unwrapInboundIn(data) { + case .head: + context.write(self.wrapOutboundOut(.head(.init(version: .http2, status: .ok))), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + case .body, .end: + break + } + } + } + } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let eventLoop = eventLoopGroup.next() + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let serverReceivedRequestPromise = eventLoop.makePromise(of: Void.self) + let triggerResponsePromise = eventLoop.makePromise(of: Void.self) + let httpBin = HTTPBin(.http2(compress: false)) { _ in + SucceedPromiseOnRequestHandler( + dataArrivedPromise: serverReceivedRequestPromise, + triggerResponseFuture: triggerResponsePromise.futureResult + ) + } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let connectionCreator = TestConnectionCreator() + let delegate = TestHTTP2ConnectionDelegate() + var maybeHTTP2Connection: HTTP2Connection? + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) + guard let http2Connection = maybeHTTP2Connection else { + return XCTFail("Expected to have an HTTP2 connection here.") + } + + var maybeRequest: HTTPClient.Request? + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to have a request bag at this point") + } + + http2Connection.executeRequest(requestBag) + + XCTAssertNoThrow(try serverReceivedRequestPromise.futureResult.wait()) + var channelCount: Int? + XCTAssertNoThrow( + channelCount = try eventLoop.submit { http2Connection.__forTesting_getStreamChannels().count }.wait() + ) + XCTAssertEqual(channelCount, 1) + triggerResponsePromise.succeed(()) + + XCTAssertNoThrow(try requestBag.task.futureResult.wait()) + + // this is racy. for this reason we allow a couple of tries + var retryCount = 0 + let maxRetries = 1000 + while retryCount < maxRetries { + XCTAssertNoThrow( + channelCount = try eventLoop.submit { http2Connection.__forTesting_getStreamChannels().count }.wait() + ) + if channelCount == 0 { + break + } + retryCount += 1 + } + XCTAssertLessThan(retryCount, maxRetries) + } + + func testServerPushIsDisabled() { + let embedded = EmbeddedChannel() + let logger = Logger(label: "test.http2.connection") + let connection = HTTP2Connection( + channel: embedded, + connectionID: 0, + decompression: .disabled, + maximumConnectionUses: nil, + delegate: TestHTTP2ConnectionDelegate(), + logger: logger + ) + _ = connection._start0() + + let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([]))) + XCTAssertNoThrow(try connection.channel.writeAndFlush(settingsFrame).wait()) + + let pushPromiseFrame = HTTP2Frame(streamID: 0, payload: .pushPromise(.init(pushedStreamID: 1, headers: [:]))) + XCTAssertThrowsError(try connection.channel.writeAndFlush(pushPromiseFrame).wait()) { error in + XCTAssertNotNil(error as? NIOHTTP2Errors.PushInViolationOfSetting) + } + } } class TestConnectionCreator { @@ -235,7 +399,7 @@ class TestConnectionCreator { } private var state: State = .idle - private let lock = Lock() + private let lock = NIOLock() init() {} @@ -405,6 +569,10 @@ extension TestConnectionCreator: HTTPConnectionRequester { } wrapper.fail(error) } + + func waitingForConnectivity(_: HTTPConnectionPool.Connection.ID, error: Swift.Error) { + preconditionFailure("TODO") + } } class TestHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { @@ -424,7 +592,7 @@ class TestHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { self.lock.withLock { self._maxStreamSetting } } - private let lock = Lock() + private let lock = NIOLock() private var _hitStreamClosed: Int = 0 private var _hitGoAwayReceived: Int = 0 private var _hitConnectionClosed: Int = 0 @@ -435,19 +603,19 @@ class TestHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) {} func http2ConnectionStreamClosed(_: HTTP2Connection, availableStreams: Int) { - self.lock.withLockVoid { + self.lock.withLock { self._hitStreamClosed += 1 } } func http2ConnectionGoAwayReceived(_: HTTP2Connection) { - self.lock.withLockVoid { + self.lock.withLock { self._hitGoAwayReceived += 1 } } func http2ConnectionClosed(_: HTTP2Connection) { - self.lock.withLockVoid { + self.lock.withLock { self._hitConnectionClosed += 1 } } diff --git a/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests+XCTest.swift deleted file mode 100644 index 5c7021e23..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests+XCTest.swift +++ /dev/null @@ -1,40 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP2IdleHandlerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP2IdleHandlerTests { - static var allTests: [(String, (HTTP2IdleHandlerTests) -> () throws -> Void)] { - return [ - ("testReceiveSettingsWithMaxConcurrentStreamSetting", testReceiveSettingsWithMaxConcurrentStreamSetting), - ("testReceiveSettingsWithoutMaxConcurrentStreamSetting", testReceiveSettingsWithoutMaxConcurrentStreamSetting), - ("testEmptySettingsDontOverwriteMaxConcurrentStreamSetting", testEmptySettingsDontOverwriteMaxConcurrentStreamSetting), - ("testOverwriteMaxConcurrentStreamSetting", testOverwriteMaxConcurrentStreamSetting), - ("testGoAwayReceivedBeforeSettings", testGoAwayReceivedBeforeSettings), - ("testGoAwayReceivedAfterSettings", testGoAwayReceivedAfterSettings), - ("testCloseEventBeforeFirstSettings", testCloseEventBeforeFirstSettings), - ("testCloseEventWhileNoOpenStreams", testCloseEventWhileNoOpenStreams), - ("testCloseEventWhileThereAreOpenStreams", testCloseEventWhileThereAreOpenStreams), - ("testGoAwayWhileThereAreOpenStreams", testGoAwayWhileThereAreOpenStreams), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift index 57560b659..f2b56daa0 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOEmbedded import NIOHTTP2 import XCTest +@testable import AsyncHTTPClient + class HTTP2IdleHandlerTests: XCTestCase { func testReceiveSettingsWithMaxConcurrentStreamSetting() { let delegate = MockHTTP2IdleHandlerDelegate() @@ -26,7 +27,10 @@ class HTTP2IdleHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [idleHandler]) XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) - let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)]))) + let settingsFrame = HTTP2Frame( + streamID: 0, + payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)])) + ) XCTAssertEqual(delegate.maxStreams, nil) XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) XCTAssertEqual(delegate.maxStreams, 10) @@ -41,7 +45,11 @@ class HTTP2IdleHandlerTests: XCTestCase { let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([]))) XCTAssertEqual(delegate.maxStreams, nil) XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) - XCTAssertEqual(delegate.maxStreams, 100, "Expected to assume 100 maxConcurrentConnection, if no setting was present") + XCTAssertEqual( + delegate.maxStreams, + 100, + "Expected to assume 100 maxConcurrentConnection, if no setting was present" + ) } func testEmptySettingsDontOverwriteMaxConcurrentStreamSetting() { @@ -50,7 +58,10 @@ class HTTP2IdleHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [idleHandler]) XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) - let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)]))) + let settingsFrame = HTTP2Frame( + streamID: 0, + payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)])) + ) XCTAssertEqual(delegate.maxStreams, nil) XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) XCTAssertEqual(delegate.maxStreams, 10) @@ -66,12 +77,18 @@ class HTTP2IdleHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [idleHandler]) XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) - let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)]))) + let settingsFrame = HTTP2Frame( + streamID: 0, + payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)])) + ) XCTAssertEqual(delegate.maxStreams, nil) XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) XCTAssertEqual(delegate.maxStreams, 10) - let emptySettings = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 20)]))) + let emptySettings = HTTP2Frame( + streamID: 0, + payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 20)])) + ) XCTAssertNoThrow(try embedded.writeInbound(emptySettings)) XCTAssertEqual(delegate.maxStreams, 20) } @@ -83,7 +100,10 @@ class HTTP2IdleHandlerTests: XCTestCase { XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) let randomStreamID = HTTP2StreamID((0.. () throws -> Void)] { - return [ - ("testProxySOCKS", testProxySOCKS), - ("testProxySOCKSBogusAddress", testProxySOCKSBogusAddress), - ("testProxySOCKSFailureNoServer", testProxySOCKSFailureNoServer), - ("testProxySOCKSFailureInvalidServer", testProxySOCKSFailureInvalidServer), - ("testProxySOCKSMisbehavingServer", testProxySOCKSMisbehavingServer), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift b/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift index 5fdc5ac61..af32284b0 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -/* NOT @testable */ import AsyncHTTPClient // Tests that need @testable go into HTTPClientInternalTests.swift +import AsyncHTTPClient // NOT @testable - tests that need @testable go into HTTPClientInternalTests.swift import Logging import NIOCore +import NIOHTTP1 import NIOPosix import NIOSOCKS import XCTest @@ -29,7 +30,7 @@ class HTTPClientSOCKSTests: XCTestCase { var backgroundLogStore: CollectEverythingLogHandler.LogStore! var defaultHTTPBinURLPrefix: String { - return "http://localhost:\(self.defaultHTTPBin.port)/" + "http://localhost:\(self.defaultHTTPBin.port)/" } override func setUp() { @@ -43,12 +44,17 @@ class HTTPClientSOCKSTests: XCTestCase { self.serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) self.defaultHTTPBin = HTTPBin() self.backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: self.backgroundLogStore!) - }) + var backgroundLogger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: self.backgroundLogStore!) + } + ) backgroundLogger.logLevel = .trace - self.defaultClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - backgroundActivityLogger: backgroundLogger) + self.defaultClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + backgroundActivityLogger: backgroundLogger + ) } override func tearDown() { @@ -75,8 +81,12 @@ class HTTPClientSOCKSTests: XCTestCase { func testProxySOCKS() throws { let socksBin = try MockSOCKSServer(expectedURL: "/socks/test", expectedResponse: "it works!") - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .socksServer(host: "localhost", port: socksBin.port))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init( + proxy: .socksServer(host: "localhost", port: socksBin.port) + ).enableFastFailureModeForTesting() + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -90,8 +100,12 @@ class HTTPClientSOCKSTests: XCTestCase { } func testProxySOCKSBogusAddress() throws { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .socksServer(host: "127.0.."))) + let config = HTTPClient.Configuration(proxy: .socksServer(host: "127.0..")) + .enableFastFailureModeForTesting() + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -102,8 +116,13 @@ class HTTPClientSOCKSTests: XCTestCase { // there is no socks server, so we should fail func testProxySOCKSFailureNoServer() throws { let localHTTPBin = HTTPBin() - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .socksServer(host: "localhost", port: localHTTPBin.port))) + let config = HTTPClient.Configuration(proxy: .socksServer(host: "localhost", port: localHTTPBin.port)) + .enableFastFailureModeForTesting() + + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -113,8 +132,13 @@ class HTTPClientSOCKSTests: XCTestCase { // speak to a server that doesn't speak SOCKS func testProxySOCKSFailureInvalidServer() throws { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .socksServer(host: "localhost"))) + let config = HTTPClient.Configuration(proxy: .socksServer(host: "localhost")) + .enableFastFailureModeForTesting() + + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) } @@ -124,8 +148,13 @@ class HTTPClientSOCKSTests: XCTestCase { // test a handshake failure with a misbehaving server func testProxySOCKSMisbehavingServer() throws { let socksBin = try MockSOCKSServer(expectedURL: "/socks/test", expectedResponse: "it works!", misbehave: true) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .socksServer(host: "localhost", port: socksBin.port))) + let config = HTTPClient.Configuration(proxy: .socksServer(host: "localhost", port: socksBin.port)) + .enableFastFailureModeForTesting() + + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) diff --git a/Tests/AsyncHTTPClientTests/HTTPClient+StructuredConcurrencyTests.swift b/Tests/AsyncHTTPClientTests/HTTPClient+StructuredConcurrencyTests.swift new file mode 100644 index 000000000..a7cc1f454 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPClient+StructuredConcurrencyTests.swift @@ -0,0 +1,101 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import NIO +import NIOFoundationCompat +import NIOHTTP1 +import XCTest + +final class HTTPClientStructuredConcurrencyTests: XCTestCase { + func testDoNothingWorks() async throws { + let actual = try await HTTPClient.withHTTPClient { httpClient in + "OK" + } + XCTAssertEqual("OK", actual) + } + + func testShuttingDownTheClientInBodyLeadsToError() async { + do { + let actual = try await HTTPClient.withHTTPClient { httpClient in + try await httpClient.shutdown() + return "OK" + } + XCTFail("Expected error, got \(actual)") + } catch let error as HTTPClientError where error == .alreadyShutdown { + // OK + } catch { + XCTFail("unexpected error: \(error)") + } + } + + func testBasicRequest() async throws { + let httpBin = HTTPBin() + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let actualBytes = try await HTTPClient.withHTTPClient { httpClient in + let response = try await httpClient.get(url: httpBin.baseURL).get() + XCTAssertEqual(response.status, .ok) + return response.body ?? ByteBuffer(string: "n/a") + } + let actual = try JSONDecoder().decode(RequestInfo.self, from: actualBytes) + + XCTAssertGreaterThanOrEqual(actual.requestNumber, 0) + XCTAssertGreaterThanOrEqual(actual.connectionNumber, 0) + } + + func testClientIsShutDownAfterReturn() async throws { + let leakedClient = try await HTTPClient.withHTTPClient { httpClient in + httpClient + } + do { + try await leakedClient.shutdown() + XCTFail("unexpected, shutdown should have failed") + } catch let error as HTTPClientError where error == .alreadyShutdown { + // OK + } catch { + XCTFail("unexpected error: \(error)") + } + } + + func testClientIsShutDownOnThrowAlso() async throws { + struct TestError: Error { + var httpClient: HTTPClient + } + + let leakedClient: HTTPClient + do { + try await HTTPClient.withHTTPClient { httpClient in + throw TestError(httpClient: httpClient) + } + XCTFail("unexpected, shutdown should have failed") + return + } catch let error as TestError { + // OK + leakedClient = error.httpClient + } catch { + XCTFail("unexpected error: \(error)") + return + } + + do { + try await leakedClient.shutdown() + XCTFail("unexpected, shutdown should have failed") + } catch let error as HTTPClientError where error == .alreadyShutdown { + // OK + } catch { + XCTFail("unexpected error: \(error)") + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientBase.swift b/Tests/AsyncHTTPClientTests/HTTPClientBase.swift new file mode 100644 index 000000000..aaf072b2f --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPClientBase.swift @@ -0,0 +1,92 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +class XCTestCaseHTTPClientTestsBaseClass: XCTestCase { + typealias Request = HTTPClient.Request + + var clientGroup: EventLoopGroup! + var serverGroup: EventLoopGroup! + var defaultHTTPBin: HTTPBin! + var defaultClient: HTTPClient! + var backgroundLogStore: CollectEverythingLogHandler.LogStore! + + var defaultHTTPBinURLPrefix: String { + self.defaultHTTPBin.baseURL + } + + override func setUp() { + XCTAssertNil(self.clientGroup) + XCTAssertNil(self.serverGroup) + XCTAssertNil(self.defaultHTTPBin) + XCTAssertNil(self.defaultClient) + XCTAssertNil(self.backgroundLogStore) + + self.clientGroup = getDefaultEventLoopGroup(numberOfThreads: 1) + self.serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + self.defaultHTTPBin = HTTPBin() + self.backgroundLogStore = CollectEverythingLogHandler.LogStore() + var backgroundLogger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: self.backgroundLogStore!) + } + ) + backgroundLogger.logLevel = .trace + self.defaultClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration().enableFastFailureModeForTesting(), + backgroundActivityLogger: backgroundLogger + ) + } + + override func tearDown() { + if let defaultClient = self.defaultClient { + XCTAssertNoThrow(try defaultClient.syncShutdown()) + self.defaultClient = nil + } + + XCTAssertNotNil(self.defaultHTTPBin) + XCTAssertNoThrow(try self.defaultHTTPBin.shutdown()) + self.defaultHTTPBin = nil + + XCTAssertNotNil(self.clientGroup) + XCTAssertNoThrow(try self.clientGroup.syncShutdownGracefully()) + self.clientGroup = nil + + XCTAssertNotNil(self.serverGroup) + XCTAssertNoThrow(try self.serverGroup.syncShutdownGracefully()) + self.serverGroup = nil + + XCTAssertNotNil(self.backgroundLogStore) + self.backgroundLogStore = nil + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientCookieTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientCookieTests+XCTest.swift deleted file mode 100644 index 7ecf54d4d..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPClientCookieTests+XCTest.swift +++ /dev/null @@ -1,44 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPClientCookieTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPClientCookieTests { - static var allTests: [(String, (HTTPClientCookieTests) -> () throws -> Void)] { - return [ - ("testCookie", testCookie), - ("testEmptyValueCookie", testEmptyValueCookie), - ("testCookieDefaults", testCookieDefaults), - ("testCookieInit", testCookieInit), - ("testMalformedCookies", testMalformedCookies), - ("testExpires", testExpires), - ("testMaxAge", testMaxAge), - ("testDomain", testDomain), - ("testPath", testPath), - ("testSecure", testSecure), - ("testHttpOnly", testHttpOnly), - ("testCookieExpiresDateParsing", testCookieExpiresDateParsing), - ("testQuotedCookies", testQuotedCookies), - ("testCookieExpiresDateParsingWithNonEnglishLocale", testCookieExpiresDateParsingWithNonEnglishLocale), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientCookieTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientCookieTests.swift index 8b4c9adf6..fa9abb9d8 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientCookieTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientCookieTests.swift @@ -19,7 +19,8 @@ import XCTest class HTTPClientCookieTests: XCTestCase { func testCookie() { - let v = "key=value; PaTh=/path; DoMaIn=EXampLE.CoM; eXpIRes=Wed, 21 Oct 2015 07:28:00 GMT; max-AGE=42; seCURE; HTTPOnly" + let v = + "key=value; PaTh=/path; DoMaIn=EXampLE.CoM; eXpIRes=Wed, 21 Oct 2015 07:28:00 GMT; max-AGE=42; seCURE; HTTPOnly" guard let c = HTTPClient.Cookie(header: v, defaultDomain: "exAMPle.cOm") else { XCTFail("Failed to parse cookie") return @@ -67,7 +68,16 @@ class HTTPClientCookieTests: XCTestCase { } func testCookieInit() { - let c = HTTPClient.Cookie(name: "key", value: "value", path: "/path", domain: "example.com", expires: Date(timeIntervalSince1970: 1_445_412_480), maxAge: 42, httpOnly: true, secure: true) + let c = HTTPClient.Cookie( + name: "key", + value: "value", + path: "/path", + domain: "example.com", + expires: Date(timeIntervalSince1970: 1_445_412_480), + maxAge: 42, + httpOnly: true, + secure: true + ) XCTAssertEqual("key", c.name) XCTAssertEqual("value", c.value) XCTAssertEqual("/path", c.path) @@ -118,17 +128,26 @@ class HTTPClientCookieTests: XCTestCase { XCTAssertNil(c?.expires) // Later values override earlier values, except if they are ignored. - c = HTTPClient.Cookie(header: "key=value; expires=Sunday, 06-Nov-94 08:49:37 GMT; expires=04/01/2022", defaultDomain: "example.com") + c = HTTPClient.Cookie( + header: "key=value; expires=Sunday, 06-Nov-94 08:49:37 GMT; expires=04/01/2022", + defaultDomain: "example.com" + ) XCTAssertEqual("key", c?.name) XCTAssertEqual("value", c?.value) XCTAssertEqual(Date(timeIntervalSince1970: 784_111_777), c?.expires) - c = HTTPClient.Cookie(header: "key=value; expires=Sunday, 06-Nov-94 08:49:37 GMT; expires=", defaultDomain: "example.com") + c = HTTPClient.Cookie( + header: "key=value; expires=Sunday, 06-Nov-94 08:49:37 GMT; expires=", + defaultDomain: "example.com" + ) XCTAssertEqual("key", c?.name) XCTAssertEqual("value", c?.value) XCTAssertEqual(Date(timeIntervalSince1970: 784_111_777), c?.expires) - c = HTTPClient.Cookie(header: "key=value; expires=Sunday, 06-Nov-94 08:49:37 GMT; expires", defaultDomain: "example.com") + c = HTTPClient.Cookie( + header: "key=value; expires=Sunday, 06-Nov-94 08:49:37 GMT; expires", + defaultDomain: "example.com" + ) XCTAssertEqual("key", c?.name) XCTAssertEqual("value", c?.value) XCTAssertEqual(Date(timeIntervalSince1970: 784_111_777), c?.expires) @@ -467,11 +486,17 @@ class HTTPClientCookieTests: XCTestCase { try XCTSkipIf(localeCheck.tm_mon != 1, "Unable to set locale") // Cookie parsing should be independent of C locale. - var c = HTTPClient.Cookie(header: "key=value; eXpIRes=Sunday, 06-Nov-94 08:49:37 GMT;", defaultDomain: "example.org") + var c = HTTPClient.Cookie( + header: "key=value; eXpIRes=Sunday, 06-Nov-94 08:49:37 GMT;", + defaultDomain: "example.org" + ) XCTAssertEqual(Date(timeIntervalSince1970: 784_111_777), c?.expires) c = HTTPClient.Cookie(header: "key=value; eXpIRes=Sun Nov 6 08:49:37 1994;", defaultDomain: "example.org")! XCTAssertEqual(Date(timeIntervalSince1970: 784_111_777), c?.expires) - c = HTTPClient.Cookie(header: "key=value; eXpIRes=Sonntag, 06-Nov-94 08:49:37 GMT;", defaultDomain: "example.org")! + c = HTTPClient.Cookie( + header: "key=value; eXpIRes=Sonntag, 06-Nov-94 08:49:37 GMT;", + defaultDomain: "example.org" + )! XCTAssertNil(c?.expires) } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests+XCTest.swift deleted file mode 100644 index 63d7f85e2..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests+XCTest.swift +++ /dev/null @@ -1,32 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPClientInformationalResponsesTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPClientReproTests { - static var allTests: [(String, (HTTPClientReproTests) -> () throws -> Void)] { - return [ - ("testServerSends100ContinueFirst", testServerSends100ContinueFirst), - ("testServerSendsSwitchingProtocols", testServerSendsSwitchingProtocols), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests.swift index f57d5fd10..5c41a6adb 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests.swift @@ -27,7 +27,10 @@ final class HTTPClientReproTests: XCTestCase { func channelRead(context: ChannelHandlerContext, data: NIOAny) { switch self.unwrapInboundIn(data) { case .head: - context.writeAndFlush(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .continue))), promise: nil) + context.writeAndFlush( + self.wrapOutboundOut(.head(.init(version: .http1_1, status: .continue))), + promise: nil + ) case .body: break case .end: @@ -37,7 +40,7 @@ final class HTTPClientReproTests: XCTestCase { } } - let client = HTTPClient(eventLoopGroupProvider: .createNew) + let client = HTTPClient(eventLoopGroupProvider: .singleton) defer { XCTAssertNoThrow(try client.syncShutdown()) } let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in @@ -47,14 +50,16 @@ final class HTTPClientReproTests: XCTestCase { let body = #"{"foo": "bar"}"# var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "http://localhost:\(httpBin.port)/", - method: .POST, - headers: [ - "Content-Type": "application/json", - ], - body: .string(body) - )) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost:\(httpBin.port)/", + method: .POST, + headers: [ + "Content-Type": "application/json" + ], + body: .string(body) + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request here") } var logger = Logger(label: "test") @@ -73,10 +78,14 @@ final class HTTPClientReproTests: XCTestCase { func channelRead(context: ChannelHandlerContext, data: NIOAny) { switch self.unwrapInboundIn(data) { case .head: - let head = HTTPResponseHead(version: .http1_1, status: .switchingProtocols, headers: [ - "Connection": "Upgrade", - "Upgrade": "Websocket", - ]) + let head = HTTPResponseHead( + version: .http1_1, + status: .switchingProtocols, + headers: [ + "Connection": "Upgrade", + "Upgrade": "Websocket", + ] + ) let body = context.channel.allocator.buffer(string: "foo bar") context.write(self.wrapOutboundOut(.head(head)), promise: nil) @@ -91,7 +100,7 @@ final class HTTPClientReproTests: XCTestCase { } } - let client = HTTPClient(eventLoopGroupProvider: .createNew) + let client = HTTPClient(eventLoopGroupProvider: .singleton) defer { XCTAssertNoThrow(try client.syncShutdown()) } let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in @@ -101,14 +110,16 @@ final class HTTPClientReproTests: XCTestCase { let body = #"{"foo": "bar"}"# var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "http://localhost:\(httpBin.port)/", - method: .POST, - headers: [ - "Content-Type": "application/json", - ], - body: .string(body) - )) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost:\(httpBin.port)/", + method: .POST, + headers: [ + "Content-Type": "application/json" + ], + body: .string(body) + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request here") } var logger = Logger(label: "test") diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift deleted file mode 100644 index 3be2c79a6..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift +++ /dev/null @@ -1,41 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPClientInternalTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPClientInternalTests { - static var allTests: [(String, (HTTPClientInternalTests) -> () throws -> Void)] { - return [ - ("testProxyStreaming", testProxyStreaming), - ("testProxyStreamingFailure", testProxyStreamingFailure), - ("testRequestURITrailingSlash", testRequestURITrailingSlash), - ("testChannelAndDelegateOnDifferentEventLoops", testChannelAndDelegateOnDifferentEventLoops), - ("testResponseFutureIsOnCorrectEL", testResponseFutureIsOnCorrectEL), - ("testUncleanCloseThrows", testUncleanCloseThrows), - ("testUploadStreamingIsCalledOnTaskEL", testUploadStreamingIsCalledOnTaskEL), - ("testTaskPromiseBoundToEL", testTaskPromiseBoundToEL), - ("testConnectErrorCalloutOnCorrectEL", testConnectErrorCalloutOnCorrectEL), - ("testInternalRequestURI", testInternalRequestURI), - ("testHasSuffix", testHasSuffix), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index eb8d523bb..5b70699a0 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -12,15 +12,17 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +import NIOFoundationCompat import NIOHTTP1 import NIOPosix import NIOTestUtils import XCTest +@testable import AsyncHTTPClient + class HTTPClientInternalTests: XCTestCase { typealias Request = HTTPClient.Request typealias Task = HTTPClient.Task @@ -52,7 +54,7 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertNoThrow(try httpBin.shutdown()) } - let body: HTTPClient.Body = .stream(length: 50) { writer in + let body: HTTPClient.Body = .stream(contentLength: 50) { writer in do { var request = try Request(url: "http://localhost:\(httpBin.port)/events/10/1") request.headers.add(name: "Accept", value: "text/event-stream") @@ -81,13 +83,13 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertNoThrow(try httpBin.shutdown()) } - var body: HTTPClient.Body = .stream(length: 50) { _ in + var body: HTTPClient.Body = .stream(contentLength: 50) { _ in httpClient.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidProxyResponse) } XCTAssertThrowsError(try httpClient.post(url: "http://localhost:\(httpBin.port)/post", body: body).wait()) - body = .stream(length: 50) { _ in + body = .stream(contentLength: 50) { _ in do { var request = try Request(url: "http://localhost:\(httpBin.port)/events/10/1") request.headers.add(name: "Accept", value: "text/event-stream") @@ -142,6 +144,25 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertEqual(request12.url.uri, "/some%2Fpathsegment1/pathsegment2") } + func testURIOfRelativeURLRequest() throws { + let requestNoLeadingSlash = try Request( + url: URL( + string: "percent%2Fencoded/hello", + relativeTo: URL(string: "http://127.0.0.1")! + )! + ) + + let requestWithLeadingSlash = try Request( + url: URL( + string: "/percent%2Fencoded/hello", + relativeTo: URL(string: "http://127.0.0.1")! + )! + ) + + XCTAssertEqual(requestNoLeadingSlash.url.uri, "/percent%2Fencoded/hello") + XCTAssertEqual(requestWithLeadingSlash.url.uri, "/percent%2Fencoded/hello") + } + func testChannelAndDelegateOnDifferentEventLoops() throws { class Delegate: HTTPClientResponseDelegate { typealias Response = ([Message], [Message]) @@ -185,15 +206,19 @@ class HTTPClientInternalTests: XCTestCase { self.receivedMessages.append(.error(error)) } - public func didReceiveHead(task: HTTPClient.Task, - _ head: HTTPResponseHead) -> EventLoopFuture { + public func didReceiveHead( + task: HTTPClient.Task, + _ head: HTTPResponseHead + ) -> EventLoopFuture { self.eventLoop.assertInEventLoop() self.receivedMessages.append(.head(head)) return self.randoEL.makeSucceededFuture(()) } - func didReceiveBodyPart(task: HTTPClient.Task, - _ buffer: ByteBuffer) -> EventLoopFuture { + func didReceiveBodyPart( + task: HTTPClient.Task, + _ buffer: ByteBuffer + ) -> EventLoopFuture { self.eventLoop.assertInEventLoop() self.receivedMessages.append(.bodyPart(buffer)) return self.randoEL.makeSucceededFuture(()) @@ -223,7 +248,7 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) } - let body: HTTPClient.Body = .stream(length: 8) { writer in + let body: HTTPClient.Body = .stream(contentLength: 8) { writer in let buffer = ByteBuffer(string: "1234") return writer.write(.byteBuffer(buffer)).flatMap { let buffer = ByteBuffer(string: "4321") @@ -231,22 +256,38 @@ class HTTPClientInternalTests: XCTestCase { } } - let request = try Request(url: "http://127.0.0.1:\(server.serverPort)/custom", - body: body) + let request = try Request( + url: "http://127.0.0.1:\(server.serverPort)/custom", + body: body + ) let delegate = Delegate(expectedEventLoop: delegateEL, randomOtherEventLoop: randoEL) - let future = httpClient.execute(request: request, - delegate: delegate, - eventLoop: .init(.testOnly_exact(channelOn: channelEL, - delegateOn: delegateEL))).futureResult - - XCTAssertNoThrow(try server.readInbound()) // .head - XCTAssertNoThrow(try server.readInbound()) // .body - XCTAssertNoThrow(try server.readInbound()) // .end + let future = httpClient.execute( + request: request, + delegate: delegate, + eventLoop: .init( + .testOnly_exact( + channelOn: channelEL, + delegateOn: delegateEL + ) + ) + ).futureResult + + XCTAssertNoThrow(try server.readInbound()) // .head + XCTAssertNoThrow(try server.readInbound()) // .body + XCTAssertNoThrow(try server.readInbound()) // .end // Send 3 parts, but only one should be received until the future is complete - XCTAssertNoThrow(try server.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), - status: .ok, - headers: HTTPHeaders([("Transfer-Encoding", "chunked")]))))) + XCTAssertNoThrow( + try server.writeOutbound( + .head( + .init( + version: .init(major: 1, minor: 1), + status: .ok, + headers: HTTPHeaders([("Transfer-Encoding", "chunked")]) + ) + ) + ) + ) let buffer = ByteBuffer(string: "1234") XCTAssertNoThrow(try server.writeOutbound(.body(.byteBuffer(buffer)))) XCTAssertNoThrow(try server.writeOutbound(.end(nil))) @@ -278,7 +319,7 @@ class HTTPClientInternalTests: XCTestCase { switch sentMessages.dropFirst(3).first { case .some(.sentRequest): - () // OK + () // OK default: XCTFail("wrong message") } @@ -316,7 +357,10 @@ class HTTPClientInternalTests: XCTestCase { let el = group.next() let req1 = client.execute(request: request, eventLoop: .delegate(on: el)) let req2 = client.execute(request: request, eventLoop: .delegateAndChannel(on: el)) - let req3 = client.execute(request: request, eventLoop: .init(.testOnly_exact(channelOn: el, delegateOn: el))) + let req3 = client.execute( + request: request, + eventLoop: .init(.testOnly_exact(channelOn: el, delegateOn: el)) + ) XCTAssert(req1.eventLoop === el) XCTAssert(req2.eventLoop === el) XCTAssert(req3.eventLoop === el) @@ -335,8 +379,8 @@ class HTTPClientInternalTests: XCTestCase { _ = httpClient.get(url: "http://localhost:\(server.serverPort)/wait") - XCTAssertNoThrow(try server.readInbound()) // .head - XCTAssertNoThrow(try server.readInbound()) // .end + XCTAssertNoThrow(try server.readInbound()) // .head + XCTAssertNoThrow(try server.readInbound()) // .end do { try httpClient.syncShutdown(requiresCleanClose: true) @@ -366,7 +410,7 @@ class HTTPClientInternalTests: XCTestCase { let el2 = group.next() XCTAssert(el1 !== el2) - let body: HTTPClient.Body = .stream(length: 8) { writer in + let body: HTTPClient.Body = .stream(contentLength: 8) { writer in XCTAssert(el1.inEventLoop) let buffer = ByteBuffer(string: "1234") return writer.write(.byteBuffer(buffer)).flatMap { @@ -376,10 +420,16 @@ class HTTPClientInternalTests: XCTestCase { } } let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/post", method: .POST, body: body) - let response = httpClient.execute(request: request, - delegate: ResponseAccumulator(request: request), - eventLoop: HTTPClient.EventLoopPreference(.testOnly_exact(channelOn: el2, - delegateOn: el1))) + let response = httpClient.execute( + request: request, + delegate: ResponseAccumulator(request: request), + eventLoop: HTTPClient.EventLoopPreference( + .testOnly_exact( + channelOn: el2, + delegateOn: el1 + ) + ) + ) XCTAssert(el1 === response.eventLoop) XCTAssertNoThrow(try response.wait()) } @@ -400,7 +450,11 @@ class HTTPClientInternalTests: XCTestCase { let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)//get") let delegate = ResponseAccumulator(request: request) - let task = client.execute(request: request, delegate: delegate, eventLoop: .init(.testOnly_exact(channelOn: el1, delegateOn: el2))) + let task = client.execute( + request: request, + delegate: delegate, + eventLoop: .init(.testOnly_exact(channelOn: el1, delegateOn: el2)) + ) XCTAssertTrue(task.futureResult.eventLoop === el2) XCTAssertNoThrow(try task.wait()) } @@ -429,7 +483,9 @@ class HTTPClientInternalTests: XCTestCase { let el2 = elg.next() let httpBin = HTTPBin(.refuse) - let client = HTTPClient(eventLoopGroupProvider: .shared(elg)) + let config = HTTPClient.Configuration() + .enableFastFailureModeForTesting() + let client = HTTPClient(eventLoopGroupProvider: .shared(elg), configuration: config) defer { XCTAssertNoThrow(try client.syncShutdown()) @@ -439,7 +495,11 @@ class HTTPClientInternalTests: XCTestCase { let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get") let delegate = TestDelegate(expectedEL: el1) XCTAssertNoThrow(try httpBin.shutdown()) - let task = client.execute(request: request, delegate: delegate, eventLoop: .init(.testOnly_exact(channelOn: el2, delegateOn: el1))) + let task = client.execute( + request: request, + delegate: delegate, + eventLoop: .init(.testOnly_exact(channelOn: el2, delegateOn: el1)) + ) XCTAssertThrowsError(try task.wait()) XCTAssertTrue(delegate.receivedError) } @@ -472,10 +532,13 @@ class HTTPClientInternalTests: XCTestCase { let request6 = try Request(url: "https://127.0.0.1") XCTAssertEqual(request6.deconstructedURL.scheme, .https) - XCTAssertEqual(request6.deconstructedURL.connectionTarget, .ipAddress( - serialization: "127.0.0.1", - address: try! SocketAddress(ipAddress: "127.0.0.1", port: 443) - )) + XCTAssertEqual( + request6.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "127.0.0.1", + address: try! SocketAddress(ipAddress: "127.0.0.1", port: 443) + ) + ) XCTAssertEqual(request6.deconstructedURL.uri, "/") let request7 = try Request(url: "https://0x7F.1:9999") @@ -485,18 +548,24 @@ class HTTPClientInternalTests: XCTestCase { let request8 = try Request(url: "http://[::1]") XCTAssertEqual(request8.deconstructedURL.scheme, .http) - XCTAssertEqual(request8.deconstructedURL.connectionTarget, .ipAddress( - serialization: "[::1]", - address: try! SocketAddress(ipAddress: "::1", port: 80) - )) + XCTAssertEqual( + request8.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "[::1]", + address: try! SocketAddress(ipAddress: "::1", port: 80) + ) + ) XCTAssertEqual(request8.deconstructedURL.uri, "/") let request9 = try Request(url: "http://[763e:61d9::6ACA:3100:6274]:4242/foo/bar?baz") XCTAssertEqual(request9.deconstructedURL.scheme, .http) - XCTAssertEqual(request9.deconstructedURL.connectionTarget, .ipAddress( - serialization: "[763e:61d9::6ACA:3100:6274]", - address: try! SocketAddress(ipAddress: "763e:61d9::6aca:3100:6274", port: 4242) - )) + XCTAssertEqual( + request9.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "[763e:61d9::6ACA:3100:6274]", + address: try! SocketAddress(ipAddress: "763e:61d9::6aca:3100:6274", port: 4242) + ) + ) XCTAssertEqual(request9.deconstructedURL.uri, "/foo/bar?baz") // Some systems have quirks in their implementations of 'ntop' which cause them to write @@ -505,18 +574,24 @@ class HTTPClientInternalTests: XCTestCase { // so the serialization must be kept verbatim as it was given in the request. let request10 = try Request(url: "http://[::c0a8:1]:4242/foo/bar?baz") XCTAssertEqual(request10.deconstructedURL.scheme, .http) - XCTAssertEqual(request10.deconstructedURL.connectionTarget, .ipAddress( - serialization: "[::c0a8:1]", - address: try! SocketAddress(ipAddress: "::c0a8:1", port: 4242) - )) + XCTAssertEqual( + request10.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "[::c0a8:1]", + address: try! SocketAddress(ipAddress: "::c0a8:1", port: 4242) + ) + ) XCTAssertEqual(request10.deconstructedURL.uri, "/foo/bar?baz") let request11 = try Request(url: "http://[::192.168.0.1]:4242/foo/bar?baz") XCTAssertEqual(request11.deconstructedURL.scheme, .http) - XCTAssertEqual(request11.deconstructedURL.connectionTarget, .ipAddress( - serialization: "[::192.168.0.1]", - address: try! SocketAddress(ipAddress: "::192.168.0.1", port: 4242) - )) + XCTAssertEqual( + request11.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "[::192.168.0.1]", + address: try! SocketAddress(ipAddress: "::192.168.0.1", port: 4242) + ) + ) XCTAssertEqual(request11.deconstructedURL.uri, "/foo/bar?baz") } @@ -545,11 +620,55 @@ class HTTPClientInternalTests: XCTestCase { } // Empty collection. do { - let elements: Array = [] + let elements: [Int] = [] XCTAssertTrue(elements.hasSuffix([])) XCTAssertFalse(elements.hasSuffix([0])) XCTAssertFalse(elements.hasSuffix([42])) XCTAssertFalse(elements.hasSuffix([0, 0, 0])) } } + + /// test to verify that we actually share the same thread pool across all ``FileDownloadDelegate``s for a given ``HTTPClient`` + func testSharedThreadPoolIsIdenticalForAllDelegates() throws { + let httpBin = HTTPBin() + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + } + var request = try Request(url: "http://localhost:\(httpBin.port)/events/10/content-length") + request.headers.add(name: "Accept", value: "text/event-stream") + + let filePaths = (0..<10).map { _ in + TemporaryFileHelpers.makeTemporaryFilePath() + } + defer { + for filePath in filePaths { + TemporaryFileHelpers.removeTemporaryFile(at: filePath) + } + } + let delegates = try filePaths.map { + try FileDownloadDelegate(path: $0) + } + + let resultFutures = delegates.map { delegate in + httpClient.execute( + request: request, + delegate: delegate + ).futureResult + } + _ = try EventLoopFuture.whenAllSucceed(resultFutures, on: self.clientGroup.next()).wait() + let threadPools = delegates.map { $0.fileIOThreadPool } + let firstThreadPool = threadPools.first ?? nil + XCTAssert(threadPools.dropFirst().allSatisfy { $0 === firstThreadPool }) + } +} + +extension HTTPClient.Configuration { + func enableFastFailureModeForTesting() -> Self { + var copy = self + copy.networkFrameworkWaitForConnectivity = false + copy.connectionPool.retryConnectionEstablishment = false + return copy + } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests+XCTest.swift deleted file mode 100644 index cc33f6aee..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests+XCTest.swift +++ /dev/null @@ -1,35 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPClientNIOTSTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPClientNIOTSTests { - static var allTests: [(String, (HTTPClientNIOTSTests) -> () throws -> Void)] { - return [ - ("testCorrectEventLoopGroup", testCorrectEventLoopGroup), - ("testTLSFailError", testTLSFailError), - ("testConnectionFailError", testConnectionFailError), - ("testTLSVersionError", testTLSVersionError), - ("testTrustRootCertificateLoadFail", testTrustRootCertificateLoadFail), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift index 172ee89ba..4c2d24dc4 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift @@ -12,16 +12,19 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient -#if canImport(Network) -import Network -#endif +import NIOConcurrencyHelpers import NIOCore import NIOPosix import NIOSSL import NIOTransportServices import XCTest +@testable import AsyncHTTPClient + +#if canImport(Network) +import Network +#endif + class HTTPClientNIOTSTests: XCTestCase { var clientGroup: EventLoopGroup! @@ -37,7 +40,7 @@ class HTTPClientNIOTSTests: XCTestCase { } func testCorrectEventLoopGroup() { - let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + let httpClient = HTTPClient(eventLoopGroupProvider: .singleton) defer { XCTAssertNoThrow(try httpClient.syncShutdown()) } @@ -54,7 +57,12 @@ class HTTPClientNIOTSTests: XCTestCase { guard isTestingNIOTS() else { return } let httpBin = HTTPBin(.http1_1(ssl: true)) - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) + let config = HTTPClient.Configuration() + .enableFastFailureModeForTesting() + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) @@ -65,8 +73,10 @@ class HTTPClientNIOTSTests: XCTestCase { _ = try httpClient.get(url: "https://localhost:\(httpBin.port)/get").wait() XCTFail("This should have failed") } catch let error as HTTPClient.NWTLSError { - XCTAssert(error.status == errSSLHandshakeFail || error.status == errSSLBadCert, - "unexpected NWTLSError with status \(error.status)") + XCTAssert( + error.status == errSSLHandshakeFail || error.status == errSSLBadCert, + "unexpected NWTLSError with status \(error.status)" + ) } catch { XCTFail("Error should have been NWTLSError not \(type(of: error))") } @@ -75,12 +85,44 @@ class HTTPClientNIOTSTests: XCTestCase { #endif } + func testConnectionFailsFastError() { + guard isTestingNIOTS() else { return } + #if canImport(Network) + let httpBin = HTTPBin(.http1_1(ssl: false)) + let config = HTTPClient.Configuration() + .enableFastFailureModeForTesting() + + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) + + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + } + + let port = httpBin.port + XCTAssertNoThrow(try httpBin.shutdown()) + + XCTAssertThrowsError(try httpClient.get(url: "http://localhost:\(port)/get").wait()) { + XCTAssertTrue($0 is NWError) + } + #endif + } + func testConnectionFailError() { guard isTestingNIOTS() else { return } - let httpBin = HTTPBin(.http1_1(ssl: true)) - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(timeout: .init(connect: .milliseconds(100), - read: .milliseconds(100)))) + #if canImport(Network) + let httpBin = HTTPBin(.http1_1(ssl: false)) + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init( + timeout: .init( + connect: .milliseconds(100), + read: .milliseconds(100) + ) + ) + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) @@ -89,9 +131,16 @@ class HTTPClientNIOTSTests: XCTestCase { let port = httpBin.port XCTAssertNoThrow(try httpBin.shutdown()) - XCTAssertThrowsError(try httpClient.get(url: "https://localhost:\(port)/get").wait()) { - XCTAssertEqual($0 as? HTTPClientError, .connectTimeout) + XCTAssertThrowsError(try httpClient.get(url: "http://localhost:\(port)/get").wait()) { + if let httpClientError = $0 as? HTTPClientError { + XCTAssertEqual(httpClientError, .connectTimeout) + } else if let posixError = $0 as? HTTPClient.NWPOSIXError { + XCTAssertEqual(posixError.errorCode, .ECONNREFUSED) + } else { + XCTFail("unexpected error \($0)") + } } + #endif } func testTLSVersionError() { @@ -102,9 +151,12 @@ class HTTPClientNIOTSTests: XCTestCase { tlsConfig.certificateVerification = .none tlsConfig.minimumTLSVersion = .tlsv11 tlsConfig.maximumTLSVersion = .tlsv1 + + let clientConfig = HTTPClient.Configuration(tlsConfiguration: tlsConfig) + .enableFastFailureModeForTesting() let httpClient = HTTPClient( eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(tlsConfiguration: tlsConfig) + configuration: clientConfig ) defer { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) @@ -124,7 +176,7 @@ class HTTPClientNIOTSTests: XCTestCase { var tlsConfig = TLSConfiguration.makeClientConfiguration() tlsConfig.trustRoots = .file("not/a/certificate") - XCTAssertThrowsError(try tlsConfig.getNWProtocolTLSOptions()) { error in + XCTAssertThrowsError(try tlsConfig.getNWProtocolTLSOptions(serverNameIndicatorOverride: nil)) { error in switch error { case let error as NIOSSL.NIOSSLError where error == .failedToLoadCertificate: break diff --git a/Tests/AsyncHTTPClientTests/HTTPClientRequestTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientRequestTests+XCTest.swift deleted file mode 100644 index 30d93f7de..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPClientRequestTests+XCTest.swift +++ /dev/null @@ -1,43 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPClientRequestTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPClientRequestTests { - static var allTests: [(String, (HTTPClientRequestTests) -> () throws -> Void)] { - return [ - ("testCustomHeadersAreRespected", testCustomHeadersAreRespected), - ("testUnixScheme", testUnixScheme), - ("testHTTPUnixScheme", testHTTPUnixScheme), - ("testHTTPSUnixScheme", testHTTPSUnixScheme), - ("testGetWithoutBody", testGetWithoutBody), - ("testPostWithoutBody", testPostWithoutBody), - ("testPostWithEmptyByteBuffer", testPostWithEmptyByteBuffer), - ("testPostWithByteBuffer", testPostWithByteBuffer), - ("testPostWithSequenceOfUnknownLength", testPostWithSequenceOfUnknownLength), - ("testPostWithSequenceWithFixedLength", testPostWithSequenceWithFixedLength), - ("testPostWithRandomAccessCollection", testPostWithRandomAccessCollection), - ("testPostWithAsyncSequenceOfUnknownLength", testPostWithAsyncSequenceOfUnknownLength), - ("testPostWithAsyncSequenceWithKnownLength", testPostWithAsyncSequenceWithKnownLength), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift index 1ebe7e939..a2cc3b108 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift @@ -12,58 +12,74 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient +import Algorithms import NIOCore +import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) class HTTPClientRequestTests: XCTestCase { - #if compiler(>=5.5.2) && canImport(_Concurrency) - @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) private typealias Request = HTTPClientRequest - @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) private typealias PreparedRequest = HTTPClientRequest.Prepared - #endif func testCustomHeadersAreRespected() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "https://example.com/get") request.headers = [ - "custom-header": "custom-header-value", + "custom-header": "custom-header-value" ] var preparedRequest: PreparedRequest? XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .https, - connectionTarget: .domain(name: "example.com", port: 443), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/get", - headers: [ - "host": "example.com", - "custom-header": "custom-header-value", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .https, + connectionTarget: .domain(name: "example.com", port: 443), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/get", + headers: [ + "host": "example.com", + "custom-header": "custom-header-value", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif + } + + func testBasicAuth() { + XCTAsyncTest { + var request = Request(url: "https://example.com/get") + request.setBasicAuth(username: "foo", password: "bar") + var preparedRequest: PreparedRequest? + XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) + guard let preparedRequest = preparedRequest else { return } + XCTAssertEqual(preparedRequest.head.headers.first(name: "Authorization")!, "Basic Zm9vOmJhcg==") + } } func testUnixScheme() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "unix://%2Fexample%2Ffolder.sock/some_path") request.headers = ["custom-header": "custom-value"] @@ -71,30 +87,37 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .unix, - connectionTarget: .unixSocket(path: "/some_path"), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/", - headers: ["custom-header": "custom-value"] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .unix, + connectionTarget: .unixSocket(path: "/some_path"), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/", + headers: ["custom-header": "custom-value"] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif } func testHTTPUnixScheme() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http+unix://%2Fexample%2Ffolder.sock/some_path") request.headers = ["custom-header": "custom-value"] @@ -102,30 +125,37 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .httpUnix, - connectionTarget: .unixSocket(path: "/example/folder.sock"), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/some_path", - headers: ["custom-header": "custom-value"] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .httpUnix, + connectionTarget: .unixSocket(path: "/example/folder.sock"), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/some_path", + headers: ["custom-header": "custom-value"] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif } func testHTTPSUnixScheme() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "https+unix://%2Fexample%2Ffolder.sock/some_path") request.headers = ["custom-header": "custom-value"] @@ -133,60 +163,74 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .httpsUnix, - connectionTarget: .unixSocket(path: "/example/folder.sock"), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/some_path", - headers: ["custom-header": "custom-value"] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .httpsUnix, + connectionTarget: .unixSocket(path: "/example/folder.sock"), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/some_path", + headers: ["custom-header": "custom-value"] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif } func testGetWithoutBody() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let request = Request(url: "https://example.com/get") var preparedRequest: PreparedRequest? XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .https, - connectionTarget: .domain(name: "example.com", port: 443), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/get", - headers: ["host": "example.com"] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .https, + connectionTarget: .domain(name: "example.com", port: 443), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/get", + headers: ["host": "example.com"] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif } func testPostWithoutBody() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST @@ -194,34 +238,41 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "0", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "0", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif } func testPostWithEmptyByteBuffer() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST @@ -230,34 +281,41 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "0", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "0", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif } func testPostWithByteBuffer() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST @@ -266,106 +324,127 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "9", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(9) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "9", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(9) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } - #endif } func testPostWithSequenceOfUnknownLength() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST - let sequence = AnySequence(ByteBuffer(string: "post body").readableBytesView) + let sequence = AnySendableSequence(ByteBuffer(string: "post body").readableBytesView) request.body = .bytes(sequence, length: .unknown) var preparedRequest: PreparedRequest? XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "transfer-encoding": "chunked", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .stream - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "transfer-encoding": "chunked", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .stream + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } - #endif } func testPostWithSequenceWithFixedLength() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST - let sequence = AnySequence(ByteBuffer(string: "post body").readableBytesView) - request.body = .bytes(sequence, length: .known(9)) + let sequence = AnySendableSequence(ByteBuffer(string: "post body").readableBytesView) + request.body = .bytes(sequence, length: .known(Int64(9))) var preparedRequest: PreparedRequest? XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "9", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(9) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "9", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(9) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } - #endif } func testPostWithRandomAccessCollection() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST @@ -375,40 +454,47 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "9", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(9) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "9", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(9) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } - #endif } func testPostWithAsyncSequenceOfUnknownLength() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST let asyncSequence = ByteBuffer(string: "post body") .readableBytesView - .chunked(maxChunkSize: 2) - .asAsyncSequence() + .chunks(ofCount: 2) + .async .map { ByteBuffer($0) } request.body = .stream(asyncSequence, length: .unknown) @@ -416,84 +502,258 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "transfer-encoding": "chunked", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .stream - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "transfer-encoding": "chunked", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .stream + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } - #endif } func testPostWithAsyncSequenceWithKnownLength() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST let asyncSequence = ByteBuffer(string: "post body") .readableBytesView - .chunked(maxChunkSize: 2) - .asAsyncSequence() + .chunks(ofCount: 2) + .async .map { ByteBuffer($0) } - request.body = .stream(asyncSequence, length: .known(9)) + request.body = .stream(asyncSequence, length: .known(Int64(9))) var preparedRequest: PreparedRequest? XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "9", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(9) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "9", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(9) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } - #endif + } + + func testChunkingRandomAccessCollection() async throws { + let body = try await HTTPClientRequest.Body.bytes( + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + ).collect() + + let expectedChunks = [ + ByteBuffer(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize), + ] + + XCTAssertEqual(body, expectedChunks) + } + + func testChunkingCollection() async throws { + let body = try await HTTPClientRequest.Body.bytes( + (String(repeating: "0", count: bagOfBytesToByteBufferConversionChunkSize) + + String(repeating: "1", count: bagOfBytesToByteBufferConversionChunkSize) + + String(repeating: "2", count: bagOfBytesToByteBufferConversionChunkSize)).utf8, + length: .known(Int64(bagOfBytesToByteBufferConversionChunkSize * 3)) + ).collect() + + let expectedChunks = [ + ByteBuffer(repeating: UInt8(ascii: "0"), count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: UInt8(ascii: "1"), count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: UInt8(ascii: "2"), count: bagOfBytesToByteBufferConversionChunkSize), + ] + + XCTAssertEqual(body, expectedChunks) + } + + func testChunkingSequenceThatDoesNotImplementWithContiguousStorageIfAvailable() async throws { + let bagOfBytesToByteBufferConversionChunkSize = 8 + let body = try await HTTPClientRequest.Body._bytes( + AnySequence( + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + ), + length: .known(Int64(bagOfBytesToByteBufferConversionChunkSize * 3)), + bagOfBytesToByteBufferConversionChunkSize: bagOfBytesToByteBufferConversionChunkSize, + byteBufferMaxSize: byteBufferMaxSize + ).collect() + + let expectedChunks = [ + ByteBuffer(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize), + ] + + XCTAssertEqual(body, expectedChunks) + } + + func testChunkingSequenceFastPath() async throws { + func makeBytes() -> some Sequence & Sendable { + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + } + let body = try await HTTPClientRequest.Body.bytes( + makeBytes(), + length: .known(Int64(bagOfBytesToByteBufferConversionChunkSize * 3)) + ).collect() + + var firstChunk = ByteBuffer(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + firstChunk.writeImmutableBuffer(ByteBuffer(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize)) + firstChunk.writeImmutableBuffer(ByteBuffer(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize)) + let expectedChunks = [ + firstChunk + ] + + XCTAssertEqual(body, expectedChunks) + } + + func testChunkingSequenceFastPathExceedingByteBufferMaxSize() async throws { + let bagOfBytesToByteBufferConversionChunkSize = 8 + let byteBufferMaxSize = 16 + func makeBytes() -> some Sequence & Sendable { + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + } + let body = try await HTTPClientRequest.Body._bytes( + makeBytes(), + length: .known(Int64(bagOfBytesToByteBufferConversionChunkSize * 3)), + bagOfBytesToByteBufferConversionChunkSize: bagOfBytesToByteBufferConversionChunkSize, + byteBufferMaxSize: byteBufferMaxSize + ).collect() + + var firstChunk = ByteBuffer(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + firstChunk.writeImmutableBuffer(ByteBuffer(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize)) + let secondChunk = ByteBuffer(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + let expectedChunks = [ + firstChunk, + secondChunk, + ] + + XCTAssertEqual(body, expectedChunks) + } + + func testBodyStringChunking() throws { + let body = try HTTPClient.Body.string( + String(repeating: "0", count: bagOfBytesToByteBufferConversionChunkSize) + + String(repeating: "1", count: bagOfBytesToByteBufferConversionChunkSize) + + String(repeating: "2", count: bagOfBytesToByteBufferConversionChunkSize) + ).collect().wait() + + let expectedChunks = [ + ByteBuffer(), // We're currently emitting an empty chunk first. + ByteBuffer(repeating: UInt8(ascii: "0"), count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: UInt8(ascii: "1"), count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: UInt8(ascii: "2"), count: bagOfBytesToByteBufferConversionChunkSize), + ] + + XCTAssertEqual(body, expectedChunks) + } + + func testBodyChunkingRandomAccessCollection() throws { + let body = try HTTPClient.Body.bytes( + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + ).collect().wait() + + let expectedChunks = [ + ByteBuffer(), // We're currently emitting an empty chunk first. + ByteBuffer(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize), + ] + + XCTAssertEqual(body, expectedChunks) + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension AsyncSequence { + func collect() async throws -> [Element] { + try await self.reduce(into: []) { $0 += CollectionOfOne($1) } + } +} + +extension HTTPClient.Body { + func collect() -> EventLoopFuture<[ByteBuffer]> { + let eelg = EmbeddedEventLoopGroup(loops: 1) + let el = eelg.next() + var body = [ByteBuffer]() + let writer = StreamWriter { + switch $0 { + case .byteBuffer(let byteBuffer): + body.append(byteBuffer) + case .fileRegion: + fatalError("file region not supported") + } + return el.makeSucceededVoidFuture() + } + return self.stream(writer).map { _ in body } } } -#if compiler(>=5.5.2) && canImport(_Concurrency) private struct LengthMismatch: Error { - var announcedLength: Int - var actualLength: Int + var announcedLength: Int64 + var actualLength: Int64 } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -extension Optional where Wrapped == HTTPClientRequest.Body { +extension Optional where Wrapped == HTTPClientRequest.Prepared.Body { /// Accumulates all data from `self` into a single `ByteBuffer` and checks that the user specified length matches /// the length of the accumulated data. fileprivate func read() async throws -> ByteBuffer { - switch self?.mode { + switch self { case .none: return ByteBuffer() case .byteBuffer(let buffer): @@ -501,8 +761,9 @@ extension Optional where Wrapped == HTTPClientRequest.Body { case .sequence(let announcedLength, _, let generate): let buffer = generate(ByteBufferAllocator()) if case .known(let announcedLength) = announcedLength, - announcedLength != buffer.readableBytes { - throw LengthMismatch(announcedLength: announcedLength, actualLength: buffer.readableBytes) + announcedLength != Int64(buffer.readableBytes) + { + throw LengthMismatch(announcedLength: announcedLength, actualLength: Int64(buffer.readableBytes)) } return buffer case .asyncSequence(length: let announcedLength, let generate): @@ -511,67 +772,14 @@ extension Optional where Wrapped == HTTPClientRequest.Body { accumulatedBuffer.writeBuffer(&buffer) } if case .known(let announcedLength) = announcedLength, - announcedLength != accumulatedBuffer.readableBytes { - throw LengthMismatch(announcedLength: announcedLength, actualLength: accumulatedBuffer.readableBytes) + announcedLength != Int64(accumulatedBuffer.readableBytes) + { + throw LengthMismatch( + announcedLength: announcedLength, + actualLength: Int64(accumulatedBuffer.readableBytes) + ) } return accumulatedBuffer } } } - -struct ChunkedSequence: Sequence { - struct Iterator: IteratorProtocol { - fileprivate var remainingElements: Wrapped.SubSequence - fileprivate let maxChunkSize: Int - mutating func next() -> Wrapped.SubSequence? { - guard !self.remainingElements.isEmpty else { - return nil - } - let chunk = self.remainingElements.prefix(self.maxChunkSize) - self.remainingElements = self.remainingElements.dropFirst(self.maxChunkSize) - return chunk - } - } - - fileprivate let wrapped: Wrapped - fileprivate let maxChunkSize: Int - - func makeIterator() -> Iterator { - .init(remainingElements: self.wrapped[...], maxChunkSize: self.maxChunkSize) - } -} - -extension Collection { - /// Lazily splits `self` into `SubSequence`s with `maxChunkSize` elements. - /// - Parameter maxChunkSize: size of each chunk except the last one which can be smaller if not enough elements are remaining. - func chunked(maxChunkSize: Int) -> ChunkedSequence { - .init(wrapped: self, maxChunkSize: maxChunkSize) - } -} - -@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -struct AsyncSequenceFromSyncSequence: AsyncSequence { - typealias Element = Wrapped.Element - struct AsyncIterator: AsyncIteratorProtocol { - fileprivate var iterator: Wrapped.Iterator - mutating func next() async throws -> Wrapped.Element? { - self.iterator.next() - } - } - - fileprivate let wrapped: Wrapped - - func makeAsyncIterator() -> AsyncIterator { - .init(iterator: self.wrapped.makeIterator()) - } -} - -@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -extension Sequence { - /// Turns `self` into an `AsyncSequence` by wending each element of `self` asynchronously. - func asAsyncSequence() -> AsyncSequenceFromSyncSequence { - .init(wrapped: self) - } -} - -#endif diff --git a/Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift new file mode 100644 index 000000000..7dcc4efe6 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift @@ -0,0 +1,64 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore +import NIOHTTP1 +import XCTest + +@testable import AsyncHTTPClient + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +final class HTTPClientResponseTests: XCTestCase { + func testSimpleResponse() { + let response = HTTPClientResponse.expectedContentLength( + requestMethod: .GET, + headers: ["content-length": "1025"], + status: .ok + ) + XCTAssertEqual(response, 1025) + } + + func testSimpleResponseNotModified() { + let response = HTTPClientResponse.expectedContentLength( + requestMethod: .GET, + headers: ["content-length": "1025"], + status: .notModified + ) + XCTAssertEqual(response, 0) + } + + func testSimpleResponseHeadRequestMethod() { + let response = HTTPClientResponse.expectedContentLength( + requestMethod: .HEAD, + headers: ["content-length": "1025"], + status: .ok + ) + XCTAssertEqual(response, 0) + } + + func testResponseNoContentLengthHeader() { + let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: [:], status: .ok) + XCTAssertEqual(response, nil) + } + + func testResponseInvalidInteger() { + let response = HTTPClientResponse.expectedContentLength( + requestMethod: .GET, + headers: ["content-length": "none"], + status: .ok + ) + XCTAssertEqual(response, nil) + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 230c91a2b..4f08bc4f5 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -import AsyncHTTPClient +import Atomics import Foundation import Logging import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +import NIOFoundationCompat import NIOHPACK import NIOHTTP1 import NIOHTTP2 @@ -27,8 +28,19 @@ import NIOSSL import NIOTLS import NIOTransportServices import XCTest -#if canImport(Darwin) + +@testable import AsyncHTTPClient + +#if canImport(xlocale) +import xlocale +#elseif canImport(locale_h) +import locale_h +#elseif canImport(Darwin) import Darwin +#elseif canImport(Musl) +import Musl +#elseif canImport(Android) +import Android #elseif canImport(Glibc) import Glibc #endif @@ -45,7 +57,8 @@ func isTestingNIOTS() -> Bool { func getDefaultEventLoopGroup(numberOfThreads: Int) -> EventLoopGroup { #if canImport(Network) if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), - isTestingNIOTS() { + isTestingNIOTS() + { return NIOTSEventLoopGroup(loopCount: numberOfThreads, defaultQoS: .default) } #endif @@ -137,7 +150,7 @@ class CountingDelegate: HTTPClientResponseDelegate { } func didFinishRequest(task: HTTPClient.Task) throws -> Int { - return self.count + self.count } } @@ -212,8 +225,8 @@ enum TemporaryFileHelpers { } else { return "/tmp" } - #endif // os - #endif // targetEnvironment + #endif // os + #endif // targetEnvironment } private static func openTemporaryFile() -> (CInt, String) { @@ -233,8 +246,10 @@ enum TemporaryFileHelpers { /// /// If the temporary directory is too long to store a UNIX domain socket path, it will `chdir` into the temporary /// directory and return a short-enough path. The iOS simulator is known to have too long paths. - internal static func withTemporaryUnixDomainSocketPathName(directory: String = temporaryDirectory, - _ body: (String) throws -> T) throws -> T { + internal static func withTemporaryUnixDomainSocketPathName( + directory: String = temporaryDirectory, + _ body: (String) throws -> T + ) throws -> T { // this is racy but we're trying to create the shortest possible path so we can't add a directory... let (fd, path) = self.openTemporaryFile() close(fd) @@ -249,10 +264,14 @@ enum TemporaryFileHelpers { shortEnoughPath = path restoreSavedCWD = false } catch SocketAddressError.unixDomainSocketPathTooLong { - FileManager.default.changeCurrentDirectoryPath(URL(fileURLWithPath: path).deletingLastPathComponent().absoluteString) + FileManager.default.changeCurrentDirectoryPath( + URL(fileURLWithPath: path).deletingLastPathComponent().absoluteString + ) shortEnoughPath = URL(fileURLWithPath: path).lastPathComponent restoreSavedCWD = true - print("WARNING: Path '\(path)' could not be used as UNIX domain socket path, using chdir & '\(shortEnoughPath)'") + print( + "WARNING: Path '\(path)' could not be used as UNIX domain socket path, using chdir & '\(shortEnoughPath)'" + ) } defer { if FileManager.default.fileExists(atPath: path) { @@ -282,23 +301,46 @@ enum TemporaryFileHelpers { return try body(path) } + internal static func makeTemporaryFilePath( + directory: String = temporaryDirectory + ) -> String { + let (fd, path) = self.openTemporaryFile() + close(fd) + try! FileManager.default.removeItem(atPath: path) + return path + } + + internal static func removeTemporaryFile( + at path: String + ) { + if FileManager.default.fileExists(atPath: path) { + try? FileManager.default.removeItem(atPath: path) + } + } + internal static func fileSize(path: String) throws -> Int? { - return try FileManager.default.attributesOfItem(atPath: path)[.size] as? Int + try FileManager.default.attributesOfItem(atPath: path)[.size] as? Int } internal static func fileExists(path: String) -> Bool { - return FileManager.default.fileExists(atPath: path) + FileManager.default.fileExists(atPath: path) } } enum TestTLS { static let certificate = try! NIOSSLCertificate(bytes: Array(cert.utf8), format: .pem) static let privateKey = try! NIOSSLPrivateKey(bytes: Array(key.utf8), format: .pem) + static let serverConfiguration: TLSConfiguration = .makeServerConfiguration( + certificateChain: [.certificate(TestTLS.certificate)], + privateKey: .privateKey(TestTLS.privateKey) + ) } -internal final class HTTPBin where +internal final class HTTPBin +where RequestHandler.InboundIn == HTTPServerRequestPart, - RequestHandler.OutboundOut == HTTPServerResponsePart { + RequestHandler.OutboundOut == HTTPServerResponsePart +{ enum BindTarget { case unixDomainSocket(String) case localhostIPv4RandomPort @@ -309,19 +351,51 @@ internal final class HTTPBin where // refuses all connections case refuse // supports http1.1 connections only, which can be either plain text or encrypted - case http1_1(ssl: Bool = false, compress: Bool = false) + case http1_1( + tlsConfiguration: TLSConfiguration? = nil, + compress: Bool = false + ) // supports http1.1 and http2 connections which must be always encrypted - case http2(compress: Bool) + case http2( + tlsConfiguration: TLSConfiguration = TestTLS.serverConfiguration, + compress: Bool = false, + settings: HTTP2Settings? = nil + ) + + static func http1_1(ssl: Bool, compress: Bool = false) -> Self { + .http1_1(tlsConfiguration: ssl ? TestTLS.serverConfiguration : nil, compress: compress) + } // supports request decompression and http response compression var compress: Bool { switch self { case .refuse: return false - case .http1_1(ssl: _, compress: let compress), .http2(compress: let compress): + case .http1_1(_, let compress), .http2(_, let compress, _): return compress } } + + var httpSettings: HTTP2Settings { + switch self { + case .http1_1, .http2(_, _, nil), .refuse: + return HTTP2Connection.defaultSettings + case .http2(_, _, .some(let customSettings)): + return customSettings + } + } + + var tlsConfiguration: TLSConfiguration? { + switch self { + case .refuse: + return nil + case .http1_1(let tlsConfiguration, _): + return tlsConfiguration + case .http2(var tlsConfiguration, _, _): + tlsConfiguration.applicationProtocols = NIOHTTP2SupportedALPNProtocols + return tlsConfiguration + } + } } enum Proxy { @@ -333,31 +407,56 @@ internal final class HTTPBin where private let activeConnCounterHandler: ConnectionsCountHandler var activeConnections: Int { - return self.activeConnCounterHandler.currentlyActiveConnections + self.activeConnCounterHandler.currentlyActiveConnections } var createdConnections: Int { - return self.activeConnCounterHandler.createdConnections + self.activeConnCounterHandler.createdConnections } var port: Int { - return Int(self.serverChannel.localAddress!.port!) + Int(self.serverChannel.localAddress!.port!) } var socketAddress: SocketAddress { - return self.serverChannel.localAddress! + self.serverChannel.localAddress! + } + + var baseURL: String { + let scheme: String = { + switch mode { + case .http1_1, .refuse: + return "http" + case .http2: + return "https" + } + }() + let host: String = { + switch self.socketAddress { + case .v4: + return self.socketAddress.ipAddress! + case .v6: + return "[\(self.socketAddress.ipAddress!)]" + case .unixDomainSocket: + return self.socketAddress.pathname! + } + }() + + return "\(scheme)://\(host):\(self.port)/" } private let mode: Mode private let sslContext: NIOSSLContext? private var serverChannel: Channel! - private let isShutdown: NIOAtomic = .makeAtomic(value: false) + private let isShutdown = ManagedAtomic(false) private let handlerFactory: (Int) -> (RequestHandler) init( _ mode: Mode = .http1_1(ssl: false, compress: false), proxy: Proxy = .none, bindTarget: BindTarget = .localhostIPv4RandomPort, + reusePort: Bool = false, + trafficShapingTargetBytesPerSecond: Int? = nil, handlerFactory: @escaping (Int) -> (RequestHandler) ) { self.mode = mode @@ -376,15 +475,26 @@ internal final class HTTPBin where self.activeConnCounterHandler = ConnectionsCountHandler() - let connectionIDAtomic = NIOAtomic.makeAtomic(value: 0) + let connectionIDAtomic = ManagedAtomic(0) self.serverChannel = try! ServerBootstrap(group: self.group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .serverChannelOption( + ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), + value: reusePort ? 1 : 0 + ) .serverChannelInitializer { channel in channel.pipeline.addHandler(self.activeConnCounterHandler) }.childChannelInitializer { channel in + if let trafficShapingTargetBytesPerSecond = trafficShapingTargetBytesPerSecond { + try! channel.pipeline.syncOperations.addHandler( + BasicInboundTrafficShapingHandler( + targetBytesPerSecond: trafficShapingTargetBytesPerSecond + ) + ) + } do { - let connectionID = connectionIDAtomic.add(1) + let connectionID = connectionIDAtomic.loadThenWrappingIncrement(ordering: .relaxed) if case .refuse = mode { throw HTTPBinError.refusedConnection @@ -430,18 +540,18 @@ internal final class HTTPBin where let responseEncoder = HTTPResponseEncoder() let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) - let proxySimulator = HTTPProxySimulator(promise: promise, expectedAuhorization: expectedAuthorization) + let proxySimulator = HTTPProxySimulator(promise: promise, expectedAuthorization: expectedAuthorization) try sync.addHandler(responseEncoder) try sync.addHandler(requestDecoder) try sync.addHandler(proxySimulator) - promise.futureResult.flatMap { _ in - channel.pipeline.removeHandler(proxySimulator) + promise.futureResult.assumeIsolated().flatMap { _ in + channel.pipeline.syncOperations.removeHandler(proxySimulator) }.flatMap { _ in - channel.pipeline.removeHandler(responseEncoder) + channel.pipeline.syncOperations.removeHandler(responseEncoder) }.flatMap { _ in - channel.pipeline.removeHandler(requestDecoder) + channel.pipeline.syncOperations.removeHandler(requestDecoder) }.whenComplete { result in switch result { case .failure: @@ -484,30 +594,8 @@ internal final class HTTPBin where } } - private static func tlsConfiguration(for mode: Mode) -> TLSConfiguration? { - var configuration: TLSConfiguration? - - switch mode { - case .refuse, .http1_1(ssl: false, compress: _): - break - case .http2: - configuration = .makeServerConfiguration( - certificateChain: [.certificate(TestTLS.certificate)], - privateKey: .privateKey(TestTLS.privateKey) - ) - configuration!.applicationProtocols = NIOHTTP2SupportedALPNProtocols - case .http1_1(ssl: true, compress: _): - configuration = .makeServerConfiguration( - certificateChain: [.certificate(TestTLS.certificate)], - privateKey: .privateKey(TestTLS.privateKey) - ) - } - - return configuration - } - private static func sslContext(for mode: Mode) -> NIOSSLContext? { - if let tlsConfiguration = self.tlsConfiguration(for: mode) { + if let tlsConfiguration = mode.tlsConfiguration { return try! NIOSSLContext(configuration: tlsConfiguration) } return nil @@ -524,20 +612,20 @@ internal final class HTTPBin where // Successful upgrade to HTTP/2. Let the user configure the pipeline. let http2Handler = NIOHTTP2Handler( mode: .server, - initialSettings: [ - // TODO: make max concurrent streams configurable - HTTP2Setting(parameter: .maxConcurrentStreams, value: 10), - HTTP2Setting(parameter: .maxHeaderListSize, value: HPACKDecoder.defaultMaxHeaderListSize), - ] + initialSettings: self.mode.httpSettings ) let multiplexer = HTTP2StreamMultiplexer( mode: .server, channel: channel, + targetWindowSize: 16 * 1024 * 1024, // 16 MiB inboundStreamInitializer: { channel in do { let sync = channel.pipeline.syncOperations try sync.addHandler(HTTP2FramePayloadToHTTP1ServerCodec()) + if self.mode.compress { + try sync.addHandler(HTTPResponseCompressor()) + } try sync.addHandler(self.handlerFactory(connectionID)) return channel.eventLoop.makeSucceededVoidFuture() @@ -567,17 +655,17 @@ internal final class HTTPBin where } } + try channel.pipeline.syncOperations.addHandler(sslHandler) try channel.pipeline.syncOperations.addHandler(alpnHandler) - try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(alpnHandler)) } func shutdown() throws { - self.isShutdown.store(true) + self.isShutdown.store(true, ordering: .relaxed) try self.group.syncShutdownGracefully() } deinit { - assert(self.isShutdown.load(), "HTTPBin not shutdown before deinit") + assert(self.isShutdown.load(ordering: .relaxed), "HTTPBin not shutdown before deinit") } } @@ -585,9 +673,17 @@ extension HTTPBin where RequestHandler == HTTPBinHandler { convenience init( _ mode: Mode = .http1_1(ssl: false, compress: false), proxy: Proxy = .none, - bindTarget: BindTarget = .localhostIPv4RandomPort + bindTarget: BindTarget = .localhostIPv4RandomPort, + reusePort: Bool = false, + trafficShapingTargetBytesPerSecond: Int? = nil ) { - self.init(mode, proxy: proxy, bindTarget: bindTarget) { HTTPBinHandler(connectionID: $0) } + self.init( + mode, + proxy: proxy, + bindTarget: bindTarget, + reusePort: reusePort, + trafficShapingTargetBytesPerSecond: trafficShapingTargetBytesPerSecond + ) { HTTPBinHandler(connectionID: $0) } } } @@ -603,14 +699,18 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { // the promise to succeed, once the proxy connection is setup let promise: EventLoopPromise - let expectedAuhorization: String? + let expectedAuthorization: String? var head: HTTPResponseHead - init(promise: EventLoopPromise, expectedAuhorization: String?) { + init(promise: EventLoopPromise, expectedAuthorization: String?) { self.promise = promise - self.expectedAuhorization = expectedAuhorization - self.head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: .init([("Content-Length", "0")])) + self.expectedAuthorization = expectedAuthorization + self.head = HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: .ok, + headers: .init([("Content-Length", "0")]) + ) } func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -622,9 +722,10 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { return } - if let expectedAuhorization = self.expectedAuhorization { + if let expectedAuthorization = self.expectedAuthorization { guard let authorization = head.headers["proxy-authorization"].first, - expectedAuhorization == authorization else { + expectedAuthorization == authorization + else { self.head.status = .proxyAuthenticationRequired return } @@ -648,12 +749,31 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { internal struct HTTPResponseBuilder { var head: HTTPResponseHead var body: ByteBuffer? + var requestBodyByteCount: Int + let responseBodyIsRequestBodyByteCount: Bool - init(_ version: HTTPVersion = HTTPVersion(major: 1, minor: 1), status: HTTPResponseStatus, headers: HTTPHeaders = HTTPHeaders()) { + init( + _ version: HTTPVersion = HTTPVersion(major: 1, minor: 1), + status: HTTPResponseStatus, + headers: HTTPHeaders = HTTPHeaders(), + responseBodyIsRequestBodyByteCount: Bool = false + ) { self.head = HTTPResponseHead(version: version, status: status, headers: headers) + self.requestBodyByteCount = 0 + self.responseBodyIsRequestBodyByteCount = responseBodyIsRequestBodyByteCount } mutating func add(_ part: ByteBuffer) { + self.requestBodyByteCount += part.readableBytes + guard !self.responseBodyIsRequestBodyByteCount else { + if self.body == nil { + self.body = ByteBuffer() + self.body!.reserveCapacity(100) + } + self.body!.clear() + self.body!.writeString("\(self.requestBodyByteCount)") + return + } if var body = body { var part = part body.writeBuffer(&part) @@ -701,8 +821,10 @@ internal final class HTTPBinHandler: ChannelInboundHandler { for header in head.headers { let needle = "x-send-back-header-" if header.name.lowercased().starts(with: needle) { - self.responseHeaders.add(name: String(header.name.dropFirst(needle.count)), - value: header.value) + self.responseHeaders.add( + name: String(header.name.dropFirst(needle.count)), + value: header.value + ) } } } @@ -715,7 +837,12 @@ internal final class HTTPBinHandler: ChannelInboundHandler { headers = HTTPHeaders() } - context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil) + context.write( + wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers)) + ), + promise: nil + ) for i in 0..<10 { let msg = "id: \(i)" var buf = context.channel.allocator.buffer(capacity: msg.count) @@ -730,7 +857,12 @@ internal final class HTTPBinHandler: ChannelInboundHandler { // This tests receiving chunks very fast: please do not insert delays here! let headers = HTTPHeaders([("Transfer-Encoding", "chunked")]) - context.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil) + context.write( + self.wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers)) + ), + promise: nil + ) for i in 0..<10 { let msg = "id: \(i)" var buf = context.channel.allocator.buffer(capacity: msg.count) @@ -741,6 +873,27 @@ internal final class HTTPBinHandler: ChannelInboundHandler { context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } + func writeManyChunks(context: ChannelHandlerContext) { + // This tests receiving a lot of tiny chunks: they must all be sent in a single flush or the test doesn't work. + let headers = HTTPHeaders([("Transfer-Encoding", "chunked")]) + + context.write( + self.wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers)) + ), + promise: nil + ) + let message = ByteBuffer(integer: UInt8(ascii: "a")) + + // This number (10k) is load-bearing and a bit magic: it has been experimentally verified as being sufficient to blow the stack + // in the old implementation on all testing platforms. Please don't change it without good reason. + for _ in 0..<10_000 { + context.write(wrapOutboundOut(.body(.byteBuffer(message))), promise: nil) + } + + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.isServingRequest = true switch self.unwrapInboundIn(data) { @@ -789,6 +942,13 @@ internal final class HTTPBinHandler: ChannelInboundHandler { } self.resps.append(HTTPResponseBuilder(status: .ok)) return + case "/post-respond-with-byte-count": + if req.method != .POST { + self.resps.append(HTTPResponseBuilder(status: .methodNotAllowed)) + return + } + self.resps.append(HTTPResponseBuilder(status: .ok, responseBodyIsRequestBodyByteCount: true)) + return case "/redirect/302": var headers = self.responseHeaders headers.add(name: "location", value: "/ok") @@ -849,9 +1009,12 @@ internal final class HTTPBinHandler: ChannelInboundHandler { context.close(promise: nil) return case "/custom": - context.writeAndFlush(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok))), promise: nil) + context.writeAndFlush( + wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok))), + promise: nil + ) return - case "/events/10/1": // TODO: parse path + case "/events/10/1": // TODO: parse path self.writeEvents(context: context) return case "/events/10/content-length": @@ -859,6 +1022,9 @@ internal final class HTTPBinHandler: ChannelInboundHandler { case "/chunked": self.writeChunked(context: context) return + case "/mega-chunked": + self.writeManyChunks(context: context) + return case "/close-on-response": var headers = self.responseHeaders headers.replaceOrAdd(name: "connection", value: "close") @@ -869,8 +1035,23 @@ internal final class HTTPBinHandler: ChannelInboundHandler { // We're forcing this closed now. self.shouldClose = true self.resps.append(builder) + case "/content-length-without-body": + var headers = self.responseHeaders + headers.replaceOrAdd(name: "content-length", value: "1234") + context.writeAndFlush( + wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers)) + ), + promise: nil + ) + return default: - context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .notFound))), promise: nil) + context.write( + wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .notFound)) + ), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) return } @@ -889,18 +1070,26 @@ internal final class HTTPBinHandler: ChannelInboundHandler { response.head.headers.add(contentsOf: self.responseHeaders) context.write(wrapOutboundOut(.head(response.head)), promise: nil) if let body = response.body { - let requestInfo = RequestInfo(data: String(buffer: body), - requestNumber: self.requestId, - connectionNumber: self.connectionID) - let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo, - allocator: context.channel.allocator) + let requestInfo = RequestInfo( + data: String(buffer: body), + requestNumber: self.requestId, + connectionNumber: self.connectionID + ) + let responseBody = try! JSONEncoder().encodeAsByteBuffer( + requestInfo, + allocator: context.channel.allocator + ) context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil) } else { - let requestInfo = RequestInfo(data: "", - requestNumber: self.requestId, - connectionNumber: self.connectionID) - let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo, - allocator: context.channel.allocator) + let requestInfo = RequestInfo( + data: "", + requestNumber: self.requestId, + connectionNumber: self.connectionID + ) + let responseBody = try! JSONEncoder().encodeAsByteBuffer( + requestInfo, + allocator: context.channel.allocator + ) context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil) } context.eventLoop.scheduleTask(in: self.delay) { @@ -913,8 +1102,9 @@ internal final class HTTPBinHandler: ChannelInboundHandler { self.isServingRequest = false switch result { case .success: - if self.responseHeaders[canonicalForm: "X-Close-Connection"].contains("true") || - self.shouldClose { + if self.responseHeaders[canonicalForm: "X-Close-Connection"].contains("true") + || self.shouldClose + { context.close(promise: nil) } case .failure(let error): @@ -946,24 +1136,24 @@ internal final class HTTPBinHandler: ChannelInboundHandler { final class ConnectionsCountHandler: ChannelInboundHandler { typealias InboundIn = Channel - private let activeConns = NIOAtomic.makeAtomic(value: 0) - private let createdConns = NIOAtomic.makeAtomic(value: 0) + private let activeConns = ManagedAtomic(0) + private let createdConns = ManagedAtomic(0) var createdConnections: Int { - self.createdConns.load() + self.createdConns.load(ordering: .relaxed) } var currentlyActiveConnections: Int { - self.activeConns.load() + self.activeConns.load(ordering: .relaxed) } func channelRead(context: ChannelHandlerContext, data: NIOAny) { let channel = self.unwrapInboundIn(data) - _ = self.activeConns.add(1) - _ = self.createdConns.add(1) + _ = self.activeConns.loadThenWrappingIncrement(ordering: .relaxed) + _ = self.createdConns.loadThenWrappingIncrement(ordering: .relaxed) channel.closeFuture.whenComplete { _ in - _ = self.activeConns.sub(1) + _ = self.activeConns.loadThenWrappingDecrement(ordering: .relaxed) } context.fireChannelRead(data) @@ -1017,6 +1207,32 @@ internal final class CloseWithoutClosingServerHandler: ChannelInboundHandler { } } +final class ExpectClosureServerHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + private let onClosePromise: EventLoopPromise + + init(onClosePromise: EventLoopPromise) { + self.onClosePromise = onClosePromise + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + switch self.unwrapInboundIn(data) { + case .head: + let head = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "0"]) + context.write(self.wrapOutboundOut(.head(head)), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + case .body, .end: + () + } + } + + func channelInactive(context: ChannelHandlerContext) { + self.onClosePromise.succeed(()) + } +} + struct EventLoopFutureTimeoutError: Error {} extension EventLoopFuture { @@ -1052,12 +1268,12 @@ struct CollectEverythingLogHandler: LogHandler { var metadata: [String: String] } - var lock = Lock() + var lock = NIOLock() var logs: [Entry] = [] var allEntries: [Entry] { get { - return self.lock.withLock { self.logs } + self.lock.withLock { self.logs } } set { self.lock.withLock { self.logs = newValue } @@ -1066,9 +1282,13 @@ struct CollectEverythingLogHandler: LogHandler { func append(level: Logger.Level, message: Logger.Message, metadata: Logger.Metadata?) { self.lock.withLock { - self.logs.append(Entry(level: level, - message: message.description, - metadata: metadata?.mapValues { $0.description } ?? [:])) + self.logs.append( + Entry( + level: level, + message: message.description, + metadata: metadata?.mapValues { $0.description } ?? [:] + ) + ) } } } @@ -1077,16 +1297,20 @@ struct CollectEverythingLogHandler: LogHandler { self.logStore = logStore } - func log(level: Logger.Level, - message: Logger.Message, - metadata: Logger.Metadata?, - file: String, function: String, line: UInt) { + func log( + level: Logger.Level, + message: Logger.Message, + metadata: Logger.Metadata?, + file: String, + function: String, + line: UInt + ) { self.logStore.append(level: level, message: message, metadata: self.metadata.merging(metadata ?? [:]) { $1 }) } subscript(metadataKey key: String) -> Logger.Metadata.Value? { get { - return self.metadata[key] + self.metadata[key] } set { self.metadata[key] = newValue @@ -1242,7 +1466,10 @@ class HTTPEchoHandler: ChannelInboundHandler { let request = self.unwrapInboundIn(data) switch request { case .head(let requestHead): - context.writeAndFlush(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: requestHead.headers))), promise: nil) + context.writeAndFlush( + self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: requestHead.headers))), + promise: nil + ) case .body(let bytes): context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(bytes))), promise: nil) case .end: @@ -1253,52 +1480,197 @@ class HTTPEchoHandler: ChannelInboundHandler { } } +final class HTTPEchoHeaders: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let request = self.unwrapInboundIn(data) + switch request { + case .head(let requestHead): + context.writeAndFlush( + self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: requestHead.headers))), + promise: nil + ) + case .body: + break + case .end: + context.writeAndFlush(self.wrapOutboundOut(.end(nil))).whenSuccess { + context.close(promise: nil) + } + } + } +} + +final class HTTP200DelayedHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + var pendingBodyParts: Int? + + init(bodyPartsBeforeResponse: Int) { + self.pendingBodyParts = bodyPartsBeforeResponse + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let request = self.unwrapInboundIn(data) + switch request { + case .head: + // Once we have received one response, all further requests are responded to immediately. + if self.pendingBodyParts == nil { + context.writeAndFlush(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok))), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + case .body: + if let pendingBodyParts = self.pendingBodyParts { + if pendingBodyParts > 0 { + self.pendingBodyParts = pendingBodyParts - 1 + } else { + self.pendingBodyParts = nil + context.writeAndFlush( + self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok))), + promise: nil + ) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + } + case .end: + break + } + } +} + private let cert = """ ------BEGIN CERTIFICATE----- -MIICmDCCAYACCQCPC8JDqMh1zzANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1 -czAgFw0xODEwMzExNTU1MjJaGA8yMTE4MTAwNzE1NTUyMlowDTELMAkGA1UEBhMC -dXMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDiC+TGmbSP/nWWN1tj -yNfnWCU5ATjtIOfdtP6ycx8JSeqkvyNXG21kNUn14jTTU8BglGL2hfVpCbMisUdb -d3LpP8unSsvlOWwORFOViSy4YljSNM/FNoMtavuITA/sEELYgjWkz2o/uHPZHud9 -+JQwGJgqIlMa3mr2IaaUZlWN3D1u88bzJYhpt3YyxRy9+OEoOKy36KdWwhKzV3S8 -kXb0Y1GbAo68jJ9RfzeLy290mIs9qG2y1CNXWO6sxf6B//LaalizZiCfzYAVKcNR -9oNYsEJc5KB/+DsAGTzR7mL+oiU4h/vwVb2GTDat5C+PFGi6j1ujxYTRPO538ljg -dslnAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAFYhA7sw8odOsRO8/DUklBOjPnmn -a078oSumgPXXw6AgcoAJv/Qthjo6CCEtrjYfcA9jaBw9/Tii7mDmqDRS5c9ZPL8+ -NEPdHjFCFBOEvlL6uHOgw0Z9Wz+5yCXnJ8oNUEgc3H2NbbzJF6sMBXSPtFS2NOK8 -OsAI9OodMrDd6+lwljrmFoCCkJHDEfE637IcsbgFKkzhO/oNCRK6OrudG4teDahz -Au4LoEYwT730QKC/VQxxEVZobjn9/sTrq9CZlbPYHxX4fz6e00sX7H9i49vk9zQ5 -5qCm9ljhrQPSa42Q62PPE2BEEGSP2KBm0J+H3vlvCD6+SNc/nMZjrRmgjrI= ------END CERTIFICATE----- -""" + -----BEGIN CERTIFICATE----- + MIICmDCCAYACCQCPC8JDqMh1zzANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1 + czAgFw0xODEwMzExNTU1MjJaGA8yMTE4MTAwNzE1NTUyMlowDTELMAkGA1UEBhMC + dXMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDiC+TGmbSP/nWWN1tj + yNfnWCU5ATjtIOfdtP6ycx8JSeqkvyNXG21kNUn14jTTU8BglGL2hfVpCbMisUdb + d3LpP8unSsvlOWwORFOViSy4YljSNM/FNoMtavuITA/sEELYgjWkz2o/uHPZHud9 + +JQwGJgqIlMa3mr2IaaUZlWN3D1u88bzJYhpt3YyxRy9+OEoOKy36KdWwhKzV3S8 + kXb0Y1GbAo68jJ9RfzeLy290mIs9qG2y1CNXWO6sxf6B//LaalizZiCfzYAVKcNR + 9oNYsEJc5KB/+DsAGTzR7mL+oiU4h/vwVb2GTDat5C+PFGi6j1ujxYTRPO538ljg + dslnAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAFYhA7sw8odOsRO8/DUklBOjPnmn + a078oSumgPXXw6AgcoAJv/Qthjo6CCEtrjYfcA9jaBw9/Tii7mDmqDRS5c9ZPL8+ + NEPdHjFCFBOEvlL6uHOgw0Z9Wz+5yCXnJ8oNUEgc3H2NbbzJF6sMBXSPtFS2NOK8 + OsAI9OodMrDd6+lwljrmFoCCkJHDEfE637IcsbgFKkzhO/oNCRK6OrudG4teDahz + Au4LoEYwT730QKC/VQxxEVZobjn9/sTrq9CZlbPYHxX4fz6e00sX7H9i49vk9zQ5 + 5qCm9ljhrQPSa42Q62PPE2BEEGSP2KBm0J+H3vlvCD6+SNc/nMZjrRmgjrI= + -----END CERTIFICATE----- + """ private let key = """ ------BEGIN PRIVATE KEY----- -MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDiC+TGmbSP/nWW -N1tjyNfnWCU5ATjtIOfdtP6ycx8JSeqkvyNXG21kNUn14jTTU8BglGL2hfVpCbMi -sUdbd3LpP8unSsvlOWwORFOViSy4YljSNM/FNoMtavuITA/sEELYgjWkz2o/uHPZ -Hud9+JQwGJgqIlMa3mr2IaaUZlWN3D1u88bzJYhpt3YyxRy9+OEoOKy36KdWwhKz -V3S8kXb0Y1GbAo68jJ9RfzeLy290mIs9qG2y1CNXWO6sxf6B//LaalizZiCfzYAV -KcNR9oNYsEJc5KB/+DsAGTzR7mL+oiU4h/vwVb2GTDat5C+PFGi6j1ujxYTRPO53 -8ljgdslnAgMBAAECggEBANZNWFNAnYJ2R5xmVuo/GxFk68Ujd4i4TZpPYbhkk+QG -g8I0w5htlEQQkVHfZx2CpTvq8feuAH/YhlA5qeD5WaPwq26q5qsmyV6tQGDgb9lO -w85l6ySZDbwdVOJe2il/MSB6MclSKvTGNm59chJnfHYsmvY3HHq4qsc2F+tRKYMW -pY75LgEbaTUV69J3cbC1wAeVjv0q/krND+YkhYpTxNZhbazK/FHOCvY+zFu9fg0L -zpwbn5fb6wIvqG7tXp7koa3QMn64AXmO/fb5mBd8G2vBGYnxwb7Egwdg/3Dw+BXu -ynQLP7ixWsE2KNfR9Ce1i3YvEo6QDTv2340I3dntxkECgYEA9vdaL4PGyvEbpim4 -kqz1vuug8Iq0nTVDo6jmgH1o+XdcIbW3imXtgi5zUJpj4oDD7/4aufiJZjG64i/v -phe11xeUvh5QNNOzeMymVDoJut97F97KKKTv7bG8Rpon/WzH2I0SoAkECCwmdWAJ -H3nvOCnXEkpbCqmIUvHVURPRDn8CgYEA6lCk3EzFQlbXs3Sj5op61R3Mscx7/35A -eGv5axzbENHt1so+s3Zvyyi1bo4VBcwnKVCvQjmTuLiqrc9VfX8XdbiTUNnEr2u3 -992Ja6DEJTZ9gy5WiviwYnwU2HpjwOVNBb17T0NLoRHkDZ6iXj7NZgwizOki5p3j -/hS0pObSIRkCgYEAiEdOGNIarHoHy9VR6H5QzR2xHYssx2NRA8p8B4MsnhxjVqaz -tUcxnJiNQXkwjRiJBrGthdnD2ASxH4dcMsb6rMpyZcbMc5ouewZS8j9khx4zCqUB -4RPC4eMmBb+jOZEBZlnSYUUYWHokbrij0B61BsTvzUQCoQuUElEoaSkKP3kCgYEA -mwdqXHvK076jjo9w1drvtEu4IDc8H2oH++TsrEr2QiWzaDZ9z71f8BnqGNCW5jQS -AQrqOjXgIArGmqMgXB0Xh4LsrUS4Fpx9ptiD0JsYy8pGtuGUzvQFt9OC80ve7kSI -dnDMwj+zLUmqCrzXjuWcfpUu/UaPGeiDbZuDfcteYhkCgYBLyL5JY7Qd4gVQIhFX -7Sv3sNJN3KZCQHEzut7IwojaxgpuxiFvgsoXXuYolVCQp32oWbYcE2Yke+hOKsTE -sCMAWZiSGN2Nrfea730IYAXkUm8bpEd3VxDXEEv13nxVeQof+JGMdlkldFGaBRDU -oYQsPj00S3/GA9WDapwe81Wl2A== ------END PRIVATE KEY----- -""" + -----BEGIN PRIVATE KEY----- + MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDiC+TGmbSP/nWW + N1tjyNfnWCU5ATjtIOfdtP6ycx8JSeqkvyNXG21kNUn14jTTU8BglGL2hfVpCbMi + sUdbd3LpP8unSsvlOWwORFOViSy4YljSNM/FNoMtavuITA/sEELYgjWkz2o/uHPZ + Hud9+JQwGJgqIlMa3mr2IaaUZlWN3D1u88bzJYhpt3YyxRy9+OEoOKy36KdWwhKz + V3S8kXb0Y1GbAo68jJ9RfzeLy290mIs9qG2y1CNXWO6sxf6B//LaalizZiCfzYAV + KcNR9oNYsEJc5KB/+DsAGTzR7mL+oiU4h/vwVb2GTDat5C+PFGi6j1ujxYTRPO53 + 8ljgdslnAgMBAAECggEBANZNWFNAnYJ2R5xmVuo/GxFk68Ujd4i4TZpPYbhkk+QG + g8I0w5htlEQQkVHfZx2CpTvq8feuAH/YhlA5qeD5WaPwq26q5qsmyV6tQGDgb9lO + w85l6ySZDbwdVOJe2il/MSB6MclSKvTGNm59chJnfHYsmvY3HHq4qsc2F+tRKYMW + pY75LgEbaTUV69J3cbC1wAeVjv0q/krND+YkhYpTxNZhbazK/FHOCvY+zFu9fg0L + zpwbn5fb6wIvqG7tXp7koa3QMn64AXmO/fb5mBd8G2vBGYnxwb7Egwdg/3Dw+BXu + ynQLP7ixWsE2KNfR9Ce1i3YvEo6QDTv2340I3dntxkECgYEA9vdaL4PGyvEbpim4 + kqz1vuug8Iq0nTVDo6jmgH1o+XdcIbW3imXtgi5zUJpj4oDD7/4aufiJZjG64i/v + phe11xeUvh5QNNOzeMymVDoJut97F97KKKTv7bG8Rpon/WzH2I0SoAkECCwmdWAJ + H3nvOCnXEkpbCqmIUvHVURPRDn8CgYEA6lCk3EzFQlbXs3Sj5op61R3Mscx7/35A + eGv5axzbENHt1so+s3Zvyyi1bo4VBcwnKVCvQjmTuLiqrc9VfX8XdbiTUNnEr2u3 + 992Ja6DEJTZ9gy5WiviwYnwU2HpjwOVNBb17T0NLoRHkDZ6iXj7NZgwizOki5p3j + /hS0pObSIRkCgYEAiEdOGNIarHoHy9VR6H5QzR2xHYssx2NRA8p8B4MsnhxjVqaz + tUcxnJiNQXkwjRiJBrGthdnD2ASxH4dcMsb6rMpyZcbMc5ouewZS8j9khx4zCqUB + 4RPC4eMmBb+jOZEBZlnSYUUYWHokbrij0B61BsTvzUQCoQuUElEoaSkKP3kCgYEA + mwdqXHvK076jjo9w1drvtEu4IDc8H2oH++TsrEr2QiWzaDZ9z71f8BnqGNCW5jQS + AQrqOjXgIArGmqMgXB0Xh4LsrUS4Fpx9ptiD0JsYy8pGtuGUzvQFt9OC80ve7kSI + dnDMwj+zLUmqCrzXjuWcfpUu/UaPGeiDbZuDfcteYhkCgYBLyL5JY7Qd4gVQIhFX + 7Sv3sNJN3KZCQHEzut7IwojaxgpuxiFvgsoXXuYolVCQp32oWbYcE2Yke+hOKsTE + sCMAWZiSGN2Nrfea730IYAXkUm8bpEd3VxDXEEv13nxVeQof+JGMdlkldFGaBRDU + oYQsPj00S3/GA9WDapwe81Wl2A== + -----END PRIVATE KEY----- + """ + +final class BasicInboundTrafficShapingHandler: ChannelDuplexHandler { + typealias OutboundIn = ByteBuffer + typealias InboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + enum ReadState { + case flowingFreely + case pausing + case paused + + mutating func pause() { + switch self { + case .flowingFreely: + self = .pausing + case .pausing, .paused: + () // nothing to do + } + } + + mutating func unpause() -> Bool { + switch self { + case .flowingFreely: + return false // no extra `read` needed + case .pausing: + self = .flowingFreely + return false // no extra `read` needed + case .paused: + self = .flowingFreely + return true // yes, we need an extra read + } + } + + mutating func shouldRead() -> Bool { + switch self { + case .flowingFreely: + return true + case .pausing: + self = .paused + return false + case .paused: + return false + } + } + } + + private let targetBytesPerSecond: Int + private var currentSecondBytesSeen: Int = 0 + private var readState: ReadState = .flowingFreely + + init(targetBytesPerSecond: Int) { + self.targetBytesPerSecond = targetBytesPerSecond + } + + func evaluatePause(context: ChannelHandlerContext) { + if self.currentSecondBytesSeen >= self.targetBytesPerSecond { + self.readState.pause() + } else if self.currentSecondBytesSeen < self.targetBytesPerSecond { + if self.readState.unpause() { + context.read() + } + } + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let loopBoundContext = NIOLoopBound(context, eventLoop: context.eventLoop) + defer { + context.fireChannelRead(data) + } + let buffer = Self.unwrapInboundIn(data) + let byteCount = buffer.readableBytes + self.currentSecondBytesSeen += byteCount + context.eventLoop.scheduleTask(in: .seconds(1)) { + self.currentSecondBytesSeen -= byteCount + self.evaluatePause(context: loopBoundContext.value) + } + self.evaluatePause(context: context) + } + + func read(context: ChannelHandlerContext) { + if self.readState.shouldRead() { + context.read() + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift deleted file mode 100644 index 7eb532cf9..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ /dev/null @@ -1,141 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPClientTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPClientTests { - static var allTests: [(String, (HTTPClientTests) -> () throws -> Void)] { - return [ - ("testRequestURI", testRequestURI), - ("testBadRequestURI", testBadRequestURI), - ("testSchemaCasing", testSchemaCasing), - ("testURLSocketPathInitializers", testURLSocketPathInitializers), - ("testBadUnixWithBaseURL", testBadUnixWithBaseURL), - ("testConvenienceExecuteMethods", testConvenienceExecuteMethods), - ("testConvenienceExecuteMethodsOverSocket", testConvenienceExecuteMethodsOverSocket), - ("testConvenienceExecuteMethodsOverSecureSocket", testConvenienceExecuteMethodsOverSecureSocket), - ("testGet", testGet), - ("testGetWithDifferentEventLoopBackpressure", testGetWithDifferentEventLoopBackpressure), - ("testPost", testPost), - ("testPostWithGenericBody", testPostWithGenericBody), - ("testPostWithFoundationDataBody", testPostWithFoundationDataBody), - ("testGetHttps", testGetHttps), - ("testGetHttpsWithIP", testGetHttpsWithIP), - ("testGetHTTPSWorksOnMTELGWithIP", testGetHTTPSWorksOnMTELGWithIP), - ("testGetHttpsWithIPv6", testGetHttpsWithIPv6), - ("testGetHTTPSWorksOnMTELGWithIPv6", testGetHTTPSWorksOnMTELGWithIPv6), - ("testPostHttps", testPostHttps), - ("testHttpRedirect", testHttpRedirect), - ("testHttpHostRedirect", testHttpHostRedirect), - ("testPercentEncoded", testPercentEncoded), - ("testPercentEncodedBackslash", testPercentEncodedBackslash), - ("testMultipleContentLengthHeaders", testMultipleContentLengthHeaders), - ("testStreaming", testStreaming), - ("testFileDownload", testFileDownload), - ("testFileDownloadError", testFileDownloadError), - ("testRemoteClose", testRemoteClose), - ("testReadTimeout", testReadTimeout), - ("testConnectTimeout", testConnectTimeout), - ("testDeadline", testDeadline), - ("testCancel", testCancel), - ("testStressCancel", testStressCancel), - ("testHTTPClientAuthorization", testHTTPClientAuthorization), - ("testProxyPlaintext", testProxyPlaintext), - ("testProxyTLS", testProxyTLS), - ("testProxyPlaintextWithCorrectlyAuthorization", testProxyPlaintextWithCorrectlyAuthorization), - ("testProxyPlaintextWithIncorrectlyAuthorization", testProxyPlaintextWithIncorrectlyAuthorization), - ("testUploadStreaming", testUploadStreaming), - ("testEventLoopArgument", testEventLoopArgument), - ("testDecompression", testDecompression), - ("testDecompressionLimit", testDecompressionLimit), - ("testLoopDetectionRedirectLimit", testLoopDetectionRedirectLimit), - ("testCountRedirectLimit", testCountRedirectLimit), - ("testRedirectToTheInitialURLDoesThrowOnFirstRedirect", testRedirectToTheInitialURLDoesThrowOnFirstRedirect), - ("testMultipleConcurrentRequests", testMultipleConcurrentRequests), - ("testWorksWith500Error", testWorksWith500Error), - ("testWorksWithHTTP10Response", testWorksWithHTTP10Response), - ("testWorksWhenServerClosesConnectionAfterReceivingRequest", testWorksWhenServerClosesConnectionAfterReceivingRequest), - ("testSubsequentRequestsWorkWithServerSendingConnectionClose", testSubsequentRequestsWorkWithServerSendingConnectionClose), - ("testSubsequentRequestsWorkWithServerAlternatingBetweenKeepAliveAndClose", testSubsequentRequestsWorkWithServerAlternatingBetweenKeepAliveAndClose), - ("testStressGetHttps", testStressGetHttps), - ("testStressGetHttpsSSLError", testStressGetHttpsSSLError), - ("testFailingConnectionIsReleased", testFailingConnectionIsReleased), - ("testResponseDelayGet", testResponseDelayGet), - ("testIdleTimeoutNoReuse", testIdleTimeoutNoReuse), - ("testStressGetClose", testStressGetClose), - ("testManyConcurrentRequestsWork", testManyConcurrentRequestsWork), - ("testRepeatedRequestsWorkWhenServerAlwaysCloses", testRepeatedRequestsWorkWhenServerAlwaysCloses), - ("testShutdownBeforeTasksCompletion", testShutdownBeforeTasksCompletion), - ("testUncleanShutdownActuallyShutsDown", testUncleanShutdownActuallyShutsDown), - ("testUncleanShutdownCancelsTasks", testUncleanShutdownCancelsTasks), - ("testDoubleShutdown", testDoubleShutdown), - ("testTaskFailsWhenClientIsShutdown", testTaskFailsWhenClientIsShutdown), - ("testRaceNewRequestsVsShutdown", testRaceNewRequestsVsShutdown), - ("testVaryingLoopPreference", testVaryingLoopPreference), - ("testMakeSecondRequestDuringCancelledCallout", testMakeSecondRequestDuringCancelledCallout), - ("testMakeSecondRequestDuringSuccessCallout", testMakeSecondRequestDuringSuccessCallout), - ("testMakeSecondRequestWhilstFirstIsOngoing", testMakeSecondRequestWhilstFirstIsOngoing), - ("testUDSBasic", testUDSBasic), - ("testUDSSocketAndPath", testUDSSocketAndPath), - ("testHTTPPlusUNIX", testHTTPPlusUNIX), - ("testHTTPSPlusUNIX", testHTTPSPlusUNIX), - ("testUseExistingConnectionOnDifferentEL", testUseExistingConnectionOnDifferentEL), - ("testWeRecoverFromServerThatClosesTheConnectionOnUs", testWeRecoverFromServerThatClosesTheConnectionOnUs), - ("testPoolClosesIdleConnections", testPoolClosesIdleConnections), - ("testRacePoolIdleConnectionsAndGet", testRacePoolIdleConnectionsAndGet), - ("testAvoidLeakingTLSHandshakeCompletionPromise", testAvoidLeakingTLSHandshakeCompletionPromise), - ("testAsyncShutdown", testAsyncShutdown), - ("testAsyncShutdownDefaultQueue", testAsyncShutdownDefaultQueue), - ("testValidationErrorsAreSurfaced", testValidationErrorsAreSurfaced), - ("testUploadsReallyStream", testUploadsReallyStream), - ("testUploadStreamingCallinToleratedFromOtsideEL", testUploadStreamingCallinToleratedFromOtsideEL), - ("testWeHandleUsSendingACloseHeaderCorrectly", testWeHandleUsSendingACloseHeaderCorrectly), - ("testWeHandleUsReceivingACloseHeaderCorrectly", testWeHandleUsReceivingACloseHeaderCorrectly), - ("testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly", testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly), - ("testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly", testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly), - ("testLoggingCorrectlyAttachesRequestInformationEvenAfterDuringRedirect", testLoggingCorrectlyAttachesRequestInformationEvenAfterDuringRedirect), - ("testLoggingCorrectlyAttachesRequestInformation", testLoggingCorrectlyAttachesRequestInformation), - ("testNothingIsLoggedAtInfoOrHigher", testNothingIsLoggedAtInfoOrHigher), - ("testAllMethodsLog", testAllMethodsLog), - ("testClosingIdleConnectionsInPoolLogsInTheBackground", testClosingIdleConnectionsInPoolLogsInTheBackground), - ("testUploadStreamingNoLength", testUploadStreamingNoLength), - ("testConnectErrorPropagatedToDelegate", testConnectErrorPropagatedToDelegate), - ("testDelegateCallinsTolerateRandomEL", testDelegateCallinsTolerateRandomEL), - ("testContentLengthTooLongFails", testContentLengthTooLongFails), - ("testContentLengthTooShortFails", testContentLengthTooShortFails), - ("testBodyUploadAfterEndFails", testBodyUploadAfterEndFails), - ("testNoBytesSentOverBodyLimit", testNoBytesSentOverBodyLimit), - ("testDoubleError", testDoubleError), - ("testSSLHandshakeErrorPropagation", testSSLHandshakeErrorPropagation), - ("testSSLHandshakeErrorPropagationDelayedClose", testSSLHandshakeErrorPropagationDelayedClose), - ("testWeCloseConnectionsWhenConnectionCloseSetByServer", testWeCloseConnectionsWhenConnectionCloseSetByServer), - ("testBiDirectionalStreaming", testBiDirectionalStreaming), - ("testSynchronousHandshakeErrorReporting", testSynchronousHandshakeErrorReporting), - ("testFileDownloadChunked", testFileDownloadChunked), - ("testCloseWhileBackpressureIsExertedIsFine", testCloseWhileBackpressureIsExertedIsFine), - ("testErrorAfterCloseWhileBackpressureExerted", testErrorAfterCloseWhileBackpressureExerted), - ("testRequestSpecificTLS", testRequestSpecificTLS), - ("testConnectionPoolSizeConfigValueIsRespected", testConnectionPoolSizeConfigValueIsRespected), - ("testRequestWithHeaderTransferEncodingIdentityDoesNotFail", testRequestWithHeaderTransferEncodingIdentityDoesNotFail), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 6bb4dd9b4..546d1c3f4 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -12,15 +12,15 @@ // //===----------------------------------------------------------------------===// -/* NOT @testable */ import AsyncHTTPClient // Tests that need @testable go into HTTPClientInternalTests.swift -#if canImport(Network) -import Network -#endif +import AsyncHTTPClient // NOT @testable - tests that need @testable go into HTTPClientInternalTests.swift +import Atomics import Logging import NIOConcurrencyHelpers import NIOCore +import NIOEmbedded import NIOFoundationCompat import NIOHTTP1 +import NIOHTTP2 import NIOHTTPCompression import NIOPosix import NIOSSL @@ -28,60 +28,11 @@ import NIOTestUtils import NIOTransportServices import XCTest -class HTTPClientTests: XCTestCase { - typealias Request = HTTPClient.Request - - var clientGroup: EventLoopGroup! - var serverGroup: EventLoopGroup! - var defaultHTTPBin: HTTPBin! - var defaultClient: HTTPClient! - var backgroundLogStore: CollectEverythingLogHandler.LogStore! - - var defaultHTTPBinURLPrefix: String { - return "http://localhost:\(self.defaultHTTPBin.port)/" - } - - override func setUp() { - XCTAssertNil(self.clientGroup) - XCTAssertNil(self.serverGroup) - XCTAssertNil(self.defaultHTTPBin) - XCTAssertNil(self.defaultClient) - XCTAssertNil(self.backgroundLogStore) - - self.clientGroup = getDefaultEventLoopGroup(numberOfThreads: 1) - self.serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - self.defaultHTTPBin = HTTPBin() - self.backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: self.backgroundLogStore!) - }) - backgroundLogger.logLevel = .trace - self.defaultClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - backgroundActivityLogger: backgroundLogger) - } - - override func tearDown() { - if let defaultClient = self.defaultClient { - XCTAssertNoThrow(try defaultClient.syncShutdown()) - self.defaultClient = nil - } - - XCTAssertNotNil(self.defaultHTTPBin) - XCTAssertNoThrow(try self.defaultHTTPBin.shutdown()) - self.defaultHTTPBin = nil - - XCTAssertNotNil(self.clientGroup) - XCTAssertNoThrow(try self.clientGroup.syncShutdownGracefully()) - self.clientGroup = nil - - XCTAssertNotNil(self.serverGroup) - XCTAssertNoThrow(try self.serverGroup.syncShutdownGracefully()) - self.serverGroup = nil - - XCTAssertNotNil(self.backgroundLogStore) - self.backgroundLogStore = nil - } +#if canImport(Network) +import Network +#endif +final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testRequestURI() throws { let request1 = try Request(url: "https://someserver.com:8888/some/path?foo=bar") XCTAssertEqual(request1.url.host, "someserver.com") @@ -94,8 +45,12 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(request2.url.path, "") let request3 = try Request(url: "unix:///tmp/file") - XCTAssertNil(request3.url.host) XCTAssertEqual(request3.host, "") + #if os(Linux) && compiler(>=6.0) + XCTAssertEqual(request3.url.host, "") + #else + XCTAssertNil(request3.url.host) + #endif XCTAssertEqual(request3.url.path, "/tmp/file") XCTAssertEqual(request3.port, 80) XCTAssertFalse(request3.useTLS) @@ -169,7 +124,10 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(url.scheme, "http+unix") XCTAssertEqual(url.host, "/tmp/file with spacesと漢字") XCTAssertEqual(url.path, "/file/path") - XCTAssertEqual(url.absoluteString, "http+unix://%2Ftmp%2Ffile%20with%20spaces%E3%81%A8%E6%BC%A2%E5%AD%97/file/path") + XCTAssertEqual( + url.absoluteString, + "http+unix://%2Ftmp%2Ffile%20with%20spaces%E3%81%A8%E6%BC%A2%E5%AD%97/file/path" + ) } let url5 = URL(httpsURLWithSocketPath: "/tmp/file") @@ -205,14 +163,11 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(url.scheme, "https+unix") XCTAssertEqual(url.host, "/tmp/file with spacesと漢字") XCTAssertEqual(url.path, "/file/path") - XCTAssertEqual(url.absoluteString, "https+unix://%2Ftmp%2Ffile%20with%20spaces%E3%81%A8%E6%BC%A2%E5%AD%97/file/path") + XCTAssertEqual( + url.absoluteString, + "https+unix://%2Ftmp%2Ffile%20with%20spaces%E3%81%A8%E6%BC%A2%E5%AD%97/file/path" + ) } - - let url9 = URL(httpURLWithSocketPath: "/tmp/file", uri: " ") - XCTAssertNil(url9) - - let url10 = URL(httpsURLWithSocketPath: "/tmp/file", uri: " ") - XCTAssertNil(url10) } func testBadUnixWithBaseURL() { @@ -224,55 +179,116 @@ class HTTPClientTests: XCTestCase { } func testConvenienceExecuteMethods() throws { - XCTAssertEqual(["GET"[...]], - try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["POST"[...]], - try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["PATCH"[...]], - try self.defaultClient.patch(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["PUT"[...]], - try self.defaultClient.put(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["DELETE"[...]], - try self.defaultClient.delete(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["GET"[...]], - try self.defaultClient.execute(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["CHECKOUT"[...]], - try self.defaultClient.execute(.CHECKOUT, url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) + XCTAssertEqual( + ["GET"[...]], + try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["POST"[...]], + try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["PATCH"[...]], + try self.defaultClient.patch(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["PUT"[...]], + try self.defaultClient.put(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["DELETE"[...]], + try self.defaultClient.delete(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["GET"[...]], + try self.defaultClient.execute(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["CHECKOUT"[...]], + try self.defaultClient.execute(.CHECKOUT, url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) } func testConvenienceExecuteMethodsOverSocket() throws { - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - defer { - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + defer { + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertEqual(["GET"[...]], - try self.defaultClient.execute(socketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["GET"[...]], - try self.defaultClient.execute(.GET, socketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["POST"[...]], - try self.defaultClient.execute(.POST, socketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - }) + XCTAssertEqual( + ["GET"[...]], + try self.defaultClient.execute(socketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["GET"[...]], + try self.defaultClient.execute(.GET, socketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["POST"[...]], + try self.defaultClient.execute(.POST, socketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + } + ) } func testConvenienceExecuteMethodsOverSecureSocket() throws { - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true, compress: false), bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localSocketPathHTTPBin = HTTPBin( + .http1_1(ssl: true, compress: false), + bindTarget: .unixDomainSocket(path) + ) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertEqual(["GET"[...]], - try localClient.execute(secureSocketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["GET"[...]], - try localClient.execute(.GET, secureSocketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["POST"[...]], - try localClient.execute(.POST, secureSocketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - }) + XCTAssertEqual( + ["GET"[...]], + try localClient.execute(secureSocketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["GET"[...]], + try localClient.execute(.GET, secureSocketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["POST"[...]], + try localClient.execute(.POST, secureSocketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + } + ) } func testGet() throws { @@ -288,7 +304,8 @@ class HTTPClientTests: XCTestCase { } func testPost() throws { - let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .string("1234")).wait() + let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .string("1234")) + .wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } let data = try JSONDecoder().decode(RequestInfo.self, from: bytes!) @@ -300,7 +317,8 @@ class HTTPClientTests: XCTestCase { let bodyData = Array("hello, world!").lazy.map { $0.uppercased().first!.asciiValue! } let erasedData = AnyRandomAccessCollection(bodyData) - let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .bytes(erasedData)).wait() + let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .bytes(erasedData)) + .wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } let data = try JSONDecoder().decode(RequestInfo.self, from: bytes!) @@ -311,7 +329,8 @@ class HTTPClientTests: XCTestCase { func testPostWithFoundationDataBody() throws { let bodyData = Data("hello, world!".utf8) - let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .data(bodyData)).wait() + let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .data(bodyData)) + .wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } let data = try JSONDecoder().decode(RequestInfo.self, from: bytes!) @@ -321,8 +340,10 @@ class HTTPClientTests: XCTestCase { func testGetHttps() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -334,8 +355,10 @@ class HTTPClientTests: XCTestCase { func testGetHttpsWithIP() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -353,8 +376,10 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try group.syncShutdownGracefully()) } let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(group), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(group), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -367,8 +392,10 @@ class HTTPClientTests: XCTestCase { func testGetHttpsWithIPv6() throws { try XCTSkipUnless(canBindIPv6Loopback, "Requires IPv6") let localHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .localhostIPv6RandomPort) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -387,8 +414,10 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try group.syncShutdownGracefully()) } let localHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .localhostIPv6RandomPort) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(group), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(group), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -400,14 +429,20 @@ class HTTPClientTests: XCTestCase { func testPostHttps() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) } - let request = try Request(url: "https://localhost:\(localHTTPBin.port)/post", method: .POST, body: .string("1234")) + let request = try Request( + url: "https://localhost:\(localHTTPBin.port)/post", + method: .POST, + body: .string("1234") + ) let response = try localClient.execute(request: request).wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } @@ -419,8 +454,13 @@ class HTTPClientTests: XCTestCase { func testHttpRedirect() throws { let httpsBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration( + certificateVerification: .none, + redirectConfiguration: .follow(max: 10, allowCycles: true) + ) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -430,102 +470,189 @@ class HTTPClientTests: XCTestCase { var response = try localClient.get(url: self.defaultHTTPBinURLPrefix + "redirect/302").wait() XCTAssertEqual(response.status, .ok) - response = try localClient.get(url: self.defaultHTTPBinURLPrefix + "redirect/https?port=\(httpsBin.port)").wait() + response = try localClient.get(url: self.defaultHTTPBinURLPrefix + "redirect/https?port=\(httpsBin.port)") + .wait() XCTAssertEqual(response.status, .ok) - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpSocketPath in - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpsSocketPath in - let socketHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(httpSocketPath)) - let socketHTTPSBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(httpsSocketPath)) - defer { - XCTAssertNoThrow(try socketHTTPBin.shutdown()) - XCTAssertNoThrow(try socketHTTPSBin.shutdown()) - } - - // From HTTP or HTTPS to HTTP+UNIX should fail to redirect - var targetURL = "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - var request = try Request(url: self.defaultHTTPBinURLPrefix + "redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - var response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .found) - XCTAssertEqual(response.headers.first(name: "Location"), targetURL) - - request = try Request(url: "https://localhost:\(httpsBin.port)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .found) - XCTAssertEqual(response.headers.first(name: "Location"), targetURL) - - // From HTTP or HTTPS to HTTPS+UNIX should also fail to redirect - targetURL = "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: self.defaultHTTPBinURLPrefix + "redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .found) - XCTAssertEqual(response.headers.first(name: "Location"), targetURL) - - request = try Request(url: "https://localhost:\(httpsBin.port)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .found) - XCTAssertEqual(response.headers.first(name: "Location"), targetURL) - - // ... while HTTP+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed - targetURL = self.defaultHTTPBinURLPrefix + "ok" - request = try Request(url: "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "https://localhost:\(httpsBin.port)/ok" - request = try Request(url: "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - // ... and HTTPS+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed - targetURL = self.defaultHTTPBinURLPrefix + "ok" - request = try Request(url: "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "https://localhost:\(httpsBin.port)/ok" - request = try Request(url: "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpSocketPath in + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpsSocketPath in + let socketHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(httpSocketPath)) + let socketHTTPSBin = HTTPBin( + .http1_1(ssl: true), + bindTarget: .unixDomainSocket(httpsSocketPath) + ) + defer { + XCTAssertNoThrow(try socketHTTPBin.shutdown()) + XCTAssertNoThrow(try socketHTTPSBin.shutdown()) + } - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - }) - }) + // From HTTP or HTTPS to HTTP+UNIX should fail to redirect + var targetURL = + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + var request = try Request( + url: self.defaultHTTPBinURLPrefix + "redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + var response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .found) + XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + + request = try Request( + url: "https://localhost:\(httpsBin.port)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .found) + XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + + // From HTTP or HTTPS to HTTPS+UNIX should also fail to redirect + targetURL = + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: self.defaultHTTPBinURLPrefix + "redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .found) + XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + + request = try Request( + url: "https://localhost:\(httpsBin.port)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .found) + XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + + // ... while HTTP+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed + targetURL = self.defaultHTTPBinURLPrefix + "ok" + request = try Request( + url: + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + + targetURL = "https://localhost:\(httpsBin.port)/ok" + request = try Request( + url: + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + + targetURL = + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + + targetURL = + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + + // ... and HTTPS+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed + targetURL = self.defaultHTTPBinURLPrefix + "ok" + request = try Request( + url: + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + + targetURL = "https://localhost:\(httpsBin.port)/ok" + request = try Request( + url: + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + + targetURL = + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + + targetURL = + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + } + ) + } + ) } func testHttpHostRedirect() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration( + certificateVerification: .none, + redirectConfiguration: .follow(max: 10, allowCycles: true) + ) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -552,12 +679,37 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(.ok, response.status) } + func testLeadingSlashRelativeURL() throws { + let noLeadingSlashURL = URL( + string: "percent%2Fencoded/hello", + relativeTo: URL(string: self.defaultHTTPBinURLPrefix)! + )! + let withLeadingSlashURL = URL( + string: "/percent%2Fencoded/hello", + relativeTo: URL(string: self.defaultHTTPBinURLPrefix)! + )! + + let noLeadingSlashURLRequest = try HTTPClient.Request(url: noLeadingSlashURL, method: .GET) + let withLeadingSlashURLRequest = try HTTPClient.Request(url: withLeadingSlashURL, method: .GET) + + let noLeadingSlashURLResponse = try self.defaultClient.execute(request: noLeadingSlashURLRequest).wait() + let withLeadingSlashURLResponse = try self.defaultClient.execute(request: withLeadingSlashURLRequest).wait() + + XCTAssertEqual(noLeadingSlashURLResponse.status, .ok) + XCTAssertEqual(withLeadingSlashURLResponse.status, .ok) + } + func testMultipleContentLengthHeaders() throws { let body = ByteBuffer(string: "hello world!") var headers = HTTPHeaders() headers.add(name: "Content-Length", value: "12") - let request = try Request(url: self.defaultHTTPBinURLPrefix + "post", method: .POST, headers: headers, body: .byteBuffer(body)) + let request = try Request( + url: self.defaultHTTPBinURLPrefix + "post", + method: .POST, + headers: headers, + body: .byteBuffer(body) + ) let response = try self.defaultClient.execute(request: request).wait() // if the library adds another content length header we'll get a bad request error. XCTAssertEqual(.ok, response.status) @@ -602,9 +754,12 @@ class HTTPClientTests: XCTestCase { let progress = try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Progress in - let delegate = try FileDownloadDelegate(path: path, reportHead: { - XCTAssertEqual($0.status, .notFound) - }) + let delegate = try FileDownloadDelegate( + path: path, + reportHead: { + XCTAssertEqual($0.status, .notFound) + } + ) let progress = try self.defaultClient.execute( request: request, @@ -621,6 +776,35 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(0, progress.receivedBytes) } + func testFileDownloadCustomError() throws { + let request = try Request(url: self.defaultHTTPBinURLPrefix + "get") + struct CustomError: Equatable, Error {} + + try TemporaryFileHelpers.withTemporaryFilePath { path in + let delegate = try FileDownloadDelegate( + path: path, + reportHead: { task, head in + XCTAssertEqual(head.status, .ok) + task.fail(reason: CustomError()) + }, + reportProgress: { _, _ in + XCTFail("should never be called") + } + ) + XCTAssertThrowsError( + try self.defaultClient.execute( + request: request, + delegate: delegate + ) + .wait() + ) { error in + XCTAssertEqualTypeAndValue(error, CustomError()) + } + + XCTAssertFalse(TemporaryFileHelpers.fileExists(path: path)) + } + } + func testRemoteClose() { XCTAssertThrowsError(try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "close").wait()) { XCTAssertEqual($0 as? HTTPClientError, .remoteConnectionClosed) @@ -628,8 +812,10 @@ class HTTPClientTests: XCTestCase { } func testReadTimeout() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(timeout: HTTPClient.Configuration.Timeout(read: .milliseconds(150)))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(timeout: HTTPClient.Configuration.Timeout(read: .milliseconds(150))) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -640,22 +826,89 @@ class HTTPClientTests: XCTestCase { } } - func testConnectTimeout() { - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(timeout: .init(connect: .milliseconds(100), read: .milliseconds(150)))) + func testWriteTimeout() throws { + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(timeout: HTTPClient.Configuration.Timeout(write: .nanoseconds(10))) + ) + + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + } + + // Create a request that writes a chunk, then waits longer than the configured write timeout, + // and then writes again. This should trigger a write timeout error. + let request = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "post", + method: .POST, + headers: ["transfer-encoding": "chunked"], + body: .stream { streamWriter in + _ = streamWriter.write(.byteBuffer(.init())) + + let promise = self.clientGroup.next().makePromise(of: Void.self) + self.clientGroup.next().scheduleTask(in: .milliseconds(3)) { + streamWriter.write(.byteBuffer(.init())).cascade(to: promise) + } + + return promise.futureResult + } + ) + + XCTAssertThrowsError(try localClient.execute(request: request).wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testConnectTimeout() throws { + #if os(Linux) + // 198.51.100.254 is reserved for documentation only and therefore should not accept any TCP connection + let url = "http://198.51.100.254/get" + #else + // on macOS we can use the TCP backlog behaviour when the queue is full to simulate a non reachable server. + // this makes this test a bit more stable if `198.51.100.254` actually responds to connection attempt. + // The backlog behaviour on Linux can not be used to simulate a non-reachable server. + // Linux sends a `SYN/ACK` back even if the `backlog` queue is full as it has two queues. + // The second queue is not limit by `ChannelOptions.backlog` but by `/proc/sys/net/ipv4/tcp_max_syn_backlog`. + + let serverChannel = try ServerBootstrap(group: self.serverGroup) + .serverChannelOption(ChannelOptions.backlog, value: 1) + .serverChannelOption(ChannelOptions.autoRead, value: false) + .bind(host: "127.0.0.1", port: 0) + .wait() + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + let port = serverChannel.localAddress!.port! + let firstClientChannel = try ClientBootstrap(group: self.serverGroup) + .connect(host: "127.0.0.1", port: port) + .wait() + defer { + XCTAssertNoThrow(try firstClientChannel.close().wait()) + } + let url = "http://localhost:\(port)/get" + #endif + + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(timeout: .init(connect: .milliseconds(100), read: .milliseconds(150))) + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown()) } - // This must throw as 198.51.100.254 is reserved for documentation only - XCTAssertThrowsError(try httpClient.get(url: "http://198.51.100.254/get").wait()) { - XCTAssertEqual($0 as? HTTPClientError, .connectTimeout) + XCTAssertThrowsError(try httpClient.get(url: url).wait()) { + XCTAssertEqualTypeAndValue($0, HTTPClientError.connectTimeout) } } func testDeadline() { - XCTAssertThrowsError(try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "wait", deadline: .now() + .milliseconds(150)).wait()) { + XCTAssertThrowsError( + try self.defaultClient.get( + url: self.defaultHTTPBinURLPrefix + "wait", + deadline: .now() + .milliseconds(150) + ).wait() + ) { XCTAssertEqual($0 as? HTTPClientError, .deadlineExceeded) } } @@ -741,7 +994,13 @@ class HTTPClientTests: XCTestCase { let localHTTPBin = HTTPBin(proxy: .simulate(authorization: "Basic YWxhZGRpbjpvcGVuc2VzYW1l")) let localClient = HTTPClient( eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .server(host: "localhost", port: localHTTPBin.port, authorization: .basic(username: "aladdin", password: "opensesame"))) + configuration: .init( + proxy: .server( + host: "localhost", + port: localHTTPBin.port, + authorization: .basic(username: "aladdin", password: "opensesame") + ) + ) ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -753,11 +1012,19 @@ class HTTPClientTests: XCTestCase { func testProxyPlaintextWithIncorrectlyAuthorization() throws { let localHTTPBin = HTTPBin(proxy: .simulate(authorization: "Basic YWxhZGRpbjpvcGVuc2VzYW1l")) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .server(host: "localhost", - port: localHTTPBin.port, - authorization: .basic(username: "aladdin", - password: "opensesamefoo")))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init( + proxy: .server( + host: "localhost", + port: localHTTPBin.port, + authorization: .basic( + username: "aladdin", + password: "opensesamefoo" + ) + ) + ).enableFastFailureModeForTesting() + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -770,7 +1037,7 @@ class HTTPClientTests: XCTestCase { } func testUploadStreaming() throws { - let body: HTTPClient.Body = .stream(length: 8) { writer in + let body: HTTPClient.Body = .stream(contentLength: 8) { writer in let buffer = ByteBuffer(string: "1234") return writer.write(.byteBuffer(buffer)).flatMap { let buffer = ByteBuffer(string: "4321") @@ -787,8 +1054,10 @@ class HTTPClientTests: XCTestCase { } func testEventLoopArgument() throws { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(redirectConfiguration: .follow(max: 10, allowCycles: true))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(redirectConfiguration: .follow(max: 10, allowCycles: true)) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) } @@ -809,26 +1078,33 @@ class HTTPClientTests: XCTestCase { } func didFinishRequest(task: HTTPClient.Task) throws -> Bool { - return self.result + self.result } } let eventLoop = self.clientGroup.next() let delegate = EventLoopValidatingDelegate(eventLoop: eventLoop) var request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get") - var response = try localClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)).wait() + var response = try localClient.execute( + request: request, + delegate: delegate, + eventLoop: .delegate(on: eventLoop) + ).wait() XCTAssertEqual(true, response) // redirect request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "redirect/302") - response = try localClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)).wait() + response = try localClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)) + .wait() XCTAssertEqual(true, response) } func testDecompression() throws { let localHTTPBin = HTTPBin(.http1_1(compress: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(decompression: .enabled(limit: .none))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(decompression: .enabled(limit: .none)) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -837,7 +1113,8 @@ class HTTPClientTests: XCTestCase { var body = "" for _ in 1...1000 { - body += "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + body += + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." } for algorithm in [nil, "gzip", "deflate"] { @@ -862,9 +1139,56 @@ class HTTPClientTests: XCTestCase { } } + func testDecompressionHTTP2() throws { + let localHTTPBin = HTTPBin(.http2(compress: true)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init( + certificateVerification: .none, + decompression: .enabled(limit: .none) + ) + ) + + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + + var body = "" + for _ in 1...1000 { + body += + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + } + + for algorithm: String? in [nil] { + var request = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/post", method: .POST) + request.body = .string(body) + if let algorithm = algorithm { + request.headers.add(name: "Accept-Encoding", value: algorithm) + } + + let response = try localClient.execute(request: request).wait() + var responseBody = try XCTUnwrap(response.body) + let data = try responseBody.readJSONDecodable(RequestInfo.self, length: responseBody.readableBytes) + + XCTAssertEqual(.ok, response.status) + let contentLength = try XCTUnwrap(response.headers["Content-Length"].first.flatMap { Int($0) }) + XCTAssertGreaterThan(body.count, contentLength) + if let algorithm = algorithm { + XCTAssertEqual(algorithm, response.headers["Content-Encoding"].first) + } else { + XCTAssertEqual("deflate", response.headers["Content-Encoding"].first) + } + XCTAssertEqual(body, data?.data) + } + } + func testDecompressionLimit() throws { let localHTTPBin = HTTPBin(.http1_1(compress: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init(decompression: .enabled(limit: .ratio(1)))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(decompression: .enabled(limit: .ratio(1))) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -882,30 +1206,47 @@ class HTTPClientTests: XCTestCase { func testLoopDetectionRedirectLimit() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 5, allowCycles: false))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration( + certificateVerification: .none, + redirectConfiguration: .follow(max: 5, allowCycles: false) + ) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) } - XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").wait(), "Should fail with redirect limit") { error in + XCTAssertThrowsError( + try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").wait(), + "Should fail with redirect limit" + ) { error in XCTAssertEqual(error as? HTTPClientError, HTTPClientError.redirectCycleDetected) } } func testCountRedirectLimit() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration( + certificateVerification: .none, + redirectConfiguration: .follow(max: 10, allowCycles: true) + ) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) } - XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").timeout(after: .seconds(10)).wait()) { error in + XCTAssertThrowsError( + try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").timeout( + after: .seconds(10) + ).wait() + ) { error in XCTAssertEqual(error as? HTTPClientError, HTTPClientError.redirectLimitReached) } } @@ -923,13 +1264,15 @@ class HTTPClientTests: XCTestCase { defer { XCTAssertNoThrow(try localClient.syncShutdown()) } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "https://localhost:\(localHTTPBin.port)/redirect/target", - method: .GET, - headers: [ - "X-Target-Redirect-URL": "/redirect/target", - ] - )) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/redirect/target", + method: .GET, + headers: [ + "X-Target-Redirect-URL": "/redirect/target" + ] + ) + ) guard let request = maybeRequest else { return } XCTAssertThrowsError( @@ -949,8 +1292,12 @@ class HTTPClientTests: XCTestCase { func channelRead(context: ChannelHandlerContext, data: NIOAny) { if case .end = self.unwrapInboundIn(data) { - let responseHead = HTTPServerResponsePart.head(.init(version: .init(major: 1, minor: 1), - status: .ok)) + let responseHead = HTTPServerResponsePart.head( + .init( + version: .init(major: 1, minor: 1), + status: .ok + ) + ) context.write(self.wrapOutboundOut(responseHead), promise: nil) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } @@ -963,27 +1310,31 @@ class HTTPClientTests: XCTestCase { } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) - .serverChannelOption(ChannelOptions.backlog, value: .init(numberOfParallelWorkers)) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false, - withServerUpgrade: nil, - withErrorHandling: false).flatMap { - channel.pipeline.addHandler(HTTPServer()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) + .serverChannelOption(ChannelOptions.backlog, value: .init(numberOfParallelWorkers)) + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline( + withPipeliningAssistance: false, + withServerUpgrade: nil, + withErrorHandling: false + ).flatMap { + channel.pipeline.addHandler(HTTPServer()) + } } - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } + let url = "http://127.0.0.1:\(server?.localAddress?.port ?? -1)/hello" let g = DispatchGroup() for workerID in 0..]() - for _ in 1...requestCount { - let req = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, headers: ["X-internal-delay": "100"]) - futureResults.append(localClient.execute(request: req)) } - XCTAssertNoThrow(try EventLoopFuture.andAllSucceed(futureResults, on: eventLoop).wait()) - } - func testStressGetHttpsSSLError() throws { let request = try Request(url: "https://localhost:\(self.defaultHTTPBin.port)/wait", method: .GET) let tasks = (1...100).map { _ -> HTTPClient.Task in - self.defaultClient.execute(request: request, delegate: TestHTTPDelegate()) + localClient.execute(request: request, delegate: TestHTTPDelegate()) } - let results = try EventLoopFuture.whenAllComplete(tasks.map { $0.futureResult }, on: self.defaultClient.eventLoopGroup.next()).wait() + let results = try EventLoopFuture.whenAllComplete( + tasks.map { $0.futureResult }, + on: localClient.eventLoopGroup.next() + ).wait() for result in results { switch result { @@ -1179,9 +1571,10 @@ class HTTPClientTests: XCTestCase { // We're speaking TLS to a plain text server. This will cause the handshake to fail but given // that the bytes "HTTP/1.1" aren't the start of a valid TLS packet, we can also get // errSSLPeerProtocolVersion because the first bytes contain the version. - XCTAssert(clientError.status == errSSLHandshakeFail || - clientError.status == errSSLPeerProtocolVersion, - "unexpected NWTLSError with status \(clientError.status)") + XCTAssert( + clientError.status == errSSLHandshakeFail || clientError.status == errSSLPeerProtocolVersion, + "unexpected NWTLSError with status \(clientError.status)" + ) #endif } else { guard let clientError = error as? NIOSSLError, case NIOSSLError.handshakeFailed = clientError else { @@ -1193,6 +1586,97 @@ class HTTPClientTests: XCTestCase { } } + func testSelfSignedCertificateIsRejectedWithCorrectError() throws { + /// key + cert was created with the follwing command: + /// openssl req -x509 -newkey rsa:4096 -keyout self_signed_key.pem -out self_signed_cert.pem -sha256 -days 99999 -nodes -subj '/CN=localhost' + let certPath = Bundle.module.path(forResource: "self_signed_cert", ofType: "pem")! + let keyPath = Bundle.module.path(forResource: "self_signed_key", ofType: "pem")! + let key = try NIOSSLPrivateKey(file: keyPath, format: .pem) + let configuration = try TLSConfiguration.makeServerConfiguration( + certificateChain: NIOSSLCertificate.fromPEMFile(certPath).map { .certificate($0) }, + privateKey: .privateKey(key) + ) + let sslContext = try NIOSSLContext(configuration: configuration) + + let server = ServerBootstrap(group: serverGroup) + .childChannelInitializer { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } + } + let serverChannel = try server.bind(host: "localhost", port: 0).wait() + defer { XCTAssertNoThrow(try serverChannel.close().wait()) } + let port = serverChannel.localAddress!.port! + + let config = HTTPClient.Configuration().enableFastFailureModeForTesting() + + let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: config) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(port)").wait()) { error in + #if canImport(Network) + guard let nwTLSError = error as? HTTPClient.NWTLSError else { + XCTFail("could not cast \(error) of type \(type(of: error)) to \(HTTPClient.NWTLSError.self)") + return + } + XCTAssertEqual(nwTLSError.status, errSSLBadCert, "unexpected tls error: \(nwTLSError)") + #else + guard let sslError = error as? NIOSSLError, + case .handshakeFailed(.sslError) = sslError + else { + XCTFail("unexpected error \(error)") + return + } + #endif + } + } + + func testSelfSignedCertificateIsRejectedWithCorrectErrorIfRequestDeadlineIsExceeded() throws { + /// key + cert was created with the follwing command: + /// openssl req -x509 -newkey rsa:4096 -keyout self_signed_key.pem -out self_signed_cert.pem -sha256 -days 99999 -nodes -subj '/CN=localhost' + let certPath = Bundle.module.path(forResource: "self_signed_cert", ofType: "pem")! + let keyPath = Bundle.module.path(forResource: "self_signed_key", ofType: "pem")! + let key = try NIOSSLPrivateKey(file: keyPath, format: .pem) + let configuration = try TLSConfiguration.makeServerConfiguration( + certificateChain: NIOSSLCertificate.fromPEMFile(certPath).map { .certificate($0) }, + privateKey: .privateKey(key) + ) + let sslContext = try NIOSSLContext(configuration: configuration) + + let server = ServerBootstrap(group: serverGroup) + .childChannelInitializer { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } + } + let serverChannel = try server.bind(host: "localhost", port: 0).wait() + defer { XCTAssertNoThrow(try serverChannel.close().wait()) } + let port = serverChannel.localAddress!.port! + + let config = HTTPClient.Configuration().enableFastFailureModeForTesting() + + let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: config) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + + XCTAssertThrowsError( + try localClient.get(url: "https://localhost:\(port)", deadline: .now() + .seconds(2)).wait() + ) { error in + #if canImport(Network) + guard let nwTLSError = error as? HTTPClient.NWTLSError else { + XCTFail("could not cast \(error) of type \(type(of: error)) to \(HTTPClient.NWTLSError.self)") + return + } + XCTAssertEqual(nwTLSError.status, errSSLBadCert, "unexpected tls error: \(nwTLSError)") + #else + guard let sslError = error as? NIOSSLError, + case .handshakeFailed(.sslError) = sslError + else { + XCTFail("unexpected error \(error)") + return + } + #endif + } + } + func testFailingConnectionIsReleased() { let localHTTPBin = HTTPBin(.refuse) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) @@ -1211,37 +1695,22 @@ class HTTPClientTests: XCTestCase { } } - func testResponseDelayGet() throws { - let req = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", - method: .GET, - headers: ["X-internal-delay": "2000"], - body: nil) - let start = Date() - let response = try! self.defaultClient.execute(request: req).wait() - XCTAssertGreaterThan(Date().timeIntervalSince(start), 2) - XCTAssertEqual(response.status, .ok) - } - - func testIdleTimeoutNoReuse() throws { - var req = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .GET) - XCTAssertNoThrow(try self.defaultClient.execute(request: req, deadline: .now() + .seconds(2)).wait()) - req.headers.add(name: "X-internal-delay", value: "2500") - try self.defaultClient.eventLoopGroup.next().scheduleTask(in: .milliseconds(250)) {}.futureResult.wait() - XCTAssertNoThrow(try self.defaultClient.execute(request: req).timeout(after: .seconds(10)).wait()) - } - func testStressGetClose() throws { let eventLoop = self.defaultClient.eventLoopGroup.next() let requestCount = 200 var futureResults = [EventLoopFuture]() for _ in 1...requestCount { - let req = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", - method: .GET, - headers: ["X-internal-delay": "5", "Connection": "close"]) + let req = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .GET, + headers: ["X-internal-delay": "5", "Connection": "close"] + ) futureResults.append(self.defaultClient.execute(request: req)) } - XCTAssertNoThrow(try EventLoopFuture.andAllComplete(futureResults, on: eventLoop) - .timeout(after: .seconds(10)).wait()) + XCTAssertNoThrow( + try EventLoopFuture.andAllComplete(futureResults, on: eventLoop) + .timeout(after: .seconds(10)).wait() + ) } func testManyConcurrentRequestsWork() { @@ -1258,8 +1727,8 @@ class HTTPClientTests: XCTestCase { let q = DispatchQueue(label: "worker \(w)") q.async(group: allDone) { func go() { - allWorkersReady.signal() // tell the driver we're ready - allWorkersGo.wait() // wait for the driver to let us go + allWorkersReady.signal() // tell the driver we're ready + allWorkersGo.wait() // wait for the driver to let us go for _ in 0..]() for i in 1...100 { - let request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .GET, headers: ["X-internal-delay": "10"]) + let request = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .GET, + headers: ["X-internal-delay": "10"] + ) let preference: HTTPClient.EventLoopPreference if i <= 50 { preference = .delegateAndChannel(on: first) @@ -1496,15 +1991,18 @@ class HTTPClientTests: XCTestCase { let seenError = DispatchGroup() seenError.enter() var maybeSecondRequest: EventLoopFuture? - XCTAssertNoThrow(maybeSecondRequest = try el.submit { - let neverSucceedingRequest = localClient.get(url: url) - let secondRequest = neverSucceedingRequest.flatMapError { error in - XCTAssertEqual(.cancelled, error as? HTTPClientError) - seenError.leave() - return localClient.get(url: url) // <== this is the main part, during the error callout, we call back in - } - return secondRequest - }.wait()) + XCTAssertNoThrow( + maybeSecondRequest = try el.submit { + let neverSucceedingRequest = localClient.get(url: url) + let secondRequest = neverSucceedingRequest.flatMapError { error in + XCTAssertEqual(.cancelled, error as? HTTPClientError) + seenError.leave() + // v this is the main part, during the error callout, we call back in + return localClient.get(url: url) + } + return secondRequest + }.wait() + ) guard let secondRequest = maybeSecondRequest else { XCTFail("couldn't get request future") @@ -1530,13 +2028,15 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try localClient.syncShutdown()) } - XCTAssertEqual(.ok, - try el.flatSubmit { () -> EventLoopFuture in - localClient.get(url: url).flatMap { firstResponse in - XCTAssertEqual(.ok, firstResponse.status) - return localClient.get(url: url) // <== interesting bit here - } - }.wait().status) + XCTAssertEqual( + .ok, + try el.flatSubmit { () -> EventLoopFuture in + localClient.get(url: url).flatMap { firstResponse in + XCTAssertEqual(.ok, firstResponse.status) + return localClient.get(url: url) // <== interesting bit here + } + }.wait().status + ) } func testMakeSecondRequestWhilstFirstIsOngoing() { @@ -1553,11 +2053,11 @@ class HTTPClientTests: XCTestCase { let url = "http://127.0.0.1:\(web.serverPort)" let firstRequest = client.get(url: url) - XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head + XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head // Now, the first request is ongoing but not complete, let's start a second one let secondRequest = client.get(url: url) - XCTAssertEqual(.end(nil), try web.readInbound()) // first request: .end + XCTAssertEqual(.end(nil), try web.readInbound()) // first request: .end XCTAssertNoThrow(try web.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .ok)))) XCTAssertNoThrow(try web.writeOutbound(.end(nil))) @@ -1565,8 +2065,8 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(.ok, try firstRequest.wait().status) // Okay, first request done successfully, let's do the second one too. - XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head - XCTAssertEqual(.end(nil), try web.readInbound()) // first request: .end + XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head + XCTAssertEqual(.end(nil), try web.readInbound()) // first request: .end XCTAssertNoThrow(try web.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .created)))) XCTAssertNoThrow(try web.writeOutbound(.end(nil))) @@ -1577,15 +2077,19 @@ class HTTPClientTests: XCTestCase { // This tests just connecting to a URL where the whole URL is the UNIX domain socket path like // unix:///this/is/my/socket.sock // We don't really have a path component, so we'll have to use "/" - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - defer { - XCTAssertNoThrow(try localHTTPBin.shutdown()) + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + defer { + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + let target = "unix://\(path)" + XCTAssertEqual( + ["Yes"[...]], + try self.defaultClient.get(url: target).wait().headers[canonicalForm: "X-Is-This-Slash"] + ) } - let target = "unix://\(path)" - XCTAssertEqual(["Yes"[...]], - try self.defaultClient.get(url: target).wait().headers[canonicalForm: "X-Is-This-Slash"]) - }) + ) } func testUDSSocketAndPath() { @@ -1593,56 +2097,73 @@ class HTTPClientTests: XCTestCase { // // 1. a "base path" which is the path to the UNIX domain socket // 2. an actual path which is the normal path in a regular URL like https://example.com/this/is/the/path - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - defer { - XCTAssertNoThrow(try localHTTPBin.shutdown()) - } - guard let target = URL(string: "/echo-uri", relativeTo: URL(string: "unix://\(path)")), - let request = try? Request(url: target) else { - XCTFail("couldn't build URL for request") - return + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + defer { + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + guard let target = URL(string: "/echo-uri", relativeTo: URL(string: "unix://\(path)")), + let request = try? Request(url: target) + else { + XCTFail("couldn't build URL for request") + return + } + XCTAssertEqual( + ["/echo-uri"[...]], + try self.defaultClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"] + ) } - XCTAssertEqual(["/echo-uri"[...]], - try self.defaultClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"]) - }) + ) } func testHTTPPlusUNIX() { // Here, we're testing a URL where the UNIX domain socket is encoded as the host name - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - defer { - XCTAssertNoThrow(try localHTTPBin.shutdown()) - } - guard let target = URL(httpURLWithSocketPath: path, uri: "/echo-uri"), - let request = try? Request(url: target) else { - XCTFail("couldn't build URL for request") - return + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + defer { + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + guard let target = URL(httpURLWithSocketPath: path, uri: "/echo-uri"), + let request = try? Request(url: target) + else { + XCTFail("couldn't build URL for request") + return + } + XCTAssertEqual( + ["/echo-uri"[...]], + try self.defaultClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"] + ) } - XCTAssertEqual(["/echo-uri"[...]], - try self.defaultClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"]) - }) + ) } func testHTTPSPlusUNIX() { // Here, we're testing a URL where the UNIX domain socket is encoded as the host name - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localHTTPBin.shutdown()) - } - guard let target = URL(httpsURLWithSocketPath: path, uri: "/echo-uri"), - let request = try? Request(url: target) else { - XCTFail("couldn't build URL for request") - return + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + guard let target = URL(httpsURLWithSocketPath: path, uri: "/echo-uri"), + let request = try? Request(url: target) + else { + XCTFail("couldn't build URL for request") + return + } + XCTAssertEqual( + ["/echo-uri"[...]], + try localClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"] + ) } - XCTAssertEqual(["/echo-uri"[...]], - try localClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"]) - }) + ) } func testUseExistingConnectionOnDifferentEL() throws { @@ -1656,13 +2177,20 @@ class HTTPClientTests: XCTestCase { let eventLoops = (1...threadCount).map { _ in elg.next() } let request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get") - let closingRequest = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", headers: ["Connection": "close"]) + let closingRequest = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + headers: ["Connection": "close"] + ) for (index, el) in eventLoops.enumerated() { if index.isMultiple(of: 2) { - XCTAssertNoThrow(try localClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait()) + XCTAssertNoThrow( + try localClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait() + ) } else { - XCTAssertNoThrow(try localClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait()) + XCTAssertNoThrow( + try localClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait() + ) XCTAssertNoThrow(try localClient.execute(request: closingRequest, eventLoop: .indifferent).wait()) } } @@ -1673,16 +2201,16 @@ class HTTPClientTests: XCTestCase { typealias InboundIn = HTTPServerRequestPart typealias OutboundOut = HTTPServerResponsePart - let requestNumber: NIOAtomic - let connectionNumber: NIOAtomic + let requestNumber: ManagedAtomic + let connectionNumber: ManagedAtomic - init(requestNumber: NIOAtomic, connectionNumber: NIOAtomic) { + init(requestNumber: ManagedAtomic, connectionNumber: ManagedAtomic) { self.requestNumber = requestNumber self.connectionNumber = connectionNumber } func channelActive(context: ChannelHandlerContext) { - _ = self.connectionNumber.add(1) + _ = self.connectionNumber.loadThenWrappingIncrement(ordering: .relaxed) } func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -1692,11 +2220,13 @@ class HTTPClientTests: XCTestCase { case .head, .body: () case .end: - let last = self.requestNumber.add(1) + let last = self.requestNumber.loadThenWrappingIncrement(ordering: .relaxed) switch last { case 0, 2: - context.write(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), - promise: nil) + context.write( + self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) case 1: context.close(promise: nil) @@ -1707,22 +2237,26 @@ class HTTPClientTests: XCTestCase { } } - let requestNumber = NIOAtomic.makeAtomic(value: 0) - let connectionNumber = NIOAtomic.makeAtomic(value: 0) - let sharedStateServerHandler = ServerThatAcceptsThenRejects(requestNumber: requestNumber, - connectionNumber: connectionNumber) + let requestNumber = ManagedAtomic(0) + let connectionNumber = ManagedAtomic(0) + let sharedStateServerHandler = ServerThatAcceptsThenRejects( + requestNumber: requestNumber, + connectionNumber: connectionNumber + ) var maybeServer: Channel? - XCTAssertNoThrow(maybeServer = try ServerBootstrap(group: self.serverGroup) - .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline().flatMap { - // We're deliberately adding a handler which is shared between multiple channels. This is normally - // very verboten but this handler is specially crafted to tolerate this. - channel.pipeline.addHandler(sharedStateServerHandler) + XCTAssertNoThrow( + maybeServer = try ServerBootstrap(group: self.serverGroup) + .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline().flatMap { + // We're deliberately adding a handler which is shared between multiple channels. This is normally + // very verboten but this handler is specially crafted to tolerate this. + channel.pipeline.addHandler(sharedStateServerHandler) + } } - } - .bind(host: "127.0.0.1", port: 0) - .wait()) + .bind(host: "127.0.0.1", port: 0) + .wait() + ) guard let server = maybeServer else { XCTFail("couldn't create server") return @@ -1737,46 +2271,57 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try client.syncShutdown()) } - XCTAssertEqual(0, sharedStateServerHandler.connectionNumber.load()) - XCTAssertEqual(0, sharedStateServerHandler.requestNumber.load()) + XCTAssertEqual(0, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed)) + XCTAssertEqual(0, sharedStateServerHandler.requestNumber.load(ordering: .relaxed)) XCTAssertEqual(.ok, try client.get(url: url).wait().status) - XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load()) - XCTAssertEqual(1, sharedStateServerHandler.requestNumber.load()) + XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed)) + XCTAssertEqual(1, sharedStateServerHandler.requestNumber.load(ordering: .relaxed)) XCTAssertThrowsError(try client.get(url: url).wait().status) { error in XCTAssertEqual(.remoteConnectionClosed, error as? HTTPClientError) } - XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load()) - XCTAssertEqual(2, sharedStateServerHandler.requestNumber.load()) + XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed)) + XCTAssertEqual(2, sharedStateServerHandler.requestNumber.load(ordering: .relaxed)) XCTAssertEqual(.ok, try client.get(url: url).wait().status) - XCTAssertEqual(2, sharedStateServerHandler.connectionNumber.load()) - XCTAssertEqual(3, sharedStateServerHandler.requestNumber.load()) + XCTAssertEqual(2, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed)) + XCTAssertEqual(3, sharedStateServerHandler.requestNumber.load(ordering: .relaxed)) } func testPoolClosesIdleConnections() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(connectionPool: .init(idleTimeout: .milliseconds(100)))) + let configuration = HTTPClient.Configuration( + certificateVerification: .none, + maximumAllowedIdleTimeInConnectionPool: .milliseconds(100) + ) + + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: configuration + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + + // Make sure that the idle timeout of the connection pool is properly propagated + // to the connection pool itself, when using both inits. + XCTAssertEqual(configuration.connectionPool.idleTimeout, .milliseconds(100)) + XCTAssertEqual( + configuration.connectionPool.idleTimeout, + HTTPClient.Configuration( + certificateVerification: .none, + connectionPool: .milliseconds(100), + backgroundActivityLogger: nil + ).connectionPool.idleTimeout + ) + XCTAssertNoThrow(try localClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()) Thread.sleep(forTimeInterval: 0.2) XCTAssertEqual(self.defaultHTTPBin.activeConnections, 0) } - func testRacePoolIdleConnectionsAndGet() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(connectionPool: .init(idleTimeout: .milliseconds(10)))) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - } - for _ in 1...500 { - XCTAssertNoThrow(try localClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()) - Thread.sleep(forTimeInterval: 0.01 + .random(in: -0.05...0.05)) - } - } - func testAvoidLeakingTLSHandshakeCompletionPromise() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init(timeout: .init(connect: .milliseconds(100)))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(timeout: .init(connect: .milliseconds(100))) + ) let localHTTPBin = HTTPBin() let port = localHTTPBin.port XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -1786,7 +2331,12 @@ class HTTPClientTests: XCTestCase { XCTAssertThrowsError(try localClient.get(url: "http://localhost:\(port)").wait()) { error in if isTestingNIOTS() { - XCTAssertEqual(error as? HTTPClientError, .connectTimeout) + #if canImport(Network) + // We can't be more specific than this. + XCTAssertTrue(error is HTTPClient.NWTLSError || error is HTTPClient.NWPOSIXError) + #else + XCTFail("Impossible condition") + #endif } else { XCTAssert(error is NIOConnectionError, "Unexpected error: \(error)") } @@ -1818,9 +2368,13 @@ class HTTPClientTests: XCTestCase { } func testValidationErrorsAreSurfaced() throws { - let request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .TRACE, body: .stream { _ in - self.defaultClient.eventLoopGroup.next().makeSucceededFuture(()) - }) + let request = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .TRACE, + body: .stream { _ in + self.defaultClient.eventLoopGroup.next().makeSucceededFuture(()) + } + ) let runningRequest = self.defaultClient.execute(request: request) XCTAssertThrowsError(try runningRequest.wait()) { error in XCTAssertEqual(HTTPClientError.traceRequestWithBody, error as? HTTPClientError) @@ -1838,9 +2392,11 @@ class HTTPClientTests: XCTestCase { private var bodyPartsSeenSoFar = 0 private var atEnd = false - init(headPromise: EventLoopPromise, - bodyPromises: [EventLoopPromise], - endPromise: EventLoopPromise) { + init( + headPromise: EventLoopPromise, + bodyPromises: [EventLoopPromise], + endPromise: EventLoopPromise + ) { self.headPromise = headPromise self.bodyPromises = bodyPromises self.endPromise = endPromise @@ -1856,8 +2412,10 @@ class HTTPClientTests: XCTestCase { self.bodyPartsSeenSoFar += 1 self.bodyPromises.dropFirst(myNumber).first?.succeed(bytes) ?? XCTFail("ouch, too many chunks") case .end: - context.write(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), - promise: nil) + context.write( + self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: self.endPromise) self.atEnd = true } @@ -1870,8 +2428,8 @@ class HTTPClientTests: XCTestCase { struct NotFulfilledError: Error {} self.headPromise.fail(NotFulfilledError()) - self.bodyPromises.forEach { - $0.fail(NotFulfilledError()) + for promise in self.bodyPromises { + promise.fail(NotFulfilledError()) } self.endPromise.fail(NotFulfilledError()) } @@ -1892,12 +2450,16 @@ class HTTPClientTests: XCTestCase { let streamWriterPromise = group.next().makePromise(of: HTTPClient.Body.StreamWriter.self) func makeServer() -> Channel? { - return try? ServerBootstrap(group: group) + try? ServerBootstrap(group: group) .childChannelInitializer { channel in channel.pipeline.configureHTTPServerPipeline().flatMap { - channel.pipeline.addHandler(HTTPServer(headPromise: headPromise, - bodyPromises: bodyPromises, - endPromise: endPromise)) + channel.pipeline.addHandler( + HTTPServer( + headPromise: headPromise, + bodyPromises: bodyPromises, + endPromise: endPromise + ) + ) } } .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) @@ -1910,13 +2472,15 @@ class HTTPClientTests: XCTestCase { return nil } - return try? HTTPClient.Request(url: "http://\(localAddress.ipAddress!):\(localAddress.port!)", - method: .POST, - headers: ["transfer-encoding": "chunked"], - body: .stream { streamWriter in - streamWriterPromise.succeed(streamWriter) - return sentOffAllBodyPartsPromise.futureResult - }) + return try? HTTPClient.Request( + url: "http://\(localAddress.ipAddress!):\(localAddress.port!)", + method: .POST, + headers: ["transfer-encoding": "chunked"], + body: .stream { streamWriter in + streamWriterPromise.succeed(streamWriter) + return sentOffAllBodyPartsPromise.futureResult + } + ) } guard let server = makeServer(), let request = makeRequest(server: server) else { @@ -1948,35 +2512,45 @@ class HTTPClientTests: XCTestCase { } func testUploadStreamingCallinToleratedFromOtsideEL() throws { - let request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .POST, body: .stream(length: 4) { writer in - let promise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self) - // We have to toleare callins from any thread - DispatchQueue(label: "upload-streaming").async { - writer.write(.byteBuffer(ByteBuffer(string: "1234"))).whenComplete { _ in - promise.succeed(()) + let request = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .POST, + body: .stream(contentLength: 4) { writer in + let promise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self) + // We have to toleare callins from any thread + DispatchQueue(label: "upload-streaming").async { + writer.write(.byteBuffer(ByteBuffer(string: "1234"))).whenComplete { _ in + promise.succeed(()) + } } + return promise.futureResult } - return promise.futureResult - }) + ) XCTAssertNoThrow(try self.defaultClient.execute(request: request).wait()) } func testWeHandleUsSendingACloseHeaderCorrectly() { - guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", - method: .GET, - headers: ["connection": "close"]), + guard + let req1 = try? Request( + url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: ["connection": "close"] + ), let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, - let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) + else { XCTFail("request 1 didn't work") return } guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) + else { XCTFail("request 2 didn't work") return } guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) + else { XCTFail("request 3 didn't work") return } @@ -1992,21 +2566,27 @@ class HTTPClientTests: XCTestCase { } func testWeHandleUsReceivingACloseHeaderCorrectly() { - guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", - method: .GET, - headers: ["X-Send-Back-Header-Connection": "close"]), + guard + let req1 = try? Request( + url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: ["X-Send-Back-Header-Connection": "close"] + ), let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, - let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) + else { XCTFail("request 1 didn't work") return } guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) + else { XCTFail("request 2 didn't work") return } guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) + else { XCTFail("request 3 didn't work") return } @@ -2023,22 +2603,32 @@ class HTTPClientTests: XCTestCase { func testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly() { for closeHeader in [("connection", "close"), ("CoNneCTION", "ClOSe")] { - guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", - method: .GET, - headers: ["X-Send-Back-Header-\(closeHeader.0)": - "foo,\(closeHeader.1),bar"]), + guard + let req1 = try? Request( + url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: [ + "X-Send-Back-Header-\(closeHeader.0)": + "foo,\(closeHeader.1),bar" + ] + ), let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, - let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) + else { XCTFail("request 1 didn't work") return } - guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + guard + let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) + else { XCTFail("request 2 didn't work") return } - guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + guard + let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) + else { XCTFail("request 3 didn't work") return } @@ -2055,22 +2645,32 @@ class HTTPClientTests: XCTestCase { func testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly() { for closeHeader in [("connection", "close"), ("CoNneCTION", "ClOSe")] { - guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", - method: .GET, - headers: ["X-Send-Back-Header-\(closeHeader.0)": - "foo,\(closeHeader.1),bar"]), + guard + let req1 = try? Request( + url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: [ + "X-Send-Back-Header-\(closeHeader.0)": + "foo,\(closeHeader.1),bar" + ] + ), let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, - let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) + else { XCTFail("request 1 didn't work") return } - guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + guard + let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) + else { XCTFail("request 2 didn't work") return } - guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + guard + let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) + else { XCTFail("request 3 didn't work") return } @@ -2088,28 +2688,35 @@ class HTTPClientTests: XCTestCase { func testLoggingCorrectlyAttachesRequestInformationEvenAfterDuringRedirect() { let logStore = CollectEverythingLogHandler.LogStore() - var logger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) + var logger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: logStore) + } + ) logger.logLevel = .trace logger[metadataKey: "custom-request-id"] = "abcd" var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "http://localhost:\(self.defaultHTTPBin.port)/redirect/target", - method: .GET, - headers: [ - "X-Target-Redirect-URL": "/get", - ] - )) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost:\(self.defaultHTTPBin.port)/redirect/target", + method: .GET, + headers: [ + "X-Target-Redirect-URL": "/get" + ] + ) + ) guard let request = maybeRequest else { return } - XCTAssertNoThrow(try self.defaultClient.execute( - request: request, - eventLoop: .indifferent, - deadline: nil, - logger: logger - ).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request, + eventLoop: .indifferent, + deadline: nil, + logger: logger + ).wait() + ) let logs = logStore.allEntries XCTAssertTrue(logs.allSatisfy { $0.metadata["custom-request-id"] == "abcd" }) @@ -2128,51 +2735,70 @@ class HTTPClientTests: XCTestCase { XCTAssertGreaterThan(secondRequestLogs.count, 0) XCTAssertTrue(secondRequestLogs.allSatisfy { $0.metadata["ahc-request-id"] == lastRequestID }) - logs.forEach { print($0) } + for log in logs { print(log) } } func testLoggingCorrectlyAttachesRequestInformation() { let logStore = CollectEverythingLogHandler.LogStore() - var loggerYolo001 = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) + var loggerYolo001 = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: logStore) + } + ) loggerYolo001.logLevel = .trace loggerYolo001[metadataKey: "yolo-request-id"] = "yolo-001" - var loggerACME002 = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) + var loggerACME002 = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: logStore) + } + ) loggerACME002.logLevel = .trace loggerACME002[metadataKey: "acme-request-id"] = "acme-002" guard let request1 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get"), - let request2 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "stats"), - let request3 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "ok") else { + let request2 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "stats"), + let request3 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "ok") + else { XCTFail("bad stuff, can't even make request structures") return } // === Request 1 (Yolo001) - XCTAssertNoThrow(try self.defaultClient.execute(request: request1, - eventLoop: .indifferent, - deadline: nil, - logger: loggerYolo001).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request1, + eventLoop: .indifferent, + deadline: nil, + logger: loggerYolo001 + ).wait() + ) let logsAfterReq1 = logStore.allEntries logStore.allEntries = [] // === Request 2 (Yolo001) - XCTAssertNoThrow(try self.defaultClient.execute(request: request2, - eventLoop: .indifferent, - deadline: nil, - logger: loggerYolo001).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request2, + eventLoop: .indifferent, + deadline: nil, + logger: loggerYolo001 + ).wait() + ) let logsAfterReq2 = logStore.allEntries logStore.allEntries = [] // === Request 3 (ACME002) - XCTAssertNoThrow(try self.defaultClient.execute(request: request3, - eventLoop: .indifferent, - deadline: nil, - logger: loggerACME002).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request3, + eventLoop: .indifferent, + deadline: nil, + logger: loggerACME002 + ).wait() + ) let logsAfterReq3 = logStore.allEntries logStore.allEntries = [] @@ -2181,176 +2807,238 @@ class HTTPClientTests: XCTestCase { XCTAssertGreaterThan(logsAfterReq2.count, 0) XCTAssertGreaterThan(logsAfterReq3.count, 0) - XCTAssert(logsAfterReq1.allSatisfy { entry in - if let httpRequestMetadata = entry.metadata["ahc-request-id"], - let yoloRequestID = entry.metadata["yolo-request-id"] { - XCTAssertNil(entry.metadata["acme-request-id"]) - XCTAssertEqual("yolo-001", yoloRequestID) - XCTAssertNotNil(Int(httpRequestMetadata)) - return true - } else { - XCTFail("log message doesn't contain the right IDs: \(entry)") - return false - } - }) - XCTAssert(logsAfterReq1.contains { entry in - // Since a new connection must be created first we expect that the request is queued - // and log message describing this is emitted. - entry.message == "Request was queued (waiting for a connection to become available)" - && entry.level == .debug - }) - XCTAssert(logsAfterReq1.contains { entry in - // After the new connection was created we expect a log message that describes that the - // request was scheduled on a connection. The connection id must be set from here on. - entry.message == "Request was scheduled on connection" - && entry.level == .debug - && entry.metadata["ahc-connection-id"] != nil - }) - - XCTAssert(logsAfterReq2.allSatisfy { entry in - if let httpRequestMetadata = entry.metadata["ahc-request-id"], - let yoloRequestID = entry.metadata["yolo-request-id"] { - XCTAssertNil(entry.metadata["acme-request-id"]) - XCTAssertEqual("yolo-001", yoloRequestID) - XCTAssertNotNil(Int(httpRequestMetadata)) - return true - } else { - XCTFail("log message doesn't contain the right IDs: \(entry)") - return false - } - }) - XCTAssertFalse(logsAfterReq2.contains { entry in - entry.message == "Request was queued (waiting for a connection to become available)" - }) - XCTAssert(logsAfterReq2.contains { entry in - entry.message == "Request was scheduled on connection" - && entry.level == .debug - && entry.metadata["ahc-connection-id"] != nil - }) - - XCTAssert(logsAfterReq3.allSatisfy { entry in - if let httpRequestMetadata = entry.metadata["ahc-request-id"], - let acmeRequestID = entry.metadata["acme-request-id"] { - XCTAssertNil(entry.metadata["yolo-request-id"]) - XCTAssertEqual("acme-002", acmeRequestID) - XCTAssertNotNil(Int(httpRequestMetadata)) - return true - } else { - XCTFail("log message doesn't contain the right IDs: \(entry)") - return false + XCTAssert( + logsAfterReq1.allSatisfy { entry in + if let httpRequestMetadata = entry.metadata["ahc-request-id"], + let yoloRequestID = entry.metadata["yolo-request-id"] + { + XCTAssertNil(entry.metadata["acme-request-id"]) + XCTAssertEqual("yolo-001", yoloRequestID) + XCTAssertNotNil(Int(httpRequestMetadata)) + return true + } else { + XCTFail("log message doesn't contain the right IDs: \(entry)") + return false + } } - }) - XCTAssertFalse(logsAfterReq3.contains { entry in - entry.message == "Request was queued (waiting for a connection to become available)" - }) - XCTAssert(logsAfterReq3.contains { entry in - entry.message == "Request was scheduled on connection" - && entry.level == .debug - && entry.metadata["ahc-connection-id"] != nil - }) - } - - func testNothingIsLoggedAtInfoOrHigher() { - let logStore = CollectEverythingLogHandler.LogStore() + ) + XCTAssert( + logsAfterReq1.contains { entry in + // Since a new connection must be created first we expect that the request is queued + // and log message describing this is emitted. + entry.message == "Request was queued (waiting for a connection to become available)" + && entry.level == .debug + } + ) + XCTAssert( + logsAfterReq1.contains { entry in + // After the new connection was created we expect a log message that describes that the + // request was scheduled on a connection. The connection id must be set from here on. + entry.message == "Request was scheduled on connection" + && entry.level == .debug + && entry.metadata["ahc-connection-id"] != nil + } + ) - var logger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) - logger.logLevel = .info + XCTAssert( + logsAfterReq2.allSatisfy { entry in + if let httpRequestMetadata = entry.metadata["ahc-request-id"], + let yoloRequestID = entry.metadata["yolo-request-id"] + { + XCTAssertNil(entry.metadata["acme-request-id"]) + XCTAssertEqual("yolo-001", yoloRequestID) + XCTAssertNotNil(Int(httpRequestMetadata)) + return true + } else { + XCTFail("log message doesn't contain the right IDs: \(entry)") + return false + } + } + ) + XCTAssertFalse( + logsAfterReq2.contains { entry in + entry.message == "Request was queued (waiting for a connection to become available)" + } + ) + XCTAssert( + logsAfterReq2.contains { entry in + entry.message == "Request was scheduled on connection" + && entry.level == .debug + && entry.metadata["ahc-connection-id"] != nil + } + ) - guard let request1 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get"), - let request2 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "stats") else { - XCTFail("bad stuff, can't even make request structures") - return + XCTAssert( + logsAfterReq3.allSatisfy { entry in + if let httpRequestMetadata = entry.metadata["ahc-request-id"], + let acmeRequestID = entry.metadata["acme-request-id"] + { + XCTAssertNil(entry.metadata["yolo-request-id"]) + XCTAssertEqual("acme-002", acmeRequestID) + XCTAssertNotNil(Int(httpRequestMetadata)) + return true + } else { + XCTFail("log message doesn't contain the right IDs: \(entry)") + return false + } + } + ) + XCTAssertFalse( + logsAfterReq3.contains { entry in + entry.message == "Request was queued (waiting for a connection to become available)" + } + ) + XCTAssert( + logsAfterReq3.contains { entry in + entry.message == "Request was scheduled on connection" + && entry.level == .debug + && entry.metadata["ahc-connection-id"] != nil + } + ) + } + + func testNothingIsLoggedAtInfoOrHigher() { + let logStore = CollectEverythingLogHandler.LogStore() + + var logger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: logStore) + } + ) + logger.logLevel = .info + + guard let request1 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get"), + let request2 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "stats") + else { + XCTFail("bad stuff, can't even make request structures") + return } // === Request 1 - XCTAssertNoThrow(try self.defaultClient.execute(request: request1, - eventLoop: .indifferent, - deadline: nil, - logger: logger).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request1, + eventLoop: .indifferent, + deadline: nil, + logger: logger + ).wait() + ) XCTAssertEqual(0, logStore.allEntries.count) // === Request 2 - XCTAssertNoThrow(try self.defaultClient.execute(request: request2, - eventLoop: .indifferent, - deadline: nil, - logger: logger).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request2, + eventLoop: .indifferent, + deadline: nil, + logger: logger + ).wait() + ) XCTAssertEqual(0, logStore.allEntries.count) // === Synthesized Request - XCTAssertNoThrow(try self.defaultClient.execute(.GET, - url: self.defaultHTTPBinURLPrefix + "get", - body: nil, - deadline: nil, - logger: logger).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + .GET, + url: self.defaultHTTPBinURLPrefix + "get", + body: nil, + deadline: nil, + logger: logger + ).wait() + ) XCTAssertEqual(0, logStore.allEntries.count) XCTAssertEqual(0, self.backgroundLogStore.allEntries.filter { $0.level >= .info }.count) // === Synthesized Socket Path Request - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: backgroundLogStore) - }) - backgroundLogger.logLevel = .trace - - let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - backgroundActivityLogger: backgroundLogger) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let backgroundLogStore = CollectEverythingLogHandler.LogStore() + var backgroundLogger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: backgroundLogStore) + } + ) + backgroundLogger.logLevel = .trace - XCTAssertNoThrow(try localClient.execute(.GET, - socketPath: path, - urlPath: "get", - body: nil, - deadline: nil, - logger: logger).wait()) - XCTAssertEqual(0, logStore.allEntries.count) + let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + backgroundActivityLogger: backgroundLogger + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .info }.count) - }) + XCTAssertNoThrow( + try localClient.execute( + .GET, + socketPath: path, + urlPath: "get", + body: nil, + deadline: nil, + logger: logger + ).wait() + ) + XCTAssertEqual(0, logStore.allEntries.count) - // === Synthesized Secure Socket Path Request - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: backgroundLogStore) - }) - backgroundLogger.logLevel = .trace - - let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none), - backgroundActivityLogger: backgroundLogger) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .info }.count) } + ) - XCTAssertNoThrow(try localClient.execute(.GET, - secureSocketPath: path, - urlPath: "get", - body: nil, - deadline: nil, - logger: logger).wait()) - XCTAssertEqual(0, logStore.allEntries.count) + // === Synthesized Secure Socket Path Request + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let backgroundLogStore = CollectEverythingLogHandler.LogStore() + var backgroundLogger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: backgroundLogStore) + } + ) + backgroundLogger.logLevel = .trace + + let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none), + backgroundActivityLogger: backgroundLogger + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .info }.count) - }) + XCTAssertNoThrow( + try localClient.execute( + .GET, + secureSocketPath: path, + urlPath: "get", + body: nil, + deadline: nil, + logger: logger + ).wait() + ) + XCTAssertEqual(0, logStore.allEntries.count) + + XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .info }.count) + } + ) } func testAllMethodsLog() { func checkExpectationsWithLogger(type: String, _ body: (Logger, String) throws -> T) throws -> T { let logStore = CollectEverythingLogHandler.LogStore() - var logger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) + var logger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: logStore) + } + ) logger.logLevel = .trace logger[metadataKey: "req"] = "yo-\(type)" @@ -2358,86 +3046,125 @@ class HTTPClientTests: XCTestCase { let result = try body(logger, url) XCTAssertGreaterThan(logStore.allEntries.count, 0) - logStore.allEntries.forEach { entry in + for entry in logStore.allEntries { XCTAssertEqual("yo-\(type)", entry.metadata["req"] ?? "n/a") XCTAssertNotNil(Int(entry.metadata["ahc-request-id"] ?? "n/a")) } return result } - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "GET") { logger, url in - try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "GET") { logger, url in + try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "PUT") { logger, url in - try self.defaultClient.put(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "PUT") { logger, url in + try self.defaultClient.put(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "POST") { logger, url in - try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "POST") { logger, url in + try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "DELETE") { logger, url in - try self.defaultClient.delete(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "DELETE") { logger, url in + try self.defaultClient.delete(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "PATCH") { logger, url in - try self.defaultClient.patch(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "PATCH") { logger, url in + try self.defaultClient.patch(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "CHECKOUT") { logger, url in - try self.defaultClient.execute(.CHECKOUT, url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "CHECKOUT") { logger, url in + try self.defaultClient.execute(.CHECKOUT, url: self.defaultHTTPBinURLPrefix + url, logger: logger) + .wait() + }.status + ) // No background activity expected here. XCTAssertEqual(0, self.backgroundLogStore.allEntries.filter { $0.level >= .debug }.count) - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: backgroundLogStore) - }) - backgroundLogger.logLevel = .trace + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let backgroundLogStore = CollectEverythingLogHandler.LogStore() + var backgroundLogger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: backgroundLogStore) + } + ) + backgroundLogger.logLevel = .trace - let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - backgroundActivityLogger: backgroundLogger) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + backgroundActivityLogger: backgroundLogger + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "GET") { logger, url in - try localClient.execute(socketPath: path, urlPath: url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "GET") { logger, url in + try localClient.execute(socketPath: path, urlPath: url, logger: logger).wait() + }.status + ) - // No background activity expected here. - XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .debug }.count) - }) + // No background activity expected here. + XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .debug }.count) + } + ) - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: backgroundLogStore) - }) - backgroundLogger.logLevel = .trace + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let backgroundLogStore = CollectEverythingLogHandler.LogStore() + var backgroundLogger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: backgroundLogStore) + } + ) + backgroundLogger.logLevel = .trace - let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none), - backgroundActivityLogger: backgroundLogger) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none), + backgroundActivityLogger: backgroundLogger + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "GET") { logger, url in - try localClient.execute(secureSocketPath: path, urlPath: url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "GET") { logger, url in + try localClient.execute(secureSocketPath: path, urlPath: url, logger: logger).wait() + }.status + ) - // No background activity expected here. - XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .debug }.count) - }) + // No background activity expected here. + XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .debug }.count) + } + ) } func testClosingIdleConnectionsInPoolLogsInTheBackground() { @@ -2446,16 +3173,19 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try self.defaultClient.syncShutdown()) XCTAssertGreaterThanOrEqual(self.backgroundLogStore.allEntries.count, 0) - XCTAssert(self.backgroundLogStore.allEntries.contains { entry in - entry.message == "Shutting down connection pool" - }) - XCTAssert(self.backgroundLogStore.allEntries.allSatisfy { entry in - entry.metadata["ahc-request-id"] == nil && - entry.metadata["ahc-request"] == nil && - entry.metadata["ahc-pool-key"] != nil - }) + XCTAssert( + self.backgroundLogStore.allEntries.contains { entry in + entry.message == "Shutting down connection pool" + } + ) + XCTAssert( + self.backgroundLogStore.allEntries.allSatisfy { entry in + entry.metadata["ahc-request-id"] == nil && entry.metadata["ahc-request"] == nil + && entry.metadata["ahc-pool-key"] != nil + } + ) - self.defaultClient = nil // so it doesn't get shut down again. + self.defaultClient = nil // so it doesn't get shut down again. } func testUploadStreamingNoLength() throws { @@ -2480,8 +3210,8 @@ class HTTPClientTests: XCTestCase { XCTFail("Unexpected part") } - XCTAssertNoThrow(try server.readInbound()) // .body - XCTAssertNoThrow(try server.readInbound()) // .end + XCTAssertNoThrow(try server.readInbound()) // .body + XCTAssertNoThrow(try server.readInbound()) // .end XCTAssertNoThrow(try server.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .ok)))) XCTAssertNoThrow(try server.writeOutbound(.end(nil))) @@ -2499,8 +3229,10 @@ class HTTPClientTests: XCTestCase { } } - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(timeout: .init(connect: .milliseconds(10)))) + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(timeout: .init(connect: .milliseconds(10))) + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown()) @@ -2511,8 +3243,8 @@ class HTTPClientTests: XCTestCase { let delegate = TestDelegate() XCTAssertThrowsError(try httpClient.execute(request: request, delegate: delegate).wait()) { - XCTAssertEqual(.connectTimeout, $0 as? HTTPClientError) - XCTAssertEqual(.connectTimeout, delegate.error as? HTTPClientError) + XCTAssertEqualTypeAndValue($0, HTTPClientError.connectTimeout) + XCTAssertEqualTypeAndValue(delegate.error, HTTPClientError.connectTimeout) } } @@ -2526,11 +3258,11 @@ class HTTPClientTests: XCTestCase { } func didReceiveHead(task: HTTPClient.Task, _: HTTPResponseHead) -> EventLoopFuture { - return self.eventLoop.makeSucceededFuture(()) + self.eventLoop.makeSucceededFuture(()) } func didReceiveBodyPart(task: HTTPClient.Task, _: ByteBuffer) -> EventLoopFuture { - return self.eventLoop.makeSucceededFuture(()) + self.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws {} @@ -2553,8 +3285,8 @@ class HTTPClientTests: XCTestCase { let request = try HTTPClient.Request(url: "http://localhost:\(httpServer.serverPort)/") let future = httpClient.execute(request: request, delegate: delegate) - XCTAssertNoThrow(try httpServer.readInbound()) // .head - XCTAssertNoThrow(try httpServer.readInbound()) // .end + XCTAssertNoThrow(try httpServer.readInbound()) // .head + XCTAssertNoThrow(try httpServer.readInbound()) // .end XCTAssertNoThrow(try httpServer.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .ok)))) XCTAssertNoThrow(try httpServer.writeOutbound(.body(.byteBuffer(ByteBuffer(string: "1234"))))) @@ -2566,15 +3298,20 @@ class HTTPClientTests: XCTestCase { func testContentLengthTooLongFails() throws { let url = self.defaultHTTPBinURLPrefix + "post" XCTAssertThrowsError( - try self.defaultClient.execute(request: - Request(url: url, - body: .stream(length: 10) { streamWriter in + try self.defaultClient.execute( + request: + Request( + url: url, + body: .stream(contentLength: 10) { streamWriter in let promise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self) DispatchQueue(label: "content-length-test").async { streamWriter.write(.byteBuffer(ByteBuffer(string: "1"))).cascade(to: promise) } return promise.futureResult - })).wait()) { error in + } + ) + ).wait() + ) { error in XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch) } // Quickly try another request and check that it works. @@ -2596,11 +3333,16 @@ class HTTPClientTests: XCTestCase { let url = self.defaultHTTPBinURLPrefix + "post" let tooLong = "XBAD BAD BAD NOT HTTP/1.1\r\n\r\n" XCTAssertThrowsError( - try self.defaultClient.execute(request: - Request(url: url, - body: .stream(length: 1) { streamWriter in + try self.defaultClient.execute( + request: + Request( + url: url, + body: .stream(contentLength: 1) { streamWriter in streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong))) - })).wait()) { error in + } + ) + ).wait() + ) { error in XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch) } // Quickly try another request and check that it works. If we by accident wrote some extra bytes into the @@ -2621,9 +3363,9 @@ class HTTPClientTests: XCTestCase { func testBodyUploadAfterEndFails() { let url = self.defaultHTTPBinURLPrefix + "post" - func uploader(_ streamWriter: HTTPClient.Body.StreamWriter) -> EventLoopFuture { + let uploader = { @Sendable (_ streamWriter: HTTPClient.Body.StreamWriter) -> EventLoopFuture in let done = streamWriter.write(.byteBuffer(ByteBuffer(string: "X"))) - done.recover { error -> Void in + done.recover { error in XCTFail("unexpected error \(error)") }.whenSuccess { // This is executed when we have already sent the end of the request. @@ -2642,7 +3384,7 @@ class HTTPClientTests: XCTestCase { } var request: HTTPClient.Request? - XCTAssertNoThrow(request = try Request(url: url, body: .stream(length: 1, uploader))) + XCTAssertNoThrow(request = try Request(url: url, body: .stream(contentLength: 1, uploader))) XCTAssertThrowsError(try self.defaultClient.execute(request: XCTUnwrap(request)).wait()) { XCTAssertEqual($0 as? HTTPClientError, .writeAfterRequestSent) } @@ -2652,54 +3394,6 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()) } - func testNoBytesSentOverBodyLimit() throws { - let server = NIOHTTP1TestServer(group: self.serverGroup) - defer { - XCTAssertNoThrow(try server.stop()) - } - - let tooLong = "XBAD BAD BAD NOT HTTP/1.1\r\n\r\n" - - let request = try Request( - url: "http://localhost:\(server.serverPort)", - body: .stream(length: 1) { streamWriter in - streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong))) - } - ) - - let future = self.defaultClient.execute(request: request) - - // Okay, what happens here needs an explanation: - // - // In the request state machine, we should start the request, which will lead to an - // invocation of `context.write(HTTPRequestHead)`. Since we will receive a streamed request - // body a `context.flush()` will be issued. Further the request stream will be started. - // Since the request stream immediately produces to much data, the request will be failed - // and the connection will be closed. - // - // Even though a flush was issued after the request head, there is no guarantee that the - // request head was written to the network. For this reason we must accept not receiving a - // request and receiving a request head. - - do { - _ = try server.receiveHead() - - // A request head was sent. We expect the request now to fail with a parsing error, - // since the client ended the connection to early (from the server's point of view.) - XCTAssertThrowsError(try server.readInbound()) { - XCTAssertEqual($0 as? HTTPParserError, HTTPParserError.invalidEOFState) - } - } catch { - // TBD: We sadly can't verify the error type, since it is private in `NIOTestUtils`: - // NIOTestUtils.BlockingQueue.TimeoutError - } - - // request must always be failed with this error - XCTAssertThrowsError(try future.wait()) { - XCTAssertEqual($0 as? HTTPClientError, .bodyLengthMismatch) - } - } - func testDoubleError() throws { // This is needed to that connection pool will not get into closed state when we release // second connection. @@ -2722,7 +3416,9 @@ class HTTPClientTests: XCTestCase { // We specify a deadline of 2 ms co that request will be timed out before all chunks are writtent, // we need to verify that second error on write after timeout does not lead to double-release. - XCTAssertThrowsError(try self.defaultClient.execute(request: request, deadline: .now() + .milliseconds(2)).wait()) + XCTAssertThrowsError( + try self.defaultClient.execute(request: request, deadline: .now() + .milliseconds(2)).wait() + ) } func testSSLHandshakeErrorPropagation() throws { @@ -2749,6 +3445,8 @@ class HTTPClientTests: XCTestCase { if isTestingNIOTS() { // If we are using Network.framework, we set the connect timeout down very low here // because on NIOTS a failing TLS handshake manifests as a connect timeout. + // Note that we do this here to prove that we correctly manifest the underlying error: + // DO NOT CHANGE THIS TO DISABLE WAITING FOR CONNECTIVITY. timeout.connect = .milliseconds(100) } @@ -2763,7 +3461,12 @@ class HTTPClientTests: XCTestCase { XCTAssertThrowsError(try task.wait()) { error in if isTestingNIOTS() { - XCTAssertEqual(error as? HTTPClientError, .connectTimeout) + #if canImport(Network) + // We can't be more specific than this. + XCTAssertTrue(error is HTTPClient.NWTLSError) + #else + XCTFail("Impossible condition") + #endif } else { switch error as? NIOSSLError { case .some(.handshakeFailed(.sslError(_))): break @@ -2815,7 +3518,12 @@ class HTTPClientTests: XCTestCase { XCTAssertThrowsError(try task.wait()) { error in if isTestingNIOTS() { - XCTAssertEqual(error as? HTTPClientError, .connectTimeout) + #if canImport(Network) + // We can't be more specific than this. + XCTAssertTrue(error is HTTPClient.NWTLSError) + #else + XCTFail("Impossible condition") + #endif } else { switch error as? NIOSSLError { case .some(.handshakeFailed(.sslError(_))): break @@ -2872,7 +3580,7 @@ class HTTPClientTests: XCTestCase { let body: HTTPClient.Body = .stream { writer in let finalPromise = writeEL.makePromise(of: Void.self) - func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { + @Sendable func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { // always invoke from the wrong el to test thread safety writeEL.preconditionInEventLoop() @@ -2914,6 +3622,262 @@ class HTTPClientTests: XCTestCase { XCTAssertNil(try delegate.next().wait()) } + func testResponseAccumulatorMaxBodySizeLimitExceedingWithContentLength() throws { + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let body = ByteBuffer(bytes: 0..<11) + + var request = try Request(url: httpBin.baseURL) + request.body = .byteBuffer(body) + XCTAssertThrowsError( + try self.defaultClient.execute( + request: request, + delegate: ResponseAccumulator(request: request, maxBodySize: 10) + ).wait() + ) { error in + XCTAssertTrue(error is ResponseAccumulator.ResponseTooBigError, "unexpected error \(error)") + } + } + + func testResponseAccumulatorMaxBodySizeLimitNotExceedingWithContentLength() throws { + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let body = ByteBuffer(bytes: 0..<10) + + var request = try Request(url: httpBin.baseURL) + request.body = .byteBuffer(body) + let response = try self.defaultClient.execute( + request: request, + delegate: ResponseAccumulator(request: request, maxBodySize: 10) + ).wait() + + XCTAssertEqual(response.body, body) + } + + func testResponseAccumulatorMaxBodySizeLimitExceedingWithContentLengthButMethodIsHead() throws { + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in HTTPEchoHeaders() } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let body = ByteBuffer(bytes: 0..<11) + + var request = try Request(url: httpBin.baseURL, method: .HEAD) + request.body = .byteBuffer(body) + let response = try self.defaultClient.execute( + request: request, + delegate: ResponseAccumulator(request: request, maxBodySize: 10) + ).wait() + + XCTAssertEqual(response.body ?? ByteBuffer(), ByteBuffer()) + } + + func testResponseAccumulatorMaxBodySizeLimitExceedingWithTransferEncodingChuncked() throws { + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let body = ByteBuffer(bytes: 0..<11) + + var request = try Request(url: httpBin.baseURL) + request.body = .stream { writer in + writer.write(.byteBuffer(body)) + } + XCTAssertThrowsError( + try self.defaultClient.execute( + request: request, + delegate: ResponseAccumulator(request: request, maxBodySize: 10) + ).wait() + ) { error in + XCTAssertTrue(error is ResponseAccumulator.ResponseTooBigError, "unexpected error \(error)") + } + } + + func testResponseAccumulatorMaxBodySizeLimitNotExceedingWithTransferEncodingChuncked() throws { + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let body = ByteBuffer(bytes: 0..<10) + + var request = try Request(url: httpBin.baseURL) + request.body = .stream { writer in + writer.write(.byteBuffer(body)) + } + let response = try self.defaultClient.execute( + request: request, + delegate: ResponseAccumulator(request: request, maxBodySize: 10) + ).wait() + + XCTAssertEqual(response.body, body) + } + + // In this test, we test that a request can continue to stream its body after the response head and end + // was received where the end is a 200. + func testBiDirectionalStreamingEarly200() { + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in + HTTP200DelayedHandler(bodyPartsBeforeResponse: 1) + } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let writeEL = eventLoopGroup.next() + let delegateEL = eventLoopGroup.next() + + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup)) + defer { XCTAssertNoThrow(try httpClient.syncShutdown()) } + + let delegate = ResponseStreamDelegate(eventLoop: delegateEL) + + let body: HTTPClient.Body = .stream { writer in + let finalPromise = writeEL.makePromise(of: Void.self) + + @Sendable func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { + // always invoke from the wrong el to test thread safety + writeEL.preconditionInEventLoop() + + if index >= 30 { + return finalPromise.succeed(()) + } + + let sent = ByteBuffer(integer: index) + writer.write(.byteBuffer(sent)).whenComplete { result in + switch result { + case .success: + writeEL.execute { + writeLoop(writer, index: index + 1) + } + + case .failure(let error): + finalPromise.fail(error) + } + } + } + + writeEL.execute { + writeLoop(writer, index: 0) + } + + return finalPromise.futureResult + } + + let request = try! HTTPClient.Request(url: "http://localhost:\(httpBin.port)", body: body) + let future = httpClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: delegateEL)) + XCTAssertNoThrow(try future.wait()) + XCTAssertNil(try delegate.next().wait()) + } + + // This test is identical to the one above, except that we send another request immediately after. This is a regression + // test for https://github.com/swift-server/async-http-client/issues/595. + func testBiDirectionalStreamingEarly200DoesntPreventUsFromSendingMoreRequests() { + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in + HTTP200DelayedHandler(bodyPartsBeforeResponse: 1) + } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let writeEL = eventLoopGroup.next() + + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup)) + defer { XCTAssertNoThrow(try httpClient.syncShutdown()) } + + let body: HTTPClient.Body = .stream { writer in + let finalPromise = writeEL.makePromise(of: Void.self) + + @Sendable func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { + // always invoke from the wrong el to test thread safety + writeEL.preconditionInEventLoop() + + if index >= 30 { + return finalPromise.succeed(()) + } + + let sent = ByteBuffer(integer: index) + writer.write(.byteBuffer(sent)).whenComplete { result in + switch result { + case .success: + writeEL.execute { + writeLoop(writer, index: index + 1) + } + + case .failure(let error): + finalPromise.fail(error) + } + } + } + + writeEL.execute { + writeLoop(writer, index: 0) + } + + return finalPromise.futureResult + } + + let request = try! HTTPClient.Request(url: "http://localhost:\(httpBin.port)", body: body) + let future = httpClient.execute(request: request) + XCTAssertNoThrow(try future.wait()) + + // Try another request + let future2 = httpClient.execute(request: request) + XCTAssertNoThrow(try future2.wait()) + } + + // This test validates that we correctly close the connection after our body completes when we've streamed a + // body and received the 2XX response _before_ we finished our stream. + func testCloseConnectionAfterEarly2XXWhenStreaming() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let onClosePromise = eventLoopGroup.next().makePromise(of: Void.self) + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in + ExpectClosureServerHandler(onClosePromise: onClosePromise) + } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let writeEL = eventLoopGroup.next() + + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup)) + defer { XCTAssertNoThrow(try httpClient.syncShutdown()) } + + let body: HTTPClient.Body = .stream { writer in + let finalPromise = writeEL.makePromise(of: Void.self) + + @Sendable func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { + // always invoke from the wrong el to test thread safety + writeEL.preconditionInEventLoop() + + if index >= 30 { + return finalPromise.succeed(()) + } + + let sent = ByteBuffer(integer: index) + writer.write(.byteBuffer(sent)).whenComplete { result in + switch result { + case .success: + writeEL.execute { + writeLoop(writer, index: index + 1) + } + + case .failure(let error): + finalPromise.fail(error) + } + } + } + + writeEL.execute { + writeLoop(writer, index: 0) + } + + return finalPromise.futureResult + } + + let headers = HTTPHeaders([("Connection", "close")]) + let request = try! HTTPClient.Request(url: "http://localhost:\(httpBin.port)", headers: headers, body: body) + let future = httpClient.execute(request: request) + XCTAssertNoThrow(try future.wait()) + XCTAssertNoThrow(try onClosePromise.futureResult.wait()) + } + func testSynchronousHandshakeErrorReporting() throws { // This only affects cases where we use NIOSSL. guard !isTestingNIOTS() else { return } @@ -2925,8 +3889,10 @@ class HTTPClientTests: XCTestCase { tlsConfig.maximumTLSVersion = .tlsv12 tlsConfig.certificateVerification = .none let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(tlsConfiguration: tlsConfig)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(tlsConfiguration: tlsConfig) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -3001,12 +3967,16 @@ class HTTPClientTests: XCTestCase { } func testRequestSpecificTLS() throws { - let configuration = HTTPClient.Configuration(tlsConfiguration: nil, - timeout: .init(), - decompression: .disabled) + let configuration = HTTPClient.Configuration( + tlsConfiguration: nil, + timeout: .init(), + decompression: .disabled + ) let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: configuration) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: configuration + ) let decoder = JSONDecoder() defer { @@ -3017,7 +3987,11 @@ class HTTPClientTests: XCTestCase { // First two requests use identical TLS configurations. var tlsConfig = TLSConfiguration.makeClientConfiguration() tlsConfig.certificateVerification = .none - let firstRequest = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, tlsConfiguration: tlsConfig) + let firstRequest = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/get", + method: .GET, + tlsConfiguration: tlsConfig + ) let firstResponse = try localClient.execute(request: firstRequest).wait() guard let firstBody = firstResponse.body else { XCTFail("No request body found") @@ -3025,7 +3999,11 @@ class HTTPClientTests: XCTestCase { } let firstConnectionNumber = try decoder.decode(RequestInfo.self, from: firstBody).connectionNumber - let secondRequest = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, tlsConfiguration: tlsConfig) + let secondRequest = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/get", + method: .GET, + tlsConfiguration: tlsConfig + ) let secondResponse = try localClient.execute(request: secondRequest).wait() guard let secondBody = secondResponse.body else { XCTFail("No request body found") @@ -3037,7 +4015,11 @@ class HTTPClientTests: XCTestCase { var tlsConfig2 = TLSConfiguration.makeClientConfiguration() tlsConfig2.certificateVerification = .none tlsConfig2.maximumTLSVersion = .tlsv1 - let thirdRequest = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, tlsConfiguration: tlsConfig2) + let thirdRequest = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/get", + method: .GET, + tlsConfiguration: tlsConfig2 + ) let thirdResponse = try localClient.execute(request: thirdRequest).wait() guard let thirdBody = thirdResponse.body else { XCTFail("No request body found") @@ -3048,51 +4030,16 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(firstResponse.status, .ok) XCTAssertEqual(secondResponse.status, .ok) XCTAssertEqual(thirdResponse.status, .ok) - XCTAssertEqual(firstConnectionNumber, secondConnectionNumber, "Identical TLS configurations did not use the same connection") - XCTAssertNotEqual(thirdConnectionNumber, firstConnectionNumber, "Different TLS configurations did not use different connections.") - } - - func testConnectionPoolSizeConfigValueIsRespected() { - let numberOfRequestsPerThread = 1000 - let numberOfParallelWorkers = 16 - let poolSize = 12 - - let httpBin = HTTPBin() - defer { XCTAssertNoThrow(try httpBin.shutdown()) } - - let group = MultiThreadedEventLoopGroup(numberOfThreads: 4) - defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } - - let configuration = HTTPClient.Configuration( - connectionPool: .init( - idleTimeout: .seconds(30), - concurrentHTTP1ConnectionsPerHostSoftLimit: poolSize - ) + XCTAssertEqual( + firstConnectionNumber, + secondConnectionNumber, + "Identical TLS configurations did not use the same connection" + ) + XCTAssertNotEqual( + thirdConnectionNumber, + firstConnectionNumber, + "Different TLS configurations did not use different connections." ) - let client = HTTPClient(eventLoopGroupProvider: .shared(group), configuration: configuration) - defer { XCTAssertNoThrow(try client.syncShutdown()) } - - let g = DispatchGroup() - for workerID in 0..) throws {} + func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { + task.cancel() + } + } + XCTAssertThrowsError(try defaultClient.execute(request: request, delegate: CancelAfterHeadSend()).wait()) + } + + func testCancelingHTTP2RequestAfterHeaderSend() throws { + let bin = HTTPBin(.http2()) + defer { XCTAssertNoThrow(try bin.shutdown()) } + var request = try HTTPClient.Request(url: bin.baseURL + "/wait", method: .POST) + // non-empty body is important + request.body = .byteBuffer(ByteBuffer([1])) + + class CancelAfterHeadSend: HTTPClientResponseDelegate { + init() {} + func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} + func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { + task.cancel() + } + } + XCTAssertThrowsError(try defaultClient.execute(request: request, delegate: CancelAfterHeadSend()).wait()) + } + + private func testMaxConnectionReuses(mode: HTTPBin.Mode, maximumUses: Int, requests: Int) throws { + let bin = HTTPBin(mode) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + var configuration = HTTPClient.Configuration(certificateVerification: .none) + // Limit each connection to two uses before discarding them. The test will verify that the + // connection number indicated by the server increments every two requests. + configuration.maximumUsesPerConnection = maximumUses + + let client = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: configuration) + defer { XCTAssertNoThrow(try client.syncShutdown()) } + + let request = try HTTPClient.Request(url: bin.baseURL + "stats") + let decoder = JSONDecoder() + + // Do two requests per batch. Both should report the same connection number. + for requestNumber in stride(from: 0, to: requests, by: maximumUses) { + var responses = [RequestInfo]() + + for _ in 0.. () throws -> Void)] { - return [ - ("testEOFFramedSuccess", testEOFFramedSuccess), - ("testContentLength", testContentLength), - ("testContentLengthButTruncated", testContentLengthButTruncated), - ("testTransferEncoding", testTransferEncoding), - ("testTransferEncodingButTruncated", testTransferEncodingButTruncated), - ("testConnectionDrop", testConnectionDrop), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests.swift index 854d9092c..b63eb7cba 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests.swift @@ -155,7 +155,8 @@ final class HTTPClientUncleanSSLConnectionShutdownTests: XCTestCase { ) defer { XCTAssertNoThrow(try client.syncShutdown()) } - XCTAssertThrowsError(try client.get(url: "https://localhost:\(httpBin.port)/transferencodingtruncated").wait()) { + XCTAssertThrowsError(try client.get(url: "https://localhost:\(httpBin.port)/transferencodingtruncated").wait()) + { XCTAssertEqual($0 as? HTTPParserError, .invalidEOFState) } } @@ -184,7 +185,7 @@ final class HTTPBinForSSLUncleanShutdown { let serverChannel: Channel var port: Int { - return Int(self.serverChannel.localAddress!.port!) + Int(self.serverChannel.localAddress!.port!) } init() { @@ -231,61 +232,61 @@ private final class HTTPBinForSSLUncleanShutdownHandler: ChannelInboundHandler { switch req.uri { case "/nocontentlength": response = """ - HTTP/1.1 200 OK\r\n\ - Connection: close\r\n\ - \r\n\ - foo - """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + \r\n\ + foo + """ case "/nocontent": response = """ - HTTP/1.1 204 OK\r\n\ - Connection: close\r\n\ - \r\n - """ + HTTP/1.1 204 OK\r\n\ + Connection: close\r\n\ + \r\n + """ case "/noresponse": response = nil case "/wrongcontentlength": response = """ - HTTP/1.1 200 OK\r\n\ - Connection: close\r\n\ - Content-Length: 6\r\n\ - \r\n\ - foo - """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + Content-Length: 6\r\n\ + \r\n\ + foo + """ case "/transferencoding": response = """ - HTTP/1.1 200 OK\r\n\ - Connection: close\r\n\ - Transfer-Encoding: chunked\r\n\ - \r\n\ - 3\r\n\ - foo\r\n\ - 0\r\n\ - \r\n - """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 3\r\n\ + foo\r\n\ + 0\r\n\ + \r\n + """ case "/transferencodingtruncated": response = """ - HTTP/1.1 200 OK\r\n\ - Connection: close\r\n\ - Transfer-Encoding: chunked\r\n\ - \r\n\ - 12\r\n\ - foo - """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 12\r\n\ + foo + """ default: response = """ - HTTP/1.1 404 OK\r\n\ - Connection: close\r\n\ - Content-Length: 9\r\n\ - \r\n\ - Not Found - """ + HTTP/1.1 404 OK\r\n\ + Connection: close\r\n\ + Content-Length: 9\r\n\ + \r\n\ + Not Found + """ } if let response = response { diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests+XCTest.swift deleted file mode 100644 index 898b2b867..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests+XCTest.swift +++ /dev/null @@ -1,34 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+FactoryTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_FactoryTests { - static var allTests: [(String, (HTTPConnectionPool_FactoryTests) -> () throws -> Void)] { - return [ - ("testConnectionCreationTimesoutIfDeadlineIsInThePast", testConnectionCreationTimesoutIfDeadlineIsInThePast), - ("testSOCKSConnectionCreationTimesoutIfRemoteIsUnresponsive", testSOCKSConnectionCreationTimesoutIfRemoteIsUnresponsive), - ("testHTTPProxyConnectionCreationTimesoutIfRemoteIsUnresponsive", testHTTPProxyConnectionCreationTimesoutIfRemoteIsUnresponsive), - ("testTLSConnectionCreationTimesoutIfRemoteIsUnresponsive", testTLSConnectionCreationTimesoutIfRemoteIsUnresponsive), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift index b13ff3d18..d9dbd4cb1 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOPosix @@ -20,18 +19,22 @@ import NIOSOCKS import NIOSSL import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_FactoryTests: XCTestCase { func testConnectionCreationTimesoutIfDeadlineIsInThePast() { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .childChannelInitializer { channel in - channel.pipeline.addHandler(NeverrespondServerHandler()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .childChannelInitializer { channel in + channel.pipeline.addHandler(NeverrespondServerHandler()) + } + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } @@ -45,12 +48,14 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { sslContextCache: .init() ) - XCTAssertThrowsError(try factory.makeChannel( - connectionID: 1, - deadline: .now() - .seconds(1), - eventLoop: group.next(), - logger: .init(label: "test") - ).wait() + XCTAssertThrowsError( + try factory.makeChannel( + requester: ExplodingRequester(), + connectionID: 1, + deadline: .now() - .seconds(1), + eventLoop: group.next(), + logger: .init(label: "test") + ).wait() ) { XCTAssertEqual($0 as? HTTPClientError, .connectTimeout) } @@ -61,12 +66,14 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .childChannelInitializer { channel in - channel.pipeline.addHandler(NeverrespondServerHandler()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .childChannelInitializer { channel in + channel.pipeline.addHandler(NeverrespondServerHandler()) + } + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } @@ -76,16 +83,19 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { let factory = HTTPConnectionPool.ConnectionFactory( key: .init(request), tlsConfiguration: nil, - clientConfiguration: .init(proxy: .socksServer(host: "127.0.0.1", port: server!.localAddress!.port!)), + clientConfiguration: .init(proxy: .socksServer(host: "127.0.0.1", port: server!.localAddress!.port!)) + .enableFastFailureModeForTesting(), sslContextCache: .init() ) - XCTAssertThrowsError(try factory.makeChannel( - connectionID: 1, - deadline: .now() + .seconds(1), - eventLoop: group.next(), - logger: .init(label: "test") - ).wait() + XCTAssertThrowsError( + try factory.makeChannel( + requester: ExplodingRequester(), + connectionID: 1, + deadline: .now() + .seconds(1), + eventLoop: group.next(), + logger: .init(label: "test") + ).wait() ) { XCTAssertEqual($0 as? HTTPClientError, .socksHandshakeTimeout) } @@ -96,12 +106,14 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .childChannelInitializer { channel in - channel.pipeline.addHandler(NeverrespondServerHandler()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .childChannelInitializer { channel in + channel.pipeline.addHandler(NeverrespondServerHandler()) + } + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } @@ -111,16 +123,19 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { let factory = HTTPConnectionPool.ConnectionFactory( key: .init(request), tlsConfiguration: nil, - clientConfiguration: .init(proxy: .server(host: "127.0.0.1", port: server!.localAddress!.port!)), + clientConfiguration: .init(proxy: .server(host: "127.0.0.1", port: server!.localAddress!.port!)) + .enableFastFailureModeForTesting(), sslContextCache: .init() ) - XCTAssertThrowsError(try factory.makeChannel( - connectionID: 1, - deadline: .now() + .seconds(1), - eventLoop: group.next(), - logger: .init(label: "test") - ).wait() + XCTAssertThrowsError( + try factory.makeChannel( + requester: ExplodingRequester(), + connectionID: 1, + deadline: .now() + .seconds(1), + eventLoop: group.next(), + logger: .init(label: "test") + ).wait() ) { XCTAssertEqual($0 as? HTTPClientError, .httpProxyHandshakeTimeout) } @@ -131,12 +146,14 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .childChannelInitializer { channel in - channel.pipeline.addHandler(NeverrespondServerHandler()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .childChannelInitializer { channel in + channel.pipeline.addHandler(NeverrespondServerHandler()) + } + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } @@ -148,16 +165,19 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { let factory = HTTPConnectionPool.ConnectionFactory( key: .init(request), tlsConfiguration: nil, - clientConfiguration: .init(tlsConfiguration: tlsConfig), + clientConfiguration: .init(tlsConfiguration: tlsConfig) + .enableFastFailureModeForTesting(), sslContextCache: .init() ) - XCTAssertThrowsError(try factory.makeChannel( - connectionID: 1, - deadline: .now() + .seconds(1), - eventLoop: group.next(), - logger: .init(label: "test") - ).wait() + XCTAssertThrowsError( + try factory.makeChannel( + requester: ExplodingRequester(), + connectionID: 1, + deadline: .now() + .seconds(1), + eventLoop: group.next(), + logger: .init(label: "test") + ).wait() ) { XCTAssertEqual($0 as? HTTPClientError, .tlsHandshakeTimeout) } @@ -171,3 +191,22 @@ class NeverrespondServerHandler: ChannelInboundHandler { // do nothing } } + +/// A `HTTPConnectionRequester` that will fail a test if any of its methods are ever called. +final class ExplodingRequester: HTTPConnectionRequester { + func http1ConnectionCreated(_: HTTP1Connection) { + XCTFail("http1ConnectionCreated called unexpectedly") + } + + func http2ConnectionCreated(_: HTTP2Connection, maximumStreams: Int) { + XCTFail("http2ConnectionCreated called unexpectedly") + } + + func failedToCreateHTTPConnection(_: HTTPConnectionPool.Connection.ID, error: Error) { + XCTFail("failedToCreateHTTPConnection called unexpectedly") + } + + func waitingForConnectivity(_: HTTPConnectionPool.Connection.ID, error: Error) { + XCTFail("waitingForConnectivity called unexpectedly") + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest+XCTest.swift deleted file mode 100644 index 21eb3029e..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest+XCTest.swift +++ /dev/null @@ -1,47 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+HTTP1ConnectionsTest+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_HTTP1ConnectionsTests { - static var allTests: [(String, (HTTPConnectionPool_HTTP1ConnectionsTests) -> () throws -> Void)] { - return [ - ("testCreatingConnections", testCreatingConnections), - ("testCreatingConnectionAndFailing", testCreatingConnectionAndFailing), - ("testLeaseConnectionOnPreferredAndAvailableEL", testLeaseConnectionOnPreferredAndAvailableEL), - ("testLeaseConnectionOnPreferredButUnavailableEL", testLeaseConnectionOnPreferredButUnavailableEL), - ("testLeaseConnectionOnRequiredButUnavailableEL", testLeaseConnectionOnRequiredButUnavailableEL), - ("testLeaseConnectionOnRequiredAndAvailableEL", testLeaseConnectionOnRequiredAndAvailableEL), - ("testCloseConnectionIfIdle", testCloseConnectionIfIdle), - ("testCloseConnectionIfIdleButLeasedRaceCondition", testCloseConnectionIfIdleButLeasedRaceCondition), - ("testCloseConnectionIfIdleButClosedRaceCondition", testCloseConnectionIfIdleButClosedRaceCondition), - ("testShutdown", testShutdown), - ("testMigrationFromHTTP2", testMigrationFromHTTP2), - ("testMigrationFromHTTP2WithPendingRequestsWithRequiredEventLoop", testMigrationFromHTTP2WithPendingRequestsWithRequiredEventLoop), - ("testMigrationFromHTTP2WithPendingRequestsWithPreferredEventLoop", testMigrationFromHTTP2WithPendingRequestsWithPreferredEventLoop), - ("testMigrationFromHTTP2WithAlreadyLeasedHTTP1Connection", testMigrationFromHTTP2WithAlreadyLeasedHTTP1Connection), - ("testMigrationFromHTTP2WithMoreStartingConnectionsThanMaximumAllowedConccurentConnections", testMigrationFromHTTP2WithMoreStartingConnectionsThanMaximumAllowedConccurentConnections), - ("testMigrationFromHTTP2StartsEnoghOverflowConnectionsForRequiredEventLoopRequests", testMigrationFromHTTP2StartsEnoghOverflowConnectionsForRequiredEventLoopRequests), - ("testMigrationFromHTTP1ToHTTP2AndBackToHTTP1", testMigrationFromHTTP1ToHTTP2AndBackToHTTP1), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest.swift index 5afe755a1..914990048 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest.swift @@ -12,15 +12,20 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCreatingConnections() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -52,7 +57,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCreatingConnectionAndFailing() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -103,7 +112,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el3 = elg.next() let el4 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) for el in [el1, el2, el3, el4] { XCTAssertEqual(connections.startingGeneralPurposeConnections, 0) @@ -130,7 +143,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el4 = elg.next() let el5 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) for el in [el1, el2, el3, el4] { XCTAssertEqual(connections.startingGeneralPurposeConnections, 0) @@ -157,7 +174,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el4 = elg.next() let el5 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) for el in [el1, el2, el3, el4] { XCTAssertEqual(connections.startingGeneralPurposeConnections, 0) @@ -181,7 +202,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el1 = elg.next() let el2 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) for el in [el1, el1, el1, el1, el2] { let connID = connections.createNewConnection(on: el) @@ -228,7 +253,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCloseConnectionIfIdle() { let elg = EmbeddedEventLoopGroup(loops: 1) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) let el1 = elg.next() @@ -248,7 +277,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCloseConnectionIfIdleButLeasedRaceCondition() { let elg = EmbeddedEventLoopGroup(loops: 1) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) let el1 = elg.next() @@ -267,7 +300,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCloseConnectionIfIdleButClosedRaceCondition() { let elg = EmbeddedEventLoopGroup(loops: 1) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) let el1 = elg.next() @@ -288,7 +325,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el3 = elg.next() let el4 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) for el in [el1, el2, el3, el4] { let connID = connections.createNewConnection(on: el) @@ -343,7 +384,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -372,7 +417,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2WithPendingRequestsWithRequiredEventLoop() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -408,10 +457,46 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { XCTAssertTrue(context.eventLoop === el3) } + func testMigrationFromHTTP2WithPendingRequestsWithRequiredEventLoopSameAsStartingConnections() { + let elg = EmbeddedEventLoopGroup(loops: 4) + let generator = HTTPConnectionPool.Connection.ID.Generator() + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil + ) + + let el1 = elg.next() + let el2 = elg.next() + + let conn1ID = generator.next() + let conn2ID = generator.next() + + connections.migrateFromHTTP2( + starting: [(conn1ID, el1)], + backingOff: [(conn2ID, el2)] + ) + + let stats = connections.stats + XCTAssertEqual(stats.idle, 0) + XCTAssertEqual(stats.leased, 0) + XCTAssertEqual(stats.connecting, 1) + XCTAssertEqual(stats.backingOff, 1) + + let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) + let (_, context) = connections.newHTTP1ConnectionEstablished(conn1) + XCTAssertEqual(context.use, .generalPurpose) + XCTAssertTrue(context.eventLoop === el1) + } + func testMigrationFromHTTP2WithPendingRequestsWithPreferredEventLoop() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -450,7 +535,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2WithAlreadyLeasedHTTP1Connection() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() let el3 = elg.next() @@ -494,7 +583,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2WithMoreStartingConnectionsThanMaximumAllowedConccurentConnections() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 2, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 2, + generator: generator, + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -529,7 +622,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2StartsEnoghOverflowConnectionsForRequiredEventLoopRequests() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 1, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 1, + generator: generator, + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -571,16 +668,23 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el2 = elg.next() let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil + ) let connID1 = connections.createNewConnection(on: el1) let context = connections.migrateToHTTP2() - XCTAssertEqual(context, .init( - backingOff: [], - starting: [(connID1, el1)], - close: [] - )) + XCTAssertEqual( + context, + .init( + backingOff: [], + starting: [(connID1, el1)], + close: [] + ) + ) let connID2 = generator.next() @@ -598,8 +702,7 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { extension HTTPConnectionPool.HTTP1Connections.HTTP1ToHTTP2MigrationContext: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { - return lhs.close == rhs.close && - lhs.starting.elementsEqual(rhs.starting, by: { $0.0 == $1.0 && $0.1 === $1.1 }) && - lhs.backingOff.elementsEqual(rhs.backingOff, by: { $0.0 == $1.0 && $0.1 === $1.1 }) + lhs.close == rhs.close && lhs.starting.elementsEqual(rhs.starting, by: { $0.0 == $1.0 && $0.1 === $1.1 }) + && lhs.backingOff.elementsEqual(rhs.backingOff, by: { $0.0 == $1.0 && $0.1 === $1.1 }) } } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests+XCTest.swift deleted file mode 100644 index 16377d07f..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests+XCTest.swift +++ /dev/null @@ -1,46 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+HTTP1StateTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_HTTP1StateMachineTests { - static var allTests: [(String, (HTTPConnectionPool_HTTP1StateMachineTests) -> () throws -> Void)] { - return [ - ("testCreatingAndFailingConnections", testCreatingAndFailingConnections), - ("testConnectionFailureBackoff", testConnectionFailureBackoff), - ("testCancelRequestWorks", testCancelRequestWorks), - ("testExecuteOnShuttingDownPool", testExecuteOnShuttingDownPool), - ("testRequestsAreQueuedIfAllConnectionsAreInUseAndRequestsAreDequeuedInOrder", testRequestsAreQueuedIfAllConnectionsAreInUseAndRequestsAreDequeuedInOrder), - ("testBestConnectionIsPicked", testBestConnectionIsPicked), - ("testConnectionAbortIsIgnoredIfThereAreNoQueuedRequests", testConnectionAbortIsIgnoredIfThereAreNoQueuedRequests), - ("testConnectionCloseLeadsToTumbleWeedIfThereNoQueuedRequests", testConnectionCloseLeadsToTumbleWeedIfThereNoQueuedRequests), - ("testConnectionAbortLeadsToNewConnectionsIfThereAreQueuedRequests", testConnectionAbortLeadsToNewConnectionsIfThereAreQueuedRequests), - ("testParkedConnectionTimesOut", testParkedConnectionTimesOut), - ("testConnectionPoolFullOfParkedConnectionsIsShutdownImmediately", testConnectionPoolFullOfParkedConnectionsIsShutdownImmediately), - ("testParkedConnectionTimesOutButIsAlsoClosedByRemote", testParkedConnectionTimesOutButIsAlsoClosedByRemote), - ("testConnectionBackoffVsShutdownRace", testConnectionBackoffVsShutdownRace), - ("testRequestThatTimesOutIsFailedWithLastConnectionCreationError", testRequestThatTimesOutIsFailedWithLastConnectionCreationError), - ("testRequestThatTimesOutBeforeAConnectionIsEstablishedIsFailedWithConnectTimeoutError", testRequestThatTimesOutBeforeAConnectionIsEstablishedIsFailedWithConnectTimeoutError), - ("testRequestThatTimesOutAfterAConnectionWasEstablishedSuccessfullyTimesOutWithGenericError", testRequestThatTimesOutAfterAConnectionWasEstablishedSuccessfullyTimesOutWithGenericError), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift index 49a6fb574..2be6cfa26 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift @@ -12,21 +12,26 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { func testCreatingAndFailingConnections() { + struct SomeError: Error, Equatable {} let elg = EmbeddedEventLoopGroup(loops: 4) defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 8 + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil ) var connections = MockConnectionPool() @@ -35,7 +40,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // for the first eight requests, the pool should try to create new connections. for _ in 0..<8 { - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) guard case .createConnection(let connectionID, let connectionEL) = action.connection else { @@ -51,7 +56,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // the next eight requests should only be queued. for _ in 0..<8 { - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) guard case .none = action.connection else { @@ -65,8 +70,6 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // fail all connection attempts while let randomConnectionID = connections.randomStartingConnection() { - struct SomeError: Error, Equatable {} - XCTAssertNoThrow(try connections.failConnectionCreation(randomConnectionID)) let action = state.failedToCreateNewConnection(SomeError(), connectionID: randomConnectionID) @@ -86,9 +89,9 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // cancel all queued requests while let request = queuer.timeoutRandomRequest() { - let cancelAction = state.cancelRequest(request) + let cancelAction = state.cancelRequest(request.0) XCTAssertEqual(cancelAction.connection, .none) - XCTAssertEqual(cancelAction.request, .cancelRequestTimeout(request)) + XCTAssertEqual(cancelAction.request, .failRequest(.init(request.1), SomeError(), cancelTimeout: true)) } // connection backoff done @@ -103,16 +106,89 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssert(connections.isEmpty) } + func testCreatingAndFailingConnectionsWithoutRetry() { + struct SomeError: Error, Equatable {} + let elg = EmbeddedEventLoopGroup(loops: 4) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: false, + preferHTTP1: true, + maximumConnectionUses: nil + ) + + var connections = MockConnectionPool() + var queuer = MockRequestQueuer() + + // for the first eight requests, the pool should try to create new connections. + + for _ in 0..<8 { + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + let action = state.executeRequest(request) + guard case .createConnection(let connectionID, let connectionEL) = action.connection else { + return XCTFail("Unexpected connection action") + } + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), action.request) + XCTAssert(connectionEL === mockRequest.eventLoop) + + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + XCTAssertNoThrow(try queuer.queue(mockRequest, id: request.id)) + } + + // the next eight requests should only be queued. + + for _ in 0..<8 { + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + let action = state.executeRequest(request) + guard case .none = action.connection else { + return XCTFail("Unexpected connection action") + } + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), action.request) + XCTAssertNoThrow(try queuer.queue(mockRequest, id: request.id)) + } + + // the first failure should cancel all requests because we have disabled connection establishtment retry + let randomConnectionID = connections.randomStartingConnection()! + XCTAssertNoThrow(try connections.failConnectionCreation(randomConnectionID)) + let action = state.failedToCreateNewConnection(SomeError(), connectionID: randomConnectionID) + XCTAssertEqual(action.connection, .none) + guard case .failRequestsAndCancelTimeouts(let requestsToFail, let requestError) = action.request else { + return XCTFail("Unexpected request action: \(action.request)") + } + XCTAssertEqualTypeAndValue(requestError, SomeError()) + for requestToFail in requestsToFail { + XCTAssertNoThrow(try queuer.fail(requestToFail.id, request: requestToFail.__testOnly_wrapped_request())) + } + + // all requests have been canceled and therefore nothing should happen if a connection fails + while let randomConnectionID = connections.randomStartingConnection() { + XCTAssertNoThrow(try connections.failConnectionCreation(randomConnectionID)) + let action = state.failedToCreateNewConnection(SomeError(), connectionID: randomConnectionID) + + XCTAssertEqual(action, .none) + } + + XCTAssert(queuer.isEmpty) + XCTAssert(connections.isEmpty) + } + func testConnectionFailureBackoff() { let elg = EmbeddedEventLoopGroup(loops: 4) defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 2 + maximumConcurrentHTTP1Connections: 2, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -122,9 +198,12 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = action.connection else { return XCTFail("Unexpected connection action: \(action.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux - let failedConnect1 = state.failedToCreateNewConnection(HTTPClientError.connectTimeout, connectionID: connectionID) + let failedConnect1 = state.failedToCreateNewConnection( + HTTPClientError.connectTimeout, + connectionID: connectionID + ) XCTAssertEqual(failedConnect1.request, .none) guard case .scheduleBackoffTimer(connectionID, let backoffTimeAmount1, _) = failedConnect1.connection else { return XCTFail("Unexpected connection action: \(failedConnect1.connection)") @@ -137,9 +216,12 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { return XCTFail("Unexpected connection action: \(backoffDoneAction.connection)") } XCTAssertGreaterThan(newConnectionID, connectionID) - XCTAssert(connectionEL === newEventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === newEventLoop) // XCTAssertIdentical not available on Linux - let failedConnect2 = state.failedToCreateNewConnection(HTTPClientError.connectTimeout, connectionID: newConnectionID) + let failedConnect2 = state.failedToCreateNewConnection( + HTTPClientError.connectTimeout, + connectionID: newConnectionID + ) XCTAssertEqual(failedConnect2.request, .none) guard case .scheduleBackoffTimer(newConnectionID, let backoffTimeAmount2, _) = failedConnect2.connection else { return XCTFail("Unexpected connection action: \(failedConnect2.connection)") @@ -152,7 +234,9 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard case .failRequest(let requestToFail, let requestError, cancelTimeout: false) = failRequest.request else { return XCTFail("Unexpected request action: \(action.request)") } - XCTAssert(requestToFail.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux + + // XCTAssertIdentical not available on Linux + XCTAssert(requestToFail.__testOnly_wrapped_request() === mockRequest) XCTAssertEqual(requestError as? HTTPClientError, .connectTimeout) XCTAssertEqual(failRequest.connection, .none) @@ -166,10 +250,13 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 2 + maximumConcurrentHTTP1Connections: 2, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -179,19 +266,21 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { return XCTFail("Unexpected connection action: \(executeAction.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux // 2. cancel request let cancelAction = state.cancelRequest(request.id) - XCTAssertEqual(cancelAction.request, .cancelRequestTimeout(request.id)) + XCTAssertEqual(cancelAction.request, .failRequest(request, HTTPClientError.cancelled, cancelTimeout: true)) XCTAssertEqual(cancelAction.connection, .none) // 3. request timeout triggers to late XCTAssertEqual(state.timeoutRequest(request.id), .none, "To late timeout is ignored") // 4. succeed connection attempt - let connectedAction = state.newHTTP1ConnectionCreated(.__testOnly_connection(id: connectionID, eventLoop: connectionEL)) + let connectedAction = state.newHTTP1ConnectionCreated( + .__testOnly_connection(id: connectionID, eventLoop: connectionEL) + ) XCTAssertEqual(connectedAction.request, .none, "Request must not be executed") XCTAssertEqual(connectedAction.connection, .scheduleTimeoutTimer(connectionID, on: connectionEL)) } @@ -202,10 +291,13 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 2 + maximumConcurrentHTTP1Connections: 2, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -215,15 +307,18 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { return XCTFail("Unexpected connection action: \(executeAction.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux // 2. connection succeeds - let connection: HTTPConnectionPool.Connection = .__testOnly_connection(id: connectionID, eventLoop: connectionEL) + let connection: HTTPConnectionPool.Connection = .__testOnly_connection( + id: connectionID, + eventLoop: connectionEL + ) let connectedAction = state.newHTTP1ConnectionCreated(connection) guard case .executeRequest(request, connection, cancelTimeout: true) = connectedAction.request else { return XCTFail("Unexpected request action: \(connectedAction.request)") } - XCTAssert(request.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux + XCTAssert(request.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux XCTAssertEqual(connectedAction.connection, .none) // 3. shutdown @@ -239,11 +334,14 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(cleanupContext.connectBackoff, []) // 4. execute another request - let finalMockRequest = MockHTTPRequest(eventLoop: elg.next()) + let finalMockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let finalRequest = HTTPConnectionPool.Request(finalMockRequest) let failAction = state.executeRequest(finalRequest) XCTAssertEqual(failAction.connection, .none) - XCTAssertEqual(failAction.request, .failRequest(finalRequest, HTTPClientError.alreadyShutdown, cancelTimeout: false)) + XCTAssertEqual( + failAction.request, + .failRequest(finalRequest, HTTPClientError.alreadyShutdown, cancelTimeout: false) + ) // 5. close open connection let closeAction = state.http1ConnectionClosed(connectionID) @@ -264,16 +362,20 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // Add eight requests to fill all connections for _ in 0..<8 { let eventLoop = elg.next() - guard let expectedConnection = connections.newestParkedConnection(for: eventLoop) ?? connections.newestParkedConnection else { + guard + let expectedConnection = connections.newestParkedConnection(for: eventLoop) + ?? connections.newestParkedConnection + else { return XCTFail("Expected to still have connections available") } - let mockRequest = MockHTTPRequest(eventLoop: eventLoop) + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) XCTAssertEqual(action.connection, .cancelTimeoutTimer(expectedConnection.id)) - guard case .executeRequest(let returnedRequest, expectedConnection, cancelTimeout: false) = action.request else { + guard case .executeRequest(let returnedRequest, expectedConnection, cancelTimeout: false) = action.request + else { return XCTFail("Expected to execute a request next, but got: \(action.request)") } @@ -288,7 +390,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var queuer = MockRequestQueuer() for _ in 0..<100 { let eventLoop = elg.next() - let mockRequest = MockHTTPRequest(eventLoop: eventLoop, requiresEventLoopForChannel: false) + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop, requiresEventLoopForChannel: false) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -347,7 +449,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // 10% of the cases enforce the eventLoop let elRequired = (0..<10).randomElement().flatMap { $0 == 0 ? true : false }! - let mockRequest = MockHTTPRequest(eventLoop: reqEventLoop, requiresEventLoopForChannel: elRequired) + let mockRequest = MockHTTPScheduableRequest( + eventLoop: reqEventLoop, + requiresEventLoopForChannel: elRequired + ) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -359,7 +464,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssert(connEventLoop === reqEventLoop) XCTAssertEqual(action.request, .scheduleRequestTimeout(for: request, on: reqEventLoop)) - let connection: HTTPConnectionPool.Connection = .__testOnly_connection(id: connectionID, eventLoop: connEventLoop) + let connection: HTTPConnectionPool.Connection = .__testOnly_connection( + id: connectionID, + eventLoop: connEventLoop + ) let createdAction = state.newHTTP1ConnectionCreated(connection) XCTAssertEqual(createdAction.request, .executeRequest(request, connection, cancelTimeout: true)) XCTAssertEqual(createdAction.connection, .none) @@ -370,7 +478,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(state.http1ConnectionClosed(connectionID), .none) case .cancelTimeoutTimer(let connectionID): - guard let expectedConnection = connections.newestParkedConnection(for: reqEventLoop) ?? connections.newestParkedConnection else { + guard + let expectedConnection = connections.newestParkedConnection(for: reqEventLoop) + ?? connections.newestParkedConnection + else { return XCTFail("Expected to have connections available") } @@ -378,7 +489,11 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssert(expectedConnection.eventLoop === reqEventLoop) } - XCTAssertEqual(connectionID, expectedConnection.id, "Request is scheduled on the connection we expected") + XCTAssertEqual( + connectionID, + expectedConnection.id, + "Request is scheduled on the connection we expected" + ) XCTAssertNoThrow(try connections.activateConnection(connectionID)) guard case .executeRequest(let request, let connection, cancelTimeout: false) = action.request else { @@ -388,8 +503,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.execute(request.__testOnly_wrapped_request(), on: connection)) XCTAssertNoThrow(try connections.finishExecution(connection.id)) - XCTAssertEqual(state.http1ConnectionReleased(connection.id), - .init(request: .none, connection: .scheduleTimeoutTimer(connection.id, on: connection.eventLoop))) + XCTAssertEqual( + state.http1ConnectionReleased(connection.id), + .init(request: .none, connection: .scheduleTimeoutTimer(connection.id, on: connection.eventLoop)) + ) XCTAssertNoThrow(try connections.parkConnection(connectionID)) default: @@ -411,7 +528,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(connections.parked, 8) // close a leased connection == abort - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) guard let connectionToAbort = connections.newestParkedConnection else { return XCTFail("Expected to have a parked connection") @@ -461,11 +578,14 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // Add eight requests to fill all connections for _ in 0..<8 { let eventLoop = elg.next() - guard let expectedConnection = connections.newestParkedConnection(for: eventLoop) ?? connections.newestParkedConnection else { + guard + let expectedConnection = connections.newestParkedConnection(for: eventLoop) + ?? connections.newestParkedConnection + else { return XCTFail("Expected to still have connections available") } - let mockRequest = MockHTTPRequest(eventLoop: eventLoop) + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -482,7 +602,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { for _ in 0..<100 { let eventLoop = elg.next() - let mockRequest = MockHTTPRequest(eventLoop: eventLoop, requiresEventLoopForChannel: false) + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop, requiresEventLoopForChannel: false) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -508,12 +628,20 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard let newConnection = maybeNewConnection else { return XCTFail("Expected to get a new connection") } let afterRecreationAction = state.newHTTP1ConnectionCreated(newConnection) XCTAssertEqual(afterRecreationAction.connection, .none) - guard case .executeRequest(let request, newConnection, cancelTimeout: true) = afterRecreationAction.request else { + guard + case .executeRequest(let request, newConnection, cancelTimeout: true) = afterRecreationAction + .request + else { return XCTFail("Unexpected request action: \(action.request)") } XCTAssertEqual(request.id, queuedRequestsOrder.popFirst()) - XCTAssertNoThrow(try connections.execute(queuer.get(request.id, request: request.__testOnly_wrapped_request()), on: newConnection)) + XCTAssertNoThrow( + try connections.execute( + queuer.get(request.id, request: request.__testOnly_wrapped_request()), + on: newConnection + ) + ) case .none: XCTAssert(queuer.isEmpty) @@ -592,10 +720,13 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 6 + maximumConcurrentHTTP1Connections: 6, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -630,10 +761,13 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 6 + maximumConcurrentHTTP1Connections: 6, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -643,7 +777,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(executeAction.request, .scheduleRequestTimeout(for: request, on: mockRequest.eventLoop)) - let failAction = state.failedToCreateNewConnection(HTTPClientError.httpProxyHandshakeTimeout, connectionID: connectionID) + let failAction = state.failedToCreateNewConnection( + HTTPClientError.httpProxyHandshakeTimeout, + connectionID: connectionID + ) guard case .scheduleBackoffTimer(connectionID, backoff: _, on: let timerEL) = failAction.connection else { return XCTFail("Expected to create a backoff timer") } @@ -651,7 +788,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(failAction.request, .none) let timeoutAction = state.timeoutRequest(request.id) - XCTAssertEqual(timeoutAction.request, .failRequest(request, HTTPClientError.httpProxyHandshakeTimeout, cancelTimeout: false)) + XCTAssertEqual( + timeoutAction.request, + .failRequest(request, HTTPClientError.httpProxyHandshakeTimeout, cancelTimeout: false) + ) XCTAssertEqual(timeoutAction.connection, .none) } @@ -661,10 +801,13 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 6 + maximumConcurrentHTTP1Connections: 6, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil ) - let mockRequest = MockHTTPRequest(eventLoop: eventLoop.next(), requiresEventLoopForChannel: false) + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop.next(), requiresEventLoopForChannel: false) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -674,7 +817,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(executeAction.request, .scheduleRequestTimeout(for: request, on: mockRequest.eventLoop)) let timeoutAction = state.timeoutRequest(request.id) - XCTAssertEqual(timeoutAction.request, .failRequest(request, HTTPClientError.connectTimeout, cancelTimeout: false)) + XCTAssertEqual( + timeoutAction.request, + .failRequest(request, HTTPClientError.connectTimeout, cancelTimeout: false) + ) XCTAssertEqual(timeoutAction.connection, .none) } @@ -684,10 +830,13 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 6 + maximumConcurrentHTTP1Connections: 6, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil ) - let mockRequest1 = MockHTTPRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) + let mockRequest1 = MockHTTPScheduableRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) let request1 = HTTPConnectionPool.Request(mockRequest1) let executeAction1 = state.executeRequest(request1) @@ -698,7 +847,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(executeAction1.request, .scheduleRequestTimeout(for: request1, on: mockRequest1.eventLoop)) - let mockRequest2 = MockHTTPRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) + let mockRequest2 = MockHTTPScheduableRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) let request2 = HTTPConnectionPool.Request(mockRequest2) let executeAction2 = state.executeRequest(request2) @@ -709,7 +858,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(executeAction2.request, .scheduleRequestTimeout(for: request2, on: connEL1)) - let failAction = state.failedToCreateNewConnection(HTTPClientError.httpProxyHandshakeTimeout, connectionID: connectionID1) + let failAction = state.failedToCreateNewConnection( + HTTPClientError.httpProxyHandshakeTimeout, + connectionID: connectionID1 + ) guard case .scheduleBackoffTimer(connectionID1, backoff: _, on: let timerEL) = failAction.connection else { return XCTFail("Expected to create a backoff timer") } @@ -723,7 +875,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(createdAction.connection, .none) let timeoutAction = state.timeoutRequest(request2.id) - XCTAssertEqual(timeoutAction.request, .failRequest(request2, HTTPClientError.getConnectionFromPoolTimeout, cancelTimeout: false)) + XCTAssertEqual( + timeoutAction.request, + .failRequest(request2, HTTPClientError.getConnectionFromPoolTimeout, cancelTimeout: false) + ) XCTAssertEqual(timeoutAction.connection, .none) } } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest+XCTest.swift deleted file mode 100644 index 95cade669..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest+XCTest.swift +++ /dev/null @@ -1,49 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+HTTP2ConnectionsTest+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_HTTP2ConnectionsTests { - static var allTests: [(String, (HTTPConnectionPool_HTTP2ConnectionsTests) -> () throws -> Void)] { - return [ - ("testCreatingConnections", testCreatingConnections), - ("testCreatingConnectionAndFailing", testCreatingConnectionAndFailing), - ("testFailConnectionRace", testFailConnectionRace), - ("testLeaseConnectionOfPreferredButUnavailableEL", testLeaseConnectionOfPreferredButUnavailableEL), - ("testLeaseConnectionOnRequiredButUnavailableEL", testLeaseConnectionOnRequiredButUnavailableEL), - ("testCloseConnectionIfIdle", testCloseConnectionIfIdle), - ("testCloseConnectionIfIdleButLeasedRaceCondition", testCloseConnectionIfIdleButLeasedRaceCondition), - ("testCloseConnectionIfIdleButClosedRaceCondition", testCloseConnectionIfIdleButClosedRaceCondition), - ("testCloseConnectionIfIdleRace", testCloseConnectionIfIdleRace), - ("testShutdown", testShutdown), - ("testLeasingAllConnections", testLeasingAllConnections), - ("testGoAway", testGoAway), - ("testNewMaxConcurrentStreamsSetting", testNewMaxConcurrentStreamsSetting), - ("testEventsAfterConnectionIsClosed", testEventsAfterConnectionIsClosed), - ("testLeaseOnPreferredEventLoopWithoutAnyAvailable", testLeaseOnPreferredEventLoopWithoutAnyAvailable), - ("testMigrationFromHTTP1", testMigrationFromHTTP1), - ("testMigrationToHTTP1", testMigrationToHTTP1), - ("testMigrationFromHTTP1WithPendingRequestsWithRequiredEventLoop", testMigrationFromHTTP1WithPendingRequestsWithRequiredEventLoop), - ("testMigrationFromHTTP1WithAlreadyEstablishedHTTP2Connection", testMigrationFromHTTP1WithAlreadyEstablishedHTTP2Connection), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest.swift index 9e9ca1df6..dd56a9102 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest.swift @@ -12,15 +12,16 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testCreatingConnections() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let el2 = elg.next() @@ -32,7 +33,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssertTrue(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) XCTAssertTrue(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el1)) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 100) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) XCTAssertEqual(conn1CreatedContext.isIdle, true) XCTAssert(conn1CreatedContext.eventLoop === el1) @@ -46,7 +50,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let conn2ID = connections.createNewConnection(on: el2) XCTAssertTrue(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el2)) let conn2: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn2ID, eventLoop: el2) - let (conn2Index, conn2CreatedContext) = connections.newHTTP2ConnectionEstablished(conn2, maxConcurrentStreams: 100) + let (conn2Index, conn2CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn2, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) XCTAssertTrue(conn1CreatedContext.isIdle) XCTAssert(conn2CreatedContext.eventLoop === el2) @@ -59,7 +66,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testCreatingConnectionAndFailing() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let el2 = elg.next() @@ -83,7 +90,9 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssert(conn1FailContext.eventLoop === el1) XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el1)) - let (replaceConn1ID, replaceConn1EL) = connections.createNewConnectionByReplacingClosedConnection(at: conn1FailIndex) + let (replaceConn1ID, replaceConn1EL) = connections.createNewConnectionByReplacingClosedConnection( + at: conn1FailIndex + ) XCTAssert(replaceConn1EL === el1) XCTAssertEqual(replaceConn1ID, 1) XCTAssertTrue(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) @@ -108,7 +117,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el1 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) // connection is idle let conn1ID = connections.createNewConnection(on: el1) @@ -130,7 +139,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el4 = elg.next() let el5 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) for el in [el1, el2, el3, el4] { XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el)) @@ -155,7 +164,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el4 = elg.next() let el5 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) for el in [el1, el2, el3, el4] { XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el)) @@ -177,7 +186,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el1 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) // connection is idle let conn1ID = connections.createNewConnection(on: el1) @@ -201,7 +210,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el1 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) // connection is idle let conn1ID = connections.createNewConnection(on: el1) @@ -224,7 +233,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el1 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) // connection is idle let conn1ID = connections.createNewConnection(on: el1) @@ -241,7 +250,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el1 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) // connection is idle let conn1ID = connections.createNewConnection(on: el1) @@ -268,7 +277,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el5 = elg.next() let el6 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) for el in [el1, el2, el3, el4] { XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el)) @@ -331,18 +340,24 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testLeasingAllConnections() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 100) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 100) XCTAssertEqual(leasedConn1, conn1) XCTAssertEqual(leasdConnContext1.wasIdle, true) - XCTAssertNil(connections.leaseStream(onRequired: el1), "should not be able to lease stream because they are all already leased") + XCTAssertNil( + connections.leaseStream(onRequired: el1), + "should not be able to lease stream because they are all already leased" + ) let (_, releaseContext) = connections.releaseStream(conn1ID) XCTAssertFalse(releaseContext.isIdle) @@ -354,17 +369,23 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssertEqual(leasedConn, conn1) XCTAssertEqual(leaseContext.wasIdle, false) - XCTAssertNil(connections.leaseStream(onRequired: el1), "should not be able to lease stream because they are all already leased") + XCTAssertNil( + connections.leaseStream(onRequired: el1), + "should not be able to lease stream because they are all already leased" + ) } func testGoAway() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 10) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 10 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 10) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 2) @@ -386,7 +407,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { ) ) - XCTAssertNil(connections.leaseStream(onRequired: el1), "we should not be able to lease a stream because the connection is draining") + XCTAssertNil( + connections.leaseStream(onRequired: el1), + "we should not be able to lease a stream because the connection is draining" + ) // a server can potentially send more than one connection go away and we should not crash XCTAssertTrue(connections.goAwayReceived(conn1ID)?.eventLoop === el1) @@ -440,12 +464,15 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testNewMaxConcurrentStreamsSetting() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 1) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 1 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 1) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 1) @@ -454,7 +481,8 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssertNil(connections.leaseStream(onRequired: el1), "all streams are in use") - guard let (_, newSettingsContext1) = connections.newHTTP2MaxConcurrentStreamsReceived(conn1ID, newMaxStreams: 2) else { + guard let (_, newSettingsContext1) = connections.newHTTP2MaxConcurrentStreamsReceived(conn1ID, newMaxStreams: 2) + else { return XCTFail("Expected to get a new settings context") } XCTAssertEqual(newSettingsContext1.availableStreams, 1) @@ -467,7 +495,8 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssertEqual(leasedConn2, conn1) XCTAssertEqual(leaseContext2.wasIdle, false) - guard let (_, newSettingsContext2) = connections.newHTTP2MaxConcurrentStreamsReceived(conn1ID, newMaxStreams: 1) else { + guard let (_, newSettingsContext2) = connections.newHTTP2MaxConcurrentStreamsReceived(conn1ID, newMaxStreams: 1) + else { return XCTFail("Expected to get a new settings context") } XCTAssertEqual(newSettingsContext2.availableStreams, 0) @@ -495,12 +524,15 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testEventsAfterConnectionIsClosed() { let elg = EmbeddedEventLoopGroup(loops: 2) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 1) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 1 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 1) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 1) @@ -530,12 +562,15 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testLeaseOnPreferredEventLoopWithoutAnyAvailable() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 1) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 1 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 1) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 1) XCTAssertEqual(leasedConn1, conn1) @@ -546,7 +581,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testMigrationFromHTTP1() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let el2 = elg.next() let conn1ID: HTTPConnectionPool.Connection.ID = 1 @@ -556,9 +591,11 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { starting: [(conn1ID, el1)], backingOff: [(conn2ID, el2)] ) - XCTAssertTrue(connections.createConnectionsAfterMigrationIfNeeded( - requiredEventLoopsOfPendingRequests: [el1, el2] - ).isEmpty) + XCTAssertTrue( + connections.createConnectionsAfterMigrationIfNeeded( + requiredEventLoopsOfPendingRequests: [el1, el2] + ).isEmpty + ) XCTAssertEqual( connections.stats, @@ -574,7 +611,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { ) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 100) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 2) @@ -598,7 +638,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testMigrationToHTTP1() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP2Connections(generator: generator) + var connections = HTTPConnectionPool.HTTP2Connections(generator: generator, maximumConnectionUses: nil) let el1 = elg.next() let el2 = elg.next() let el3 = elg.next() @@ -615,7 +655,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { ) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 100) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 2) @@ -663,7 +706,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testMigrationFromHTTP1WithPendingRequestsWithRequiredEventLoop() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP2Connections(generator: generator) + var connections = HTTPConnectionPool.HTTP2Connections(generator: generator, maximumConnectionUses: nil) let el1 = elg.next() let el2 = elg.next() let el3 = elg.next() @@ -696,7 +739,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testMigrationFromHTTP1WithAlreadyEstablishedHTTP2Connection() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP2Connections(generator: generator) + var connections = HTTPConnectionPool.HTTP2Connections(generator: generator, maximumConnectionUses: nil) let el1 = elg.next() let el2 = elg.next() let el3 = elg.next() @@ -714,9 +757,12 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { backingOff: [(conn3ID, el3)] ) - XCTAssertTrue(connections.createConnectionsAfterMigrationIfNeeded( - requiredEventLoopsOfPendingRequests: [el1, el2, el3] - ).isEmpty, "we still have an active connection for el1 and should not create a new one") + XCTAssertTrue( + connections.createConnectionsAfterMigrationIfNeeded( + requiredEventLoopsOfPendingRequests: [el1, el2, el3] + ).isEmpty, + "we still have an active connection for el1 and should not create a new one" + ) guard let (leasedConn, _) = connections.leaseStream(onRequired: el1) else { return XCTFail("could not lease stream on el1") diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests+XCTest.swift deleted file mode 100644 index 9dca0c934..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests+XCTest.swift +++ /dev/null @@ -1,49 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+HTTP2StateMachineTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_HTTP2StateMachineTests { - static var allTests: [(String, (HTTPConnectionPool_HTTP2StateMachineTests) -> () throws -> Void)] { - return [ - ("testCreatingOfConnection", testCreatingOfConnection), - ("testConnectionFailureBackoff", testConnectionFailureBackoff), - ("testCancelRequestWorks", testCancelRequestWorks), - ("testExecuteOnShuttingDownPool", testExecuteOnShuttingDownPool), - ("testHTTP1ToHTTP2MigrationAndShutdownIfFirstConnectionIsHTTP1", testHTTP1ToHTTP2MigrationAndShutdownIfFirstConnectionIsHTTP1), - ("testSchedulingAndCancelingOfIdleTimeout", testSchedulingAndCancelingOfIdleTimeout), - ("testConnectionTimeout", testConnectionTimeout), - ("testConnectionEstablishmentFailure", testConnectionEstablishmentFailure), - ("testGoAwayOnIdleConnection", testGoAwayOnIdleConnection), - ("testGoAwayWithLeasedStream", testGoAwayWithLeasedStream), - ("testGoAwayWithPendingRequestsStartsNewConnection", testGoAwayWithPendingRequestsStartsNewConnection), - ("testMigrationFromHTTP1ToHTTP2", testMigrationFromHTTP1ToHTTP2), - ("testMigrationFromHTTP1ToHTTP2WhileShuttingDown", testMigrationFromHTTP1ToHTTP2WhileShuttingDown), - ("testMigrationFromHTTP1ToHTTP2WithAlreadyStartedHTTP1Connections", testMigrationFromHTTP1ToHTTP2WithAlreadyStartedHTTP1Connections), - ("testHTTP2toHTTP1Migration", testHTTP2toHTTP1Migration), - ("testHTTP2toHTTP1MigrationDuringShutdown", testHTTP2toHTTP1MigrationDuringShutdown), - ("testConnectionIsImmediatelyCreatedAfterBackoffTimerFires", testConnectionIsImmediatelyCreatedAfterBackoffTimerFires), - ("testMaxConcurrentStreamsIsRespected", testMaxConcurrentStreamsIsRespected), - ("testEventsAfterConnectionIsClosed", testEventsAfterConnectionIsClosed), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift index 825ffc9b3..e64fd5e71 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + private typealias Action = HTTPConnectionPool.StateMachine.Action private typealias ConnectionAction = HTTPConnectionPool.StateMachine.ConnectionAction private typealias RequestAction = HTTPConnectionPool.StateMachine.RequestAction @@ -29,10 +30,15 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el1 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: .init(), lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: .init(), + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) /// first request should create a new connection - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -48,7 +54,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { /// subsequent requests should not create a connection for _ in 0..<9 { - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -99,7 +105,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { /// 4 streams are available and therefore request should be executed immediately for _ in 0..<4 { - let mockRequest = MockHTTPRequest(eventLoop: el1, requiresEventLoopForChannel: true) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1, requiresEventLoopForChannel: true) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -122,14 +128,17 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { /// shutdown should only close one connection let shutdownAction = state.shutdown() XCTAssertEqual(shutdownAction.request, .none) - XCTAssertEqual(shutdownAction.connection, .cleanupConnections( - .init( - close: [conn], - cancel: [], - connectBackoff: [] - ), - isShutdown: .yes(unclean: false) - )) + XCTAssertEqual( + shutdownAction.connection, + .cleanupConnections( + .init( + close: [conn], + cancel: [], + connectBackoff: [] + ), + isShutdown: .yes(unclean: false) + ) + ) } func testConnectionFailureBackoff() { @@ -138,10 +147,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { var state = HTTPConnectionPool.HTTP2StateMachine( idGenerator: .init(), - lifecycleState: .running + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -151,9 +162,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = action.connection else { return XCTFail("Unexpected connection action: \(action.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux - let failedConnect1 = state.failedToCreateNewConnection(HTTPClientError.connectTimeout, connectionID: connectionID) + let failedConnect1 = state.failedToCreateNewConnection( + HTTPClientError.connectTimeout, + connectionID: connectionID + ) XCTAssertEqual(failedConnect1.request, .none) guard case .scheduleBackoffTimer(connectionID, let backoffTimeAmount1, _) = failedConnect1.connection else { return XCTFail("Unexpected connection action: \(failedConnect1.connection)") @@ -166,9 +180,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { return XCTFail("Unexpected connection action: \(backoffDoneAction.connection)") } XCTAssertGreaterThan(newConnectionID, connectionID) - XCTAssert(connectionEL === newEventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === newEventLoop) // XCTAssertIdentical not available on Linux - let failedConnect2 = state.failedToCreateNewConnection(HTTPClientError.connectTimeout, connectionID: newConnectionID) + let failedConnect2 = state.failedToCreateNewConnection( + HTTPClientError.connectTimeout, + connectionID: newConnectionID + ) XCTAssertEqual(failedConnect2.request, .none) guard case .scheduleBackoffTimer(newConnectionID, let backoffTimeAmount2, _) = failedConnect2.connection else { return XCTFail("Unexpected connection action: \(failedConnect2.connection)") @@ -181,7 +198,8 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .failRequest(let requestToFail, let requestError, cancelTimeout: false) = failRequest.request else { return XCTFail("Unexpected request action: \(action.request)") } - XCTAssert(requestToFail.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux + // XCTAssertIdentical not available on Linux + XCTAssert(requestToFail.__testOnly_wrapped_request() === mockRequest) XCTAssertEqual(requestError as? HTTPClientError, .connectTimeout) XCTAssertEqual(failRequest.connection, .none) @@ -189,16 +207,91 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertEqual(state.connectionCreationBackoffDone(newConnectionID), .none) } + func testConnectionFailureWhileShuttingDown() { + struct SomeError: Error, Equatable {} + let elg = EmbeddedEventLoopGroup(loops: 4) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: .init(), + retryConnectionEstablishment: false, + lifecycleState: .running, + maximumConnectionUses: nil + ) + + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + + let action = state.executeRequest(request) + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), action.request) + + // 1. connection attempt + guard case .createConnection(let connectionID, on: let connectionEL) = action.connection else { + return XCTFail("Unexpected connection action: \(action.connection)") + } + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + + // 2. initialise shutdown + let shutdownAction = state.shutdown() + XCTAssertEqual(shutdownAction.connection, .cleanupConnections(.init(), isShutdown: .no)) + guard case .failRequestsAndCancelTimeouts(let requestsToFail, let requestError) = shutdownAction.request else { + return XCTFail("Unexpected request action: \(action.request)") + } + XCTAssertEqualTypeAndValue(requestError, HTTPClientError.cancelled) + XCTAssertEqualTypeAndValue(requestsToFail, [request]) + + // 3. connection attempt fails + let failedConnectAction = state.failedToCreateNewConnection(SomeError(), connectionID: connectionID) + XCTAssertEqual(failedConnectAction.request, .none) + XCTAssertEqual(failedConnectAction.connection, .cleanupConnections(.init(), isShutdown: .yes(unclean: true))) + } + + func testConnectionFailureWithoutRetry() { + struct SomeError: Error, Equatable {} + let elg = EmbeddedEventLoopGroup(loops: 4) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: .init(), + retryConnectionEstablishment: false, + lifecycleState: .running, + maximumConnectionUses: nil + ) + + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + + let action = state.executeRequest(request) + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), action.request) + + // 1. connection attempt + guard case .createConnection(let connectionID, on: let connectionEL) = action.connection else { + return XCTFail("Unexpected connection action: \(action.connection)") + } + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + + let failedConnectAction = state.failedToCreateNewConnection(SomeError(), connectionID: connectionID) + XCTAssertEqual(failedConnectAction.connection, .none) + guard case .failRequestsAndCancelTimeouts(let requestsToFail, let requestError) = failedConnectAction.request + else { + return XCTFail("Unexpected request action: \(action.request)") + } + XCTAssertEqualTypeAndValue(requestError, SomeError()) + XCTAssertEqualTypeAndValue(requestsToFail, [request]) + } + func testCancelRequestWorks() { let elg = EmbeddedEventLoopGroup(loops: 4) defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } var state = HTTPConnectionPool.HTTP2StateMachine( idGenerator: .init(), - lifecycleState: .running + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -208,11 +301,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { return XCTFail("Unexpected connection action: \(executeAction.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux // 2. cancel request let cancelAction = state.cancelRequest(request.id) - XCTAssertEqual(cancelAction.request, .cancelRequestTimeout(request.id)) + XCTAssertEqual(cancelAction.request, .failRequest(request, HTTPClientError.cancelled, cancelTimeout: true)) XCTAssertEqual(cancelAction.connection, .none) // 3. request timeout triggers to late @@ -233,10 +326,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { var state = HTTPConnectionPool.HTTP2StateMachine( idGenerator: .init(), - lifecycleState: .running + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -246,15 +341,18 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { return XCTFail("Unexpected connection action: \(executeAction.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux // 2. connection succeeds - let connection: HTTPConnectionPool.Connection = .__testOnly_connection(id: connectionID, eventLoop: connectionEL) + let connection: HTTPConnectionPool.Connection = .__testOnly_connection( + id: connectionID, + eventLoop: connectionEL + ) let connectedAction = state.newHTTP2ConnectionEstablished(connection, maxConcurrentStreams: 100) guard case .executeRequestsAndCancelTimeouts([request], connection) = connectedAction.request else { return XCTFail("Unexpected request action: \(connectedAction.request)") } - XCTAssert(request.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux + XCTAssert(request.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux XCTAssertEqual(connectedAction.connection, .none) // 3. shutdown @@ -270,11 +368,14 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertEqual(cleanupContext.connectBackoff, []) // 4. execute another request - let finalMockRequest = MockHTTPRequest(eventLoop: elg.next()) + let finalMockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let finalRequest = HTTPConnectionPool.Request(finalMockRequest) let failAction = state.executeRequest(finalRequest) XCTAssertEqual(failAction.connection, .none) - XCTAssertEqual(failAction.request, .failRequest(finalRequest, HTTPClientError.alreadyShutdown, cancelTimeout: false)) + XCTAssertEqual( + failAction.request, + .failRequest(finalRequest, HTTPClientError.alreadyShutdown, cancelTimeout: false) + ) // 5. close open connection let closeAction = state.http2ConnectionClosed(connectionID) @@ -287,11 +388,17 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el1 = elg.next() let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1State = HTTPConnectionPool.HTTP1StateMachine(idGenerator: idGenerator, maximumConcurrentConnections: 8, lifecycleState: .running) + var http1State = HTTPConnectionPool.HTTP1StateMachine( + idGenerator: idGenerator, + maximumConcurrentConnections: 8, + retryConnectionEstablishment: true, + maximumConnectionUses: nil, + lifecycleState: .running + ) - let mockRequest1 = MockHTTPRequest(eventLoop: el1) + let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el1) let request1 = HTTPConnectionPool.Request(mockRequest1) - let mockRequest2 = MockHTTPRequest(eventLoop: el1) + let mockRequest2 = MockHTTPScheduableRequest(eventLoop: el1) let request2 = HTTPConnectionPool.Request(mockRequest2) let executeAction1 = http1State.executeRequest(request1) @@ -313,7 +420,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // second connection is a HTTP2 connection and we need to migrate let conn2: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn2ID, eventLoop: el1) - var http2State = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var http2State = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let http2ConnectAction = http2State.migrateFromHTTP1( http1Connections: http1State.connections, @@ -322,7 +434,10 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { newHTTP2Connection: conn2, maxConcurrentStreams: 100 ) - XCTAssertEqual(http2ConnectAction.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) + XCTAssertEqual( + http2ConnectAction.connection, + .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil) + ) guard case .executeRequestsAndCancelTimeouts([request2], conn2) = http2ConnectAction.request else { return XCTFail("Unexpected request action \(http2ConnectAction.request)") } @@ -334,11 +449,17 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let shutdownAction = http2State.shutdown() XCTAssertEqual(shutdownAction.request, .none) - XCTAssertEqual(shutdownAction.connection, .cleanupConnections(.init( - close: [conn2], - cancel: [], - connectBackoff: [] - ), isShutdown: .no)) + XCTAssertEqual( + shutdownAction.connection, + .cleanupConnections( + .init( + close: [conn2], + cancel: [], + connectBackoff: [] + ), + isShutdown: .no + ) + ) let releaseAction = http2State.http1ConnectionReleased(conn1ID) XCTAssertEqual(releaseAction.request, .none) @@ -351,22 +472,39 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil + ) let conn1ID = http1Conns.createNewConnection(on: el1) - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) - let connectAction = state.migrateFromHTTP1(http1Connections: http1Conns, requests: .init(), newHTTP2Connection: conn1, maxConcurrentStreams: 100) + let connectAction = state.migrateFromHTTP1( + http1Connections: http1Conns, + requests: .init(), + newHTTP2Connection: conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // execute request on idle connection - let mockRequest1 = MockHTTPRequest(eventLoop: el1) + let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el1) let request1 = HTTPConnectionPool.Request(mockRequest1) let request1Action = state.executeRequest(request1) XCTAssertEqual(request1Action.request, .executeRequest(request1, conn1, cancelTimeout: false)) @@ -378,7 +516,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertEqual(closeStream1Action.connection, .scheduleTimeoutTimer(conn1ID, on: el1)) // execute request on idle connection with required event loop - let mockRequest2 = MockHTTPRequest(eventLoop: el1, requiresEventLoopForChannel: true) + let mockRequest2 = MockHTTPScheduableRequest(eventLoop: el1, requiresEventLoopForChannel: true) let request2 = HTTPConnectionPool.Request(mockRequest2) let request2Action = state.executeRequest(request2) XCTAssertEqual(request2Action.request, .executeRequest(request2, conn1, cancelTimeout: false)) @@ -396,18 +534,35 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil + ) let conn1ID = http1Conns.createNewConnection(on: el1) - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) - let connectAction = state.migrateFromHTTP1(http1Connections: http1Conns, requests: .init(), newHTTP2Connection: conn1, maxConcurrentStreams: 100) + let connectAction = state.migrateFromHTTP1( + http1Connections: http1Conns, + requests: .init(), + newHTTP2Connection: conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // let the connection timeout let timeoutAction = state.connectionIdleTimeout(conn1ID) @@ -424,20 +579,37 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil + ) let conn1ID = http1Conns.createNewConnection(on: el1) - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) - let connectAction = state.migrateFromHTTP1(http1Connections: http1Conns, requests: .init(), newHTTP2Connection: conn1, maxConcurrentStreams: 100) + let connectAction = state.migrateFromHTTP1( + http1Connections: http1Conns, + requests: .init(), + newHTTP2Connection: conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // create new http2 connection - let mockRequest1 = MockHTTPRequest(eventLoop: el2, requiresEventLoopForChannel: true) + let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el2, requiresEventLoopForChannel: true) let request1 = HTTPConnectionPool.Request(mockRequest1) let executeAction = state.executeRequest(request1) XCTAssertEqual(executeAction.request, .scheduleRequestTimeout(for: request1, on: el2)) @@ -459,9 +631,18 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil + ) let conn1ID = http1Conns.createNewConnection(on: el1) - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) @@ -472,11 +653,14 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { maxConcurrentStreams: 100 ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) let goAwayAction = state.http2ConnectionGoAwayReceived(conn1ID) XCTAssertEqual(goAwayAction.request, .none) @@ -489,9 +673,18 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil + ) let conn1ID = http1Conns.createNewConnection(on: el1) - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) let connectAction = state.migrateFromHTTP1( @@ -501,14 +694,17 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { maxConcurrentStreams: 100 ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // execute request on idle connection - let mockRequest1 = MockHTTPRequest(eventLoop: el1) + let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el1) let request1 = HTTPConnectionPool.Request(mockRequest1) let request1Action = state.executeRequest(request1) XCTAssertEqual(request1Action.request, .executeRequest(request1, conn1, cancelTimeout: false)) @@ -530,9 +726,18 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil + ) let conn1ID = http1Conns.createNewConnection(on: el1) - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) let connectAction1 = state.migrateFromHTTP1( @@ -542,21 +747,24 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { maxConcurrentStreams: 1 ) XCTAssertEqual(connectAction1.request, .none) - XCTAssertEqual(connectAction1.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction1.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // execute request - let mockRequest1 = MockHTTPRequest(eventLoop: el1) + let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el1) let request1 = HTTPConnectionPool.Request(mockRequest1) let request1Action = state.executeRequest(request1) XCTAssertEqual(request1Action.request, .executeRequest(request1, conn1, cancelTimeout: false)) XCTAssertEqual(request1Action.connection, .cancelTimeoutTimer(conn1ID)) // queue request - let mockRequest2 = MockHTTPRequest(eventLoop: el1) + let mockRequest2 = MockHTTPScheduableRequest(eventLoop: el1) let request2 = HTTPConnectionPool.Request(mockRequest2) let request2Action = state.executeRequest(request2) XCTAssertEqual(request2Action.request, .scheduleRequestTimeout(for: request2, on: el1)) @@ -592,12 +800,18 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el1 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.StateMachine(idGenerator: .init(), maximumConcurrentHTTP1Connections: 8) + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil + ) /// first 8 request should create a new connection var connectionIDs: [HTTPConnectionPool.Connection.ID] = [] for _ in 0..<8 { - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) guard case .createConnection(let connID, let eventLoop) = action.connection else { @@ -616,7 +830,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { /// after we reached the `maximumConcurrentHTTP1Connections`, we will not create new connections for _ in 0..<8 { - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) XCTAssertEqual(action.connection, .none) @@ -640,11 +854,14 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.execute(request.__testOnly_wrapped_request(), on: conn1)) } - XCTAssertEqual(migrationAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: nil - )) + XCTAssertEqual( + migrationAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: nil + ) + ) /// remaining connections should be closed immediately without executing any request for connID in connectionIDs.dropFirst() { @@ -678,10 +895,16 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el1 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.StateMachine(idGenerator: .init(), maximumConcurrentHTTP1Connections: 8) + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: false, + maximumConnectionUses: nil + ) /// create a new connection - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) guard case .createConnection(let conn1ID, let eventLoop) = action.connection else { @@ -720,12 +943,18 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el1 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.StateMachine(idGenerator: .init(), maximumConcurrentHTTP1Connections: 8) + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil + ) /// first 8 request should create a new connection var connectionIDs: [HTTPConnectionPool.Connection.ID] = [] for _ in 0..<8 { - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) guard case .createConnection(let connID, let eventLoop) = action.connection else { @@ -740,7 +969,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { /// after we reached the `maximumConcurrentHTTP1Connections`, we will not create new connections for _ in 0..<8 { - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) XCTAssertEqual(action.connection, .none) @@ -791,7 +1020,10 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .executeRequestsAndCancelTimeouts(let requests, let conn) = migrationAction.request else { return XCTFail("unexpected request action \(migrationAction.request)") } - XCTAssertEqual(migrationAction.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) + XCTAssertEqual( + migrationAction.connection, + .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil) + ) XCTAssertEqual(conn, http2Conn) XCTAssertEqual(requests.count, 10) @@ -855,10 +1087,16 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el2 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.StateMachine(idGenerator: .init(), maximumConcurrentHTTP1Connections: 8) + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: false, + maximumConnectionUses: nil + ) // create http2 connection - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request1 = HTTPConnectionPool.Request(mockRequest) let action1 = state.executeRequest(request1) guard case .createConnection(let http2ConnID, let http2EventLoop) = action1.connection else { @@ -870,11 +1108,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try queuer.queue(mockRequest, id: request1.id)) let http2Conn: HTTPConnectionPool.Connection = .__testOnly_connection(id: http2ConnID, eventLoop: el1) XCTAssertNoThrow(try connections.succeedConnectionCreationHTTP2(http2ConnID, maxConcurrentStreams: 10)) - let migrationAction1 = state.newHTTP2ConnectionCreated(http2Conn, maxConcurrentStreams: 10) - guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn) = migrationAction1.request else { - return XCTFail("unexpected request action \(migrationAction1.request)") + let executeAction1 = state.newHTTP2ConnectionCreated(http2Conn, maxConcurrentStreams: 10) + guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn) = executeAction1.request else { + return XCTFail("unexpected request action \(executeAction1.request)") } - XCTAssertEqual(migrationAction1.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) + XCTAssertEqual(requests.count, 1) for request in requests { XCTAssertNoThrow(try queuer.get(request.id, request: request.__testOnly_wrapped_request())) @@ -882,14 +1120,20 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } // a request with new required event loop should create a new connection - let mockRequestWithRequiredEventLoop = MockHTTPRequest(eventLoop: el2, requiresEventLoopForChannel: true) + let mockRequestWithRequiredEventLoop = MockHTTPScheduableRequest( + eventLoop: el2, + requiresEventLoopForChannel: true + ) let requestWithRequiredEventLoop = HTTPConnectionPool.Request(mockRequestWithRequiredEventLoop) let action2 = state.executeRequest(requestWithRequiredEventLoop) guard case .createConnection(let http1ConnId, let http1EventLoop) = action2.connection else { return XCTFail("Unexpected connection action \(action2.connection)") } XCTAssertTrue(http1EventLoop === el2) - XCTAssertEqual(action2.request, .scheduleRequestTimeout(for: requestWithRequiredEventLoop, on: mockRequestWithRequiredEventLoop.eventLoop)) + XCTAssertEqual( + action2.request, + .scheduleRequestTimeout(for: requestWithRequiredEventLoop, on: mockRequestWithRequiredEventLoop.eventLoop) + ) XCTAssertNoThrow(try connections.createConnection(http1ConnId, on: el2)) XCTAssertNoThrow(try queuer.queue(mockRequestWithRequiredEventLoop, id: requestWithRequiredEventLoop.id)) @@ -900,7 +1144,10 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .executeRequest(let request2, http1Conn, cancelTimeout: true) = migrationAction2.request else { return XCTFail("unexpected request action \(migrationAction2.request)") } - guard case .migration(let createConnections, closeConnections: [], scheduleTimeout: nil) = migrationAction2.connection else { + guard + case .migration(let createConnections, closeConnections: [], scheduleTimeout: nil) = migrationAction2 + .connection + else { return XCTFail("unexpected connection action \(migrationAction2.connection)") } XCTAssertEqual(createConnections.map { $0.1.id }, [el2.id]) @@ -921,10 +1168,16 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el2 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.StateMachine(idGenerator: .init(), maximumConcurrentHTTP1Connections: 8) + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: false, + maximumConnectionUses: nil + ) // create http2 connection - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request1 = HTTPConnectionPool.Request(mockRequest) let action1 = state.executeRequest(request1) guard case .createConnection(let http2ConnID, let http2EventLoop) = action1.connection else { @@ -936,11 +1189,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try queuer.queue(mockRequest, id: request1.id)) let http2Conn: HTTPConnectionPool.Connection = .__testOnly_connection(id: http2ConnID, eventLoop: el1) XCTAssertNoThrow(try connections.succeedConnectionCreationHTTP2(http2ConnID, maxConcurrentStreams: 10)) - let migrationAction1 = state.newHTTP2ConnectionCreated(http2Conn, maxConcurrentStreams: 10) - guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn) = migrationAction1.request else { - return XCTFail("unexpected request action \(migrationAction1.request)") + let executeAction1 = state.newHTTP2ConnectionCreated(http2Conn, maxConcurrentStreams: 10) + guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn) = executeAction1.request else { + return XCTFail("unexpected request action \(executeAction1.request)") } - XCTAssertEqual(migrationAction1.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) + XCTAssertEqual(requests.count, 1) for request in requests { XCTAssertNoThrow(try queuer.get(request.id, request: request.__testOnly_wrapped_request())) @@ -948,14 +1201,20 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } // a request with new required event loop should create a new connection - let mockRequestWithRequiredEventLoop = MockHTTPRequest(eventLoop: el2, requiresEventLoopForChannel: true) + let mockRequestWithRequiredEventLoop = MockHTTPScheduableRequest( + eventLoop: el2, + requiresEventLoopForChannel: true + ) let requestWithRequiredEventLoop = HTTPConnectionPool.Request(mockRequestWithRequiredEventLoop) let action2 = state.executeRequest(requestWithRequiredEventLoop) guard case .createConnection(let http1ConnId, let http1EventLoop) = action2.connection else { return XCTFail("Unexpected connection action \(action2.connection)") } XCTAssertTrue(http1EventLoop === el2) - XCTAssertEqual(action2.request, .scheduleRequestTimeout(for: requestWithRequiredEventLoop, on: mockRequestWithRequiredEventLoop.eventLoop)) + XCTAssertEqual( + action2.request, + .scheduleRequestTimeout(for: requestWithRequiredEventLoop, on: mockRequestWithRequiredEventLoop.eventLoop) + ) XCTAssertNoThrow(try connections.createConnection(http1ConnId, on: el2)) XCTAssertNoThrow(try queuer.queue(mockRequestWithRequiredEventLoop, id: requestWithRequiredEventLoop.id)) @@ -971,13 +1230,16 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } XCTAssertTrue(queuer.isEmpty) - // if we established a new http/1 connection we should migrate back to http/1, + // if we established a new http/1 connection we should migrate to http/1, // close the connection and shutdown the pool let http1Conn: HTTPConnectionPool.Connection = .__testOnly_connection(id: http1ConnId, eventLoop: el2) XCTAssertNoThrow(try connections.succeedConnectionCreationHTTP1(http1ConnId)) let migrationAction2 = state.newHTTP1ConnectionCreated(http1Conn) XCTAssertEqual(migrationAction2.request, .none) - XCTAssertEqual(migrationAction2.connection, .migration(createConnections: [], closeConnections: [http1Conn], scheduleTimeout: nil)) + XCTAssertEqual( + migrationAction2.connection, + .migration(createConnections: [], closeConnections: [http1Conn], scheduleTimeout: nil) + ) // in http/1 state, we should close idle http2 connections XCTAssertNoThrow(try connections.finishExecution(http2Conn.id)) @@ -993,11 +1255,17 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el2 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.StateMachine(idGenerator: .init(), maximumConcurrentHTTP1Connections: 8) + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: false, + maximumConnectionUses: nil + ) var connectionIDs: [HTTPConnectionPool.Connection.ID] = [] - for el in [el1, el2, el2] { - let mockRequest = MockHTTPRequest(eventLoop: el, requiresEventLoopForChannel: true) + for el in [el1, el2] { + let mockRequest = MockHTTPScheduableRequest(eventLoop: el, requiresEventLoopForChannel: true) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) guard case .createConnection(let connID, let eventLoop) = action.connection else { @@ -1010,7 +1278,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try queuer.queue(mockRequest, id: request.id)) } - // fail the two connections for el2 + // fail the connection for el2 for connectionID in connectionIDs.dropFirst() { struct SomeError: Error {} XCTAssertNoThrow(try connections.failConnectionCreation(connectionID)) @@ -1023,16 +1291,14 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } let http2ConnID1 = connectionIDs[0] let http2ConnID2 = connectionIDs[1] - let http2ConnID3 = connectionIDs[2] // let the first connection on el1 succeed as a http2 connection let http2Conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: http2ConnID1, eventLoop: el1) XCTAssertNoThrow(try connections.succeedConnectionCreationHTTP2(http2ConnID1, maxConcurrentStreams: 10)) - let migrationAction1 = state.newHTTP2ConnectionCreated(http2Conn1, maxConcurrentStreams: 10) - guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn1) = migrationAction1.request else { - return XCTFail("unexpected request action \(migrationAction1.request)") + let connectionAction = state.newHTTP2ConnectionCreated(http2Conn1, maxConcurrentStreams: 10) + guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn1) = connectionAction.request else { + return XCTFail("unexpected request action \(connectionAction.request)") } - XCTAssertEqual(migrationAction1.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) XCTAssertEqual(requests.count, 1) for request in requests { XCTAssertNoThrow(try queuer.get(request.id, request: request.__testOnly_wrapped_request())) @@ -1051,14 +1317,6 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } XCTAssertTrue(eventLoop2 === el2) XCTAssertNoThrow(try connections.createConnection(newHttp2ConnID2, on: el2)) - - // we now have a starting connection for el2 and another one backing off - - // if the backoff timer fires now for a connection on el2, we should *not* start a new connection - XCTAssertNoThrow(try connections.connectionBackoffTimerDone(http2ConnID3)) - let action3 = state.connectionCreationBackoffDone(http2ConnID3) - XCTAssertEqual(action3.request, .none) - XCTAssertEqual(action3.connection, .none) } func testMaxConcurrentStreamsIsRespected() { @@ -1076,7 +1334,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // shall be queued. for i in 0..<1000 { let requestEL = elg.next() - let mockRequest = MockHTTPRequest(eventLoop: requestEL) + let mockRequest = MockHTTPScheduableRequest(eventLoop: requestEL) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -1084,10 +1342,16 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { case 0: XCTAssertEqual(executeAction.connection, .cancelTimeoutTimer(generalPurposeConnection.id)) XCTAssertNoThrow(try connections.activateConnection(generalPurposeConnection.id)) - XCTAssertEqual(executeAction.request, .executeRequest(request, generalPurposeConnection, cancelTimeout: false)) + XCTAssertEqual( + executeAction.request, + .executeRequest(request, generalPurposeConnection, cancelTimeout: false) + ) XCTAssertNoThrow(try connections.execute(mockRequest, on: generalPurposeConnection)) case 1..<100: - XCTAssertEqual(executeAction.request, .executeRequest(request, generalPurposeConnection, cancelTimeout: false)) + XCTAssertEqual( + executeAction.request, + .executeRequest(request, generalPurposeConnection, cancelTimeout: false) + ) XCTAssertEqual(executeAction.connection, .none) XCTAssertNoThrow(try connections.execute(mockRequest, on: generalPurposeConnection)) case 100..<1000: @@ -1105,7 +1369,8 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.finishExecution(generalPurposeConnection.id)) let finishAction = state.http2ConnectionStreamClosed(generalPurposeConnection.id) XCTAssertEqual(finishAction.connection, .none) - guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request else { + guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request + else { return XCTFail("Unexpected request action: \(finishAction.request)") } guard requests.count == 1, let request = requests.first else { @@ -1120,11 +1385,23 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // Next the server allows for more concurrent streams let newMaxStreams = 200 - XCTAssertNoThrow(try connections.newHTTP2ConnectionSettingsReceived(generalPurposeConnection.id, maxConcurrentStreams: newMaxStreams)) - let newMaxStreamsAction = state.newHTTP2MaxConcurrentStreamsReceived(generalPurposeConnection.id, newMaxStreams: newMaxStreams) + XCTAssertNoThrow( + try connections.newHTTP2ConnectionSettingsReceived( + generalPurposeConnection.id, + maxConcurrentStreams: newMaxStreams + ) + ) + let newMaxStreamsAction = state.newHTTP2MaxConcurrentStreamsReceived( + generalPurposeConnection.id, + newMaxStreams: newMaxStreams + ) XCTAssertEqual(newMaxStreamsAction.connection, .none) - guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = newMaxStreamsAction.request else { - return XCTFail("Unexpected request action after new max concurrent stream setting: \(newMaxStreamsAction.request)") + guard + case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = newMaxStreamsAction.request + else { + return XCTFail( + "Unexpected request action after new max concurrent stream setting: \(newMaxStreamsAction.request)" + ) } XCTAssertEqual(requests.count, 100, "Expected to execute 100 more requests") for request in requests { @@ -1141,7 +1418,8 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.finishExecution(generalPurposeConnection.id)) let finishAction = state.http2ConnectionStreamClosed(generalPurposeConnection.id) XCTAssertEqual(finishAction.connection, .none) - guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request else { + guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request + else { return XCTFail("Unexpected request action: \(finishAction.request)") } guard requests.count == 1, let request = requests.first else { @@ -1154,8 +1432,16 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // Next the server allows for fewer concurrent streams let fewerMaxStreams = 50 - XCTAssertNoThrow(try connections.newHTTP2ConnectionSettingsReceived(generalPurposeConnection.id, maxConcurrentStreams: fewerMaxStreams)) - let fewerMaxStreamsAction = state.newHTTP2MaxConcurrentStreamsReceived(generalPurposeConnection.id, newMaxStreams: fewerMaxStreams) + XCTAssertNoThrow( + try connections.newHTTP2ConnectionSettingsReceived( + generalPurposeConnection.id, + maxConcurrentStreams: fewerMaxStreams + ) + ) + let fewerMaxStreamsAction = state.newHTTP2MaxConcurrentStreamsReceived( + generalPurposeConnection.id, + newMaxStreams: fewerMaxStreams + ) XCTAssertEqual(fewerMaxStreamsAction.connection, .none) XCTAssertEqual(fewerMaxStreamsAction.request, .none) @@ -1173,7 +1459,8 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.finishExecution(generalPurposeConnection.id)) let finishAction = state.http2ConnectionStreamClosed(generalPurposeConnection.id) XCTAssertEqual(finishAction.connection, .none) - guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request else { + guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request + else { return XCTFail("Unexpected request action: \(finishAction.request)") } guard requests.count == 1, let request = requests.first else { @@ -1193,7 +1480,10 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { switch remaining { case 1: timeoutTimerScheduled = true - XCTAssertEqual(finishAction.connection, .scheduleTimeoutTimer(generalPurposeConnection.id, on: generalPurposeConnection.eventLoop)) + XCTAssertEqual( + finishAction.connection, + .scheduleTimeoutTimer(generalPurposeConnection.id, on: generalPurposeConnection.eventLoop) + ) XCTAssertNoThrow(try connections.parkConnection(generalPurposeConnection.id)) case 2...50: XCTAssertEqual(finishAction.connection, .none) @@ -1235,16 +1525,20 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { func XCTAssertEqualTypeAndValue( _ lhs: @autoclosure () throws -> Left, _ rhs: @autoclosure () throws -> Right, - file: StaticString = #file, + file: StaticString = #filePath, line: UInt = #line ) { - XCTAssertNoThrow(try { - let lhs = try lhs() - let rhs = try rhs() - guard let lhsAsRhs = lhs as? Right else { - XCTFail("could not cast \(lhs) of type \(Right.self) to \(Left.self)") - return - } - XCTAssertEqual(lhsAsRhs, rhs) - }(), file: file, line: line) + XCTAssertNoThrow( + try { + let lhs = try lhs() + let rhs = try rhs() + guard let lhsAsRhs = lhs as? Right else { + XCTFail("could not cast \(lhs) of type \(type(of: lhs)) to \(type(of: rhs))", file: file, line: line) + return + } + XCTAssertEqual(lhsAsRhs, rhs, file: file, line: line) + }(), + file: file, + line: line + ) } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests+XCTest.swift deleted file mode 100644 index 93945f63c..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests+XCTest.swift +++ /dev/null @@ -1,33 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+ManagerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_ManagerTests { - static var allTests: [(String, (HTTPConnectionPool_ManagerTests) -> () throws -> Void)] { - return [ - ("testManagerHappyPath", testManagerHappyPath), - ("testShutdownManagerThatHasSeenNoConnections", testShutdownManagerThatHasSeenNoConnections), - ("testExecutingARequestOnAShutdownPoolManager", testExecutingARequestOnAShutdownPoolManager), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests.swift index d84e7f442..724c00b1f 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests.swift @@ -12,12 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient +import Logging import NIOCore import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_ManagerTests: XCTestCase { func testManagerHappyPath() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 4) @@ -49,15 +51,17 @@ class HTTPConnectionPool_ManagerTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -105,15 +109,17 @@ class HTTPConnectionPool_ManagerTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests+XCTest.swift deleted file mode 100644 index 2511ba267..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests+XCTest.swift +++ /dev/null @@ -1,31 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+RequestQueueTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_RequestQueueTests { - static var allTests: [(String, (HTTPConnectionPool_RequestQueueTests) -> () throws -> Void)] { - return [ - ("testCountAndIsEmptyWorks", testCountAndIsEmptyWorks), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift index f8d6044cd..d792895d3 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOEmbedded @@ -20,6 +19,8 @@ import NIOHTTP1 import NIOSSL import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_RequestQueueTests: XCTestCase { func testCountAndIsEmptyWorks() { var queue = HTTPConnectionPool.RequestQueue() diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift index cb67837d7..bd9752d5d 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift @@ -12,28 +12,30 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient +import Atomics import Dispatch import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +@testable import AsyncHTTPClient + /// An `EventLoopGroup` of `EmbeddedEventLoop`s. final class EmbeddedEventLoopGroup: EventLoopGroup { private let loops: [EmbeddedEventLoop] - private let index = NIOAtomic.makeAtomic(value: 0) + private let index = ManagedAtomic(0) internal init(loops: Int) { self.loops = (0.. EventLoop { - let index: Int = self.index.add(1) + let index: Int = self.index.loadThenWrappingIncrement(ordering: .relaxed) return self.loops[index % self.loops.count] } internal func makeIterator() -> EventLoopIterator { - return EventLoopIterator(self.loops) + EventLoopIterator(self.loops) } internal func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { @@ -55,7 +57,7 @@ final class EmbeddedEventLoopGroup: EventLoopGroup { extension HTTPConnectionPool.Request: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { - return lhs.id == rhs.id + lhs.id == rhs.id } } @@ -77,15 +79,24 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction: Equatable { switch (lhs, rhs) { case (.createConnection(let lhsConnID, on: let lhsEL), .createConnection(let rhsConnID, on: let rhsEL)): return lhsConnID == rhsConnID && lhsEL === rhsEL - case (.scheduleBackoffTimer(let lhsConnID, let lhsBackoff, on: let lhsEL), .scheduleBackoffTimer(let rhsConnID, let rhsBackoff, on: let rhsEL)): + case ( + .scheduleBackoffTimer(let lhsConnID, let lhsBackoff, on: let lhsEL), + .scheduleBackoffTimer(let rhsConnID, let rhsBackoff, on: let rhsEL) + ): return lhsConnID == rhsConnID && lhsBackoff == rhsBackoff && lhsEL === rhsEL case (.scheduleTimeoutTimer(let lhsConnID, on: let lhsEL), .scheduleTimeoutTimer(let rhsConnID, on: let rhsEL)): return lhsConnID == rhsConnID && lhsEL === rhsEL case (.cancelTimeoutTimer(let lhsConnID), .cancelTimeoutTimer(let rhsConnID)): return lhsConnID == rhsConnID - case (.closeConnection(let lhsConn, isShutdown: let lhsShut), .closeConnection(let rhsConn, isShutdown: let rhsShut)): + case ( + .closeConnection(let lhsConn, isShutdown: let lhsShut), + .closeConnection(let rhsConn, isShutdown: let rhsShut) + ): return lhsConn == rhsConn && lhsShut == rhsShut - case (.cleanupConnections(let lhsContext, isShutdown: let lhsShut), .cleanupConnections(let rhsContext, isShutdown: let rhsShut)): + case ( + .cleanupConnections(let lhsContext, isShutdown: let lhsShut), + .cleanupConnections(let rhsContext, isShutdown: let rhsShut) + ): return lhsContext == rhsContext && lhsShut == rhsShut case ( .migration( @@ -99,12 +110,13 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction: Equatable { let rhsScheduleTimeout ) ): - return lhsCreateConnections.elementsEqual(rhsCreateConnections, by: { - $0.0 == $1.0 && $0.1 === $1.1 - }) && - lhsCloseConnections == rhsCloseConnections && - lhsScheduleTimeout?.0 == rhsScheduleTimeout?.0 && - lhsScheduleTimeout?.1 === rhsScheduleTimeout?.1 + return lhsCreateConnections.elementsEqual( + rhsCreateConnections, + by: { + $0.0 == $1.0 && $0.1 === $1.1 + } + ) && lhsCloseConnections == rhsCloseConnections && lhsScheduleTimeout?.0 == rhsScheduleTimeout?.0 + && lhsScheduleTimeout?.1 === rhsScheduleTimeout?.1 case (.none, .none): return true default: @@ -116,18 +128,28 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction: Equatable { extension HTTPConnectionPool.StateMachine.RequestAction: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { switch (lhs, rhs) { - case (.executeRequest(let lhsReq, let lhsConn, let lhsReqID), .executeRequest(let rhsReq, let rhsConn, let rhsReqID)): + case ( + .executeRequest(let lhsReq, let lhsConn, let lhsReqID), + .executeRequest(let rhsReq, let rhsConn, let rhsReqID) + ): return lhsReq == rhsReq && lhsConn == rhsConn && lhsReqID == rhsReqID - case (.executeRequestsAndCancelTimeouts(let lhsReqs, let lhsConn), .executeRequestsAndCancelTimeouts(let rhsReqs, let rhsConn)): + case ( + .executeRequestsAndCancelTimeouts(let lhsReqs, let lhsConn), + .executeRequestsAndCancelTimeouts(let rhsReqs, let rhsConn) + ): return lhsReqs.elementsEqual(rhsReqs, by: { $0 == $1 }) && lhsConn == rhsConn - case (.failRequest(let lhsReq, _, cancelTimeout: let lhsReqID), .failRequest(let rhsReq, _, cancelTimeout: let rhsReqID)): + case ( + .failRequest(let lhsReq, _, cancelTimeout: let lhsReqID), + .failRequest(let rhsReq, _, cancelTimeout: let rhsReqID) + ): return lhsReq == rhsReq && lhsReqID == rhsReqID case (.failRequestsAndCancelTimeouts(let lhsReqs, _), .failRequestsAndCancelTimeouts(let rhsReqs, _)): return lhsReqs.elementsEqual(rhsReqs, by: { $0 == $1 }) - case (.scheduleRequestTimeout(for: let lhsReq, on: let lhsEL), .scheduleRequestTimeout(for: let rhsReq, on: let rhsEL)): + case ( + .scheduleRequestTimeout(for: let lhsReq, on: let lhsEL), + .scheduleRequestTimeout(for: let rhsReq, on: let rhsEL) + ): return lhsReq == rhsReq && lhsEL === rhsEL - case (.cancelRequestTimeout(let lhsReqID), .cancelRequestTimeout(let rhsReqID)): - return lhsReqID == rhsReqID case (.none, .none): return true default: @@ -147,7 +169,10 @@ extension HTTPConnectionPool.HTTP2StateMachine.EstablishedConnectionAction: Equa switch (lhs, rhs) { case (.scheduleTimeoutTimer(let lhsConnID, on: let lhsEL), .scheduleTimeoutTimer(let rhsConnID, on: let rhsEL)): return lhsConnID == rhsConnID && lhsEL === rhsEL - case (.closeConnection(let lhsConn, isShutdown: let lhsShut), .closeConnection(let rhsConn, isShutdown: let rhsShut)): + case ( + .closeConnection(let lhsConn, isShutdown: let lhsShut), + .closeConnection(let rhsConn, isShutdown: let rhsShut) + ): return lhsConn == rhsConn && lhsShut == rhsShut case (.none, .none): return true diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests+XCTest.swift deleted file mode 100644 index acdc0ab26..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests+XCTest.swift +++ /dev/null @@ -1,39 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPoolTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPoolTests { - static var allTests: [(String, (HTTPConnectionPoolTests) -> () throws -> Void)] { - return [ - ("testOnlyOneConnectionIsUsedForSubSequentRequests", testOnlyOneConnectionIsUsedForSubSequentRequests), - ("testConnectionsForEventLoopRequirementsAreClosed", testConnectionsForEventLoopRequirementsAreClosed), - ("testConnectionPoolGrowsToMaxConcurrentConnections", testConnectionPoolGrowsToMaxConcurrentConnections), - ("testConnectionCreationIsRetriedUntilRequestIsFailed", testConnectionCreationIsRetriedUntilRequestIsFailed), - ("testConnectionCreationIsRetriedUntilPoolIsShutdown", testConnectionCreationIsRetriedUntilPoolIsShutdown), - ("testConnectionCreationIsRetriedUntilRequestIsCancelled", testConnectionCreationIsRetriedUntilRequestIsCancelled), - ("testConnectionShutdownIsCalledOnActiveConnections", testConnectionShutdownIsCalledOnActiveConnections), - ("testConnectionPoolStressResistanceHTTP1", testConnectionPoolStressResistanceHTTP1), - ("testBackoffBehavesSensibly", testBackoffBehavesSensibly), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift index 60e5077ee..a40703456 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPoolTests: XCTestCase { func testOnlyOneConnectionIsUsedForSubSequentRequests() { let httpBin = HTTPBin() @@ -53,15 +54,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -82,7 +85,6 @@ class HTTPConnectionPoolTests: XCTestCase { let request = try! HTTPClient.Request(url: "http://localhost:\(httpBin.port)") let poolDelegate = TestDelegate(eventLoop: eventLoop) - let pool = HTTPConnectionPool( eventLoopGroup: eventLoopGroup, sslContextCache: .init(), @@ -93,6 +95,74 @@ class HTTPConnectionPoolTests: XCTestCase { idGenerator: .init(), backgroundActivityLogger: .init(label: "test") ) + defer { + pool.shutdown() + XCTAssertNoThrow(try poolDelegate.future.wait()) + XCTAssertNoThrow(try eventLoop.scheduleTask(in: .milliseconds(100)) {}.futureResult.wait()) + XCTAssertEqual(httpBin.activeConnections, 0) + // Since we would migrate from h2 -> h1, which creates a general purpose connection + // for every connection in .starting state, after the first request which will + // be serviced by an overflow connection, the rest of requests will use the general + // purpose connection since they are all on the same event loop. + // Hence we will only create 1 overflow connection and 1 general purpose connection. + XCTAssertEqual(httpBin.createdConnections, 2) + } + + XCTAssertEqual(httpBin.createdConnections, 0) + + for _ in 0..<10 { + var maybeRequest: HTTPClient.Request? + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .init( + .testOnly_exact(channelOn: eventLoopGroup.next(), delegateOn: eventLoopGroup.next()) + ), + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) + + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } + + pool.executeRequest(requestBag) + XCTAssertNoThrow(try requestBag.task.futureResult.wait()) + + // Flakiness Alert: We check <= and >= instead of == + // While migration from h2 -> h1, one general purpose and one over flow connection + // will be created, there's no guarantee as to whether the request is executed + // after both are created. + XCTAssertGreaterThanOrEqual(httpBin.createdConnections, 1) + XCTAssertLessThanOrEqual(httpBin.createdConnections, 2) + } + } + + func testConnectionsForEventLoopRequirementsAreClosedH1Only() { + let httpBin = HTTPBin() + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + let eventLoop = eventLoopGroup.next() + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let request = try! HTTPClient.Request(url: "http://localhost:\(httpBin.port)") + let poolDelegate = TestDelegate(eventLoop: eventLoop) + var configuration = HTTPClient.Configuration() + configuration.httpVersion = .http1Only + let pool = HTTPConnectionPool( + eventLoopGroup: eventLoopGroup, + sslContextCache: .init(), + tlsConfiguration: .none, + clientConfiguration: configuration, + key: .init(request), + delegate: poolDelegate, + idGenerator: .init(), + backgroundActivityLogger: .init(label: "test") + ) defer { pool.shutdown() XCTAssertNoThrow(try poolDelegate.future.wait()) @@ -107,15 +177,19 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .init(.testOnly_exact(channelOn: eventLoopGroup.next(), delegateOn: eventLoopGroup.next())), - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .init( + .testOnly_exact(channelOn: eventLoopGroup.next(), delegateOn: eventLoopGroup.next()) + ), + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -162,15 +236,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -216,15 +292,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -264,15 +342,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -320,21 +400,23 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } pool.executeRequest(requestBag) XCTAssertNoThrow(try eventLoop.scheduleTask(in: .seconds(1)) {}.futureResult.wait()) - requestBag.cancel() + requestBag.fail(HTTPClientError.cancelled) XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { XCTAssertEqual($0 as? HTTPClientError, .cancelled) @@ -366,15 +448,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/wait")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -419,22 +503,24 @@ class HTTPConnectionPoolTests: XCTestCase { let dispatchGroup = DispatchGroup() for workerID in 0..? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: url)) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } pool.executeRequest(requestBag) @@ -458,7 +544,10 @@ class HTTPConnectionPoolTests: XCTestCase { var backoff = HTTPConnectionPool.calculateBackoff(failedAttempt: 1) // The value should be 100ms±3ms - XCTAssertLessThanOrEqual((backoff - .milliseconds(100)).nanoseconds.magnitude, TimeAmount.milliseconds(3).nanoseconds.magnitude) + XCTAssertLessThanOrEqual( + (backoff - .milliseconds(100)).nanoseconds.magnitude, + TimeAmount.milliseconds(3).nanoseconds.magnitude + ) // Should always increase // We stop when we get within the jitter of 60s, which is 1.8s @@ -474,7 +563,8 @@ class HTTPConnectionPoolTests: XCTestCase { // Ok, now we should be able to do a hundred increments, and always hit 60s, plus or minus 1.8s of jitter. for offset in 0..<100 { XCTAssertLessThanOrEqual( - (HTTPConnectionPool.calculateBackoff(failedAttempt: attempt + offset) - .seconds(60)).nanoseconds.magnitude, + (HTTPConnectionPool.calculateBackoff(failedAttempt: attempt + offset) - .seconds(60)).nanoseconds + .magnitude, TimeAmount.milliseconds(1800).nanoseconds.magnitude ) } diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift deleted file mode 100644 index b54865fd8..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift +++ /dev/null @@ -1,66 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPRequestStateMachineTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPRequestStateMachineTests { - static var allTests: [(String, (HTTPRequestStateMachineTests) -> () throws -> Void)] { - return [ - ("testSimpleGETRequest", testSimpleGETRequest), - ("testPOSTRequestWithWriterBackpressure", testPOSTRequestWithWriterBackpressure), - ("testPOSTContentLengthIsTooLong", testPOSTContentLengthIsTooLong), - ("testPOSTContentLengthIsTooShort", testPOSTContentLengthIsTooShort), - ("testRequestBodyStreamIsCancelledIfServerRespondsWith301", testRequestBodyStreamIsCancelledIfServerRespondsWith301), - ("testRequestBodyStreamIsCancelledIfServerRespondsWith301WhileWriteBackpressure", testRequestBodyStreamIsCancelledIfServerRespondsWith301WhileWriteBackpressure), - ("testRequestBodyStreamIsContinuedIfServerRespondsWith200", testRequestBodyStreamIsContinuedIfServerRespondsWith200), - ("testRequestBodyStreamIsContinuedIfServerSendHeadWithStatus200", testRequestBodyStreamIsContinuedIfServerSendHeadWithStatus200), - ("testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerRespondedWith200", testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerRespondedWith200), - ("testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerSendHeadWithStatus200", testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerSendHeadWithStatus200), - ("testRequestIsNotSendUntilChannelIsWritable", testRequestIsNotSendUntilChannelIsWritable), - ("testConnectionBecomesInactiveWhileWaitingForWritable", testConnectionBecomesInactiveWhileWaitingForWritable), - ("testResponseReadingWithBackpressure", testResponseReadingWithBackpressure), - ("testChannelReadCompleteTriggersButNoBodyDataWasReceivedSoFar", testChannelReadCompleteTriggersButNoBodyDataWasReceivedSoFar), - ("testResponseReadingWithBackpressureEndOfResponseAllowsReadEventsToTriggerDirectly", testResponseReadingWithBackpressureEndOfResponseAllowsReadEventsToTriggerDirectly), - ("testCancellingARequestInStateInitializedKeepsTheConnectionAlive", testCancellingARequestInStateInitializedKeepsTheConnectionAlive), - ("testCancellingARequestBeforeBeingSendKeepsTheConnectionAlive", testCancellingARequestBeforeBeingSendKeepsTheConnectionAlive), - ("testConnectionBecomesWritableBeforeFirstRequest", testConnectionBecomesWritableBeforeFirstRequest), - ("testCancellingARequestThatIsSent", testCancellingARequestThatIsSent), - ("testRemoteSuddenlyClosesTheConnection", testRemoteSuddenlyClosesTheConnection), - ("testReadTimeoutLeadsToFailureWithEverythingAfterBeingIgnored", testReadTimeoutLeadsToFailureWithEverythingAfterBeingIgnored), - ("testResponseWithStatus1XXAreIgnored", testResponseWithStatus1XXAreIgnored), - ("testReadTimeoutThatFiresToLateIsIgnored", testReadTimeoutThatFiresToLateIsIgnored), - ("testCancellationThatIsInvokedToLateIsIgnored", testCancellationThatIsInvokedToLateIsIgnored), - ("testErrorWhileRunningARequestClosesTheStream", testErrorWhileRunningARequestClosesTheStream), - ("testCanReadHTTP1_0ResponseWithoutBody", testCanReadHTTP1_0ResponseWithoutBody), - ("testCanReadHTTP1_0ResponseWithBody", testCanReadHTTP1_0ResponseWithBody), - ("testFailHTTP1_0RequestThatIsStillUploading", testFailHTTP1_0RequestThatIsStillUploading), - ("testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdown", testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdown), - ("testNIOSSLErrorUncleanShutdownShouldBeTreatedAsRemoteConnectionCloseWhileInWaitingForHeadState", testNIOSSLErrorUncleanShutdownShouldBeTreatedAsRemoteConnectionCloseWhileInWaitingForHeadState), - ("testArbitraryErrorShouldBeTreatedAsARequestFailureWhileInWaitingForHeadState", testArbitraryErrorShouldBeTreatedAsARequestFailureWhileInWaitingForHeadState), - ("testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt", testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt), - ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand), - ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead), - ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemand), - ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemandMultipleTimes", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemandMultipleTimes), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index ab55345c9..8fe879745 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -12,21 +12,29 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore +import NIOEmbedded import NIOHTTP1 import NIOSSL import XCTest +@testable import AsyncHTTPClient + class HTTPRequestStateMachineTests: XCTestCase { func testSimpleGETRequest() { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) @@ -35,32 +43,47 @@ class HTTPRequestStateMachineTests: XCTestCase { func testPOSTRequestWithWriterBackpressure() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "4")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) let part2 = IOData.byteBuffer(ByteBuffer(bytes: [2])) let part3 = IOData.byteBuffer(ByteBuffer(bytes: [3])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) // oh the channel reports... we should slow down producing... XCTAssertEqual(state.writabilityChanged(writable: false), .pauseRequestBodyStream) // but we issued a .produceMoreRequestBodyData before... Thus, we must accept more produced // data - XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) + XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) // however when we have put the data on the channel, we should not issue further // .produceMoreRequestBodyData events // once we receive a writable event again, we can allow the producer to produce more data XCTAssertEqual(state.writabilityChanged(writable: true), .resumeRequestBodyStream) - XCTAssertEqual(state.requestStreamPartReceived(part3), .sendBodyPart(part3)) - XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd) + XCTAssertEqual(state.requestStreamPartReceived(part3, promise: nil), .sendBodyPart(part3, nil)) + XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) @@ -69,14 +92,25 @@ class HTTPRequestStateMachineTests: XCTestCase { func testPOSTContentLengthIsTooLong() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "4")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) - state.requestStreamPartReceived(part1).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) + state.requestStreamPartReceived(part1, promise: nil).assertFailRequest( + HTTPClientError.bodyLengthMismatch, + .close(nil) + ) // if another error happens the new one is ignored XCTAssertEqual(state.errorHappened(HTTPClientError.remoteConnectionClosed), .wait) @@ -84,140 +118,257 @@ class HTTPRequestStateMachineTests: XCTestCase { func testPOSTContentLengthIsTooShort() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "8")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "8")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(8)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) - state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) + state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) } func testRequestBodyStreamIsCancelledIfServerRespondsWith301() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) - XCTAssertEqual(state.requestStreamPartReceived(part), .sendBodyPart(part)) + XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .sendBodyPart(part, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .movedPermanently) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: true)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: true) + ) XCTAssertEqual(state.writabilityChanged(writable: false), .wait) XCTAssertEqual(state.writabilityChanged(writable: true), .wait) - XCTAssertEqual(state.requestStreamPartReceived(part), .wait, - "Expected to drop all stream data after having received a response head, with status >= 300") + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init())) - XCTAssertEqual(state.requestStreamPartReceived(part), .wait, - "Expected to drop all stream data after having received a response head, with status >= 300") + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) + + XCTAssertEqual( + state.requestStreamFinished(promise: nil), + .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) + } - XCTAssertEqual(state.requestStreamFinished(), .wait, - "Expected to drop all stream data after having received a response head, with status >= 300") + func testStreamPartReceived_whenCancelled() { + var state = HTTPRequestStateMachine(isChannelWritable: false) + let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) + + XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .none)) + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.cancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) } func testRequestBodyStreamIsCancelledIfServerRespondsWith301WhileWriteBackpressure() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) - XCTAssertEqual(state.requestStreamPartReceived(part), .sendBodyPart(part)) + XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .sendBodyPart(part, nil)) XCTAssertEqual(state.writabilityChanged(writable: false), .pauseRequestBodyStream) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .movedPermanently) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.writabilityChanged(writable: true), .wait) - XCTAssertEqual(state.requestStreamPartReceived(part), .wait, - "Expected to drop all stream data after having received a response head, with status >= 300") + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init())) - XCTAssertEqual(state.requestStreamPartReceived(part), .wait, - "Expected to drop all stream data after having received a response head, with status >= 300") - - XCTAssertEqual(state.requestStreamFinished(), .wait, - "Expected to drop all stream data after having received a response head, with status >= 300") + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) + + XCTAssertEqual( + state.requestStreamFinished(promise: nil), + .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) } func testRequestBodyStreamIsContinuedIfServerRespondsWith200() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseBodyParts(.init())) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) let part2 = IOData.byteBuffer(ByteBuffer(bytes: 8...11)) - XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) - XCTAssertEqual(state.requestStreamFinished(), .succeedRequest(.sendRequestEnd, .init())) + XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) + XCTAssertEqual(state.requestStreamFinished(promise: nil), .succeedRequest(.sendRequestEnd(nil), .init())) + + XCTAssertEqual( + state.requestStreamPartReceived(part2, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil) + ) } func testRequestBodyStreamIsContinuedIfServerSendHeadWithStatus200() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) let part2 = IOData.byteBuffer(ByteBuffer(bytes: 8...11)) - XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) - XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd) + XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) + XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) } func testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerRespondedWith200() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseBodyParts(.init())) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) - state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) + state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) XCTAssertEqual(state.channelInactive(), .wait) } func testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerSendHeadWithStatus200() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) - state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) + state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) XCTAssertEqual(state.channelRead(.end(nil)), .wait) } @@ -227,10 +378,13 @@ class HTTPRequestStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .wait) XCTAssertEqual(state.read(), .read) - XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) @@ -249,10 +403,20 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) let part1 = ByteBuffer(bytes: 4...7) let part2 = ByteBuffer(bytes: 8...11) @@ -276,10 +440,20 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) let part1 = ByteBuffer(bytes: 4...7) let part2 = ByteBuffer(bytes: 8...11) @@ -303,10 +477,20 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) let part1 = ByteBuffer(bytes: 4...7) let part2 = ByteBuffer(bytes: 8...11) @@ -321,7 +505,11 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelRead(.body(part2)), .wait) XCTAssertEqual(state.read(), .read, "Calling `read` while we wait for a channelReadComplete doesn't crash") - XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait, "Calling `demandMoreResponseBodyParts` while we wait for a channelReadComplete doesn't crash") + XCTAssertEqual( + state.demandMoreResponseBodyParts(), + .wait, + "Calling `demandMoreResponseBodyParts` while we wait for a channelReadComplete doesn't crash" + ) XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts(.init([part2]))) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) @@ -350,11 +538,17 @@ class HTTPRequestStateMachineTests: XCTestCase { // --- sending request let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) // --- receiving response let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "4"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) @@ -365,30 +559,54 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close(nil)) } func testRemoteSuddenlyClosesTheConnection() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: .init([("content-length", "4")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: .init([("content-length", "4")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) - state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close) - XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(.init(bytes: 1...3))), .wait) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) + state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close(nil)) + XCTAssertEqual( + state.requestStreamPartReceived(.byteBuffer(.init(bytes: 1...3)), promise: nil), + .failSendBodyPart(HTTPClientError.cancelled, nil) + ) } func testReadTimeoutLeadsToFailureWithEverythingAfterBeingIgnored() { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) XCTAssertEqual(state.channelRead(.body(part0)), .wait) - state.idleReadTimeoutTriggered().assertFailRequest(HTTPClientError.readTimeout, .close) + state.idleReadTimeoutTriggered().assertFailRequest(HTTPClientError.readTimeout, .close(nil)) XCTAssertEqual(state.channelRead(.body(ByteBuffer(bytes: 4...7))), .wait) XCTAssertEqual(state.channelRead(.body(ByteBuffer(bytes: 8...11))), .wait) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) @@ -399,13 +617,19 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let continueHead = HTTPResponseHead(version: .http1_1, status: .continue) XCTAssertEqual(state.channelRead(.head(continueHead)), .wait) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -415,10 +639,16 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) XCTAssertEqual(state.idleReadTimeoutTriggered(), .wait, "A read timeout that fires to late must be ignored") } @@ -427,10 +657,16 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored") } @@ -439,9 +675,15 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - - state.errorHappened(HTTPParserError.invalidChunkSize).assertFailRequest(HTTPParserError.invalidChunkSize, .close) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + state.errorHappened(HTTPParserError.invalidChunkSize).assertFailRequest( + HTTPParserError.invalidChunkSize, + .close(nil) + ) XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored") } @@ -449,10 +691,16 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_0, status: .internalServerError) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -465,11 +713,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_0, status: .internalServerError) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -483,19 +737,28 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .stream) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part1: ByteBuffer = .init(string: "foo") - XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(part1)), .sendBodyPart(.byteBuffer(part1))) + XCTAssertEqual( + state.requestStreamPartReceived(.byteBuffer(part1), promise: nil), + .sendBodyPart(.byteBuffer(part1), nil) + ) let responseHead = HTTPResponseHead(version: .http1_0, status: .ok) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.channelRead(.body(body)), .wait) - state.channelRead(.end(nil)).assertFailRequest(HTTPClientError.remoteConnectionClosed, .close) + state.channelRead(.end(nil)).assertFailRequest(HTTPClientError.remoteConnectionClosed, .close(nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -503,14 +766,20 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelRead(.body(body)), .wait) - state.errorHappened(NIOSSLError.uncleanShutdown).assertFailRequest(NIOSSLError.uncleanShutdown, .close) + state.errorHappened(NIOSSLError.uncleanShutdown).assertFailRequest(NIOSSLError.uncleanShutdown, .close(nil)) XCTAssertEqual(state.channelRead(.end(nil)), .wait) XCTAssertEqual(state.channelInactive(), .wait) } @@ -519,7 +788,10 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .wait) state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none) @@ -530,9 +802,12 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) - state.errorHappened(ArbitraryError()).assertFailRequest(ArbitraryError(), .close) + state.errorHappened(ArbitraryError()).assertFailRequest(ArbitraryError(), .close(nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -540,17 +815,26 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "30"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelRead(.body(body)), .wait) XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body])) XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .wait) - state.errorHappened(HTTPParserError.invalidEOFState).assertFailRequest(HTTPParserError.invalidEOFState, .close) + state.errorHappened(HTTPParserError.invalidEOFState).assertFailRequest( + HTTPParserError.invalidEOFState, + .close(nil) + ) XCTAssertEqual(state.channelInactive(), .wait) } @@ -558,11 +842,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -579,11 +869,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -600,11 +896,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -620,11 +922,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -656,24 +964,36 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.sendRequestHead(let lhsHead, let lhsStartBody), .sendRequestHead(let rhsHead, let rhsStartBody)): return lhsHead == rhsHead && lhsStartBody == rhsStartBody - case (.sendBodyPart(let lhsData), .sendBodyPart(let rhsData)): - return lhsData == rhsData + case ( + .notifyRequestHeadSendSuccessfully(let lhsResumeRequestBodyStream, let lhsStartIdleTimer), + .notifyRequestHeadSendSuccessfully(let rhsResumeRequestBodyStream, let rhsStartIdleTimer) + ): + return lhsResumeRequestBodyStream == rhsResumeRequestBodyStream && lhsStartIdleTimer == rhsStartIdleTimer - case (.sendRequestEnd, .sendRequestEnd): - return true + case (.sendBodyPart(let lhsData, let lhsPromise), .sendBodyPart(let rhsData, let rhsPromise)): + return lhsData == rhsData && lhsPromise?.futureResult == rhsPromise?.futureResult + + case (.sendRequestEnd(let lhsPromise), .sendRequestEnd(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult case (.pauseRequestBodyStream, .pauseRequestBodyStream): return true case (.resumeRequestBodyStream, .resumeRequestBodyStream): return true - case (.forwardResponseHead(let lhsHead, let lhsPauseRequestBodyStream), .forwardResponseHead(let rhsHead, let rhsPauseRequestBodyStream)): + case ( + .forwardResponseHead(let lhsHead, let lhsPauseRequestBodyStream), + .forwardResponseHead(let rhsHead, let rhsPauseRequestBodyStream) + ): return lhsHead == rhsHead && lhsPauseRequestBodyStream == rhsPauseRequestBodyStream case (.forwardResponseBodyParts(let lhsData), .forwardResponseBodyParts(let rhsData)): return lhsData == rhsData - case (.succeedRequest(let lhsFinalAction, let lhsFinalBuffer), .succeedRequest(let rhsFinalAction, let rhsFinalBuffer)): + case ( + .succeedRequest(let lhsFinalAction, let lhsFinalBuffer), + .succeedRequest(let rhsFinalAction, let rhsFinalBuffer) + ): return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer case (.failRequest(_, let lhsFinalAction), .failRequest(_, let rhsFinalAction)): @@ -685,6 +1005,57 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.wait, .wait): return true + case ( + .failSendBodyPart(let lhsError as HTTPClientError, let lhsPromise), + .failSendBodyPart(let rhsError as HTTPClientError, let rhsPromise) + ): + return lhsError == rhsError && lhsPromise?.futureResult == rhsPromise?.futureResult + + case ( + .failSendStreamFinished(let lhsError as HTTPClientError, let lhsPromise), + .failSendStreamFinished(let rhsError as HTTPClientError, let rhsPromise) + ): + return lhsError == rhsError && lhsPromise?.futureResult == rhsPromise?.futureResult + + default: + return false + } + } +} + +extension HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction: Equatable { + public static func == ( + lhs: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction, + rhs: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction + ) -> Bool { + switch (lhs, rhs) { + case (.close, close): + return true + + case (.sendRequestEnd(let lhsPromise), .sendRequestEnd(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult + + case (.none, .none): + return true + + default: + return false + } + } +} + +extension HTTPRequestStateMachine.Action.FinalFailedRequestAction: Equatable { + public static func == ( + lhs: HTTPRequestStateMachine.Action.FinalFailedRequestAction, + rhs: HTTPRequestStateMachine.Action.FinalFailedRequestAction + ) -> Bool { + switch (lhs, rhs) { + case (.close(let lhsPromise), close(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult + + case (.none, .none): + return true + default: return false } @@ -694,12 +1065,16 @@ extension HTTPRequestStateMachine.Action: Equatable { extension HTTPRequestStateMachine.Action { fileprivate func assertFailRequest( _ expectedError: Error, - _ expectedFinalStreamAction: HTTPRequestStateMachine.Action.FinalStreamAction, - file: StaticString = #file, + _ expectedFinalStreamAction: HTTPRequestStateMachine.Action.FinalFailedRequestAction, + file: StaticString = #filePath, line: UInt = #line ) where Error: Swift.Error & Equatable { guard case .failRequest(let actualError, let actualFinalStreamAction) = self else { - return XCTFail("expected .failRequest(\(expectedError), \(expectedFinalStreamAction)) but got \(self)", file: file, line: line) + return XCTFail( + "expected .failRequest(\(expectedError), \(expectedFinalStreamAction)) but got \(self)", + file: file, + line: line + ) } if let actualError = actualError as? Error { XCTAssertEqual(actualError, expectedError, file: file, line: line) diff --git a/Tests/AsyncHTTPClientTests/IdleTimeoutNoReuseTests.swift b/Tests/AsyncHTTPClientTests/IdleTimeoutNoReuseTests.swift new file mode 100644 index 000000000..e9a0d46dc --- /dev/null +++ b/Tests/AsyncHTTPClientTests/IdleTimeoutNoReuseTests.swift @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +final class TestIdleTimeoutNoReuse: XCTestCaseHTTPClientTestsBaseClass { + func testIdleTimeoutNoReuse() throws { + var req = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .GET) + XCTAssertNoThrow(try self.defaultClient.execute(request: req, deadline: .now() + .seconds(2)).wait()) + req.headers.add(name: "X-internal-delay", value: "2500") + try self.defaultClient.eventLoopGroup.next().scheduleTask(in: .milliseconds(250)) {}.futureResult.wait() + XCTAssertNoThrow(try self.defaultClient.execute(request: req).timeout(after: .seconds(10)).wait()) + } +} diff --git a/Tests/AsyncHTTPClientTests/LRUCacheTests+XCTest.swift b/Tests/AsyncHTTPClientTests/LRUCacheTests+XCTest.swift deleted file mode 100644 index a0231bf0d..000000000 --- a/Tests/AsyncHTTPClientTests/LRUCacheTests+XCTest.swift +++ /dev/null @@ -1,33 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// LRUCacheTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension LRUCacheTests { - static var allTests: [(String, (LRUCacheTests) -> () throws -> Void)] { - return [ - ("testBasicsWork", testBasicsWork), - ("testCachesTheRightThings", testCachesTheRightThings), - ("testAppendingTheSameDoesNotEvictButUpdates", testAppendingTheSameDoesNotEvictButUpdates), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/LRUCacheTests.swift b/Tests/AsyncHTTPClientTests/LRUCacheTests.swift index 6392bcebe..6173c34eb 100644 --- a/Tests/AsyncHTTPClientTests/LRUCacheTests.swift +++ b/Tests/AsyncHTTPClientTests/LRUCacheTests.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import XCTest +@testable import AsyncHTTPClient + class LRUCacheTests: XCTestCase { func testBasicsWork() { var cache = LRUCache(capacity: 1) diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift b/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift index eedc499ad..e49c67f19 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOHTTP1 import NIOSSL +@testable import AsyncHTTPClient + /// A mock connection pool (not creating any actual connections) that is used to validate /// connection actions returned by the `HTTPConnectionPool.StateMachine`. struct MockConnectionPool { @@ -541,17 +542,22 @@ extension MockConnectionPool { ) throws -> (Self, HTTPConnectionPool.StateMachine) { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: maxNumberOfConnections + maximumConcurrentHTTP1Connections: maxNumberOfConnections, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil ) var connections = MockConnectionPool() var queuer = MockRequestQueuer() for _ in 0.. (Self, HTTPConnectionPool.StateMachine) { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 8 + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: false, + maximumConnectionUses: nil ) var connections = MockConnectionPool() var queuer = MockRequestQueuer() // 1. Schedule one request to create a connection - let mockRequest = MockHTTPRequest(eventLoop: eventLoop ?? elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop ?? elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) - guard case .scheduleRequestTimeout(request, on: let waitEL) = executeAction.request, mockRequest.eventLoop === waitEL else { + guard case .scheduleRequestTimeout(request, on: let waitEL) = executeAction.request, + mockRequest.eventLoop === waitEL + else { throw SetupError.expectedRequestToBeAddedToQueue } @@ -628,17 +639,16 @@ extension MockConnectionPool { // 2. the connection becomes available - let newConnection = try connections.succeedConnectionCreationHTTP2(connectionID, maxConcurrentStreams: maxConcurrentStreams) + let newConnection = try connections.succeedConnectionCreationHTTP2( + connectionID, + maxConcurrentStreams: maxConcurrentStreams + ) let action = state.newHTTP2ConnectionCreated(newConnection, maxConcurrentStreams: maxConcurrentStreams) guard case .executeRequestsAndCancelTimeouts([request], newConnection) = action.request else { throw SetupError.expectedPreviouslyQueuedRequestToBeRunNow } - guard case .migration(createConnections: let create, closeConnections: [], scheduleTimeout: nil) = action.connection, create.isEmpty else { - throw SetupError.expectedNoConnectionAction - } - guard try queuer.get(request.id, request: request.__testOnly_wrapped_request()) === mockRequest else { throw SetupError.expectedPreviouslyQueuedRequestToBeRunNow } @@ -664,7 +674,7 @@ extension MockConnectionPool { /// A request that can be used when testing the `HTTPConnectionPool.StateMachine` /// with the `MockConnectionPool`. -class MockHTTPRequest: HTTPSchedulableRequest { +final class MockHTTPScheduableRequest: HTTPSchedulableRequest { let logger: Logger let connectionDeadline: NIODeadline let requestOptions: RequestOptions @@ -672,10 +682,12 @@ class MockHTTPRequest: HTTPSchedulableRequest { let preferredEventLoop: EventLoop let requiredEventLoop: EventLoop? - init(eventLoop: EventLoop, - logger: Logger = Logger(label: "mock"), - connectionTimeout: TimeAmount = .seconds(60), - requiresEventLoopForChannel: Bool = false) { + init( + eventLoop: EventLoop, + logger: Logger = Logger(label: "mock"), + connectionTimeout: TimeAmount = .seconds(60), + requiresEventLoopForChannel: Bool = false + ) { self.logger = logger self.connectionDeadline = .now() + connectionTimeout @@ -690,7 +702,7 @@ class MockHTTPRequest: HTTPSchedulableRequest { } var eventLoop: EventLoop { - return self.preferredEventLoop + self.preferredEventLoop } // MARK: HTTPSchedulableRequest diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift b/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift new file mode 100644 index 000000000..021c69731 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift @@ -0,0 +1,164 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore +import NIOHTTP1 +import XCTest + +@testable import AsyncHTTPClient + +final class MockHTTPExecutableRequest: HTTPExecutableRequest { + enum Event { + /// ``Event`` without associated values + enum Kind: Hashable { + case willExecuteRequest + case requestHeadSent + case resumeRequestBodyStream + case pauseRequestBodyStream + case receiveResponseHead + case receiveResponseBodyParts + case succeedRequest + case fail + } + + case willExecuteRequest(HTTPRequestExecutor) + case requestHeadSent + case resumeRequestBodyStream + case pauseRequestBodyStream + case receiveResponseHead(HTTPResponseHead) + case receiveResponseBodyParts(CircularBuffer) + case succeedRequest(CircularBuffer?) + case fail(Error) + + var kind: Kind { + switch self { + case .willExecuteRequest: return .willExecuteRequest + case .requestHeadSent: return .requestHeadSent + case .resumeRequestBodyStream: return .resumeRequestBodyStream + case .pauseRequestBodyStream: return .pauseRequestBodyStream + case .receiveResponseHead: return .receiveResponseHead + case .receiveResponseBodyParts: return .receiveResponseBodyParts + case .succeedRequest: return .succeedRequest + case .fail: return .fail + } + } + } + + var logger: Logging.Logger = Logger(label: "request") + var requestHead: NIOHTTP1.HTTPRequestHead + var requestFramingMetadata: RequestFramingMetadata + var requestOptions: RequestOptions = .forTests() + + /// if true and ``HTTPExecutableRequest`` method is called without setting a corresponding callback on `self` e.g. + /// If ``HTTPExecutableRequest\.willExecuteRequest(_:)`` is called but ``willExecuteRequestCallback`` is not set, + /// ``XCTestFail(_:)`` will be called to fail the current test. + var raiseErrorIfUnimplementedMethodIsCalled: Bool = true + private var file: StaticString + private var line: UInt + + var willExecuteRequestCallback: ((HTTPRequestExecutor) -> Void)? + var requestHeadSentCallback: (() -> Void)? + var resumeRequestBodyStreamCallback: (() -> Void)? + var pauseRequestBodyStreamCallback: (() -> Void)? + var receiveResponseHeadCallback: ((HTTPResponseHead) -> Void)? + var receiveResponseBodyPartsCallback: ((CircularBuffer) -> Void)? + var succeedRequestCallback: ((CircularBuffer?) -> Void)? + var failCallback: ((Error) -> Void)? + + /// captures all ``HTTPExecutableRequest`` method calls in the order of occurrence, including arguments. + /// If you are not interested in the arguments you can use `events.map(\.kind)` to get all events without arguments. + private(set) var events: [Event] = [] + + init( + head: NIOHTTP1.HTTPRequestHead = .init(version: .http1_1, method: .GET, uri: "http://localhost/"), + framingMetadata: RequestFramingMetadata = .init(connectionClose: false, body: .fixedSize(0)), + file: StaticString = #file, + line: UInt = #line + ) { + self.requestHead = head + self.requestFramingMetadata = framingMetadata + self.file = file + self.line = line + } + + private func calledUnimplementedMethod(_ name: String) { + guard self.raiseErrorIfUnimplementedMethodIsCalled else { return } + XCTFail("\(name) invoked but it is not implemented", file: self.file, line: self.line) + } + + func willExecuteRequest(_ executor: HTTPRequestExecutor) { + self.events.append(.willExecuteRequest(executor)) + guard let willExecuteRequestCallback = willExecuteRequestCallback else { + return self.calledUnimplementedMethod(#function) + } + willExecuteRequestCallback(executor) + } + + func requestHeadSent() { + self.events.append(.requestHeadSent) + guard let requestHeadSentCallback = requestHeadSentCallback else { + return self.calledUnimplementedMethod(#function) + } + requestHeadSentCallback() + } + + func resumeRequestBodyStream() { + self.events.append(.resumeRequestBodyStream) + guard let resumeRequestBodyStreamCallback = resumeRequestBodyStreamCallback else { + return self.calledUnimplementedMethod(#function) + } + resumeRequestBodyStreamCallback() + } + + func pauseRequestBodyStream() { + self.events.append(.pauseRequestBodyStream) + guard let pauseRequestBodyStreamCallback = pauseRequestBodyStreamCallback else { + return self.calledUnimplementedMethod(#function) + } + pauseRequestBodyStreamCallback() + } + + func receiveResponseHead(_ head: HTTPResponseHead) { + self.events.append(.receiveResponseHead(head)) + guard let receiveResponseHeadCallback = receiveResponseHeadCallback else { + return self.calledUnimplementedMethod(#function) + } + receiveResponseHeadCallback(head) + } + + func receiveResponseBodyParts(_ buffer: CircularBuffer) { + self.events.append(.receiveResponseBodyParts(buffer)) + guard let receiveResponseBodyPartsCallback = receiveResponseBodyPartsCallback else { + return self.calledUnimplementedMethod(#function) + } + receiveResponseBodyPartsCallback(buffer) + } + + func succeedRequest(_ buffer: CircularBuffer?) { + self.events.append(.succeedRequest(buffer)) + guard let succeedRequestCallback = succeedRequestCallback else { + return self.calledUnimplementedMethod(#function) + } + succeedRequestCallback(buffer) + } + + func fail(_ error: Error) { + self.events.append(.fail(error)) + guard let failCallback = failCallback else { + return self.calledUnimplementedMethod(#function) + } + failCallback(error) + } +} diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift index b5b67c809..f85c75ce5 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift @@ -12,10 +12,11 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOConcurrencyHelpers import NIOCore +@testable import AsyncHTTPClient + // This is a MockRequestExecutor, that is synchronized on its EventLoop. final class MockRequestExecutor { enum Errors: Error { @@ -47,7 +48,7 @@ final class MockRequestExecutor { } var requestBodyPartsCount: Int { - return self.blockingQueue.count + self.blockingQueue.count } let eventLoop: EventLoop @@ -82,7 +83,8 @@ final class MockRequestExecutor { request.requestHeadSent() } - func receiveRequestBody(deadline: NIODeadline = .now() + .seconds(5), _ verify: (ByteBuffer) throws -> Void) throws { + func receiveRequestBody(deadline: NIODeadline = .now() + .seconds(5), _ verify: (ByteBuffer) throws -> Void) throws + { enum ReceiveAction { case value(RequestParts) case future(EventLoopFuture) @@ -155,10 +157,11 @@ final class MockRequestExecutor { func receiveResponseDemand(deadline: NIODeadline = .now() + .seconds(5)) throws { let secondsUntilDeath = deadline - NIODeadline.now() - guard self.responseBodyDemandLock.lock( - whenValue: true, - timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) - ) + guard + self.responseBodyDemandLock.lock( + whenValue: true, + timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) + ) else { throw TimeoutError() } @@ -168,10 +171,11 @@ final class MockRequestExecutor { func receiveCancellation(deadline: NIODeadline = .now() + .seconds(5)) throws { let secondsUntilDeath = deadline - NIODeadline.now() - guard self.cancellationLock.lock( - whenValue: true, - timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) - ) + guard + self.cancellationLock.lock( + whenValue: true, + timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) + ) else { throw TimeoutError() } @@ -184,12 +188,14 @@ extension MockRequestExecutor: HTTPRequestExecutor { // this should always be called twice. When we receive the first call, the next call to produce // data is already scheduled. If we call pause here, once, after the second call new subsequent // calls should not be scheduled. - func writeRequestBodyPart(_ part: IOData, request: HTTPExecutableRequest) { + func writeRequestBodyPart(_ part: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { self.writeNextRequestPart(.body(part), request: request) + promise?.succeed(()) } - func finishRequestBodyStream(_ request: HTTPExecutableRequest) { + func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { self.writeNextRequestPart(.endOfStream, request: request) + promise?.succeed(()) } private func writeNextRequestPart(_ part: RequestParts, request: HTTPExecutableRequest) { @@ -263,8 +269,12 @@ extension MockRequestExecutor { internal func popFirst(deadline: NIODeadline) throws -> Element { let secondsUntilDeath = deadline - NIODeadline.now() - guard self.condition.lock(whenValue: true, - timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000)) else { + guard + self.condition.lock( + whenValue: true, + timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) + ) + else { throw TimeoutError() } let first = self.buffer.removeFirst() diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift b/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift index e81f1ed0a..44e820444 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift @@ -12,11 +12,12 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOHTTP1 +@testable import AsyncHTTPClient + /// A mock request queue (not creating any timers) that is used to validate /// request actions returned by the `HTTPConnectionPool.StateMachine`. struct MockRequestQueuer { @@ -82,11 +83,11 @@ struct MockRequestQueuer { return waiter.request } - mutating func timeoutRandomRequest() -> RequestID? { - guard let waiterID = self.waiters.randomElement().map(\.0) else { + mutating func timeoutRandomRequest() -> (RequestID, HTTPSchedulableRequest)? { + guard let waiter = self.waiters.randomElement() else { return nil } - self.waiters.removeValue(forKey: waiterID) - return waiterID + self.waiters.removeValue(forKey: waiter.key) + return (waiter.key, waiter.value.request) } } diff --git a/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift b/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift new file mode 100644 index 000000000..033214ffe --- /dev/null +++ b/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift @@ -0,0 +1,87 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if canImport(Network) +@testable import AsyncHTTPClient +import Network +import NIOCore +import NIOEmbedded +import NIOSSL +import NIOTransportServices +import XCTest + +@available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) +class NWWaitingHandlerTests: XCTestCase { + class MockRequester: HTTPConnectionRequester { + var waitingForConnectivityCalled = false + var connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID? + var transientError: NWError? + + func http1ConnectionCreated(_: AsyncHTTPClient.HTTP1Connection) {} + + func http2ConnectionCreated(_: AsyncHTTPClient.HTTP2Connection, maximumStreams: Int) {} + + func failedToCreateHTTPConnection(_: AsyncHTTPClient.HTTPConnectionPool.Connection.ID, error: Error) {} + + func waitingForConnectivity(_ connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID, error: Error) { + self.waitingForConnectivityCalled = true + self.connectionID = connectionID + self.transientError = error as? NWError + } + } + + func testWaitingHandlerInvokesWaitingForConnectivity() { + let requester = MockRequester() + let connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID = 1 + let waitingEventHandler = NWWaitingHandler(requester: requester, connectionID: connectionID) + let embedded = EmbeddedChannel(handlers: [waitingEventHandler]) + + embedded.pipeline.fireUserInboundEventTriggered( + NIOTSNetworkEvents.WaitingForConnectivity(transientError: .dns(1)) + ) + + XCTAssertTrue( + requester.waitingForConnectivityCalled, + "Expected the handler to invoke .waitingForConnectivity on the requester" + ) + XCTAssertEqual(requester.connectionID, connectionID, "Expected the handler to pass connectionID to requester") + XCTAssertEqual(requester.transientError, NWError.dns(1)) + } + + func testWaitingHandlerDoesNotInvokeWaitingForConnectionOnUnrelatedErrors() { + let requester = MockRequester() + let waitingEventHandler = NWWaitingHandler(requester: requester, connectionID: 1) + let embedded = EmbeddedChannel(handlers: [waitingEventHandler]) + embedded.pipeline.fireUserInboundEventTriggered(NIOTSNetworkEvents.BetterPathAvailable()) + + XCTAssertFalse( + requester.waitingForConnectivityCalled, + "Should not call .waitingForConnectivity on unrelated events" + ) + } + + func testWaitingHandlerPassesTheEventDownTheContext() { + let requester = MockRequester() + let waitingEventHandler = NWWaitingHandler(requester: requester, connectionID: 1) + let tlsEventsHandler = TLSEventsHandler(deadline: nil) + let embedded = EmbeddedChannel(handlers: [waitingEventHandler, tlsEventsHandler]) + + embedded.pipeline.fireErrorCaught(NIOSSLError.handshakeFailed(BoringSSLError.wantConnect)) + XCTAssertThrowsError(try XCTUnwrap(tlsEventsHandler.tlsEstablishedFuture).wait()) { + XCTAssertEqualTypeAndValue($0, NIOSSLError.handshakeFailed(BoringSSLError.wantConnect)) + } + } +} + +#endif diff --git a/Tests/AsyncHTTPClientTests/NoBytesSentOverBodyLimitTests.swift b/Tests/AsyncHTTPClientTests/NoBytesSentOverBodyLimitTests.swift new file mode 100644 index 000000000..026a45d4c --- /dev/null +++ b/Tests/AsyncHTTPClientTests/NoBytesSentOverBodyLimitTests.swift @@ -0,0 +1,81 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +final class NoBytesSentOverBodyLimitTests: XCTestCaseHTTPClientTestsBaseClass { + func testNoBytesSentOverBodyLimit() throws { + let server = NIOHTTP1TestServer(group: self.serverGroup) + defer { + XCTAssertNoThrow(try server.stop()) + } + + let tooLong = "XBAD BAD BAD NOT HTTP/1.1\r\n\r\n" + + let request = try Request( + url: "http://localhost:\(server.serverPort)", + body: .stream(contentLength: 1) { streamWriter in + streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong))) + } + ) + + let future = self.defaultClient.execute(request: request) + + // Okay, what happens here needs an explanation: + // + // In the request state machine, we should start the request, which will lead to an + // invocation of `context.write(HTTPRequestHead)`. Since we will receive a streamed request + // body a `context.flush()` will be issued. Further the request stream will be started. + // Since the request stream immediately produces to much data, the request will be failed + // and the connection will be closed. + // + // Even though a flush was issued after the request head, there is no guarantee that the + // request head was written to the network. For this reason we must accept not receiving a + // request and receiving a request head. + + do { + _ = try server.receiveHead() + + // A request head was sent. We expect the request now to fail with a parsing error, + // since the client ended the connection to early (from the server's point of view.) + XCTAssertThrowsError(try server.readInbound()) { + XCTAssertEqual($0 as? HTTPParserError, HTTPParserError.invalidEOFState) + } + } catch { + // TBD: We sadly can't verify the error type, since it is private in `NIOTestUtils`: + // NIOTestUtils.BlockingQueue.TimeoutError + } + + // request must always be failed with this error + XCTAssertThrowsError(try future.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .bodyLengthMismatch) + } + } +} diff --git a/Tests/AsyncHTTPClientTests/RacePoolIdleConnectionsAndGetTests.swift b/Tests/AsyncHTTPClientTests/RacePoolIdleConnectionsAndGetTests.swift new file mode 100644 index 000000000..35a09c421 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/RacePoolIdleConnectionsAndGetTests.swift @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +final class RacePoolIdleConnectionsAndGetTests: XCTestCaseHTTPClientTestsBaseClass { + func testRacePoolIdleConnectionsAndGet() { + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(connectionPool: .init(idleTimeout: .milliseconds(10))) + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + } + for _ in 1...200 { + XCTAssertNoThrow(try localClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()) + Thread.sleep(forTimeInterval: 0.01 + .random(in: -0.01...0.01)) + } + } +} diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift b/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift deleted file mode 100644 index b2919081c..000000000 --- a/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift +++ /dev/null @@ -1,38 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// RequestBagTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension RequestBagTests { - static var allTests: [(String, (RequestBagTests) -> () throws -> Void)] { - return [ - ("testWriteBackpressureWorks", testWriteBackpressureWorks), - ("testTaskIsFailedIfWritingFails", testTaskIsFailedIfWritingFails), - ("testCancelFailsTaskBeforeRequestIsSent", testCancelFailsTaskBeforeRequestIsSent), - ("testCancelFailsTaskAfterRequestIsSent", testCancelFailsTaskAfterRequestIsSent), - ("testCancelFailsTaskWhenTaskIsQueued", testCancelFailsTaskWhenTaskIsQueued), - ("testFailsTaskWhenTaskIsWaitingForMoreFromServer", testFailsTaskWhenTaskIsWaitingForMoreFromServer), - ("testHTTPUploadIsCancelledEvenThoughRequestSucceeds", testHTTPUploadIsCancelledEvenThoughRequestSucceeds), - ("testRaceBetweenConnectionCloseAndDemandMoreData", testRaceBetweenConnectionCloseAndDemandMoreData), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests.swift b/Tests/AsyncHTTPClientTests/RequestBagTests.swift index 2f2eb8cf5..9aa595224 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests.swift @@ -12,37 +12,54 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient +import Atomics import Logging +import NIOConcurrencyHelpers import NIOCore import NIOEmbedded import NIOHTTP1 +import NIOPosix import XCTest +@testable import AsyncHTTPClient + final class RequestBagTests: XCTestCase { func testWriteBackpressureWorks() { let embeddedEventLoop = EmbeddedEventLoop() defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } let logger = Logger(label: "test") - var writtenBytes = 0 - var writes = 0 + struct TestState { + var writtenBytes: Int = 0 + var writes: Int = 0 + var streamIsAllowedToWrite: Bool = false + } + + let testState = NIOLockedValueBox(TestState()) + let bytesToSent = (3000...10000).randomElement()! let expectedWrites = bytesToSent / 100 + ((bytesToSent % 100 > 0) ? 1 : 0) - var streamIsAllowedToWrite = false let writeDonePromise = embeddedEventLoop.makePromise(of: Void.self) - let requestBody: HTTPClient.Body = .stream(length: bytesToSent) { writer -> EventLoopFuture in - func write(donePromise: EventLoopPromise) { - XCTAssertTrue(streamIsAllowedToWrite) - guard writtenBytes < bytesToSent else { - return donePromise.succeed(()) + let requestBody: HTTPClient.Body = .stream(contentLength: Int64(bytesToSent)) { + writer -> EventLoopFuture in + @Sendable func write(donePromise: EventLoopPromise) { + let futureWrite: EventLoopFuture? = testState.withLockedValue { state in + XCTAssertTrue(state.streamIsAllowedToWrite) + guard state.writtenBytes < bytesToSent else { + donePromise.succeed(()) + return nil + } + let byteCount = min(bytesToSent - state.writtenBytes, 100) + let buffer = ByteBuffer(bytes: [UInt8](repeating: 1, count: byteCount)) + state.writes += 1 + return writer.write(.byteBuffer(buffer)) } - let byteCount = min(bytesToSent - writtenBytes, 100) - let buffer = ByteBuffer(bytes: [UInt8](repeating: 1, count: byteCount)) - writes += 1 - writer.write(.byteBuffer(buffer)).whenSuccess { _ in - writtenBytes += 100 + + futureWrite?.whenSuccess { _ in + testState.withLockedValue { state in + state.writtenBytes += 100 + } write(donePromise: donePromise) } } @@ -53,20 +70,24 @@ final class RequestBagTests: XCTestCase { } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody)) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request") } let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.task.eventLoop === embeddedEventLoop) @@ -80,24 +101,26 @@ final class RequestBagTests: XCTestCase { executor.runRequest(bag) XCTAssertEqual(delegate.hitDidSendRequestHead, 1) - streamIsAllowedToWrite = true + testState.withLockedValue { $0.streamIsAllowedToWrite = true } bag.resumeRequestBodyStream() - streamIsAllowedToWrite = false + testState.withLockedValue { $0.streamIsAllowedToWrite = false } // after starting the body stream we should have received two writes var receivedBytes = 0 for i in 0.. EventLoopFuture in + let requestBody: HTTPClient.Body = .stream(contentLength: 12) { writer -> EventLoopFuture in writer.write(.byteBuffer(ByteBuffer(bytes: 0...3))).flatMap { _ -> EventLoopFuture in embeddedEventLoop.makeFailedFuture(TestError()) @@ -160,20 +187,24 @@ final class RequestBagTests: XCTestCase { } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody)) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request") } let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.task.eventLoop === embeddedEventLoop) @@ -206,20 +237,22 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.eventLoop === embeddedEventLoop) let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) - bag.cancel() + bag.fail(HTTPClientError.cancelled) bag.willExecuteRequest(executor) XCTAssertTrue(executor.isCancelled, "The request bag, should call cancel immediately on the executor") @@ -228,6 +261,90 @@ final class RequestBagTests: XCTestCase { } } + func testDeadlineExceededFailsTaskEvenIfRaceBetweenCancelingSchedulerAndRequestStart() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + XCTAssert(bag.eventLoop === embeddedEventLoop) + + let queuer = MockTaskQueuer() + bag.requestWasQueued(queuer) + + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + XCTAssertEqual(queuer.hitCancelCount, 0) + bag.deadlineExceeded() + XCTAssertEqual(queuer.hitCancelCount, 1) + + bag.willExecuteRequest(executor) + XCTAssertTrue(executor.isCancelled, "The request bag, should call cancel immediately on the executor") + XCTAssertThrowsError(try bag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .deadlineExceeded) + } + } + + func testCancelHasNoEffectAfterDeadlineExceededFailsTask() { + struct MyError: Error, Equatable {} + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + XCTAssert(bag.eventLoop === embeddedEventLoop) + + let queuer = MockTaskQueuer() + bag.requestWasQueued(queuer) + + XCTAssertEqual(queuer.hitCancelCount, 0) + bag.deadlineExceeded() + XCTAssertEqual(queuer.hitCancelCount, 1) + XCTAssertEqual(delegate.hitDidReceiveError, 0) + bag.fail(MyError()) + XCTAssertEqual(delegate.hitDidReceiveError, 1) + + bag.fail(HTTPClientError.cancelled) + XCTAssertEqual(delegate.hitDidReceiveError, 1) + + XCTAssertThrowsError(try bag.task.futureResult.wait()) { + XCTAssertEqualTypeAndValue($0, MyError()) + } + } + func testCancelFailsTaskAfterRequestIsSent() { let embeddedEventLoop = EmbeddedEventLoop() defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } @@ -239,15 +356,17 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.eventLoop === embeddedEventLoop) @@ -261,10 +380,10 @@ final class RequestBagTests: XCTestCase { XCTAssertEqual(delegate.hitDidSendRequestHead, 1) XCTAssertEqual(delegate.hitDidSendRequest, 1) - bag.cancel() + bag.fail(HTTPClientError.cancelled) XCTAssertTrue(executor.isCancelled, "The request bag, should call cancel immediately on the executor") - XCTAssertThrowsError(try bag.task.futureResult.wait()) { + XCTAssertThrowsError(try bag.task.futureResult.timeout(after: .seconds(10)).wait()) { XCTAssertEqual($0 as? HTTPClientError, .cancelled) } } @@ -280,22 +399,24 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let queuer = MockTaskQueuer() bag.requestWasQueued(queuer) XCTAssertEqual(queuer.hitCancelCount, 0) - bag.cancel() + bag.fail(HTTPClientError.cancelled) XCTAssertEqual(queuer.hitCancelCount, 1) XCTAssertThrowsError(try bag.task.futureResult.wait()) { @@ -314,15 +435,17 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) @@ -336,6 +459,120 @@ final class RequestBagTests: XCTestCase { } } + func testChannelBecomingWritableDoesntCrashCancelledTask() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "https://swift.org", + body: .bytes([1, 2, 3, 4, 5]) + ) + ) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + executor.runRequest(bag) + + // This simulates a race between the user cancelling the task (which invokes `RequestBag.fail(_:)`) and the + // call to `resumeRequestBodyStream` (which comes from the `Channel` event loop and so may have to hop. + bag.fail(HTTPClientError.cancelled) + bag.resumeRequestBodyStream() + + XCTAssertEqual(executor.isCancelled, true) + XCTAssertThrowsError(try bag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .cancelled) + } + } + + func testDidReceiveBodyPartFailedPromise() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "https://swift.org", + method: .POST, + body: .byteBuffer(.init(bytes: [1])) + ) + ) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + struct MyError: Error, Equatable {} + final class Delegate: HTTPClientResponseDelegate { + typealias Response = Void + let didFinishPromise: EventLoopPromise + init(didFinishPromise: EventLoopPromise) { + self.didFinishPromise = didFinishPromise + } + + func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { + task.eventLoop.makeFailedFuture(MyError()) + } + + func didReceiveError(task: HTTPClient.Task, _ error: Error) { + self.didFinishPromise.fail(error) + } + + func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws { + XCTFail("\(#function) should not be called") + self.didFinishPromise.succeed(()) + } + } + let delegate = Delegate(didFinishPromise: embeddedEventLoop.makePromise()) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + + executor.runRequest(bag) + + bag.resumeRequestBodyStream() + XCTAssertNoThrow(try executor.receiveRequestBody { XCTAssertEqual($0, ByteBuffer(bytes: [1])) }) + + bag.receiveResponseHead(.init(version: .http1_1, status: .ok)) + + bag.succeedRequest([ByteBuffer([1])]) + + XCTAssertThrowsError(try delegate.didFinishPromise.futureResult.wait()) { error in + XCTAssertEqualTypeAndValue(error, MyError()) + } + XCTAssertThrowsError(try bag.task.futureResult.wait()) { error in + XCTAssertEqualTypeAndValue(error, MyError()) + } + } + func testHTTPUploadIsCancelledEvenThoughRequestSucceeds() { let embeddedEventLoop = EmbeddedEventLoop() defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } @@ -343,42 +580,46 @@ final class RequestBagTests: XCTestCase { var maybeRequest: HTTPClient.Request? let writeSecondPartPromise = embeddedEventLoop.makePromise(of: Void.self) - - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "https://swift.org", - method: .POST, - headers: ["content-length": "12"], - body: .stream(length: 12) { writer -> EventLoopFuture in - var firstWriteSuccess = false - return writer.write(.byteBuffer(.init(bytes: 0...3))).flatMap { _ in - firstWriteSuccess = true - - return writeSecondPartPromise.futureResult - }.flatMap { - return writer.write(.byteBuffer(.init(bytes: 4...7))) - }.always { result in - XCTAssertTrue(firstWriteSuccess) - - guard case .failure(let error) = result else { - return XCTFail("Expected the second write to fail") + let firstWriteSuccess: NIOLockedValueBox = .init(false) + + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "https://swift.org", + method: .POST, + headers: ["content-length": "12"], + body: .stream(contentLength: 12) { writer -> EventLoopFuture in + writer.write(.byteBuffer(.init(bytes: 0...3))).flatMap { _ in + firstWriteSuccess.withLockedValue { $0 = true } + + return writeSecondPartPromise.futureResult + }.flatMap { + writer.write(.byteBuffer(.init(bytes: 4...7))) + }.always { result in + XCTAssertTrue(firstWriteSuccess.withLockedValue { $0 }) + + guard case .failure(let error) = result else { + return XCTFail("Expected the second write to fail") + } + XCTAssertEqual(error as? HTTPClientError, .requestStreamCancelled) } - XCTAssertEqual(error as? HTTPClientError, .requestStreamCancelled) } - } - )) + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request") } let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) @@ -414,15 +655,17 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) @@ -457,6 +700,293 @@ final class RequestBagTests: XCTestCase { XCTAssertNoThrow(try XCTUnwrap(delegate.backpressurePromise).succeed(())) XCTAssertEqual(delegate.hitDidReceiveResponse, 1) } + + func testRedirectWith3KBBody() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + var maybeRequestBag: RequestBag? + var redirectTriggered = false + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: .init( + request: request, + redirectState: RedirectState( + .follow(max: 5, allowCycles: false), + initialURL: request.url.absoluteString + )!, + execute: { request, _ in + XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") + XCTAssertFalse(redirectTriggered) + + let task = HTTPClient.Task( + eventLoop: embeddedEventLoop, + logger: logger + ) + task.promise.fail(HTTPClientError.cancelled) + redirectTriggered = true + return task + } + ), + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + executor.runRequest(bag) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseHead( + .init( + version: .http1_1, + status: .permanentRedirect, + headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"] + ) + ) + XCTAssertNil(delegate.backpressurePromise) + XCTAssertTrue(executor.signalledDemandForResponseBody) + executor.resetResponseStreamDemandSignal() + + // "foo" is forwarded for consumption. We expect the RequestBag to consume "foo" with the + // delegate and call demandMoreBody afterwards. + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseBodyParts([ByteBuffer(repeating: 0, count: 1024)]) + XCTAssertTrue(executor.signalledDemandForResponseBody) + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertNil(delegate.backpressurePromise) + executor.resetResponseStreamDemandSignal() + + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseBodyParts([ByteBuffer(repeating: 1, count: 1024)]) + XCTAssertTrue(executor.signalledDemandForResponseBody) + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertNil(delegate.backpressurePromise) + executor.resetResponseStreamDemandSignal() + + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.succeedRequest([ByteBuffer(repeating: 2, count: 1024)]) + XCTAssertFalse(executor.signalledDemandForResponseBody) + XCTAssertEqual(delegate.hitDidReceiveResponse, 0) + XCTAssertNil(delegate.backpressurePromise) + executor.resetResponseStreamDemandSignal() + + XCTAssertTrue(redirectTriggered) + } + + func testRedirectWith4KBBodyAnnouncedInResponseHead() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + var maybeRequestBag: RequestBag? + var redirectTriggered = false + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: .init( + request: request, + redirectState: RedirectState( + .follow(max: 5, allowCycles: false), + initialURL: request.url.absoluteString + )!, + execute: { request, _ in + XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") + XCTAssertFalse(redirectTriggered) + + let task = HTTPClient.Task( + eventLoop: embeddedEventLoop, + logger: logger + ) + task.promise.fail(HTTPClientError.cancelled) + redirectTriggered = true + return task + } + ), + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + executor.runRequest(bag) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseHead( + .init( + version: .http1_1, + status: .permanentRedirect, + headers: ["content-length": "\(4 * 1024)", "location": "https://swift.org/sswg"] + ) + ) + XCTAssertNil(delegate.backpressurePromise) + XCTAssertFalse(executor.signalledDemandForResponseBody) + XCTAssertTrue(executor.isCancelled) + + XCTAssertTrue(redirectTriggered) + } + + func testRedirectWith4KBBodyNotAnnouncedInResponseHead() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + var maybeRequestBag: RequestBag? + var redirectTriggered = false + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: .init( + request: request, + redirectState: RedirectState( + .follow(max: 5, allowCycles: false), + initialURL: request.url.absoluteString + )!, + execute: { request, _ in + XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") + XCTAssertFalse(redirectTriggered) + + let task = HTTPClient.Task( + eventLoop: embeddedEventLoop, + logger: logger + ) + task.promise.fail(HTTPClientError.cancelled) + redirectTriggered = true + return task + } + ), + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + executor.runRequest(bag) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseHead( + .init( + version: .http1_1, + status: .permanentRedirect, + headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"] + ) + ) + XCTAssertNil(delegate.backpressurePromise) + XCTAssertTrue(executor.signalledDemandForResponseBody) + executor.resetResponseStreamDemandSignal() + + // "foo" is forwarded for consumption. We expect the RequestBag to consume "foo" with the + // delegate and call demandMoreBody afterwards. + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseBodyParts([ByteBuffer(repeating: 0, count: 2024)]) + XCTAssertTrue(executor.signalledDemandForResponseBody) + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertNil(delegate.backpressurePromise) + executor.resetResponseStreamDemandSignal() + + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertFalse(executor.isCancelled) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseBodyParts([ByteBuffer(repeating: 1, count: 2024)]) + XCTAssertFalse(executor.signalledDemandForResponseBody) + XCTAssertTrue(executor.isCancelled) + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertNil(delegate.backpressurePromise) + executor.resetResponseStreamDemandSignal() + + XCTAssertTrue(redirectTriggered) + } + + func testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise() { + final class LeakDetector {} + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } + + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(group)) + defer { XCTAssertNoThrow(try httpClient.shutdown().wait()) } + + let httpBin = HTTPBin() + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + var leakDetector = LeakDetector() + + do { + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/", method: .POST) + ) + guard var request = maybeRequest else { return XCTFail("Expected to have a request here") } + + let writerPromise = group.any().makePromise(of: HTTPClient.Body.StreamWriter.self) + let donePromise = group.any().makePromise(of: Void.self) + request.body = .stream { [leakDetector] writer in + _ = leakDetector + writerPromise.succeed(writer) + return donePromise.futureResult + } + + let resultFuture = httpClient.execute(request: request) + request.body = nil + writerPromise.futureResult.whenSuccess { writer in + writer.write(.byteBuffer(ByteBuffer(string: "hello"))).map { + print("written") + }.cascade(to: donePromise) + } + XCTAssertNoThrow(try donePromise.futureResult.wait()) + print("HTTP sent") + + var result: HTTPClient.Response? + XCTAssertNoThrow(result = try resultFuture.wait()) + + XCTAssertEqual(.ok, result?.status) + let body = result?.body.map { String(buffer: $0) } + XCTAssertNotNil(body) + print("HTTP done") + } + XCTAssertTrue(isKnownUniquelyReferenced(&leakDetector)) + } +} + +extension HTTPClient.Task { + convenience init( + eventLoop: EventLoop, + logger: Logger + ) { + self.init(eventLoop: eventLoop, logger: logger) { + preconditionFailure("thread pool not needed in tests") + } + } } class UploadCountingDelegate: HTTPClientResponseDelegate { @@ -522,20 +1052,35 @@ class UploadCountingDelegate: HTTPClientResponseDelegate { } } -class MockTaskQueuer: HTTPRequestScheduler { - private(set) var hitCancelCount = 0 +final class MockTaskQueuer: HTTPRequestScheduler { + private let _hitCancelCount = ManagedAtomic(0) - init() {} + var hitCancelCount: Int { + self._hitCancelCount.load(ordering: .sequentiallyConsistent) + } + + let onCancelRequest: (@Sendable (HTTPSchedulableRequest) -> Void)? + + init(onCancelRequest: (@Sendable (HTTPSchedulableRequest) -> Void)? = nil) { + self.onCancelRequest = onCancelRequest + } - func cancelRequest(_: HTTPSchedulableRequest) { - self.hitCancelCount += 1 + func cancelRequest(_ request: HTTPSchedulableRequest) { + self._hitCancelCount.wrappingIncrement(ordering: .sequentiallyConsistent) + self.onCancelRequest?(request) } } extension RequestOptions { - static func forTests(idleReadTimeout: TimeAmount? = nil) -> Self { + static func forTests( + idleReadTimeout: TimeAmount? = nil, + idleWriteTimeout: TimeAmount? = nil, + dnsOverride: [String: String] = [:] + ) -> Self { RequestOptions( - idleReadTimeout: idleReadTimeout + idleReadTimeout: idleReadTimeout, + idleWriteTimeout: idleWriteTimeout, + dnsOverride: dnsOverride ) } } diff --git a/Tests/AsyncHTTPClientTests/RequestValidationTests+XCTest.swift b/Tests/AsyncHTTPClientTests/RequestValidationTests+XCTest.swift deleted file mode 100644 index 3a93d70ec..000000000 --- a/Tests/AsyncHTTPClientTests/RequestValidationTests+XCTest.swift +++ /dev/null @@ -1,52 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// RequestValidationTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension RequestValidationTests { - static var allTests: [(String, (RequestValidationTests) -> () throws -> Void)] { - return [ - ("testContentLengthHeaderIsRemovedFromGETIfNoBody", testContentLengthHeaderIsRemovedFromGETIfNoBody), - ("testContentLengthHeaderIsAddedToPOSTAndPUTWithNoBody", testContentLengthHeaderIsAddedToPOSTAndPUTWithNoBody), - ("testContentLengthHeaderIsChangedIfBodyHasDifferentLength", testContentLengthHeaderIsChangedIfBodyHasDifferentLength), - ("testTRACERequestMustNotHaveBody", testTRACERequestMustNotHaveBody), - ("testGET_HEAD_DELETE_CONNECTRequestCanHaveBody", testGET_HEAD_DELETE_CONNECTRequestCanHaveBody), - ("testInvalidHeaderFieldNames", testInvalidHeaderFieldNames), - ("testValidHeaderFieldNames", testValidHeaderFieldNames), - ("testMetadataDetectConnectionClose", testMetadataDetectConnectionClose), - ("testMetadataDefaultIsConnectionCloseIsFalse", testMetadataDefaultIsConnectionCloseIsFalse), - ("testNoHeadersNoBody", testNoHeadersNoBody), - ("testNoHeadersHasBody", testNoHeadersHasBody), - ("testContentLengthHeaderNoBody", testContentLengthHeaderNoBody), - ("testContentLengthHeaderHasBody", testContentLengthHeaderHasBody), - ("testTransferEncodingHeaderNoBody", testTransferEncodingHeaderNoBody), - ("testTransferEncodingHeaderHasBody", testTransferEncodingHeaderHasBody), - ("testBothHeadersNoBody", testBothHeadersNoBody), - ("testBothHeadersHasBody", testBothHeadersHasBody), - ("testHostHeaderIsSetCorrectlyInCreateRequestHead", testHostHeaderIsSetCorrectlyInCreateRequestHead), - ("testTraceMethodIsNotAllowedToHaveAFixedLengthBody", testTraceMethodIsNotAllowedToHaveAFixedLengthBody), - ("testTraceMethodIsNotAllowedToHaveADynamicLengthBody", testTraceMethodIsNotAllowedToHaveADynamicLengthBody), - ("testTransferEncodingsAreOverwrittenIfBodyLengthIsFixed", testTransferEncodingsAreOverwrittenIfBodyLengthIsFixed), - ("testTransferEncodingsAreOverwrittenIfBodyLengthIsDynamic", testTransferEncodingsAreOverwrittenIfBodyLengthIsDynamic), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/RequestValidationTests.swift b/Tests/AsyncHTTPClientTests/RequestValidationTests.swift index c50d3afd1..ea5a6bd66 100644 --- a/Tests/AsyncHTTPClientTests/RequestValidationTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestValidationTests.swift @@ -12,11 +12,12 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + class RequestValidationTests: XCTestCase { func testContentLengthHeaderIsRemovedFromGETIfNoBody() { var headers = HTTPHeaders([("Content-Length", "0")]) @@ -29,13 +30,17 @@ class RequestValidationTests: XCTestCase { func testContentLengthHeaderIsAddedToPOSTAndPUTWithNoBody() { var putHeaders = HTTPHeaders() var putMetadata: RequestFramingMetadata? - XCTAssertNoThrow(putMetadata = try putHeaders.validateAndSetTransportFraming(method: .PUT, bodyLength: .known(0))) + XCTAssertNoThrow( + putMetadata = try putHeaders.validateAndSetTransportFraming(method: .PUT, bodyLength: .known(0)) + ) XCTAssertEqual(putHeaders.first(name: "Content-Length"), "0") XCTAssertEqual(putMetadata?.body, .fixedSize(0)) var postHeaders = HTTPHeaders() var postMetadata: RequestFramingMetadata? - XCTAssertNoThrow(postMetadata = try postHeaders.validateAndSetTransportFraming(method: .POST, bodyLength: .known(0))) + XCTAssertNoThrow( + postMetadata = try postHeaders.validateAndSetTransportFraming(method: .POST, bodyLength: .known(0)) + ) XCTAssertEqual(postHeaders.first(name: "Content-Length"), "0") XCTAssertEqual(postMetadata?.body, .fixedSize(0)) } @@ -90,7 +95,7 @@ class RequestValidationTests: XCTestCase { func testMetadataDetectConnectionClose() { var headers = HTTPHeaders([ - ("Connection", "close"), + ("Connection", "close") ]) var metadata: RequestFramingMetadata? XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: .GET, bodyLength: .known(0))) @@ -114,7 +119,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT, .TRACE] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -123,7 +130,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertEqual(headers["content-length"].first, "0") XCTAssertFalse(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -139,7 +148,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1)) + ) XCTAssertEqual(headers["content-length"].first, "1") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(1)) @@ -149,7 +160,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .unknown)) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .unknown) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertTrue(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .stream) @@ -159,7 +172,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1)) + ) XCTAssertEqual(headers["content-length"].first, "1") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(1)) @@ -169,7 +184,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .unknown)) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .unknown) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertTrue(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .stream) @@ -184,7 +201,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT, .TRACE] { var headers: HTTPHeaders = .init([("Content-Length", "1")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -193,7 +212,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init([("Content-Length", "1")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertEqual(headers["content-length"].first, "0") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -208,7 +229,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT] { var headers: HTTPHeaders = .init([("Content-Length", "1")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1)) + ) XCTAssertEqual(headers["content-length"].first, "1") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(1)) @@ -217,7 +240,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init([("Content-Length", "1")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1)) + ) XCTAssertEqual(headers["content-length"].first, "1") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(1)) @@ -232,7 +257,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT, .TRACE] { var headers: HTTPHeaders = .init([("Transfer-Encoding", "chunked")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertFalse(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -241,7 +268,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init([("Transfer-Encoding", "chunked")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertEqual(headers["content-length"].first, "0") XCTAssertFalse(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -337,21 +366,27 @@ class RequestValidationTests: XCTestCase { func testTransferEncodingsAreOverwrittenIfBodyLengthIsFixed() { var headers: HTTPHeaders = [ - "Transfer-Encoding": "gzip, chunked", + "Transfer-Encoding": "gzip, chunked" ] XCTAssertNoThrow(try headers.validateAndSetTransportFraming(method: .POST, bodyLength: .known(1))) - XCTAssertEqual(headers, [ - "Content-Length": "1", - ]) + XCTAssertEqual( + headers, + [ + "Content-Length": "1" + ] + ) } func testTransferEncodingsAreOverwrittenIfBodyLengthIsDynamic() { var headers: HTTPHeaders = [ - "Transfer-Encoding": "gzip, chunked", + "Transfer-Encoding": "gzip, chunked" ] XCTAssertNoThrow(try headers.validateAndSetTransportFraming(method: .POST, bodyLength: .unknown)) - XCTAssertEqual(headers, [ - "Transfer-Encoding": "chunked", - ]) + XCTAssertEqual( + headers, + [ + "Transfer-Encoding": "chunked" + ] + ) } } diff --git a/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem b/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem new file mode 100644 index 000000000..f6314d47a --- /dev/null +++ b/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBxDCCAUmgAwIBAgIVAPY31L1kyEnjO1E4inpE7+SYRO9mMAoGCCqGSM49BAMD +MCoxFDASBgNVBAoMC1NlbGYgU2lnbmVkMRIwEAYDVQQDDAlsb2NhbGhvc3QwHhcN +MjQwMzI4MjI0MDUyWhcNMjUwMzI4MjI0MDUyWjAqMRQwEgYDVQQKDAtTZWxmIFNp +Z25lZDESMBAGA1UEAwwJbG9jYWxob3N0MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE +o2i+uiLtMu0Jzsk3oEUnfoM9n44/aV9UeOXxyDs57i2E13HrJeWIXACetybkB+Q8 +Poab6ohbskTwrS7WN3tFgoGdRBCKQow/rTECdezR/fdz2cGADaBN+CNMuFSnFSr5 +oy8wLTAWBgNVHREEDzANggtleGFtcGxlLmNvbTATBgNVHSUEDDAKBggrBgEFBQcD +ATAKBggqhkjOPQQDAwNpADBmAjEAwF5OlUBOloDTIAxgaSSvHBMSVOE1rY5hUlkT +kQ+dQFeUe3Fn+Er5ohvkt+qVOQ5yAjEAt9s5b/Iz+JmWxKKUyExHob6QHEuuHmJy +AKdrn20Ply60bb8qxGYHhwhoyV2MZYVV +-----END CERTIFICATE----- \ No newline at end of file diff --git a/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem b/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem new file mode 100644 index 000000000..7cf27cc35 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem @@ -0,0 +1,6 @@ +-----BEGIN PRIVATE KEY----- +MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDhC5OSjPQeYRm4irIH +z4EyM/NbJsX39SlI6J4/q0Syt0BwojgJKhCWfeveanbIjbWhZANiAASjaL66Iu0y +7QnOyTegRSd+gz2fjj9pX1R45fHIOznuLYTXcesl5YhcAJ63JuQH5Dw+hpvqiFuy +RPCtLtY3e0WCgZ1EEIpCjD+tMQJ17NH993PZwYANoE34I0y4VKcVKvk= +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/Tests/AsyncHTTPClientTests/Resources/self_signed_cert.pem b/Tests/AsyncHTTPClientTests/Resources/self_signed_cert.pem new file mode 100644 index 000000000..20b46f355 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/Resources/self_signed_cert.pem @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIEpjCCAo4CCQCeTmfiTQcJrzANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls +b2NhbGhvc3QwIBcNMjIwNjE0MTI1NDQ4WhgPMjI5NjAzMjgxMjU0NDhaMBQxEjAQ +BgNVBAMMCWxvY2FsaG9zdDCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIB +AK16gPDwP/Xbaf36x5BNd6yHDxCPIIJP4JLfMEuozwLE0YRqwmZOuklb4jUbAXf7 +u9u24ANrC4XS6VVWkfPdugokAUkaKPpwkV4GOiMCXeSDjDiLt1dYxlbp+MLV78a5 +oUDbCAqfFKebIgv1oiK+L6/p818eAHSBWEXXMhTeBDEQAIpJLTG88iVu6r3fMJeH +FbMWuPmAajmx2AEGmwD1x6+NHZLJv1zaufa7j0sHADraagXnfKn6rkLn1is6QFu4 +v7xaNlEwsRCYbh0nrtCtEdJIqnEHc0GCu/gnw5GE3CuRG3FYBTZStIF7d9h+XZQB +ky/YEWSGw9DXFBbebOZugopvl91qaZLqo6Wg0J8qCodgFtJHOSVMq/SAOBmKyw+b +7FYZbj4tQKpuuhwCN+gwEveTy+BK+zGY/sVzPwR8PNjpCgT/HiOBM7dNt4+2r9pY +Ld/mcMvakgRzM4Iqqntem9ltuckZev0TRjdrIylVWsAlNYVXm4ncMLkbzxFkv5Gb +AlhAuTwxyFkIo0M7+GS4lXCZ2bX2umJ0DTl3/NGJserFdkOhvHZSHHC9BzDBysmc +SejX/cGOFQ8O3sFeJdVMGlO64dU482O0FbBcLHmTLXWR4t8dlhrzJuXZ4X6WtHqY +83RwyD1gacYRZnT0eL+Z7XGrO1/qypji1RNaFIaGUt7DAgMBAAEwDQYJKoZIhvcN +AQELBQADggIBAIigOuEVirgqXoUMStTwYObs/DcNIPEugn9gAq9Lt1cr6fm7CvhG +AupxoJTbKLHQX6FegvFSA+4Kt3KYXX9Qi9SJF3Vr4zOhV0q203d4Aui6Lamo5Yye +nhbzzXuDSIyxpaWPFRC2RqCA6+hV8/Ar9Bx0TCI4NQxWxQEPerwqzqWCuTbViccw +WzlwRD2AHibaQaCbpzXg9lOX0fRJHoSM3exYQd91pDoSoL3f/EV3I/czssq+10M8 +F4GhE4bQjaKD7jL5U59dlvfy73nLAzzxzsxsFuYTAgzZwDg586sdbrqqFjzjoZ9A +dF8NuVYkHyFDQkpe66e1isNZi7eFdSjeVmj8llp4b6in59ik7ZS7arzGOxhZZzmv +Jf3nfE4hJzMS/4GJsKMdtcI+6K+hMi6Yt9OoPh82SQ2q8gK4QSWWrwAKuQ4F4UeO +pgiWBryKrkOXlGARBbsR/ZDhlqyAskeGuhIpEY5NLCByFfQ5KlcrX+n4TVLRZMvb +/7PZqboGgU+CUVawm/suPAs8jOlFQOzrxWQPRfWVvFII62ABgozS8N/xZ/WbgTVj +kOtWj85NpaBSCUliIY/7z1FkjpMZO8Kds45WQzAq4YChDLZGbgV0MkyXqO/LEYFJ +zqGOP1yGxVcKxu6t8Xh0hL6JPFmKWiMEWVrd1wut6NAIu6WNftmWZX6J +-----END CERTIFICATE----- diff --git a/Tests/AsyncHTTPClientTests/Resources/self_signed_key.pem b/Tests/AsyncHTTPClientTests/Resources/self_signed_key.pem new file mode 100644 index 000000000..8811c2d81 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/Resources/self_signed_key.pem @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCteoDw8D/122n9 ++seQTXeshw8QjyCCT+CS3zBLqM8CxNGEasJmTrpJW+I1GwF3+7vbtuADawuF0ulV +VpHz3boKJAFJGij6cJFeBjojAl3kg4w4i7dXWMZW6fjC1e/GuaFA2wgKnxSnmyIL +9aIivi+v6fNfHgB0gVhF1zIU3gQxEACKSS0xvPIlbuq93zCXhxWzFrj5gGo5sdgB +BpsA9cevjR2Syb9c2rn2u49LBwA62moF53yp+q5C59YrOkBbuL+8WjZRMLEQmG4d +J67QrRHSSKpxB3NBgrv4J8ORhNwrkRtxWAU2UrSBe3fYfl2UAZMv2BFkhsPQ1xQW +3mzmboKKb5fdammS6qOloNCfKgqHYBbSRzklTKv0gDgZissPm+xWGW4+LUCqbroc +AjfoMBL3k8vgSvsxmP7Fcz8EfDzY6QoE/x4jgTO3TbePtq/aWC3f5nDL2pIEczOC +Kqp7XpvZbbnJGXr9E0Y3ayMpVVrAJTWFV5uJ3DC5G88RZL+RmwJYQLk8MchZCKND +O/hkuJVwmdm19rpidA05d/zRibHqxXZDobx2UhxwvQcwwcrJnEno1/3BjhUPDt7B +XiXVTBpTuuHVOPNjtBWwXCx5ky11keLfHZYa8ybl2eF+lrR6mPN0cMg9YGnGEWZ0 +9Hi/me1xqztf6sqY4tUTWhSGhlLewwIDAQABAoICAApRcP3jrEo5RLKgieIhWW7f +kZvQh4R4r8jMkZjOb5Gglz2jA/EF2bqnRmsWMh4q0N+envBVG5hYFRzIS2IP3BLi +VVk9vxY2P88x259dcqw2zs5GMR923kUpIWylQN+3BspOvMm08IuPhJTlhUE/wqJZ +7enIZQqI7vEofYgUNHeelgmjlJaSwGxNjpTAg6lflYDTZykf5DGOTGSzOeDyvW/J +muqyKTmioND2Eu3JetAFUa0MObP6fwbntytXCaDq+ix/yR9HICD2kAYX6CPtR1QU +kl6qrMZGultmMhGjr1zAArvZGmZCwQ26hERSL8qv1UtRNKegBGGViVJa5GtIQ2dT +UmTWmWu/5gyxKvvjuqYl8Dub2/ZT0iGAsA6hGyUr+vpgjcNEZqsYhiEiQPi0g1sM +XyszytqG1F7JzXYgVzcdFA9L+eLD+i4nKD18TYTYHFGRmxwQ+HzHnetgDQ2gqRbB +XwT4lp643oNLMGyL+T0cQ7i1Hpq7Ko0S2FeXzzFe9B33uXbDvc0usier5qx2tgxc +zfgSqJjahfo4LCxhxvBWOup3U/sXNgyMCctr1qjpwGwLek+H1keOyv7FO9O6OgI1 +v5ZPFsJV7mK1fDLM/8QLDpUcUNnhPUfzsBdxKrjLfnZ8MPNczgv1GPzb4jsLvewf +g6ps8oBwnZDQVa6dMuyRAoIBAQDnTKRUsTMmFo01o0k90C8SwwE2x7Wry8r6vIIf +PMni3ZAS+zWFnu1zg82+83QpdvskntWM2iXS7nimmkXClCCFMDU/hYA9EsZtGIv6 ++xA6gYF0Xd3Qf9QrvhixOxHj3ixNyCeee3/9XUYln3ZfEx8cgCwHjPSIm3rOKI2M +PFnuG9xJ513sy6YCDrCdtb661E6bmsaMcIhu6S7At0njwnoL9aB617TSds5tFEr8 +74EW3D9epN01uUQ9MgZSXbzdQ82IswLps4a/k4wfDFp4qKpx7zOsoTSjA9il3fgW +QLhBXxnzTYYTvwxIgaW//fyqEL3p6t9zuYcjbORcrj7v8xIvAoIBAQDAASGjsSCA +hn03DXrI/atoXEC0htVwPwp4HTI0Z1/rOS0IrFBcX3CWx90Dr/clePHQGPk1yOO7 +oM83zumwggIOymtDhlTcCa77yN9x9AZMW3qPMF+mvAouUzItnlMrOjvfEnIWziWC +UsylBiV4/I6tf0zpH8zFYPNXq98fpv+UXyJDTW+YGBc2b2BwZZA6RdtFalqvunM7 +M8FIH8vSYEMR0YC47L2ceBJY/U9EQpsc6vuS7+CoXOH/WRb5v1z+a5O9sHWp8Rdc +Oh67B6v2feUT9TwhGUVF0L+ktW389e3N+VzPvbvICvRsOvo6+bceCJTszhNno00s +87bPyelaHXutAoIBAFtJ6onqri9YMz96RMv6wLl88Zu3UsKNWn1/rTO7AEtj+xsi +vssQINO4r5mv6Kb86L5ZWhuPdeI8cK4AsYvMftFSZ5G8lRKFuH8Scx0Jviv5NSjC +a2uBKDJjgsdgcv0mkQHZ/5kTUT6kc60htMxtdZgAFmCch17rTprTcppor23E3Trl +8DInZkvllFuKgc6nQKc1fSustoxfyC4TqTwVY6oYtdAGFr4CWhK/MaGGvcJSB0jJ +dO1hQ8eLWOdlS8dgnVxYmsu2KXavO1x9ua9pkmwJZrG5pla4i+dbJjFSNebHLCzU +6hgdDTIIyWxvSCuvE+Wg57R7AxU+Qxs5Qmnd280CggEAex4+m+BwnvmeQTb7jPZc +e0bsltX+90L1S6AtGT1QXF0Fa5JS1Wi9oXH3Xu3u5LBxHqdk5gAzR5UOSxL69pvn +BeT2cw4oTBBJjFp6LW/0ufHO3RJ/w0LApIPkoSvs2MM2sQv67HSzyKWfZBJU5QfN +1aLTholFnStV3tnu8TT8nf+C0PVOoZCREe7JQElf+n3g5NoV3KkKSuQdBEqfP/9K +Apr8l5f23eaAnV+Q/IxZOmnTd50pycwFft95xBvZXatNyUzlpltaR2FdY0DAHAcO +ZYXTUMYLjYEV4mAUbyijnHhR80QOrW+Y2+3VlwuZSEDofhCGkOY+Dp0YlJU8dPSC +4QKCAQEA3qlwsjJT8Vv+Sx2sZySDNtey/00yjxe4TckdH8dWTC22G38/ppbazlp/ +YVqFUgteoo5FI60HKa0+jLKQZpCIH6JgpL3yq81Obl0wRkrSNePRTL1Ikiff8P2j +bowFpbIdJLgDco1opJpDgTOz2mB7HlHu6RyoKjiVrNA/EOoks1Uljxdth6h/5ctr +rLn8dnz2sTtwxcUsOpyFcFQ2qaWJvSg+bF7JPPzMrpQfCR1qVWa43Kl8KlcWSKaq +ITpglIBY+h3F2GygAAcnpfkXde381Iw89y7TFd2LxWQR98zhnbJWF2JmuuPDtVRv ++HYZkcyQcpDwfC+2NOWOU7NQj+IDIA== +-----END PRIVATE KEY----- diff --git a/Tests/AsyncHTTPClientTests/ResponseDelayGetTests.swift b/Tests/AsyncHTTPClientTests/ResponseDelayGetTests.swift new file mode 100644 index 000000000..5fd1d6720 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/ResponseDelayGetTests.swift @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +final class ResponseDelayGetTests: XCTestCaseHTTPClientTestsBaseClass { + func testResponseDelayGet() throws { + let req = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .GET, + headers: ["X-internal-delay": "2000"], + body: nil + ) + let start = NIODeadline.now() + let response = try self.defaultClient.execute(request: req).wait() + XCTAssertGreaterThanOrEqual(.now() - start, .milliseconds(1_900)) + XCTAssertEqual(response.status, .ok) + } +} diff --git a/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests+XCTest.swift deleted file mode 100644 index 0338adf3c..000000000 --- a/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests+XCTest.swift +++ /dev/null @@ -1,35 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// SOCKSEventsHandlerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension SOCKSEventsHandlerTests { - static var allTests: [(String, (SOCKSEventsHandlerTests) -> () throws -> Void)] { - return [ - ("testHandlerHappyPath", testHandlerHappyPath), - ("testHandlerFailsFutureWhenRemovedWithoutEvent", testHandlerFailsFutureWhenRemovedWithoutEvent), - ("testHandlerFailsFutureWhenHandshakeFails", testHandlerFailsFutureWhenHandshakeFails), - ("testHandlerClosesConnectionIfHandshakeTimesout", testHandlerClosesConnectionIfHandshakeTimesout), - ("testHandlerWorksIfDeadlineIsInPast", testHandlerWorksIfDeadlineIsInPast), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift b/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift index 066a631a5..1170aa444 100644 --- a/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOSOCKS import XCTest +@testable import AsyncHTTPClient + class SOCKSEventsHandlerTests: XCTestCase { func testHandlerHappyPath() { let socksEventsHandler = SOCKSEventsHandler(deadline: .now() + .seconds(10)) diff --git a/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift b/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift index d7c97e6fe..ebff55a6d 100644 --- a/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift @@ -40,7 +40,13 @@ class MockSOCKSServer { self.channel.localAddress!.port! } - init(expectedURL: String, expectedResponse: String, misbehave: Bool = false, file: String = #file, line: UInt = #line) throws { + init( + expectedURL: String, + expectedResponse: String, + misbehave: Bool = false, + file: String = #filePath, + line: UInt = #line + ) throws { let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1) let bootstrap: ServerBootstrap if misbehave { @@ -53,12 +59,19 @@ class MockSOCKSServer { bootstrap = ServerBootstrap(group: elg) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelInitializer { channel in - let handshakeHandler = SOCKSServerHandshakeHandler() - return channel.pipeline.addHandlers([ - handshakeHandler, - SOCKSTestHandler(handshakeHandler: handshakeHandler), - TestHTTPServer(expectedURL: expectedURL, expectedResponse: expectedResponse, file: file, line: line), - ]) + channel.eventLoop.makeCompletedFuture { + let handshakeHandler = SOCKSServerHandshakeHandler() + try channel.pipeline.syncOperations.addHandlers([ + handshakeHandler, + SOCKSTestHandler(handshakeHandler: handshakeHandler), + TestHTTPServer( + expectedURL: expectedURL, + expectedResponse: expectedResponse, + file: file, + line: line + ), + ]) + } } } self.channel = try bootstrap.bind(host: "localhost", port: 0).wait() @@ -86,19 +99,34 @@ class SOCKSTestHandler: ChannelInboundHandler, RemovableChannelHandler { let message = self.unwrapInboundIn(data) switch message { case .greeting: - context.writeAndFlush(.init( - ServerMessage.selectedAuthenticationMethod(.init(method: .noneRequired))), promise: nil) + context.writeAndFlush( + .init( + ServerMessage.selectedAuthenticationMethod(.init(method: .noneRequired)) + ), + promise: nil + ) case .authenticationData: context.fireErrorCaught(MockSOCKSError(description: "Received authentication data but didn't receive any.")) case .request(let request): - context.writeAndFlush(.init( - ServerMessage.response(.init(reply: .succeeded, boundAddress: request.addressType))), promise: nil) - context.channel.pipeline.addHandlers([ - ByteToMessageHandler(HTTPRequestDecoder()), - HTTPResponseEncoder(), - ], position: .after(self)).whenSuccess { - context.channel.pipeline.removeHandler(self, promise: nil) - context.channel.pipeline.removeHandler(self.handshakeHandler, promise: nil) + context.writeAndFlush( + .init( + ServerMessage.response(.init(reply: .succeeded, boundAddress: request.addressType)) + ), + promise: nil + ) + + do { + try context.channel.pipeline.syncOperations.addHandlers( + [ + ByteToMessageHandler(HTTPRequestDecoder()), + HTTPResponseEncoder(), + ], + position: .after(self) + ) + context.channel.pipeline.syncOperations.removeHandler(self, promise: nil) + context.channel.pipeline.syncOperations.removeHandler(self.handshakeHandler, promise: nil) + } catch { + context.fireErrorCaught(error) } } } @@ -134,7 +162,12 @@ class TestHTTPServer: ChannelInboundHandler { break case .end: context.write(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok))), promise: nil) - context.write(self.wrapOutboundOut(.body(.byteBuffer(context.channel.allocator.buffer(string: self.expectedResponse)))), promise: nil) + context.write( + self.wrapOutboundOut( + .body(.byteBuffer(context.channel.allocator.buffer(string: self.expectedResponse))) + ), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } } diff --git a/Tests/AsyncHTTPClientTests/SSLContextCacheTests+XCTest.swift b/Tests/AsyncHTTPClientTests/SSLContextCacheTests+XCTest.swift deleted file mode 100644 index d98f5a853..000000000 --- a/Tests/AsyncHTTPClientTests/SSLContextCacheTests+XCTest.swift +++ /dev/null @@ -1,33 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// SSLContextCacheTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension SSLContextCacheTests { - static var allTests: [(String, (SSLContextCacheTests) -> () throws -> Void)] { - return [ - ("testRequestingSSLContextWorks", testRequestingSSLContextWorks), - ("testCacheWorks", testCacheWorks), - ("testCacheDoesNotReturnWrongEntry", testCacheDoesNotReturnWrongEntry), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift b/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift index 438c643d7..c7588cc7d 100644 --- a/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift +++ b/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOPosix import NIOSSL import XCTest +@testable import AsyncHTTPClient + final class SSLContextCacheTests: XCTestCase { func testRequestingSSLContextWorks() { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) @@ -27,9 +28,13 @@ final class SSLContextCacheTests: XCTestCase { XCTAssertNoThrow(try group.syncShutdownGracefully()) } - XCTAssertNoThrow(try cache.sslContext(tlsConfiguration: .makeClientConfiguration(), - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNoThrow( + try cache.sslContext( + tlsConfiguration: .makeClientConfiguration(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) } func testCacheWorks() { @@ -43,12 +48,20 @@ final class SSLContextCacheTests: XCTestCase { var firstContext: NIOSSLContext? var secondContext: NIOSSLContext? - XCTAssertNoThrow(firstContext = try cache.sslContext(tlsConfiguration: .makeClientConfiguration(), - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) - XCTAssertNoThrow(secondContext = try cache.sslContext(tlsConfiguration: .makeClientConfiguration(), - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNoThrow( + firstContext = try cache.sslContext( + tlsConfiguration: .makeClientConfiguration(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) + XCTAssertNoThrow( + secondContext = try cache.sslContext( + tlsConfiguration: .makeClientConfiguration(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) XCTAssertNotNil(firstContext) XCTAssertNotNil(secondContext) XCTAssert(firstContext === secondContext) @@ -65,16 +78,24 @@ final class SSLContextCacheTests: XCTestCase { var firstContext: NIOSSLContext? var secondContext: NIOSSLContext? - XCTAssertNoThrow(firstContext = try cache.sslContext(tlsConfiguration: .makeClientConfiguration(), - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNoThrow( + firstContext = try cache.sslContext( + tlsConfiguration: .makeClientConfiguration(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) // Second one has a _different_ TLSConfiguration. var testTLSConfig = TLSConfiguration.makeClientConfiguration() testTLSConfig.certificateVerification = .none - XCTAssertNoThrow(secondContext = try cache.sslContext(tlsConfiguration: testTLSConfig, - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNoThrow( + secondContext = try cache.sslContext( + tlsConfiguration: testTLSConfig, + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) XCTAssertNotNil(firstContext) XCTAssertNotNil(secondContext) XCTAssert(firstContext !== secondContext) diff --git a/Tests/AsyncHTTPClientTests/StressGetHttpsTests.swift b/Tests/AsyncHTTPClientTests/StressGetHttpsTests.swift new file mode 100644 index 000000000..587e6c64c --- /dev/null +++ b/Tests/AsyncHTTPClientTests/StressGetHttpsTests.swift @@ -0,0 +1,58 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +final class StressGetHttpsTests: XCTestCaseHTTPClientTestsBaseClass { + func testStressGetHttps() throws { + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + + let eventLoop = localClient.eventLoopGroup.next() + let requestCount = 200 + var futureResults = [EventLoopFuture]() + for _ in 1...requestCount { + let req = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/get", + method: .GET, + headers: ["X-internal-delay": "100"] + ) + futureResults.append(localClient.execute(request: req)) + } + XCTAssertNoThrow(try EventLoopFuture.andAllSucceed(futureResults, on: eventLoop).wait()) + } +} diff --git a/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests+XCTest.swift deleted file mode 100644 index 062132f4e..000000000 --- a/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests+XCTest.swift +++ /dev/null @@ -1,34 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// TLSEventsHandlerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension TLSEventsHandlerTests { - static var allTests: [(String, (TLSEventsHandlerTests) -> () throws -> Void)] { - return [ - ("testHandlerHappyPath", testHandlerHappyPath), - ("testHandlerFailsFutureWhenRemovedWithoutEvent", testHandlerFailsFutureWhenRemovedWithoutEvent), - ("testHandlerFailsFutureWhenHandshakeFails", testHandlerFailsFutureWhenHandshakeFails), - ("testHandlerIgnoresShutdownCompletedEvent", testHandlerIgnoresShutdownCompletedEvent), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift b/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift index c119c7e50..96cdf68f6 100644 --- a/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOSSL import NIOTLS import XCTest +@testable import AsyncHTTPClient + class TLSEventsHandlerTests: XCTestCase { func testHandlerHappyPath() { let tlsEventsHandler = TLSEventsHandler(deadline: nil) diff --git a/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests+XCTest.swift deleted file mode 100644 index a46c7dfc0..000000000 --- a/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests+XCTest.swift +++ /dev/null @@ -1,35 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// Transaction+StateMachineTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension Transaction_StateMachineTests { - static var allTests: [(String, (Transaction_StateMachineTests) -> () throws -> Void)] { - return [ - ("testRequestWasQueuedAfterWillExecuteRequestWasCalled", testRequestWasQueuedAfterWillExecuteRequestWasCalled), - ("testRequestBodyStreamWasPaused", testRequestBodyStreamWasPaused), - ("testQueuedRequestGetsRemovedWhenDeadlineExceeded", testQueuedRequestGetsRemovedWhenDeadlineExceeded), - ("testScheduledRequestGetsRemovedWhenDeadlineExceeded", testScheduledRequestGetsRemovedWhenDeadlineExceeded), - ("testRequestWithHeadReceivedGetNotCancelledWhenDeadlineExceeded", testRequestWithHeadReceivedGetNotCancelledWhenDeadlineExceeded), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests.swift b/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests.swift index ff1972330..a631e9a93 100644 --- a/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests.swift @@ -12,16 +12,21 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + +struct NoOpAsyncSequenceProducerDelegate: NIOAsyncSequenceProducerDelegate { + func produceMore() {} + func didTerminate() {} +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) final class Transaction_StateMachineTests: XCTestCase { func testRequestWasQueuedAfterWillExecuteRequestWasCalled() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let eventLoop = EmbeddedEventLoop() XCTAsyncTest { func workaround(_ continuation: CheckedContinuation) { @@ -33,7 +38,10 @@ final class Transaction_StateMachineTests: XCTestCase { state.requestWasQueued(queuer) let failAction = state.fail(HTTPClientError.cancelled) - guard case .failResponseHead(_, let error, let scheduler, let rexecutor, let bodyStreamContinuation) = failAction else { + guard + case .failResponseHead(_, let error, let scheduler, let rexecutor, let bodyStreamContinuation) = + failAction + else { return XCTFail("Unexpected fail action: \(failAction)") } XCTAssertEqual(error as? HTTPClientError, .cancelled) @@ -46,12 +54,9 @@ final class Transaction_StateMachineTests: XCTestCase { await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) } - #endif } func testRequestBodyStreamWasPaused() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let eventLoop = EmbeddedEventLoop() XCTAsyncTest { func workaround(_ continuation: CheckedContinuation) { @@ -69,12 +74,10 @@ final class Transaction_StateMachineTests: XCTestCase { await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) } - #endif } func testQueuedRequestGetsRemovedWhenDeadlineExceeded() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + struct MyError: Error, Equatable {} XCTAsyncTest { func workaround(_ continuation: CheckedContinuation) { var state = Transaction.StateMachine(continuation) @@ -82,23 +85,69 @@ final class Transaction_StateMachineTests: XCTestCase { state.requestWasQueued(queuer) - let failAction = state.deadlineExceeded() - guard case .cancel(let continuation, let scheduler, nil, nil) = failAction else { + let deadlineExceededAction = state.deadlineExceeded() + guard case .cancelSchedulerOnly(let scheduler) = deadlineExceededAction else { + return XCTFail("Unexpected fail action: \(deadlineExceededAction)") + } + XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) + + let failAction = state.fail(MyError()) + guard + case .failResponseHead(let continuation, let error, nil, nil, bodyStreamContinuation: nil) = + failAction + else { return XCTFail("Unexpected fail action: \(failAction)") } XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) - continuation.resume(throwing: HTTPClientError.deadlineExceeded) + continuation.resume(throwing: error) } - await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) + await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) { + XCTAssertEqualTypeAndValue($0, MyError()) + } + } + } + + func testDeadlineExceededAndFullyFailedRequestCanBeCanceledWithNoEffect() { + struct MyError: Error, Equatable {} + XCTAsyncTest { + func workaround(_ continuation: CheckedContinuation) { + var state = Transaction.StateMachine(continuation) + let queuer = MockTaskQueuer() + + state.requestWasQueued(queuer) + + let deadlineExceededAction = state.deadlineExceeded() + guard case .cancelSchedulerOnly(let scheduler) = deadlineExceededAction else { + return XCTFail("Unexpected fail action: \(deadlineExceededAction)") + } + XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) + + let failAction = state.fail(MyError()) + guard + case .failResponseHead(let continuation, let error, nil, nil, bodyStreamContinuation: nil) = + failAction + else { + return XCTFail("Unexpected fail action: \(failAction)") + } + XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) + + let secondFailAction = state.fail(HTTPClientError.cancelled) + guard case .none = secondFailAction else { + return XCTFail("Unexpected fail action: \(secondFailAction)") + } + + continuation.resume(throwing: error) + } + + await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) { + XCTAssertEqualTypeAndValue($0, MyError()) + } } - #endif } func testScheduledRequestGetsRemovedWhenDeadlineExceeded() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let eventLoop = EmbeddedEventLoop() XCTAsyncTest { func workaround(_ continuation: CheckedContinuation) { @@ -120,12 +169,40 @@ final class Transaction_StateMachineTests: XCTestCase { await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) } - #endif + } + + func testDeadlineExceededRaceWithRequestWillExecute() { + let eventLoop = EmbeddedEventLoop() + XCTAsyncTest { + func workaround(_ continuation: CheckedContinuation) { + var state = Transaction.StateMachine(continuation) + let expectedExecutor = MockRequestExecutor(eventLoop: eventLoop) + let queuer = MockTaskQueuer() + + state.requestWasQueued(queuer) + + let deadlineExceededAction = state.deadlineExceeded() + guard case .cancelSchedulerOnly(let scheduler) = deadlineExceededAction else { + return XCTFail("Unexpected fail action: \(deadlineExceededAction)") + } + XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) + + let failAction = state.willExecuteRequest(expectedExecutor) + guard case .cancelAndFail(let returnedExecutor, let continuation, with: let error) = failAction else { + return XCTFail("Unexpected fail action: \(failAction)") + } + XCTAssertIdentical(returnedExecutor as? MockRequestExecutor, expectedExecutor) + + continuation.resume(throwing: error) + } + + await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) { + XCTAssertEqualTypeAndValue($0, HTTPClientError.deadlineExceeded) + } + } } func testRequestWithHeadReceivedGetNotCancelledWhenDeadlineExceeded() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let eventLoop = EmbeddedEventLoop() XCTAsyncTest { func workaround(_ continuation: CheckedContinuation) { @@ -136,8 +213,11 @@ final class Transaction_StateMachineTests: XCTestCase { XCTAssertEqual(state.willExecuteRequest(executor), .none) state.requestWasQueued(queuer) let head = HTTPResponseHead(version: .http1_1, status: .ok) - let receiveResponseHeadAction = state.receiveResponseHead(head) - guard case .succeedResponseHead(head, let continuation) = receiveResponseHeadAction else { + let receiveResponseHeadAction = state.receiveResponseHead( + head, + delegate: NoOpAsyncSequenceProducerDelegate() + ) + guard case .succeedResponseHead(_, let continuation) = receiveResponseHeadAction else { return XCTFail("Unexpected action: \(receiveResponseHeadAction)") } @@ -150,11 +230,9 @@ final class Transaction_StateMachineTests: XCTestCase { await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) } - #endif } } -#if compiler(>=5.5.2) && canImport(_Concurrency) @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension Transaction.StateMachine.StartExecutionAction: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { @@ -193,7 +271,7 @@ extension Transaction.StateMachine.NextWriteAction: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { switch (lhs, rhs) { case (.writeAndWait(let lhsEx), .writeAndWait(let rhsEx)), - (.writeAndContinue(let lhsEx), .writeAndContinue(let rhsEx)): + (.writeAndContinue(let lhsEx), .writeAndContinue(let rhsEx)): if let lhsMock = lhsEx as? MockRequestExecutor, let rhsMock = rhsEx as? MockRequestExecutor { return lhsMock === rhsMock } @@ -205,4 +283,3 @@ extension Transaction.StateMachine.NextWriteAction: Equatable { } } } -#endif diff --git a/Tests/AsyncHTTPClientTests/TransactionTests+XCTest.swift b/Tests/AsyncHTTPClientTests/TransactionTests+XCTest.swift deleted file mode 100644 index 190260647..000000000 --- a/Tests/AsyncHTTPClientTests/TransactionTests+XCTest.swift +++ /dev/null @@ -1,39 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// TransactionTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension TransactionTests { - static var allTests: [(String, (TransactionTests) -> () throws -> Void)] { - return [ - ("testCancelAsyncRequest", testCancelAsyncRequest), - ("testResponseStreamingWorks", testResponseStreamingWorks), - ("testIgnoringResponseBodyWorks", testIgnoringResponseBodyWorks), - ("testWriteBackpressureWorks", testWriteBackpressureWorks), - ("testSimpleGetRequest", testSimpleGetRequest), - ("testSimplePostRequest", testSimplePostRequest), - ("testPostStreamFails", testPostStreamFails), - ("testResponseStreamFails", testResponseStreamFails), - ("testBiDirectionalStreamingHTTP2", testBiDirectionalStreamingHTTP2), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/TransactionTests.swift b/Tests/AsyncHTTPClientTests/TransactionTests.swift index 7e2c62a0d..8609597b3 100644 --- a/Tests/AsyncHTTPClientTests/TransactionTests.swift +++ b/Tests/AsyncHTTPClientTests/TransactionTests.swift @@ -12,27 +12,29 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +import NIOFoundationCompat import NIOHTTP1 import NIOPosix import XCTest -#if compiler(>=5.5.2) && canImport(_Concurrency) +@testable import AsyncHTTPClient + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) typealias PreparedRequest = HTTPClientRequest.Prepared -#endif +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) final class TransactionTests: XCTestCase { func testCancelAsyncRequest() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + // creating the `XCTestExpectation` off the main thread crashes on Linux with Swift 5.6 + // therefore we create it here as a workaround which works fine + let scheduledRequestCanceled = self.expectation(description: "scheduled request canceled") XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -41,34 +43,92 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return XCTFail("Expected to have a request here.") } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let queuer = MockTaskQueuer() + let queuer = MockTaskQueuer { _ in + scheduledRequestCanceled.fulfill() + } transaction.requestWasQueued(queuer) + XCTAssertEqual(queuer.hitCancelCount, 0) Task.detached { try await Task.sleep(nanoseconds: 5 * 1000 * 1000) transaction.cancel() } - XCTAssertEqual(queuer.hitCancelCount, 0) - await XCTAssertThrowsError(try await responseTask.value) { - XCTAssertEqual($0 as? HTTPClientError, .cancelled) + await XCTAssertThrowsError(try await responseTask.value) { error in + XCTAssertTrue(error is CancellationError, "unexpected error \(error)") + } + + // self.fulfillment(of:) is not available on Linux + _ = { + self.wait(for: [scheduledRequestCanceled], timeout: 1) + }() + } + } + + func testDeadlineExceededWhileQueuedAndExecutorImmediatelyCancelsTask() { + XCTAsyncTest { + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } + + var request = HTTPClientRequest(url: "https://localhost/") + request.method = .GET + var maybePreparedRequest: PreparedRequest? + XCTAssertNoThrow(maybePreparedRequest = try PreparedRequest(request)) + guard let preparedRequest = maybePreparedRequest else { + return XCTFail("Expected to have a request here.") + } + let (transaction, responseTask) = await Transaction.makeWithResultTask( + request: preparedRequest, + preferredEventLoop: loop + ) + + let queuer = MockTaskQueuer() + transaction.requestWasQueued(queuer) + + transaction.deadlineExceeded() + + struct Executor: HTTPRequestExecutor { + func writeRequestBodyPart( + _: NIOCore.IOData, + request: AsyncHTTPClient.HTTPExecutableRequest, + promise: NIOCore.EventLoopPromise? + ) { + XCTFail() + } + + func finishRequestBodyStream( + _ task: AsyncHTTPClient.HTTPExecutableRequest, + promise: NIOCore.EventLoopPromise? + ) { + XCTFail() + } + + func demandResponseBodyStream(_: AsyncHTTPClient.HTTPExecutableRequest) { + XCTFail() + } + + func cancelRequest(_ task: AsyncHTTPClient.HTTPExecutableRequest) { + task.fail(HTTPClientError.cancelled) + } + } + + transaction.willExecuteRequest(Executor()) + + await XCTAssertThrowsError(try await responseTask.value) { error in + XCTAssertEqualTypeAndValue(error, HTTPClientError.deadlineExceeded) } - XCTAssertEqual(queuer.hitCancelCount, 1) } - #endif } func testResponseStreamingWorks() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -78,14 +138,14 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let executor = MockRequestExecutor( pauseRequestBodyPartStreamAfterASingleWrite: true, - eventLoop: embeddedEventLoop + eventLoop: loop ) transaction.willExecuteRequest(executor) @@ -100,11 +160,11 @@ final class TransactionTests: XCTestCase { XCTAssertEqual(response.headers, responseHead.headers) XCTAssertEqual(response.version, responseHead.version) - let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }.makeAsyncIterator()) + let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }) - for i in 0..<100 { - XCTAssertFalse(executor.signalledDemandForResponseBody, "Demand was not signalled yet.") + XCTAssertFalse(executor.signalledDemandForResponseBody, "Demand was not signalled yet.") + for i in 0..<100 { async let part = iterator.next() XCTAssertNoThrow(try executor.receiveResponseDemand()) @@ -115,7 +175,6 @@ final class TransactionTests: XCTestCase { XCTAssertEqual(result, ByteBuffer(integer: i)) } - XCTAssertFalse(executor.signalledDemandForResponseBody, "Demand was not signalled yet.") async let part = iterator.next() XCTAssertNoThrow(try executor.receiveResponseDemand()) executor.resetResponseStreamDemandSignal() @@ -123,15 +182,12 @@ final class TransactionTests: XCTestCase { let result = try await part XCTAssertNil(result) } - #endif } func testIgnoringResponseBodyWorks() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -141,9 +197,9 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return } - var tuple: (Transaction, Task)! = Transaction.makeWithResultTask( + var tuple: (Transaction, Task)! = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let transaction = tuple.0 @@ -152,9 +208,10 @@ final class TransactionTests: XCTestCase { let executor = MockRequestExecutor( pauseRequestBodyPartStreamAfterASingleWrite: true, - eventLoop: embeddedEventLoop + eventLoop: loop ) executor.runRequest(transaction) + await loop.run() let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["foo": "bar"]) XCTAssertFalse(executor.signalledDemandForResponseBody) @@ -174,15 +231,12 @@ final class TransactionTests: XCTestCase { transaction.receiveResponseBodyParts([ByteBuffer(string: "foo bar")]) transaction.succeedRequest(nil) } - #endif } func testWriteBackpressureWorks() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } let streamWriter = AsyncSequenceWriter() XCTAssertFalse(streamWriter.hasDemand, "Did not expect to have a demand at this point") @@ -196,28 +250,31 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return XCTFail("Expected to have a request here.") } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + let executor = MockRequestExecutor(eventLoop: loop) executor.runRequest(transaction) + await loop.run() for i in 0..<100 { XCTAssertFalse(streamWriter.hasDemand, "Did not expect to have demand yet") transaction.resumeRequestBodyStream() - await streamWriter.demand() // wait's for the stream writer to signal demand + await streamWriter.demand() // wait's for the stream writer to signal demand transaction.pauseRequestBodyStream() let part = ByteBuffer(integer: i) streamWriter.write(part) - XCTAssertNoThrow(try executor.receiveRequestBody { - XCTAssertEqual($0, part) - }) + XCTAssertNoThrow( + try executor.receiveRequestBody { + XCTAssertEqual($0, part) + } + ) } transaction.resumeRequestBodyStream() @@ -237,7 +294,7 @@ final class TransactionTests: XCTestCase { XCTAssertEqual(response.headers, responseHead.headers) XCTAssertEqual(response.version, responseHead.version) - let iterator = SharedIterator(response.body.makeAsyncIterator()) + let iterator = SharedIterator(response.body) XCTAssertFalse(executor.signalledDemandForResponseBody, "Demand was not signalled yet.") async let part = iterator.next() @@ -248,12 +305,9 @@ final class TransactionTests: XCTestCase { let result = try await part XCTAssertNil(result) } - #endif } func testSimpleGetRequest() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) let eventLoop = eventLoopGroup.next() @@ -265,11 +319,13 @@ final class TransactionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - )) + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") } @@ -282,7 +338,7 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return XCTFail("Expected to have a request here.") } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, preferredEventLoop: eventLoopGroup.next() ) @@ -306,15 +362,12 @@ final class TransactionTests: XCTestCase { RequestInfo(data: "", requestNumber: 1, connectionNumber: 0) ) } - #endif } func testSimplePostRequest() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .POST @@ -324,17 +377,20 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return XCTFail("Expected to have a request here.") } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + let executor = MockRequestExecutor(eventLoop: loop) executor.runRequest(transaction) + await loop.run() executor.resumeRequestBodyStream() - XCTAssertNoThrow(try executor.receiveRequestBody { - XCTAssertEqual($0.getString(at: 0, length: $0.readableBytes), "Hello world!") - }) + XCTAssertNoThrow( + try executor.receiveRequestBody { + XCTAssertEqual($0.getString(at: 0, length: $0.readableBytes), "Hello world!") + } + ) XCTAssertNoThrow(try executor.receiveEndOfStream()) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["foo": "bar"]) @@ -346,15 +402,12 @@ final class TransactionTests: XCTestCase { XCTAssertEqual(response.version, .http1_1) XCTAssertEqual(response.headers, ["foo": "bar"]) } - #endif } func testPostStreamFails() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } let writer = AsyncSequenceWriter() @@ -366,21 +419,24 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return XCTFail("Expected to have a request here.") } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + let executor = MockRequestExecutor(eventLoop: loop) executor.runRequest(transaction) + await loop.run() executor.resumeRequestBodyStream() await writer.demand() writer.write(.init(string: "Hello world!")) - XCTAssertNoThrow(try executor.receiveRequestBody { - XCTAssertEqual($0.getString(at: 0, length: $0.readableBytes), "Hello world!") - }) + XCTAssertNoThrow( + try executor.receiveRequestBody { + XCTAssertEqual($0.getString(at: 0, length: $0.readableBytes), "Hello world!") + } + ) XCTAssertFalse(executor.isCancelled) struct WriteError: Error, Equatable {} @@ -391,15 +447,12 @@ final class TransactionTests: XCTestCase { } XCTAssertNoThrow(try executor.receiveCancellation()) } - #endif } func testResponseStreamFails() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } - XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + XCTAsyncTest(timeout: 30) { + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -409,14 +462,14 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let executor = MockRequestExecutor( pauseRequestBodyPartStreamAfterASingleWrite: true, - eventLoop: embeddedEventLoop + eventLoop: loop ) transaction.willExecuteRequest(executor) @@ -427,17 +480,19 @@ final class TransactionTests: XCTestCase { transaction.receiveResponseHead(responseHead) let response = try await responseTask.value + XCTAssertEqual(response.status, responseHead.status) XCTAssertEqual(response.headers, responseHead.headers) XCTAssertEqual(response.version, responseHead.version) XCTAssertFalse(executor.signalledDemandForResponseBody, "Demand was not signalled yet.") - let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }.makeAsyncIterator()) + let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }) async let part1 = iterator.next() XCTAssertNoThrow(try executor.receiveResponseDemand()) executor.resetResponseStreamDemandSignal() transaction.receiveResponseBodyParts([ByteBuffer(integer: 123)]) + let result = try await part1 XCTAssertEqual(result, ByteBuffer(integer: 123)) @@ -454,12 +509,9 @@ final class TransactionTests: XCTestCase { XCTAssertEqual($0 as? HTTPClientError, .readTimeout) } } - #endif } func testBiDirectionalStreamingHTTP2() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) let eventLoop = eventLoopGroup.next() @@ -471,11 +523,13 @@ final class TransactionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - )) + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") } @@ -486,14 +540,14 @@ final class TransactionTests: XCTestCase { var request = HTTPClientRequest(url: "https://localhost:\(httpBin.port)/") request.method = .POST request.headers = ["host": "localhost:\(httpBin.port)"] - request.body = .stream(streamWriter, length: .known(800)) + request.body = .stream(streamWriter, length: .known(Int64(800))) var maybePreparedRequest: PreparedRequest? XCTAssertNoThrow(maybePreparedRequest = try PreparedRequest(request)) guard let preparedRequest = maybePreparedRequest else { return } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, preferredEventLoop: eventLoopGroup.next() ) @@ -508,7 +562,7 @@ final class TransactionTests: XCTestCase { XCTAssertEqual(response.version, .http2) XCTAssertEqual(delegate.hitStreamClosed, 0) - let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }.makeAsyncIterator()) + let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }) // at this point we can start to write to the stream and wait for the results @@ -529,30 +583,73 @@ final class TransactionTests: XCTestCase { XCTAssertNil(final) XCTAssertEqual(delegate.hitStreamClosed, 1) } - #endif } } -#if compiler(>=5.5.2) && canImport(_Concurrency) - // This needs a small explanation. If an iterator is a struct, it can't be used across multiple // tasks. Since we want to wait for things to happen in tests, we need to `async let`, which creates // implicit tasks. Therefore we need to wrap our iterator struct. @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -actor SharedIterator { - private var iterator: Iterator +actor SharedIterator where Wrapped.Element: Sendable { + private var wrappedIterator: Wrapped.AsyncIterator + private var nextCallInProgress: Bool = false - init(_ iterator: Iterator) { - self.iterator = iterator + init(_ sequence: Wrapped) { + self.wrappedIterator = sequence.makeAsyncIterator() } - func next() async throws -> Iterator.Element? { - var iter = self.iterator - defer { self.iterator = iter } + func next() async throws -> Wrapped.Element? { + precondition(self.nextCallInProgress == false) + self.nextCallInProgress = true + var iter = self.wrappedIterator + defer { + precondition(self.nextCallInProgress == true) + self.nextCallInProgress = false + self.wrappedIterator = iter + } return try await iter.next() } } +/// non fail-able promise that only supports one observer +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +private actor Promise { + private enum State { + case initialised + case fulfilled(Value) + } + + private var state: State = .initialised + + private var observer: CheckedContinuation? + + init() {} + + func fulfil(_ value: Value) { + switch self.state { + case .initialised: + self.state = .fulfilled(value) + self.observer?.resume(returning: value) + case .fulfilled: + preconditionFailure("\(Self.self) over fulfilled") + } + } + + var value: Value { + get async { + switch self.state { + case .initialised: + return await withCheckedContinuation { (continuation: CheckedContinuation) in + precondition(self.observer == nil, "\(Self.self) supports only one observer") + self.observer = continuation + } + case .fulfilled(let value): + return value + } + } + } +} + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension Transaction { fileprivate static func makeWithResultTask( @@ -561,10 +658,11 @@ extension Transaction { logger: Logger = Logger(label: "test"), connectionDeadline: NIODeadline = .distantFuture, preferredEventLoop: EventLoop - ) -> (Transaction, _Concurrency.Task) { - let transactionPromise = preferredEventLoop.makePromise(of: Transaction.self) - let result = Task { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + ) async -> (Transaction, _Concurrency.Task) { + let transactionPromise = Promise() + let task = Task { + try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) in let transaction = Transaction( request: request, requestOptions: requestOptions, @@ -573,13 +671,12 @@ extension Transaction { preferredEventLoop: preferredEventLoop, responseContinuation: continuation ) - transactionPromise.succeed(transaction) + Task { + await transactionPromise.fulfil(transaction) + } } } - // the promise can never fail and it is therefore safe to force unwrap - let transaction = try! transactionPromise.futureResult.wait() - return (transaction, result) + return (await transactionPromise.value, task) } } -#endif diff --git a/Tests/AsyncHTTPClientTests/XCTest+AsyncAwait.swift b/Tests/AsyncHTTPClientTests/XCTest+AsyncAwait.swift index fbc429b10..6cdcf4f8a 100644 --- a/Tests/AsyncHTTPClientTests/XCTest+AsyncAwait.swift +++ b/Tests/AsyncHTTPClientTests/XCTest+AsyncAwait.swift @@ -11,26 +11,25 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -/* - * Copyright 2021, gRPC Authors All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#if compiler(>=5.5.2) && canImport(_Concurrency) +// +// Copyright 2021, gRPC Authors All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + import XCTest extension XCTestCase { - @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) /// Cross-platform XCTest support for async-await tests. /// /// Currently the Linux implementation of XCTest doesn't have async-await support. @@ -39,6 +38,7 @@ extension XCTestCase { /// /// - NOTE: Support for Linux is tracked by https://bugs.swift.org/browse/SR-14403. /// - NOTE: Implementation currently in progress: https://github.com/apple/swift-corelibs-xctest/pull/326 + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) func XCTAsyncTest( expectationDescription: String = "Async operation", timeout: TimeInterval = 30, @@ -53,7 +53,7 @@ extension XCTestCase { try await operation() } catch { XCTFail("Error thrown while executing \(function): \(error)", file: file, line: line) - Thread.callStackSymbols.forEach { print($0) } + for symbol in Thread.callStackSymbols { print(symbol) } } expectation.fulfill() } @@ -65,7 +65,7 @@ extension XCTestCase { internal func XCTAssertThrowsError( _ expression: @autoclosure () async throws -> T, verify: (Error) -> Void = { _ in }, - file: StaticString = #file, + file: StaticString = #filePath, line: UInt = #line ) async { do { @@ -79,7 +79,7 @@ internal func XCTAssertThrowsError( @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) internal func XCTAssertNoThrowWithResult( _ expression: @autoclosure () async throws -> Result, - file: StaticString = #file, + file: StaticString = #filePath, line: UInt = #line ) async -> Result? { do { @@ -89,5 +89,3 @@ internal func XCTAssertNoThrowWithResult( } return nil } - -#endif diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift deleted file mode 100644 index cebced614..000000000 --- a/Tests/LinuxMain.swift +++ /dev/null @@ -1,64 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// LinuxMain.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -#if os(Linux) || os(FreeBSD) -@testable import AsyncHTTPClientTests - -XCTMain([ - testCase(AsyncAwaitEndToEndTests.allTests), - testCase(HTTP1ClientChannelHandlerTests.allTests), - testCase(HTTP1ConnectionStateMachineTests.allTests), - testCase(HTTP1ConnectionTests.allTests), - testCase(HTTP1ProxyConnectHandlerTests.allTests), - testCase(HTTP2ClientRequestHandlerTests.allTests), - testCase(HTTP2ClientTests.allTests), - testCase(HTTP2ConnectionTests.allTests), - testCase(HTTP2IdleHandlerTests.allTests), - testCase(HTTPClientCookieTests.allTests), - testCase(HTTPClientInternalTests.allTests), - testCase(HTTPClientNIOTSTests.allTests), - testCase(HTTPClientReproTests.allTests), - testCase(HTTPClientRequestTests.allTests), - testCase(HTTPClientSOCKSTests.allTests), - testCase(HTTPClientTests.allTests), - testCase(HTTPClientUncleanSSLConnectionShutdownTests.allTests), - testCase(HTTPConnectionPoolTests.allTests), - testCase(HTTPConnectionPool_FactoryTests.allTests), - testCase(HTTPConnectionPool_HTTP1ConnectionsTests.allTests), - testCase(HTTPConnectionPool_HTTP1StateMachineTests.allTests), - testCase(HTTPConnectionPool_HTTP2ConnectionsTests.allTests), - testCase(HTTPConnectionPool_HTTP2StateMachineTests.allTests), - testCase(HTTPConnectionPool_ManagerTests.allTests), - testCase(HTTPConnectionPool_RequestQueueTests.allTests), - testCase(HTTPRequestStateMachineTests.allTests), - testCase(LRUCacheTests.allTests), - testCase(RequestBagTests.allTests), - testCase(RequestValidationTests.allTests), - testCase(SOCKSEventsHandlerTests.allTests), - testCase(SSLContextCacheTests.allTests), - testCase(TLSEventsHandlerTests.allTests), - testCase(TransactionTests.allTests), - testCase(Transaction_StateMachineTests.allTests), -]) -#endif diff --git a/docker/Dockerfile b/docker/Dockerfile deleted file mode 100644 index 6395405c1..000000000 --- a/docker/Dockerfile +++ /dev/null @@ -1,34 +0,0 @@ -ARG swift_version=5.2 -ARG ubuntu_version=bionic -ARG base_image=swift:$swift_version-$ubuntu_version -FROM $base_image -# needed to do again after FROM due to docker limitation -ARG swift_version -ARG ubuntu_version - -# set as UTF-8 -RUN apt-get update && apt-get install -y locales locales-all -ENV LC_ALL en_US.UTF-8 -ENV LANG en_US.UTF-8 -ENV LANGUAGE en_US.UTF-8 - -# dependencies -RUN apt-get update && apt-get install -y wget -RUN apt-get update && apt-get install -y lsof dnsutils netcat-openbsd net-tools libz-dev curl jq # used by integration tests - -# ruby and jazzy for docs generation -RUN apt-get update && apt-get install -y ruby ruby-dev libsqlite3-dev build-essential -# jazzy no longer works on xenial as ruby is too old. -RUN if [ "${ubuntu_version}" = "focal" ] ; then echo "gem: --no-document" > ~/.gemrc; fi -RUN if [ "${ubuntu_version}" = "focal" ] ; then gem install jazzy; fi - -# tools -RUN mkdir -p $HOME/.tools -RUN echo 'export PATH="$HOME/.tools:$PATH"' >> $HOME/.profile - -# swiftformat (until part of the toolchain) - -ARG swiftformat_version=0.48.8 -RUN git clone --branch $swiftformat_version --depth 1 https://github.com/nicklockwood/SwiftFormat $HOME/.tools/swift-format -RUN cd $HOME/.tools/swift-format && swift build -c release -RUN ln -s $HOME/.tools/swift-format/.build/release/swiftformat $HOME/.tools/swiftformat diff --git a/docker/docker-compose.1604.52.yaml b/docker/docker-compose.1604.52.yaml deleted file mode 100644 index 4f74bca9f..000000000 --- a/docker/docker-compose.1604.52.yaml +++ /dev/null @@ -1,18 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:16.04-5.2 - build: - args: - ubuntu_version: "xenial" - swift_version: "5.2" - - test: - image: async-http-client:16.04-5.2 - environment: - - SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:16.04-5.2 diff --git a/docker/docker-compose.1804.53.yaml b/docker/docker-compose.1804.53.yaml deleted file mode 100644 index e9e4e53dc..000000000 --- a/docker/docker-compose.1804.53.yaml +++ /dev/null @@ -1,18 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:18.04-5.3 - build: - args: - ubuntu_version: "bionic" - swift_version: "5.3" - - test: - image: async-http-client:18.04-5.3 - environment: [] - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:18.04-5.3 diff --git a/docker/docker-compose.2004.54.yaml b/docker/docker-compose.2004.54.yaml deleted file mode 100644 index 154540ccb..000000000 --- a/docker/docker-compose.2004.54.yaml +++ /dev/null @@ -1,18 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:20.04-5.4 - build: - args: - ubuntu_version: "focal" - swift_version: "5.4" - - test: - image: async-http-client:20.04-5.4 - environment: [] - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:20.04-5.4 diff --git a/docker/docker-compose.2004.55.yaml b/docker/docker-compose.2004.55.yaml deleted file mode 100644 index 4d0a12ee7..000000000 --- a/docker/docker-compose.2004.55.yaml +++ /dev/null @@ -1,18 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:20.04-5.5 - build: - args: - ubuntu_version: "focal" - swift_version: "5.5" - - test: - image: async-http-client:20.04-5.5 - environment: [] - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:20.04-5.5 diff --git a/docker/docker-compose.2004.56.yaml b/docker/docker-compose.2004.56.yaml deleted file mode 100644 index ed61267a9..000000000 --- a/docker/docker-compose.2004.56.yaml +++ /dev/null @@ -1,18 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:20.04-5.6 - build: - args: - ubuntu_version: "focal" - swift_version: "5.6" - - test: - image: async-http-client:20.04-5.6 - environment: [] - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:20.04-5.6 diff --git a/docker/docker-compose.2004.57.yaml b/docker/docker-compose.2004.57.yaml deleted file mode 100644 index 16c564482..000000000 --- a/docker/docker-compose.2004.57.yaml +++ /dev/null @@ -1,17 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:20.04-5.7 - build: - args: - base_image: "swiftlang/swift:nightly-main-focal" - - test: - image: async-http-client:20.04-5.7 - environment: [] - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:20.04-5.7 diff --git a/docker/docker-compose.2004.main.yaml b/docker/docker-compose.2004.main.yaml deleted file mode 100644 index 11c7517ba..000000000 --- a/docker/docker-compose.2004.main.yaml +++ /dev/null @@ -1,17 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:20.04-main - build: - args: - base_image: "swiftlang/swift:nightly-main-focal" - - test: - image: async-http-client:20.04-main - environment: [] - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:20.04-main diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml deleted file mode 100644 index 6269e953b..000000000 --- a/docker/docker-compose.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# this file is not designed to be run directly -# instead, use the docker-compose.. files -# eg docker-compose -f docker/docker-compose.yaml -f docker/docker-compose.1804.50.yaml run test -version: "3" - -services: - - runtime-setup: - image: async-http-client:default - build: - context: . - dockerfile: Dockerfile - - common: &common - image: async-http-client:default - depends_on: [runtime-setup] - volumes: - - ~/.ssh:/root/.ssh - - ..:/code:z - working_dir: /code - cap_drop: - - CAP_NET_RAW - - CAP_NET_BIND_SERVICE - - soundness: - <<: *common - command: /bin/bash -xcl "./scripts/soundness.sh" - - test: - <<: *common - command: /bin/bash -xcl "swift test --parallel -Xswiftc -warnings-as-errors $${SANITIZER_ARG-}" - - # util - - shell: - <<: *common - entrypoint: /bin/bash - - docs: - <<: *common - command: /bin/bash -cl "./scripts/generate_docs.sh" diff --git a/scripts/check_no_api_breakages.sh b/scripts/check_no_api_breakages.sh deleted file mode 100755 index f380e2423..000000000 --- a/scripts/check_no_api_breakages.sh +++ /dev/null @@ -1,136 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -##===----------------------------------------------------------------------===## -## -## This source file is part of the SwiftNIO open source project -## -## Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of SwiftNIO project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu - -# repodir -function all_modules() { - local repodir="$1" - ( - set -eu - cd "$repodir" - swift package dump-package | jq '.products | - map(select(.type | has("library") )) | - map(.name) | .[]' | tr -d '"' - ) -} - -# repodir tag output -function build_and_do() { - local repodir=$1 - local tag=$2 - local output=$3 - - ( - cd "$repodir" - git checkout -q "$tag" - swift build - while read -r module; do - swift api-digester -sdk "$sdk" -dump-sdk -module "$module" \ - -o "$output/$module.json" -I "$repodir/.build/debug" - done < <(all_modules "$repodir") - ) -} - -function usage() { - echo >&2 "Usage: $0 REPO-GITHUB-URL NEW-VERSION OLD-VERSIONS..." - echo >&2 - echo >&2 "This script requires a Swift 5.1+ toolchain." - echo >&2 - echo >&2 "Examples:" - echo >&2 - echo >&2 "Check between master and tag 2.1.1 of swift-nio:" - echo >&2 " $0 https://github.com/apple/swift-nio master 2.1.1" - echo >&2 - echo >&2 "Check between HEAD and commit 64cf63d7 using the provided toolchain:" - echo >&2 " xcrun --toolchain org.swift.5120190702a $0 ../some-local-repo HEAD 64cf63d7" -} - -if [[ $# -lt 3 ]]; then - usage - exit 1 -fi - -sdk=/ -if [[ "$(uname -s)" == Darwin ]]; then - sdk=$(xcrun --show-sdk-path) -fi - -hash jq 2> /dev/null || { echo >&2 "ERROR: jq must be installed"; exit 1; } -tmpdir=$(mktemp -d /tmp/.check-api_XXXXXX) -repo_url=$1 -new_tag=$2 -shift 2 - -repodir="$tmpdir/repo" -git clone "$repo_url" "$repodir" -git -C "$repodir" fetch -q origin '+refs/pull/*:refs/remotes/origin/pr/*' -errors=0 - -for old_tag in "$@"; do - mkdir "$tmpdir/api-old" - mkdir "$tmpdir/api-new" - - echo "Checking public API breakages from $old_tag to $new_tag" - - build_and_do "$repodir" "$new_tag" "$tmpdir/api-new/" - build_and_do "$repodir" "$old_tag" "$tmpdir/api-old/" - - for f in "$tmpdir/api-new"/*; do - f=$(basename "$f") - report="$tmpdir/$f.report" - if [[ ! -f "$tmpdir/api-old/$f" ]]; then - echo "NOTICE: NEW MODULE $f" - continue - fi - - echo -n "Checking $f... " - swift api-digester -sdk "$sdk" -diagnose-sdk \ - --input-paths "$tmpdir/api-old/$f" -input-paths "$tmpdir/api-new/$f" 2>&1 \ - > "$report" 2>&1 - - if ! shasum "$report" | grep -q afd2a1b542b33273920d65821deddc653063c700; then - echo ERROR - echo >&2 "==============================" - echo >&2 "ERROR: public API change in $f" - echo >&2 "==============================" - cat >&2 "$report" - errors=$(( errors + 1 )) - else - echo OK - fi - done - rm -rf "$tmpdir/api-new" "$tmpdir/api-old" -done - -if [[ "$errors" == 0 ]]; then - echo "OK, all seems good" -fi -echo done -exit "$errors" diff --git a/scripts/generate_contributors_list.sh b/scripts/generate_contributors_list.sh deleted file mode 100755 index 00c162638..000000000 --- a/scripts/generate_contributors_list.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu -here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -contributors=$( cd "$here"/.. && git shortlog -es | cut -f2 | sed 's/^/- /' ) - -cat > "$here/../CONTRIBUTORS.txt" <<- EOF - For the purpose of tracking copyright, this is the list of individuals and - organizations who have contributed source code to the AsyncHTTPClient. - - For employees of an organization/company where the copyright of work done - by employees of that company is held by the company itself, only the company - needs to be listed here. - - ## COPYRIGHT HOLDERS - - - Apple Inc. (all contributors with '@apple.com') - - ### Contributors - - $contributors - - **Updating this list** - - Please do not edit this file manually. It is generated using \`./scripts/generate_contributors_list.sh\`. If a name is misspelled or appearing multiple times: add an entry in \`./.mailmap\` -EOF diff --git a/scripts/generate_docs.sh b/scripts/generate_docs.sh deleted file mode 100755 index 82da814d3..000000000 --- a/scripts/generate_docs.sh +++ /dev/null @@ -1,114 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -e - -my_path="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -root_path="$my_path/.." -version=$(git describe --abbrev=0 --tags || echo "main") -modules=(AsyncHTTPClient) - -if [[ "$(uname -s)" == "Linux" ]]; then - # build code if required - if [[ ! -d "$root_path/.build/x86_64-unknown-linux" ]]; then - swift build - fi - # setup source-kitten if required - mkdir -p "$root_path/.build/sourcekitten" - source_kitten_source_path="$root_path/.build/sourcekitten/source" - if [[ ! -d "$source_kitten_source_path" ]]; then - git clone https://github.com/jpsim/SourceKitten.git "$source_kitten_source_path" - fi - source_kitten_path="$source_kitten_source_path/.build/debug" - if [[ ! -d "$source_kitten_path" ]]; then - rm -rf "$source_kitten_source_path/.swift-version" - cd "$source_kitten_source_path" && swift build && cd "$root_path" - fi - # generate - for module in "${modules[@]}"; do - if [[ ! -f "$root_path/.build/sourcekitten/$module.json" ]]; then - "$source_kitten_path/sourcekitten" doc --spm --module-name $module > "$root_path/.build/sourcekitten/$module.json" - fi - done -fi - -[[ -d docs/$version ]] || mkdir -p docs/$version -[[ -d async-http-client.xcodeproj ]] || swift package generate-xcodeproj - -# run jazzy -if ! command -v jazzy > /dev/null; then - gem install jazzy --no-ri --no-rdoc -fi - -jazzy_dir="$root_path/.build/jazzy" -rm -rf "$jazzy_dir" -mkdir -p "$jazzy_dir" - -module_switcher="$jazzy_dir/README.md" -jazzy_args=(--clean - --author 'AsyncHTTPClient team' - --readme "$module_switcher" - --author_url https://github.com/swift-server/async-http-client - --github_url https://github.com/swift-server/async-http-client - --github-file-prefix "https://github.com/swift-server/async-http-client/tree/$version" - --theme fullwidth - --xcodebuild-arguments -scheme,async-http-client-Package) -cat > "$module_switcher" <<"EOF" -# AsyncHTTPClient Docs - -AsyncHTTPClient is a Swift HTTP Client package. - -To get started with AsyncHTTPClient, [`import AsyncHTTPClient`](../AsyncHTTPClient/index.html). The -most important type is [`HTTPClient`](https://swift-server.github.io/async-http-client/docs/current/AsyncHTTPClient/Classes/HTTPClient.html) -which you can use to emit log messages. - -EOF - -tmp=`mktemp -d` -for module in "${modules[@]}"; do - args=("${jazzy_args[@]}" --output "$jazzy_dir/docs/$version/$module" --docset-path "$jazzy_dir/docset/$version/$module" - --module "$module" --module-version $version - --root-url "https://swift-server.github.io/async-http-client/docs/$version/$module/") - if [[ -f "$root_path/.build/sourcekitten/$module.json" ]]; then - args+=(--sourcekitten-sourcefile "$root_path/.build/sourcekitten/$module.json") - fi - jazzy "${args[@]}" -done - -# push to github pages -if [[ $PUSH == true ]]; then - BRANCH_NAME=$(git rev-parse --abbrev-ref HEAD) - GIT_AUTHOR=$(git --no-pager show -s --format='%an <%ae>' HEAD) - git fetch origin +gh-pages:gh-pages - git checkout gh-pages - rm -rf "docs/$version" - rm -rf "docs/current" - cp -r "$jazzy_dir/docs/$version" docs/ - cp -r "docs/$version" docs/current - git add --all docs - echo '' > index.html - git add index.html - touch .nojekyll - git add .nojekyll - changes=$(git diff-index --name-only HEAD) - if [[ -n "$changes" ]]; then - echo -e "changes detected\n$changes" - git commit --author="$GIT_AUTHOR" -m "publish $version docs" - git push origin gh-pages - else - echo "no changes detected" - fi - git checkout -f $BRANCH_NAME -fi diff --git a/scripts/generate_linux_tests.rb b/scripts/generate_linux_tests.rb deleted file mode 100755 index ed887f83c..000000000 --- a/scripts/generate_linux_tests.rb +++ /dev/null @@ -1,231 +0,0 @@ -#!/usr/bin/env ruby - -# -# process_test_files.rb -# -# Copyright 2016 Tony Stone -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Created by Tony Stone on 5/4/16. -# -require 'getoptlong' -require 'fileutils' -require 'pathname' - -include FileUtils - -# -# This ruby script will auto generate LinuxMain.swift and the +XCTest.swift extension files for Swift Package Manager on Linux platforms. -# -# See https://github.com/apple/swift-corelibs-xctest/blob/master/Documentation/Linux.md -# -def header(fileName) - string = <<-eos -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - eos - - string - .sub('', File.basename(fileName)) - .sub('', Time.now.to_s) -end - -def createExtensionFile(fileName, classes) - extensionFile = fileName.sub! '.swift', '+XCTest.swift' - print 'Creating file: ' + extensionFile + "\n" - - File.open(extensionFile, 'w') do |file| - file.write header(extensionFile) - file.write "\n" - - for classArray in classes - file.write 'extension ' + classArray[0] + " {\n" - file.write ' static var allTests: [(String, (' + classArray[0] + ") -> () throws -> Void)] {\n" - file.write " return [\n" - - for funcName in classArray[1] - file.write ' ("' + funcName + '", ' + funcName + "),\n" - end - - file.write " ]\n" - file.write " }\n" - file.write "}\n" - end - end -end - -def createLinuxMain(testsDirectory, allTestSubDirectories, files) - fileName = testsDirectory + '/LinuxMain.swift' - print 'Creating file: ' + fileName + "\n" - - File.open(fileName, 'w') do |file| - file.write header(fileName) - file.write "\n" - - file.write "#if os(Linux) || os(FreeBSD)\n" - for testSubDirectory in allTestSubDirectories.sort { |x, y| x <=> y } - file.write '@testable import ' + testSubDirectory + "\n" - end - file.write "\n" - file.write "XCTMain([\n" - - testCases = [] - for classes in files - for classArray in classes - testCases << classArray[0] - end - end - - for testCase in testCases.sort { |x, y| x <=> y } - file.write ' testCase(' + testCase + ".allTests),\n" - end - file.write "])\n" - file.write "#endif\n" - end -end - -def parseSourceFile(fileName) - puts 'Parsing file: ' + fileName + "\n" - - classes = [] - currentClass = nil - inIfLinux = false - inElse = false - ignore = false - - # - # Read the file line by line - # and parse to find the class - # names and func names - # - File.readlines(fileName).each do |line| - if inIfLinux - if /\#else/.match(line) - inElse = true - ignore = true - else - if /\#end/.match(line) - inElse = false - inIfLinux = false - ignore = false - end - end - else - if /\#if[ \t]+os\(Linux\)/.match(line) - inIfLinux = true - ignore = false - end - end - - next if ignore - # Match class or func - match = line[/class[ \t]+[a-zA-Z0-9_]*(?=[ \t]*:[ \t]*XCTestCase)|func[ \t]+test[a-zA-Z0-9_]*(?=[ \t]*\(\))/, 0] - if match - - if match[/class/, 0] == 'class' - className = match.sub(/^class[ \t]+/, '') - # - # Create a new class / func structure - # and add it to the classes array. - # - currentClass = [className, []] - classes << currentClass - else # Must be a func - funcName = match.sub(/^func[ \t]+/, '') - # - # Add each func name the the class / func - # structure created above. - # - currentClass[1] << funcName - end - end - end - classes -end - -# -# Main routine -# -# - -testsDirectory = 'Tests' - -options = GetoptLong.new(['--tests-dir', GetoptLong::OPTIONAL_ARGUMENT]) -options.quiet = true - -begin - options.each do |option, value| - case option - when '--tests-dir' - testsDirectory = value - end - end -rescue GetoptLong::InvalidOption -end - -allTestSubDirectories = [] -allFiles = [] - -Dir[testsDirectory + '/*'].each do |subDirectory| - next unless File.directory?(subDirectory) - directoryHasClasses = false - Dir[subDirectory + '/*Test{s,}.swift'].each do |fileName| - next unless File.file? fileName - fileClasses = parseSourceFile(fileName) - - # - # If there are classes in the - # test source file, create an extension - # file for it. - # - next unless fileClasses.count > 0 - createExtensionFile(fileName, fileClasses) - directoryHasClasses = true - allFiles << fileClasses - end - - if directoryHasClasses - allTestSubDirectories << Pathname.new(subDirectory).split.last.to_s - end -end - -# -# Last step is the create a LinuxMain.swift file that -# references all the classes and funcs in the source files. -# -if allFiles.count > 0 - createLinuxMain(testsDirectory, allTestSubDirectories, allFiles) -end -# eof diff --git a/scripts/soundness.sh b/scripts/soundness.sh deleted file mode 100755 index da9a91d24..000000000 --- a/scripts/soundness.sh +++ /dev/null @@ -1,164 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) 2018-2022 Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu -here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" - -function replace_acceptable_years() { - # this needs to replace all acceptable forms with 'YEARS' - sed -e 's/20[12][0-9]-20[12][0-9]/YEARS/' -e 's/2019/YEARS/' -e 's/2020/YEARS/' -e 's/2021/YEARS/' -e 's/2022/YEARS/' -} - -printf "=> Checking linux tests... " -FIRST_OUT="$(git status --porcelain)" -ruby "$here/../scripts/generate_linux_tests.rb" > /dev/null -SECOND_OUT="$(git status --porcelain)" -if [[ "$FIRST_OUT" != "$SECOND_OUT" ]]; then - printf "\033[0;31mmissing changes!\033[0m\n" - git --no-pager diff - exit 1 -else - printf "\033[0;32mokay.\033[0m\n" -fi - -printf "=> Checking for unacceptable language... " -# This greps for unacceptable terminology. The square bracket[s] are so that -# "git grep" doesn't find the lines that greps :). -unacceptable_terms=( - -e blacklis[t] - -e whitelis[t] - -e slav[e] - -e sanit[y] -) -if git grep --color=never -i "${unacceptable_terms[@]}" > /dev/null; then - printf "\033[0;31mUnacceptable language found.\033[0m\n" - git grep -i "${unacceptable_terms[@]}" - exit 1 -fi -printf "\033[0;32mokay.\033[0m\n" - -printf "=> Checking format... " -FIRST_OUT="$(git status --porcelain)" -swiftformat . > /dev/null 2>&1 -SECOND_OUT="$(git status --porcelain)" -if [[ "$FIRST_OUT" != "$SECOND_OUT" ]]; then - printf "\033[0;31mformatting issues!\033[0m\n" - git --no-pager diff - exit 1 -else - printf "\033[0;32mokay.\033[0m\n" -fi - -printf "=> Checking license headers\n" -tmp=$(mktemp /tmp/.async-http-client-soundness_XXXXXX) - -for language in swift-or-c bash dtrace; do - printf " * $language... " - declare -a matching_files - declare -a exceptions - expections=( ) - matching_files=( -name '*' ) - case "$language" in - swift-or-c) - exceptions=( -name c_nio_http_parser.c -o -name c_nio_http_parser.h -o -name cpp_magic.h -o -name Package.swift -o -name CNIOSHA1.h -o -name c_nio_sha1.c -o -name ifaddrs-android.c -o -name ifaddrs-android.h) - matching_files=( -name '*.swift' -o -name '*.c' -o -name '*.h' ) - cat > "$tmp" <<"EOF" -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) YEARS Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -EOF - ;; - bash) - matching_files=( -name '*.sh' ) - cat > "$tmp" <<"EOF" -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) YEARS Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## -EOF - ;; - dtrace) - matching_files=( -name '*.d' ) - cat > "$tmp" <<"EOF" -#!/usr/sbin/dtrace -q -s -/*===----------------------------------------------------------------------===* - * - * This source file is part of the AsyncHTTPClient open source project - * - * Copyright (c) YEARS Apple Inc. and the AsyncHTTPClient project authors - * Licensed under Apache License v2.0 - * - * See LICENSE.txt for license information - * See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors - * - * SPDX-License-Identifier: Apache-2.0 - * - *===----------------------------------------------------------------------===*/ -EOF - ;; - *) - echo >&2 "ERROR: unknown language '$language'" - ;; - esac - - expected_lines=$(cat "$tmp" | wc -l) - expected_sha=$(cat "$tmp" | shasum) - - ( - cd "$here/.." - find . \ - \( \! -path './.build/*' -a \ - \( "${matching_files[@]}" \) -a \ - \( \! \( "${exceptions[@]}" \) \) \) | while read line; do - if [[ "$(cat "$line" | replace_acceptable_years | head -n $expected_lines | shasum)" != "$expected_sha" ]]; then - printf "\033[0;31mmissing headers in file '$line'!\033[0m\n" - diff -u <(cat "$line" | replace_acceptable_years | head -n $expected_lines) "$tmp" - exit 1 - fi - done - printf "\033[0;32mokay.\033[0m\n" - ) -done - -rm "$tmp" - -# This checks for the umbrella NIO module. -printf "=> Checking for imports of umbrella NIO module... " -if git grep --color=never -i "^[ \t]*import \+NIO[ \t]*$" > /dev/null; then - printf "\033[0;31mUmbrella imports found.\033[0m\n" - git grep -i "^[ \t]*import \+NIO[ \t]*$" - exit 1 -fi -printf "\033[0;32mokay.\033[0m\n"