@@ -11,22 +11,6 @@ import NIOHTTP1
1111import Logging
1212import Atomics
1313
14- enum HttpProtocol {
15- case HTTP
16- case HTTPS
17- }
18-
19- func convertToClientRequestPart( _ reqPart: HTTPServerRequestPart ) -> HTTPClientRequestPart {
20- switch reqPart {
21- case . head( let head) :
22- return . head( head)
23- case . body( let buffer) :
24- return . body( . byteBuffer( buffer) )
25- case . end( let headers) :
26- return . end( headers)
27- }
28- }
29-
3014final class ProxyHandler : ChannelInboundHandler , RemovableChannelHandler , Equatable {
3115 static func == ( lhs: ProxyHandler , rhs: ProxyHandler ) -> Bool {
3216 return false
@@ -40,6 +24,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
4024 private var logger : Logger
4125 private var targetHost : String ?
4226 private var targetPort : Int ?
27+ private var targetProtocol : HttpProtocol ?
4328 private static let globalRequestID = ManagedAtomic < Int > ( 0 ) // FIXME: should initialize with latest saved ID
4429 public var requestParts : [ HTTPClientRequestPart ] = [ ]
4530 private var waitingContext : ChannelHandlerContext ?
@@ -62,11 +47,20 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
6247 case connectRequested
6348 }
6449
65- public func forwardRequestForProtocol( httpProtocol: HttpProtocol , context: ChannelHandlerContext , requestParts: [ HTTPClientRequestPart ] ) {
50+ enum HttpProtocol {
51+ case HTTP
52+ case HTTPS
53+ }
54+
55+ public func forwardRequestForProtocol( context: ChannelHandlerContext , requestParts: [ HTTPClientRequestPart ] ) {
6656 guard let port = self . targetPort else {
6757 fatalError ( " Port was not passed " )
6858 }
6959
60+ guard let httpProtocol = self . targetProtocol else {
61+ fatalError ( " Targer protocol was not passed " )
62+ }
63+
7064 if httpProtocol == . HTTPS {
7165 forwardRequestForHttps ( context: context, port: port, requestParts: requestParts)
7266 } else {
@@ -81,18 +75,18 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
8175 switch upgradeState {
8276 case . idle:
8377 let channelFuture = clientBootstrap. channelInitializer { channel in
84- let sslContext = try ! NIOSSLContext ( configuration: . makeClientConfiguration( ) )
85- let sslHandler = try ! NIOSSLClientHandler ( context: sslContext, serverHostname: self . targetHost!)
86-
87- return channel. pipeline. addHandler ( sslHandler) . flatMap {
88- channel. pipeline. addHandler ( HTTPRequestEncoder ( ) ) . flatMap {
89- channel. pipeline. addHandler ( ByteToMessageHandler ( HTTPResponseDecoder ( leftOverBytesStrategy: . forwardBytes) ) )
90- . flatMap {
91- channel. pipeline. addHandler ( ResponseHandler ( context: context, preUpgradedRequest: self . preUpgradedRequest, request: self . loggedRequest!) )
92- }
93- }
78+ let sslContext = try ! NIOSSLContext ( configuration: . makeClientConfiguration( ) )
79+ let sslHandler = try ! NIOSSLClientHandler ( context: sslContext, serverHostname: self . targetHost!)
80+
81+ return channel. pipeline. addHandler ( sslHandler) . flatMap {
82+ channel. pipeline. addHandler ( HTTPRequestEncoder ( ) ) . flatMap {
83+ channel. pipeline. addHandler ( ByteToMessageHandler ( HTTPResponseDecoder ( leftOverBytesStrategy: . forwardBytes) ) )
84+ . flatMap {
85+ channel. pipeline. addHandler ( ResponseHandler ( context: context, preUpgradedRequest: self . preUpgradedRequest, request: self . loggedRequest!) )
86+ }
9487 }
9588 }
89+ }
9690 . connect ( host: targetHost!, port: port)
9791
9892 channelFuture. whenSuccess { channel in
@@ -117,13 +111,13 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
117111 switch upgradeState {
118112 case . idle:
119113 let channelFuture = clientBootstrap. channelInitializer { channel in
120- channel. pipeline. addHandler ( HTTPRequestEncoder ( ) ) . flatMap {
121- channel. pipeline. addHandler ( ByteToMessageHandler ( HTTPResponseDecoder ( leftOverBytesStrategy: . forwardBytes) ) )
122- . flatMap {
123- return channel. pipeline. addHandler ( ResponseHandler ( context: context, preUpgradedRequest: self . preUpgradedRequest, request: self . loggedRequest!) )
124- }
125- }
114+ channel. pipeline. addHandler ( HTTPRequestEncoder ( ) ) . flatMap {
115+ channel. pipeline. addHandler ( ByteToMessageHandler ( HTTPResponseDecoder ( leftOverBytesStrategy: . forwardBytes) ) )
116+ . flatMap {
117+ return channel. pipeline. addHandler ( ResponseHandler ( context: context, preUpgradedRequest: self . preUpgradedRequest, request: self . loggedRequest!) )
118+ }
126119 }
120+ }
127121 . connect ( host: targetHost!, port: port)
128122
129123 channelFuture. whenSuccess { channel in
@@ -206,7 +200,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
206200 }
207201 }
208202
209- func handleRequest( context: ChannelHandlerContext , reqPart: HTTPServerRequestPart , httpProtocol: HttpProtocol , port : Int ) {
203+ func handleRequest( context: ChannelHandlerContext , reqPart: HTTPServerRequestPart , httpProtocol: HttpProtocol ) {
210204 switch ( reqPart) {
211205 case . head( var head) :
212206 // Remove the Accept-Encoding header
@@ -223,7 +217,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
223217 requestParts. append ( clientReqPart)
224218 case . end( let headers) :
225219 requestParts. append ( HTTPClientRequestPart . end ( headers) )
226-
220+
227221 if AppState . shared. isInterceptEnabled {
228222 waitingContext = context
229223
@@ -238,20 +232,20 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
238232 }
239233 }
240234 } else {
241- forwardRequestForProtocol ( httpProtocol : httpProtocol , context: context, requestParts: requestParts)
235+ forwardRequestForProtocol ( context: context, requestParts: requestParts)
242236 }
243237 }
244238 }
245239
246240 func channelReadForHttps( context: ChannelHandlerContext , data: NIOAny ) {
247241 let reqPart = self . unwrapInboundIn ( data)
248- handleRequest ( context: context, reqPart: reqPart, httpProtocol: . HTTPS, port : 443 )
242+ handleRequest ( context: context, reqPart: reqPart, httpProtocol: . HTTPS)
249243 }
250244
251245 func channelRead( context: ChannelHandlerContext , data: NIOAny ) {
252246 let reqPart = self . unwrapInboundIn ( data)
253247 print ( " Invoked channel read " )
254-
248+
255249 switch ( upgradeState) {
256250 case . connectRequested:
257251 print ( " Connect return " )
@@ -279,38 +273,40 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
279273 }
280274 self . targetHost = newHost
281275 self . targetPort = originalURI. port ?? 80
276+ self . targetProtocol = . HTTP
282277
283278 head. headers. replaceOrAdd ( name: " Host " , value: newHost)
284279 head. uri = originalURI. relativePath
285280
286- handleRequest ( context: context, reqPart: reqPart, httpProtocol: . HTTP, port : 80 )
281+ handleRequest ( context: context, reqPart: reqPart, httpProtocol: . HTTP)
287282 }
288283 default :
289- handleRequest ( context: context, reqPart: reqPart, httpProtocol: . HTTP, port : 80 )
284+ handleRequest ( context: context, reqPart: reqPart, httpProtocol: . HTTP)
290285 }
291286 }
292287
293288 private func handleConnectRequest( context: ChannelHandlerContext , head: inout HTTPRequestHead ) {
294289 let uriComponents = head. uri. split ( separator: " : " , maxSplits: 1 , omittingEmptySubsequences: false )
295290 self . targetHost = String ( uriComponents. first!)
296291 self . targetPort = Int ( uriComponents. last!)
292+ self . targetProtocol = . HTTPS
297293
298294 guard let targetHost = self . targetHost,
299295 let _ = self . targetPort else {
300296 sendHttpResponse ( ctx: context, status: . badRequest)
301297 context. close ( promise: nil )
302298 return
303299 }
304-
300+
305301 let selfSignedCertAndKey = CertificateManager . shared. certificateForDomain ( String ( targetHost) )
306302 let selfSignedRootCa = try ! CertificateManager . shared. loadRootCAFromKeychain ( )
307303
308304 var serializer = DER . Serializer ( )
309305 try ! serializer. serialize ( selfSignedCertAndKey!. certificate)
310-
306+
311307 var selfSignedRootCaSerializer = DER . Serializer ( )
312308 try ! selfSignedRootCaSerializer. serialize ( selfSignedRootCa. rootCertificate)
313-
309+
314310
315311 let certificate = try ! NIOSSLCertificate ( bytes: serializer. serializedBytes, format: . der)
316312 let privateKey = try ! NIOSSLPrivateKey ( bytes: [ UInt8] ( selfSignedCertAndKey!. privateKey. derRepresentation) , format: . der)
@@ -319,13 +315,13 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
319315 let rootCert = NIOSSLCertificateSource . certificate ( rootCertificate)
320316 let tlsConfiguration = TLSConfiguration . makeServerConfiguration ( certificateChain: [ serverCert, rootCert] , privateKey: . privateKey( privateKey) )
321317 let sslContext = try ! NIOSSLContext ( configuration: tlsConfiguration)
322-
318+
323319 let sslHandler = NIOSSLServerHandler ( context: sslContext)
324320
325321 let responseHead = HTTPResponseHead ( version: . http1_1, status: . ok)
326322 let responsePart = HTTPServerResponsePart . head ( responseHead)
327323 self . preUpgradedRequest = context
328-
324+
329325 context. writeAndFlush ( self . wrapOutboundOut ( responsePart) )
330326 . flatMap { _ in
331327 context. pipeline. removeHandler ( name: " HTTPResponseEncoder " )
@@ -386,8 +382,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
386382 let updatedRequestParts = parseRawRequest ( rawRequest: rawRequest)
387383
388384 // Forward the updated request
389- // FIXME: don't hardcode used protocol
390- forwardRequestForProtocol ( httpProtocol: . HTTP, context: waitingContext!, requestParts: updatedRequestParts)
385+ forwardRequestForProtocol ( context: waitingContext!, requestParts: updatedRequestParts)
391386 }
392387
393388 private func parseRawRequest( rawRequest: String ) -> [ HTTPClientRequestPart ] {
0 commit comments