Skip to content

Commit 2166cbe

Browse files
authored
Add support for proxying in WebsocketClient (vapor#130)
* WebsocketClient supports proxying Added support for TLS and plain text proxying of Websocket traffic. In the TLS case a CONNECT header is first sent establishing the proxied traffic. In the plain text case the modified URI in the initial upgrade request header indicates to the proxy server that the traffic is to be proxied. Use `NIOWebSocketFrameAggregator` to handle aggregating frame fragments. This brings with it more protections e.g. against memory exhaustion. Accompanying config has been added to support this change. * Reduce allocations and copies in WebSocket.send Reduce allocation and copies necessary to send `ByteBuffer` and `ByteBufferView` through `WebSocket.send`. In fact sending `ByteBuffer` or `ByteBufferView` doesn’t require any allocation or copy of the data. Sending a `String` now correctly pre allocates the `ByteBuffer` if multibyte characters are present in the `String`. Remove custom random websocket mask generation which would only generate bytes between `UInt8.min..<UInt8.max`, therefore excluding `UInt8.max` aka `255`. * add DocC comments * DocC comments for new APIs
1 parent 2b88859 commit 2166cbe

File tree

8 files changed

+911
-65
lines changed

8 files changed

+911
-65
lines changed

NOTICES.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,12 @@ This product contains a derivation of `NIOSSLTestHelpers.swift` from SwiftNIO SS
1717
* https://www.apache.org/licenses/LICENSE-2.0
1818
* HOMEPAGE:
1919
* https://github.com/apple/swift-nio-ssl
20+
21+
---
22+
23+
This product contains derivations of "HTTPProxySimulator" and "HTTPBin" test utils from AsyncHTTPClient.
24+
25+
* LICENSE (Apache License 2.0):
26+
* https://www.apache.org/licenses/LICENSE-2.0
27+
* HOMEPAGE:
28+
* https://github.com/swift-server/async-http-client

Package.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ let package = Package(
1313
],
1414
dependencies: [
1515
.package(url: "https://github.com/apple/swift-nio.git", from: "2.33.0"),
16+
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.16.0"),
1617
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.0"),
1718
.package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"),
1819
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"),
@@ -22,6 +23,7 @@ let package = Package(
2223
.product(name: "NIO", package: "swift-nio"),
2324
.product(name: "NIOCore", package: "swift-nio"),
2425
.product(name: "NIOConcurrencyHelpers", package: "swift-nio"),
26+
.product(name: "NIOExtras", package: "swift-nio-extras"),
2527
.product(name: "NIOFoundationCompat", package: "swift-nio"),
2628
.product(name: "NIOHTTP1", package: "swift-nio"),
2729
.product(name: "NIOSSL", package: "swift-nio-ssl"),

Sources/WebSocketKit/HTTPInitialRequestHandler.swift renamed to Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import NIO
22
import NIOHTTP1
33

4-
final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHandler {
4+
final class HTTPUpgradeRequestHandler: ChannelInboundHandler, RemovableChannelHandler {
55
typealias InboundIn = HTTPClientResponsePart
66
typealias OutboundOut = HTTPClientRequestPart
77

@@ -11,6 +11,8 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa
1111
let headers: HTTPHeaders
1212
let upgradePromise: EventLoopPromise<Void>
1313

14+
private var requestSent = false
15+
1416
init(host: String, path: String, query: String?, headers: HTTPHeaders, upgradePromise: EventLoopPromise<Void>) {
1517
self.host = host
1618
self.path = path
@@ -20,10 +22,33 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa
2022
}
2123

2224
func channelActive(context: ChannelHandlerContext) {
25+
self.sendRequest(context: context)
26+
context.fireChannelActive()
27+
}
28+
29+
func handlerAdded(context: ChannelHandlerContext) {
30+
if context.channel.isActive {
31+
self.sendRequest(context: context)
32+
}
33+
}
34+
35+
private func sendRequest(context: ChannelHandlerContext) {
36+
if self.requestSent {
37+
// we might run into this handler twice, once in handlerAdded and once in channelActive.
38+
return
39+
}
40+
self.requestSent = true
41+
2342
var headers = self.headers
2443
headers.add(name: "Host", value: self.host)
2544

26-
var uri = self.path.hasPrefix("/") ? self.path : "/" + self.path
45+
var uri: String
46+
if self.path.hasPrefix("/") || self.path.hasPrefix("ws://") || self.path.hasPrefix("wss://") {
47+
uri = self.path
48+
} else {
49+
uri = "/" + self.path
50+
}
51+
2752
if let query = self.query {
2853
uri += "?\(query)"
2954
}
@@ -43,10 +68,13 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa
4368
}
4469

4570
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
71+
// `NIOHTTPClientUpgradeHandler` should consume the first response in the success case,
72+
// any response we see here indicates a failure. Report the failure and tidy up at the end of the response.
4673
let clientResponse = self.unwrapInboundIn(data)
4774
switch clientResponse {
4875
case .head(let responseHead):
49-
self.upgradePromise.fail(WebSocketClient.Error.invalidResponseStatus(responseHead))
76+
let error = WebSocketClient.Error.invalidResponseStatus(responseHead)
77+
self.upgradePromise.fail(error)
5078
case .body: break
5179
case .end:
5280
context.close(promise: nil)

Sources/WebSocketKit/WebSocket+Connect.swift

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@ import NIOHTTP1
33
import Foundation
44

55
extension WebSocket {
6+
/// Establish a WebSocket connection.
7+
///
8+
/// - Parameters:
9+
/// - url: URL for the WebSocket server.
10+
/// - headers: Headers to send to the WebSocket server.
11+
/// - configuration: Configuration for the WebSocket client.
12+
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
13+
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
14+
/// - Returns: An future which completes when the connection to the WebSocket server is established.
615
public static func connect(
716
to url: String,
817
headers: HTTPHeaders = [:],
@@ -22,6 +31,15 @@ extension WebSocket {
2231
)
2332
}
2433

34+
/// Establish a WebSocket connection.
35+
///
36+
/// - Parameters:
37+
/// - url: URL for the WebSocket server.
38+
/// - headers: Headers to send to the WebSocket server.
39+
/// - configuration: Configuration for the WebSocket client.
40+
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
41+
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
42+
/// - Returns: An future which completes when the connection to the WebSocket server is established.
2543
public static func connect(
2644
to url: URL,
2745
headers: HTTPHeaders = [:],
@@ -43,6 +61,19 @@ extension WebSocket {
4361
)
4462
}
4563

64+
/// Establish a WebSocket connection.
65+
///
66+
/// - Parameters:
67+
/// - scheme: Scheme component of the URI for the WebSocket server.
68+
/// - host: Host component of the URI for the WebSocket server.
69+
/// - port: Port on which to connect to the WebSocket server.
70+
/// - path: Path component of the URI for the WebSocket server.
71+
/// - query: Query component of the URI for the WebSocket server.
72+
/// - headers: Headers to send to the WebSocket server.
73+
/// - configuration: Configuration for the WebSocket client.
74+
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
75+
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
76+
/// - Returns: An future which completes when the connection to the WebSocket server is established.
4677
public static func connect(
4778
scheme: String = "ws",
4879
host: String,
@@ -67,4 +98,98 @@ extension WebSocket {
6798
onUpgrade: onUpgrade
6899
)
69100
}
101+
102+
/// Establish a WebSocket connection via a proxy server.
103+
///
104+
/// - Parameters:
105+
/// - scheme: Scheme component of the URI for the origin server.
106+
/// - host: Host component of the URI for the origin server.
107+
/// - port: Port on which to connect to the origin server.
108+
/// - path: Path component of the URI for the origin server.
109+
/// - query: Query component of the URI for the origin server.
110+
/// - headers: Headers to send to the origin server.
111+
/// - proxy: Host component of the URI for the proxy server.
112+
/// - proxyPort: Port on which to connect to the proxy server.
113+
/// - proxyHeaders: Headers to send to the proxy server.
114+
/// - proxyConnectDeadline: Deadline for establishing the proxy connection.
115+
/// - configuration: Configuration for the WebSocket client.
116+
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
117+
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
118+
/// - Returns: An future which completes when the connection to the origin server is established.
119+
public static func connect(
120+
scheme: String = "ws",
121+
host: String,
122+
port: Int = 80,
123+
path: String = "/",
124+
query: String? = nil,
125+
headers: HTTPHeaders = [:],
126+
proxy: String?,
127+
proxyPort: Int? = nil,
128+
proxyHeaders: HTTPHeaders = [:],
129+
proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture,
130+
configuration: WebSocketClient.Configuration = .init(),
131+
on eventLoopGroup: EventLoopGroup,
132+
onUpgrade: @escaping (WebSocket) -> ()
133+
) -> EventLoopFuture<Void> {
134+
return WebSocketClient(
135+
eventLoopGroupProvider: .shared(eventLoopGroup),
136+
configuration: configuration
137+
).connect(
138+
scheme: scheme,
139+
host: host,
140+
port: port,
141+
path: path,
142+
query: query,
143+
headers: headers,
144+
proxy: proxy,
145+
proxyPort: proxyPort,
146+
proxyHeaders: proxyHeaders,
147+
proxyConnectDeadline: proxyConnectDeadline,
148+
onUpgrade: onUpgrade
149+
)
150+
}
151+
152+
153+
/// Description
154+
/// - Parameters:
155+
/// - url: URL for the origin server.
156+
/// - headers: Headers to send to the origin server.
157+
/// - proxy: Host component of the URI for the proxy server.
158+
/// - proxyPort: Port on which to connect to the proxy server.
159+
/// - proxyHeaders: Headers to send to the proxy server.
160+
/// - proxyConnectDeadline: Deadline for establishing the proxy connection.
161+
/// - configuration: Configuration for the WebSocket client.
162+
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
163+
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
164+
/// - Returns: An future which completes when the connection to the origin server is established.
165+
public static func connect(
166+
to url: String,
167+
headers: HTTPHeaders = [:],
168+
proxy: String?,
169+
proxyPort: Int? = nil,
170+
proxyHeaders: HTTPHeaders = [:],
171+
proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture,
172+
configuration: WebSocketClient.Configuration = .init(),
173+
on eventLoopGroup: EventLoopGroup,
174+
onUpgrade: @escaping (WebSocket) -> ()
175+
) -> EventLoopFuture<Void> {
176+
guard let url = URL(string: url) else {
177+
return eventLoopGroup.next().makeFailedFuture(WebSocketClient.Error.invalidURL)
178+
}
179+
let scheme = url.scheme ?? "ws"
180+
return self.connect(
181+
scheme: scheme,
182+
host: url.host ?? "localhost",
183+
port: url.port ?? (scheme == "wss" ? 443 : 80),
184+
path: url.path,
185+
query: url.query,
186+
headers: headers,
187+
proxy: proxy,
188+
proxyPort: proxyPort,
189+
proxyHeaders: proxyHeaders,
190+
proxyConnectDeadline: proxyConnectDeadline,
191+
on: eventLoopGroup,
192+
onUpgrade: onUpgrade
193+
)
194+
}
70195
}

Sources/WebSocketKit/WebSocket.swift

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ public final class WebSocket {
2424
self.channel.closeFuture
2525
}
2626

27-
private let channel: Channel
27+
@usableFromInline
28+
/* private but @usableFromInline */
29+
internal let channel: Channel
2830
private var onTextCallback: (WebSocket, String) -> ()
2931
private var onBinaryCallback: (WebSocket, ByteBuffer) -> ()
3032
private var onPongCallback: (WebSocket) -> ()
@@ -64,10 +66,10 @@ public final class WebSocket {
6466
}
6567

6668
/// If set, this will trigger automatic pings on the connection. If ping is not answered before
67-
/// the next ping is sent, then the WebSocket will be presumed innactive and will be closed
69+
/// the next ping is sent, then the WebSocket will be presumed inactive and will be closed
6870
/// automatically.
6971
/// These pings can also be used to keep the WebSocket alive if there is some other timeout
70-
/// mechanism shutting down innactive connections, such as a Load Balancer deployed in
72+
/// mechanism shutting down inactive connections, such as a Load Balancer deployed in
7173
/// front of the server.
7274
public var pingInterval: TimeAmount? {
7375
didSet {
@@ -82,13 +84,13 @@ public final class WebSocket {
8284
}
8385
}
8486

87+
@inlinable
8588
public func send<S>(_ text: S, promise: EventLoopPromise<Void>? = nil)
8689
where S: Collection, S.Element == Character
8790
{
8891
let string = String(text)
89-
var buffer = channel.allocator.buffer(capacity: text.count)
90-
buffer.writeString(string)
91-
self.send(raw: buffer.readableBytesView, opcode: .text, fin: true, promise: promise)
92+
let buffer = channel.allocator.buffer(string: string)
93+
self.send(buffer, opcode: .text, fin: true, promise: promise)
9294

9395
}
9496

@@ -105,6 +107,7 @@ public final class WebSocket {
105107
)
106108
}
107109

110+
@inlinable
108111
public func send<Data>(
109112
raw data: Data,
110113
opcode: WebSocketOpcode,
@@ -113,13 +116,32 @@ public final class WebSocket {
113116
)
114117
where Data: DataProtocol
115118
{
116-
var buffer = channel.allocator.buffer(capacity: data.count)
117-
buffer.writeBytes(data)
119+
if let byteBufferView = data as? ByteBufferView {
120+
// optimisation: converting from `ByteBufferView` to `ByteBuffer` doesn't allocate or copy any data
121+
send(ByteBuffer(byteBufferView), opcode: opcode, fin: fin, promise: promise)
122+
} else {
123+
let buffer = channel.allocator.buffer(bytes: data)
124+
send(buffer, opcode: opcode, fin: fin, promise: promise)
125+
}
126+
}
127+
128+
/// Send the provided data in a WebSocket frame.
129+
/// - Parameters:
130+
/// - data: Data to be sent.
131+
/// - opcode: Frame opcode.
132+
/// - fin: The value of the fin bit.
133+
/// - promise: A promise to be completed when the write is complete.
134+
public func send(
135+
_ data: ByteBuffer,
136+
opcode: WebSocketOpcode = .binary,
137+
fin: Bool = true,
138+
promise: EventLoopPromise<Void>? = nil
139+
) {
118140
let frame = WebSocketFrame(
119141
fin: fin,
120142
opcode: opcode,
121143
maskKey: self.makeMaskKey(),
122-
data: buffer
144+
data: data
123145
)
124146
self.channel.writeAndFlush(frame, promise: promise)
125147
}
@@ -164,11 +186,7 @@ public final class WebSocket {
164186
func makeMaskKey() -> WebSocketMaskingKey? {
165187
switch type {
166188
case .client:
167-
var bytes: [UInt8] = []
168-
for _ in 0..<4 {
169-
bytes.append(.random(in: .min ..< .max))
170-
}
171-
return WebSocketMaskingKey(bytes)
189+
return WebSocketMaskingKey.random()
172190
case .server:
173191
return nil
174192
}
@@ -237,14 +255,8 @@ public final class WebSocket {
237255
frameSequence.append(frame)
238256
self.frameSequence = frameSequence
239257
case .continuation:
240-
// we must have an existing sequence
241-
if var frameSequence = self.frameSequence {
242-
// append this frame and update
243-
frameSequence.append(frame)
244-
self.frameSequence = frameSequence
245-
} else {
246-
self.close(code: .protocolError, promise: nil)
247-
}
258+
/// continuations are filtered by ``NIOWebSocketFrameAggregator``
259+
preconditionFailure("We will never receive a continuation frame")
248260
default:
249261
// We ignore all other frames.
250262
break

0 commit comments

Comments
 (0)