Skip to content

Commit e627693

Browse files
committed
Don't hardcode HTTP protocol
1 parent 97c9503 commit e627693

File tree

1 file changed

+42
-47
lines changed

1 file changed

+42
-47
lines changed

Request Ranger/Networking/ProxyHandler.swift

Lines changed: 42 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,6 @@ import NIOHTTP1
1111
import Logging
1212
import Atomics
1313

14-
enum HttpProtocol {
15-
case HTTP
16-
case HTTPS
17-
}
18-
19-
func convertToClientRequestPart(_ reqPart: HTTPServerRequestPart) -> HTTPClientRequestPart {
20-
switch reqPart {
21-
case .head(let head):
22-
return .head(head)
23-
case .body(let buffer):
24-
return .body(.byteBuffer(buffer))
25-
case .end(let headers):
26-
return .end(headers)
27-
}
28-
}
29-
3014
final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equatable {
3115
static func == (lhs: ProxyHandler, rhs: ProxyHandler) -> Bool {
3216
return false
@@ -40,6 +24,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
4024
private var logger: Logger
4125
private var targetHost: String?
4226
private var targetPort: Int?
27+
private var targetProtocol: HttpProtocol?
4328
private static let globalRequestID = ManagedAtomic<Int>(0) // FIXME: should initialize with latest saved ID
4429
public var requestParts: [HTTPClientRequestPart] = []
4530
private var waitingContext: ChannelHandlerContext?
@@ -62,11 +47,20 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
6247
case connectRequested
6348
}
6449

65-
public func forwardRequestForProtocol(httpProtocol: HttpProtocol, context: ChannelHandlerContext, requestParts: [HTTPClientRequestPart]) {
50+
enum HttpProtocol {
51+
case HTTP
52+
case HTTPS
53+
}
54+
55+
public func forwardRequestForProtocol(context: ChannelHandlerContext, requestParts: [HTTPClientRequestPart]) {
6656
guard let port = self.targetPort else {
6757
fatalError("Port was not passed")
6858
}
6959

60+
guard let httpProtocol = self.targetProtocol else {
61+
fatalError("Targer protocol was not passed")
62+
}
63+
7064
if httpProtocol == .HTTPS {
7165
forwardRequestForHttps(context: context, port: port, requestParts: requestParts)
7266
} else {
@@ -81,18 +75,18 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
8175
switch upgradeState {
8276
case .idle:
8377
let channelFuture = clientBootstrap.channelInitializer { channel in
84-
let sslContext = try! NIOSSLContext(configuration: .makeClientConfiguration())
85-
let sslHandler = try! NIOSSLClientHandler(context: sslContext, serverHostname: self.targetHost!)
86-
87-
return channel.pipeline.addHandler(sslHandler).flatMap {
88-
channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap {
89-
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)))
90-
.flatMap {
91-
channel.pipeline.addHandler(ResponseHandler(context: context, preUpgradedRequest: self.preUpgradedRequest, request: self.loggedRequest!))
92-
}
93-
}
78+
let sslContext = try! NIOSSLContext(configuration: .makeClientConfiguration())
79+
let sslHandler = try! NIOSSLClientHandler(context: sslContext, serverHostname: self.targetHost!)
80+
81+
return channel.pipeline.addHandler(sslHandler).flatMap {
82+
channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap {
83+
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)))
84+
.flatMap {
85+
channel.pipeline.addHandler(ResponseHandler(context: context, preUpgradedRequest: self.preUpgradedRequest, request: self.loggedRequest!))
86+
}
9487
}
9588
}
89+
}
9690
.connect(host: targetHost!, port: port)
9791

9892
channelFuture.whenSuccess { channel in
@@ -117,13 +111,13 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
117111
switch upgradeState {
118112
case .idle:
119113
let channelFuture = clientBootstrap.channelInitializer { channel in
120-
channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap {
121-
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)))
122-
.flatMap {
123-
return channel.pipeline.addHandler(ResponseHandler(context: context, preUpgradedRequest: self.preUpgradedRequest, request: self.loggedRequest!))
124-
}
125-
}
114+
channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap {
115+
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)))
116+
.flatMap {
117+
return channel.pipeline.addHandler(ResponseHandler(context: context, preUpgradedRequest: self.preUpgradedRequest, request: self.loggedRequest!))
118+
}
126119
}
120+
}
127121
.connect(host: targetHost!, port: port)
128122

129123
channelFuture.whenSuccess { channel in
@@ -206,7 +200,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
206200
}
207201
}
208202

209-
func handleRequest(context: ChannelHandlerContext, reqPart: HTTPServerRequestPart, httpProtocol: HttpProtocol, port: Int) {
203+
func handleRequest(context: ChannelHandlerContext, reqPart: HTTPServerRequestPart, httpProtocol: HttpProtocol) {
210204
switch(reqPart) {
211205
case .head(var head):
212206
// Remove the Accept-Encoding header
@@ -223,7 +217,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
223217
requestParts.append(clientReqPart)
224218
case .end(let headers):
225219
requestParts.append(HTTPClientRequestPart.end(headers))
226-
220+
227221
if AppState.shared.isInterceptEnabled {
228222
waitingContext = context
229223

@@ -238,20 +232,20 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
238232
}
239233
}
240234
} else {
241-
forwardRequestForProtocol(httpProtocol: httpProtocol, context: context, requestParts: requestParts)
235+
forwardRequestForProtocol(context: context, requestParts: requestParts)
242236
}
243237
}
244238
}
245239

246240
func channelReadForHttps(context: ChannelHandlerContext, data: NIOAny) {
247241
let reqPart = self.unwrapInboundIn(data)
248-
handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTPS, port: 443)
242+
handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTPS)
249243
}
250244

251245
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
252246
let reqPart = self.unwrapInboundIn(data)
253247
print("Invoked channel read")
254-
248+
255249
switch(upgradeState) {
256250
case .connectRequested:
257251
print("Connect return")
@@ -279,38 +273,40 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
279273
}
280274
self.targetHost = newHost
281275
self.targetPort = originalURI.port ?? 80
276+
self.targetProtocol = .HTTP
282277

283278
head.headers.replaceOrAdd(name: "Host", value: newHost)
284279
head.uri = originalURI.relativePath
285280

286-
handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTP, port: 80)
281+
handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTP)
287282
}
288283
default:
289-
handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTP, port: 80)
284+
handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTP)
290285
}
291286
}
292287

293288
private func handleConnectRequest(context: ChannelHandlerContext, head: inout HTTPRequestHead) {
294289
let uriComponents = head.uri.split(separator: ":", maxSplits: 1, omittingEmptySubsequences: false)
295290
self.targetHost = String(uriComponents.first!)
296291
self.targetPort = Int(uriComponents.last!)
292+
self.targetProtocol = .HTTPS
297293

298294
guard let targetHost = self.targetHost,
299295
let _ = self.targetPort else {
300296
sendHttpResponse(ctx: context, status: .badRequest)
301297
context.close(promise: nil)
302298
return
303299
}
304-
300+
305301
let selfSignedCertAndKey = CertificateManager.shared.certificateForDomain(String(targetHost))
306302
let selfSignedRootCa = try! CertificateManager.shared.loadRootCAFromKeychain()
307303

308304
var serializer = DER.Serializer()
309305
try! serializer.serialize(selfSignedCertAndKey!.certificate)
310-
306+
311307
var selfSignedRootCaSerializer = DER.Serializer()
312308
try! selfSignedRootCaSerializer.serialize(selfSignedRootCa.rootCertificate)
313-
309+
314310

315311
let certificate = try! NIOSSLCertificate(bytes: serializer.serializedBytes, format: .der)
316312
let privateKey = try! NIOSSLPrivateKey(bytes: [UInt8](selfSignedCertAndKey!.privateKey.derRepresentation), format: .der)
@@ -319,13 +315,13 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
319315
let rootCert = NIOSSLCertificateSource.certificate(rootCertificate)
320316
let tlsConfiguration = TLSConfiguration.makeServerConfiguration(certificateChain: [serverCert, rootCert], privateKey: .privateKey(privateKey))
321317
let sslContext = try! NIOSSLContext(configuration: tlsConfiguration)
322-
318+
323319
let sslHandler = NIOSSLServerHandler(context: sslContext)
324320

325321
let responseHead = HTTPResponseHead(version: .http1_1, status: .ok)
326322
let responsePart = HTTPServerResponsePart.head(responseHead)
327323
self.preUpgradedRequest = context
328-
324+
329325
context.writeAndFlush(self.wrapOutboundOut(responsePart))
330326
.flatMap { _ in
331327
context.pipeline.removeHandler(name: "HTTPResponseEncoder")
@@ -386,8 +382,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
386382
let updatedRequestParts = parseRawRequest(rawRequest: rawRequest)
387383

388384
// Forward the updated request
389-
// FIXME: don't hardcode used protocol
390-
forwardRequestForProtocol(httpProtocol: .HTTP, context: waitingContext!, requestParts: updatedRequestParts)
385+
forwardRequestForProtocol(context: waitingContext!, requestParts: updatedRequestParts)
391386
}
392387

393388
private func parseRawRequest(rawRequest: String) -> [HTTPClientRequestPart] {

0 commit comments

Comments
 (0)