@@ -17,54 +17,110 @@ import NIOCore
1717import NIOHTTP1
1818
1919/// Handler that manages the CORS protocol for requests incoming from the browser.
20- internal class WebCORSHandler {
21- var requestMethod : HTTPMethod ?
20+ internal final class WebCORSHandler {
21+ let configuration : Server . Configuration . CORS
22+
23+ private var state : State = . idle
24+ private enum State : Equatable {
25+ /// Starting state.
26+ case idle
27+ /// CORS preflight request is in progress.
28+ case processingPreflightRequest
29+ /// "Real" request is in progress.
30+ case processingRequest( origin: String ? )
31+ }
32+
33+ init ( configuration: Server . Configuration . CORS ) {
34+ self . configuration = configuration
35+ }
2236}
2337
2438extension WebCORSHandler : ChannelInboundHandler {
2539 typealias InboundIn = HTTPServerRequestPart
40+ typealias InboundOut = HTTPServerRequestPart
2641 typealias OutboundOut = HTTPServerResponsePart
2742
2843 func channelRead( context: ChannelHandlerContext , data: NIOAny ) {
29- // If the request is OPTIONS, the request is not propagated further.
3044 switch self . unwrapInboundIn ( data) {
31- case let . head( requestHead) :
32- self . requestMethod = requestHead. method
33- if self . requestMethod == . OPTIONS {
34- var headers = HTTPHeaders ( )
35- headers. add ( name: " Access-Control-Allow-Origin " , value: " * " )
36- headers. add ( name: " Access-Control-Allow-Methods " , value: " POST " )
37- headers. add (
38- name: " Access-Control-Allow-Headers " ,
39- value: " content-type,x-grpc-web,x-user-agent "
40- )
41- headers. add ( name: " Access-Control-Max-Age " , value: " 86400 " )
42- context. write (
43- self . wrapOutboundOut ( . head( HTTPResponseHead (
44- version: requestHead. version,
45- status: . ok,
46- headers: headers
47- ) ) ) ,
48- promise: nil
49- )
50- return
45+ case let . head( head) :
46+ self . receivedRequestHead ( context: context, head)
47+
48+ case let . body( body) :
49+ self . receivedRequestBody ( context: context, body)
50+
51+ case let . end( trailers) :
52+ self . receivedRequestEnd ( context: context, trailers)
53+ }
54+ }
55+
56+ private func receivedRequestHead( context: ChannelHandlerContext , _ head: HTTPRequestHead ) {
57+ if head. method == . OPTIONS,
58+ head. headers. contains ( . accessControlRequestMethod) ,
59+ let origin = head. headers. first ( name: " origin " ) {
60+ // If the request is OPTIONS with a access-control-request-method header it's a CORS
61+ // preflight request and is not propagated further.
62+ self . state = . processingPreflightRequest
63+ self . handlePreflightRequest ( context: context, head: head, origin: origin)
64+ } else {
65+ self . state = . processingRequest( origin: head. headers. first ( name: " origin " ) )
66+ context. fireChannelRead ( self . wrapInboundOut ( . head( head) ) )
67+ }
68+ }
69+
70+ private func receivedRequestBody( context: ChannelHandlerContext , _ body: ByteBuffer ) {
71+ // OPTIONS requests do not have a body, but still handle this case to be
72+ // cautious.
73+ if self . state == . processingPreflightRequest {
74+ return
75+ }
76+
77+ context. fireChannelRead ( self . wrapInboundOut ( . body( body) ) )
78+ }
79+
80+ private func receivedRequestEnd( context: ChannelHandlerContext , _ trailers: HTTPHeaders ? ) {
81+ if self . state == . processingPreflightRequest {
82+ // End of OPTIONS request; reset state and finish the response.
83+ self . state = . idle
84+ context. writeAndFlush ( self . wrapOutboundOut ( . end( nil ) ) , promise: nil )
85+ } else {
86+ context. fireChannelRead ( self . wrapInboundOut ( . end( trailers) ) )
87+ }
88+ }
89+
90+ private func handlePreflightRequest(
91+ context: ChannelHandlerContext ,
92+ head: HTTPRequestHead ,
93+ origin: String
94+ ) {
95+ let responseHead : HTTPResponseHead
96+
97+ if let allowedOrigin = self . configuration. allowedOrigins. header ( origin) {
98+ var headers = HTTPHeaders ( )
99+ headers. reserveCapacity ( 4 + self . configuration. allowedHeaders. count)
100+ headers. add ( name: . accessControlAllowOrigin, value: allowedOrigin)
101+ headers. add ( name: . accessControlAllowMethods, value: " POST " )
102+
103+ for value in self . configuration. allowedHeaders {
104+ headers. add ( name: . accessControlAllowHeaders, value: value)
51105 }
52- case . body:
53- if self . requestMethod == . OPTIONS {
54- // OPTIONS requests do not have a body, but still handle this case to be
55- // cautious.
56- return
106+
107+ if self . configuration. allowCredentialedRequests {
108+ headers. add ( name: . accessControlAllowCredentials, value: " true " )
57109 }
58110
59- case . end :
60- if self . requestMethod == . OPTIONS {
61- context . writeAndFlush ( self . wrapOutboundOut ( . end ( nil ) ) , promise : nil )
62- self . requestMethod = nil
63- return
111+ if self . configuration . preflightCacheExpiration > 0 {
112+ headers . add (
113+ name : . accessControlMaxAge ,
114+ value : " \( self . configuration . preflightCacheExpiration ) "
115+ )
64116 }
117+ responseHead = HTTPResponseHead ( version: head. version, status: . ok, headers: headers)
118+ } else {
119+ // Not allowed; respond with 403. This is okay in a pre-flight request.
120+ responseHead = HTTPResponseHead ( version: head. version, status: . forbidden)
65121 }
66- // The OPTIONS request should be fully handled at this point.
67- context. fireChannelRead ( data )
122+
123+ context. write ( self . wrapOutboundOut ( . head ( responseHead ) ) , promise : nil )
68124 }
69125}
70126
@@ -74,25 +130,76 @@ extension WebCORSHandler: ChannelOutboundHandler {
74130 func write( context: ChannelHandlerContext , data: NIOAny , promise: EventLoopPromise < Void > ? ) {
75131 let responsePart = self . unwrapOutboundIn ( data)
76132 switch responsePart {
77- case let . head( responseHead) :
78- var headers = responseHead. headers
79- // CORS requires all requests to have an Allow-Origin header.
80- headers. add ( name: " Access-Control-Allow-Origin " , value: " * " )
81- //! FIXME: Check whether we can let browsers keep connections alive. It's not possible
82- // now as the channel has a state that can't be reused since the pipeline is modified to
83- // inject the gRPC call handler.
84- headers. add ( name: " Connection " , value: " close " )
85-
86- context. write (
87- self . wrapOutboundOut ( . head( HTTPResponseHead (
88- version: responseHead. version,
89- status: responseHead. status,
90- headers: headers
91- ) ) ) ,
92- promise: promise
93- )
94- default :
133+ case var . head( responseHead) :
134+ switch self . state {
135+ case let . processingRequest( origin) :
136+ self . prepareCORSResponseHead ( & responseHead, origin: origin)
137+ context. write ( self . wrapOutboundOut ( . head( responseHead) ) , promise: promise)
138+
139+ case . idle, . processingPreflightRequest:
140+ assertionFailure ( " Writing response head when no request is in progress " )
141+ context. close ( promise: nil )
142+ }
143+
144+ case . body:
145+ context. write ( data, promise: promise)
146+
147+ case . end:
148+ self . state = . idle
95149 context. write ( data, promise: promise)
96150 }
97151 }
152+
153+ private func prepareCORSResponseHead( _ head: inout HTTPResponseHead , origin: String ? ) {
154+ guard let header = origin. flatMap ( { self . configuration. allowedOrigins. header ( $0) } ) else {
155+ // No origin or the origin is not allowed; don't treat it as a CORS request.
156+ return
157+ }
158+
159+ head. headers. replaceOrAdd ( name: . accessControlAllowOrigin, value: header)
160+
161+ if self . configuration. allowCredentialedRequests {
162+ head. headers. add ( name: . accessControlAllowCredentials, value: " true " )
163+ }
164+
165+ //! FIXME: Check whether we can let browsers keep connections alive. It's not possible
166+ // now as the channel has a state that can't be reused since the pipeline is modified to
167+ // inject the gRPC call handler.
168+ head. headers. replaceOrAdd ( name: " Connection " , value: " close " )
169+ }
170+ }
171+
172+ extension HTTPHeaders {
173+ fileprivate enum CORSHeader : String {
174+ case accessControlRequestMethod = " access-control-request-method "
175+ case accessControlRequestHeaders = " access-control-request-headers "
176+ case accessControlAllowOrigin = " access-control-allow-origin "
177+ case accessControlAllowMethods = " access-control-allow-methods "
178+ case accessControlAllowHeaders = " access-control-allow-headers "
179+ case accessControlAllowCredentials = " access-control-allow-credentials "
180+ case accessControlMaxAge = " access-control-max-age "
181+ }
182+
183+ fileprivate func contains( _ name: CORSHeader ) -> Bool {
184+ return self . contains ( name: name. rawValue)
185+ }
186+
187+ fileprivate mutating func add( name: CORSHeader , value: String ) {
188+ self . add ( name: name. rawValue, value: value)
189+ }
190+
191+ fileprivate mutating func replaceOrAdd( name: CORSHeader , value: String ) {
192+ self . replaceOrAdd ( name: name. rawValue, value: value)
193+ }
194+ }
195+
196+ extension Server . Configuration . CORS . AllowedOrigins {
197+ internal func header( _ origin: String ) -> String ? {
198+ switch self . wrapped {
199+ case . all:
200+ return " * "
201+ case let . only( allowed) :
202+ return allowed. contains ( origin) ? origin : nil
203+ }
204+ }
98205}
0 commit comments