Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions Sources/NIOHTTP1/HTTPHeaderValidator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,16 @@ public final class NIOHTTPResponseHeadersValidator: ChannelOutboundHandler, Remo
}

private var state: State
private let sendResponseOnInvalidHeader: Bool

public init() {
self.state = .validating
self.sendResponseOnInvalidHeader = false
}

public init(pipelineConfiguration: ChannelPipeline.SynchronousOperations.Configuration) {
self.state = .validating
self.sendResponseOnInvalidHeader = pipelineConfiguration.headerValidationResponse
}

public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
Expand All @@ -82,6 +89,14 @@ public final class NIOHTTPResponseHeadersValidator: ChannelOutboundHandler, Remo
if head.headers.areValidToSend {
context.write(data, promise: promise)
} else {
// We won't write another header since we drop them going forward to write
// out a response if configured to do so
if self.sendResponseOnInvalidHeader {
let headers = HTTPHeaders([("Connection", "close"), ("Content-Length", "0")])
let head = HTTPResponseHead(version: .http1_1, status: .internalServerError, headers: headers)
context.write(Self.wrapOutboundOut(.head(head)), promise: nil)
context.writeAndFlush(Self.wrapOutboundOut(.end(nil)), promise: nil)
}
self.state = .dropping
promise?.fail(HTTPParserError.invalidHeaderToken)
context.fireErrorCaught(HTTPParserError.invalidHeaderToken)
Expand Down
68 changes: 66 additions & 2 deletions Sources/NIOHTTP1/HTTPPipelineSetup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -913,13 +913,67 @@ extension ChannelPipeline.SynchronousOperations {
)
}

/// Configure a `ChannelPipeline` for use as a HTTP server.
///
/// This function knows how to set up all first-party HTTP channel handlers appropriately
/// for server use. It supports the following features:
///
/// 1. Providing assistance handling clients that pipeline HTTP requests, using the
/// `HTTPServerPipelineHandler`.
/// 2. Supporting HTTP upgrade, using the `HTTPServerUpgradeHandler`.
/// 3. Providing assistance handling protocol errors.
/// 4. Validating outbound header fields to protect against response splitting attacks.
/// 5. Specifying whether the header validation should return a response
///
/// This method will likely be extended in future with more support for other first-party
/// features.
///
/// - important: This **must** be called on the Channel's event loop.
/// - Parameters:
/// - position: Where in the pipeline to add the HTTP server handlers, defaults to `.last`.
/// - pipelining: Whether to provide assistance handling HTTP clients that pipeline
/// their requests. Defaults to `true`. If `false`, users will need to handle
/// clients that pipeline themselves.
/// - upgrade: Whether to add a `HTTPServerUpgradeHandler` to the pipeline, configured for
/// HTTP upgrade. Defaults to `nil`, which will not add the handler to the pipeline. If
/// provided should be a tuple of an array of `HTTPServerProtocolUpgrader` and the upgrade
/// completion handler. See the documentation on `HTTPServerUpgradeHandler` for more
/// details.
/// - errorHandling: Whether to provide assistance handling protocol errors (e.g.
/// failure to parse the HTTP request) by sending 400 errors. Defaults to `true`.
/// - headerValidation: Whether to validate outbound request headers to confirm that they meet
/// spec compliance. Defaults to `true`.
/// - encoderConfiguration: The configuration for the ``HTTPRequestEncoder``.
/// - configuration: Confguration for setting up for the pipeline. Provides additional options
/// for configuring the pipeline.
/// - Throws: If the pipeline could not be configured.
public func configureHTTPServerPipeline(
position: ChannelPipeline.SynchronousOperations.Position = .last,
withPipeliningAssistance pipelining: Bool = true,
withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil,
withErrorHandling errorHandling: Bool = true,
withOutboundHeaderValidation headerValidation: Bool = true,
withEncoderConfiguration encoderConfiguration: HTTPResponseEncoder.Configuration = .init(),
withConfiguration configuration: Configuration
) throws {
try self._configureHTTPServerPipeline(
position: position,
withPipeliningAssistance: pipelining,
withServerUpgrade: upgrade,
withErrorHandling: errorHandling,
withOutboundHeaderValidation: headerValidation,
configuration: configuration
)
}

private func _configureHTTPServerPipeline(
position: ChannelPipeline.SynchronousOperations.Position = .last,
withPipeliningAssistance pipelining: Bool = true,
withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil,
withErrorHandling errorHandling: Bool = true,
withOutboundHeaderValidation headerValidation: Bool = true,
withEncoderConfiguration encoderConfiguration: HTTPResponseEncoder.Configuration = .init()
withEncoderConfiguration encoderConfiguration: HTTPResponseEncoder.Configuration = .init(),
configuration: Configuration = .init(),
) throws {
self.eventLoop.assertInEventLoop()

Expand All @@ -933,7 +987,7 @@ extension ChannelPipeline.SynchronousOperations {
}

if headerValidation {
handlers.append(NIOHTTPResponseHeadersValidator())
handlers.append(NIOHTTPResponseHeadersValidator(pipelineConfiguration: configuration))
}

if errorHandling {
Expand All @@ -952,4 +1006,14 @@ extension ChannelPipeline.SynchronousOperations {

try self.addHandlers(handlers, position: position)
}

/// Configuration for setting up an HTTP client pipeline.
public struct Configuration {
/// Whether or not a response is returned when the header validation fails.
public var headerValidationResponse: Bool

public init() {
self.headerValidationResponse = false
}
}
}
55 changes: 55 additions & 0 deletions Tests/NIOHTTP1Tests/HTTPHeaderValidationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,61 @@ final class HTTPHeaderValidationTests: XCTestCase {
XCTAssertEqual(maybeReceivedHeadBytes, toleratedRequestBytes)
XCTAssertEqual(maybeReceivedTrailerBytes, toleratedTrailerBytes)
}

func testBadRequestResponseIsReturnedIfHeadersInvalidAndConfiguredToDoSo() throws {
let channel = EmbeddedChannel()
var pipelineConfig = ChannelPipeline.SynchronousOperations.Configuration()
pipelineConfig.headerValidationResponse = true
try channel.pipeline.syncOperations.configureHTTPServerPipeline(withConfiguration: pipelineConfig)
try channel.primeForResponse()

func assertReadHead(from channel: EmbeddedChannel) throws {
if case .head = try channel.readInbound(as: HTTPServerRequestPart.self) {
()
} else {
XCTFail("Expected 'head'")
}
}

func assertReadEnd(from channel: EmbeddedChannel) throws {
if case .end = try channel.readInbound(as: HTTPServerRequestPart.self) {
()
} else {
XCTFail("Expected 'end'")
}
}

// Read the first request.
try assertReadHead(from: channel)
try assertReadEnd(from: channel)
XCTAssertNil(try channel.readInbound(as: HTTPServerRequestPart.self))

// Respond with bad headers; they should cause an error and result in the rest of the
// response being dropped, but a fallback response being sent
let head = HTTPResponseHead(version: .http1_1, status: .ok, headers: [":pseudo-header": "not-here"])
XCTAssertThrowsError(try channel.writeOutbound(HTTPServerResponsePart.head(head)))

// We expect exactly one ByteBuffer in the output.
guard var written = try channel.readOutbound(as: ByteBuffer.self) else {
XCTFail("No writes")
return
}

XCTAssertNoThrow(XCTAssertNil(try channel.readOutbound()))

// Check the response.
assertResponseIs(
response: written.readString(length: written.readableBytes)!,
expectedResponseLine: "HTTP/1.1 500 Internal Server Error",
expectedResponseHeaders: ["Connection: close", "Content-Length: 0"]
)
XCTAssertThrowsError(try channel.writeOutbound(HTTPServerResponsePart.body(.byteBuffer(ByteBuffer()))))
XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self))
XCTAssertThrowsError(try channel.writeOutbound(HTTPServerResponsePart.end(nil)))
XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self))

_ = try? channel.finish()
}
}

extension EmbeddedChannel {
Expand Down