11
11
//
12
12
13
13
import NIOCore
14
+ import NIOFoundationCompat
14
15
import NIOWebSocket
15
16
16
17
final class WebSocketHandler : ChannelInboundHandler {
17
18
typealias InboundIn = WebSocketFrame
18
19
19
20
private let websocket : WebSocket
20
- init ( websocket: WebSocket ) {
21
+ private var firstFrame : WebSocketFrame ?
22
+ private var bufferedFrameData : ByteBuffer
23
+ private var totalBufferedFrameCount : Int
24
+ private let configuration : LCLWebSocket . Configuration
25
+ init ( websocket: WebSocket , configuration: LCLWebSocket . Configuration ) {
21
26
self . websocket = websocket
27
+ self . firstFrame = nil
28
+ self . bufferedFrameData = ByteBuffer ( )
29
+ self . totalBufferedFrameCount = 0
30
+ self . configuration = configuration
22
31
}
23
32
24
33
#if DEBUG
@@ -32,8 +41,61 @@ final class WebSocketHandler: ChannelInboundHandler {
32
41
#endif // DEBUG
33
42
34
43
func channelRead( context: ChannelHandlerContext , data: NIOAny ) {
35
- let frame = self . unwrapInboundIn ( data)
36
- self . websocket. handleFrame ( frame)
44
+ var frame = self . unwrapInboundIn ( data)
45
+ if let maskKey = frame. maskKey {
46
+ frame. data. webSocketUnmask ( maskKey)
47
+ }
48
+
49
+ do {
50
+ switch frame. opcode {
51
+ case . continuation:
52
+ guard let firstFrame = self . firstFrame else {
53
+ // close channel due to policy violation
54
+ throw LCLWebSocketError . receivedContinuationFrameWithoutPreviousFragmentFrame
55
+ }
56
+
57
+ // buffer the frame
58
+ try self . bufferFrame ( frame)
59
+
60
+ guard frame. fin else {
61
+ // continuation frame
62
+ break
63
+ }
64
+
65
+ // final frame is received
66
+ // combine frame
67
+ // clear buffer
68
+ let combinedFrame = self . combineFrames ( firstFrame: firstFrame, allocator: context. channel. allocator)
69
+ try validateUFT8Encoding ( of: combinedFrame. data)
70
+ self . websocket. handleFrame ( combinedFrame)
71
+ self . clearBufferedFrames ( )
72
+
73
+ case . binary, . text:
74
+ if frame. fin {
75
+ // unfragmented frame
76
+ guard self . firstFrame == nil else {
77
+ // close channel due to policy violatioin
78
+ throw LCLWebSocketError . receivedNewFrameWithoutFinishingPreviousOne
79
+ }
80
+ try validateUFT8Encoding ( of: frame. data)
81
+
82
+ self . websocket. handleFrame ( frame)
83
+ } else {
84
+ // fragmented frame
85
+ try self . bufferFrame ( frame)
86
+ return
87
+ }
88
+ default :
89
+ self . websocket. handleFrame ( frame)
90
+ }
91
+ } catch LCLWebSocketError . invalidUTF8String , is ByteBuffer . ReadUTF8ValidationError {
92
+ self . websocket. close ( code: . dataInconsistentWithMessage, promise: nil )
93
+ context. close ( mode: . all, promise: nil )
94
+ } catch {
95
+ let reason = ( error as? LCLWebSocketError ) ? . description
96
+ self . websocket. close ( code: . protocolError, reason: reason, promise: nil )
97
+ context. close ( mode: . all, promise: nil )
98
+ }
37
99
}
38
100
39
101
func errorCaught( context: ChannelHandlerContext , error: any Error ) {
@@ -46,6 +108,70 @@ final class WebSocketHandler: ChannelInboundHandler {
46
108
}
47
109
context. close ( mode: . all, promise: nil )
48
110
}
111
+
112
+ private func bufferFrame( _ frame: WebSocketFrame ) throws {
113
+ guard self . firstFrame == nil || frame. opcode == . continuation else {
114
+ throw LCLWebSocketError . receivedNewFrameWithoutFinishingPreviousOne
115
+ }
116
+
117
+ guard frame. fin || frame. length >= self . configuration. minNonFinalFragmentSize else {
118
+ throw LCLWebSocketError . nonFinalFragmentSizeIsTooSmall
119
+ }
120
+
121
+ guard self . totalBufferedFrameCount < self . configuration. maxFrameSize else {
122
+ throw LCLWebSocketError . tooManyFrameFragments
123
+ }
124
+
125
+ guard frame. fin || ( self . totalBufferedFrameCount + 1 ) < self . configuration. maxAccumulatedFrameCount else {
126
+ throw LCLWebSocketError . tooManyFrameFragments
127
+ }
128
+
129
+ if self . firstFrame == nil {
130
+ self . firstFrame = frame
131
+ }
132
+ self . totalBufferedFrameCount += 1
133
+ var frame = frame
134
+ self . bufferedFrameData. writeBuffer ( & frame. data)
135
+
136
+ guard self . bufferedFrameData. readableBytes <= self . configuration. maxAccumulatedFrameSize else {
137
+ throw LCLWebSocketError . accumulatedFrameSizeIsTooLarge
138
+ }
139
+ }
140
+
141
+ private func validateUFT8Encoding( of data: ByteBuffer ) throws {
142
+ if data. readableBytes == 0 {
143
+ return
144
+ }
145
+
146
+ if #available( macOS 15 , iOS 18 , tvOS 18 , watchOS 11 , * ) {
147
+ _ = try self . bufferedFrameData. getUTF8ValidatedString ( at: data. readableBytes, length: data. readableBytes)
148
+ } else {
149
+ guard let bytes = self . bufferedFrameData. getData ( at: data. readerIndex, length: data. readableBytes) ,
150
+ String ( data: bytes, encoding: . utf8) != nil
151
+ else {
152
+ throw LCLWebSocketError . invalidUTF8String
153
+ }
154
+ }
155
+ }
156
+
157
+ private func combineFrames( firstFrame: WebSocketFrame , allocator: ByteBufferAllocator ) -> WebSocketFrame {
158
+ WebSocketFrame (
159
+ fin: firstFrame. fin,
160
+ rsv1: firstFrame. rsv1,
161
+ rsv2: firstFrame. rsv2,
162
+ rsv3: firstFrame. rsv3,
163
+ opcode: firstFrame. opcode,
164
+ maskKey: firstFrame. maskKey,
165
+ data: self . bufferedFrameData,
166
+ extensionData: firstFrame. extensionData
167
+ )
168
+ }
169
+
170
+ private func clearBufferedFrames( ) {
171
+ self . firstFrame = nil
172
+ self . bufferedFrameData. clear ( )
173
+ self . totalBufferedFrameCount = 0
174
+ }
49
175
}
50
176
51
177
extension WebSocketErrorCode {
0 commit comments