Skip to content

Commit c5ada48

Browse files
authored
Fix fragmented frame handling in websocket (#3)
1 parent 484e9d7 commit c5ada48

File tree

5 files changed

+172
-39
lines changed

5 files changed

+172
-39
lines changed

Sources/LCLWebSocket/Client/WebSocketClient.swift

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -203,14 +203,9 @@ public struct WebSocketClient: Sendable, LCLWebSocketListenable {
203203
)
204204
)
205205

206-
try channel.pipeline.syncOperations.addHandlers([
207-
NIOWebSocketFrameAggregator(
208-
minNonFinalFragmentSize: configuration.minNonFinalFragmentSize,
209-
maxAccumulatedFrameCount: configuration.maxAccumulatedFrameCount,
210-
maxAccumulatedFrameSize: configuration.maxAccumulatedFrameSize
211-
),
212-
WebSocketHandler(websocket: websocket),
213-
])
206+
try channel.pipeline.syncOperations.addHandlers(
207+
WebSocketHandler(websocket: websocket, configuration: configuration)
208+
)
214209
self._onOpen?(websocket)
215210
return channel.eventLoop.makeSucceededVoidFuture()
216211
} catch {
@@ -547,14 +542,9 @@ extension WebSocketClient {
547542
)
548543
)
549544

550-
try channel.pipeline.syncOperations.addHandlers([
551-
NIOWebSocketFrameAggregator(
552-
minNonFinalFragmentSize: configuration.minNonFinalFragmentSize,
553-
maxAccumulatedFrameCount: configuration.maxAccumulatedFrameCount,
554-
maxAccumulatedFrameSize: configuration.maxAccumulatedFrameSize
555-
),
556-
WebSocketHandler(websocket: websocket),
557-
])
545+
try channel.pipeline.syncOperations.addHandler(
546+
WebSocketHandler(websocket: websocket, configuration: configuration)
547+
)
558548
self._onOpen?(websocket)
559549
} catch {
560550
return channel.eventLoop.makeFailedFuture(error)

Sources/LCLWebSocket/LCLWebSocket+Error.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,24 @@ public enum LCLWebSocketError: Error {
4848

4949
/// HTTP method is not allowed during the upgrade request.
5050
case methodNotAllowed
51+
52+
/// Received a new fragment frame without finishing the previous fragment sequence.
53+
case receivedNewFrameWithoutFinishingPreviousOne
54+
55+
/// The size of the non-final fragment is too small.
56+
case nonFinalFragmentSizeIsTooSmall
57+
58+
/// There are too many fragment frames.
59+
case tooManyFrameFragments
60+
61+
/// The buffered frame sizes is too large.
62+
case accumulatedFrameSizeIsTooLarge
63+
64+
/// Received a continuation frame without a previous fragment frame.
65+
case receivedContinuationFrameWithoutPreviousFragmentFrame
66+
67+
/// Invalid UTF-8 string.
68+
case invalidUTF8String
5169
}
5270

5371
extension LCLWebSocketError: CustomStringConvertible {
@@ -75,6 +93,18 @@ extension LCLWebSocketError: CustomStringConvertible {
7593
return "Unknown opcode \(code)"
7694
case .methodNotAllowed:
7795
return "HTTP Method not allowed"
96+
case .receivedNewFrameWithoutFinishingPreviousOne:
97+
return "Received new frame without finishing previous one"
98+
case .nonFinalFragmentSizeIsTooSmall:
99+
return "Non-final fragment size is too small"
100+
case .tooManyFrameFragments:
101+
return "Too many frame fragments"
102+
case .accumulatedFrameSizeIsTooLarge:
103+
return "Accumulated frame size is too large"
104+
case .receivedContinuationFrameWithoutPreviousFragmentFrame:
105+
return "Received continuation frame without previous fragment frame"
106+
case .invalidUTF8String:
107+
return "Invalid UTF-8 string"
78108
}
79109
}
80110
}

Sources/LCLWebSocket/Server/WebSocketServer.swift

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,9 @@ public struct WebSocketServer: Sendable, LCLWebSocketListenable {
194194
high: configuration.writeBufferWaterMarkHigh
195195
)
196196
)
197-
try channel.pipeline.syncOperations.addHandlers([
198-
NIOWebSocketFrameAggregator(
199-
minNonFinalFragmentSize: configuration.minNonFinalFragmentSize,
200-
maxAccumulatedFrameCount: configuration.maxAccumulatedFrameCount,
201-
maxAccumulatedFrameSize: configuration.maxAccumulatedFrameSize
202-
),
203-
WebSocketHandler(websocket: websocket),
204-
])
197+
try channel.pipeline.syncOperations.addHandler(
198+
WebSocketHandler(websocket: websocket, configuration: configuration)
199+
)
205200
self._onOpen?(websocket)
206201
return channel.eventLoop.makeSucceededVoidFuture()
207202
} catch {
@@ -408,14 +403,9 @@ extension WebSocketServer {
408403
high: configuration.writeBufferWaterMarkHigh
409404
)
410405
)
411-
try channel.pipeline.syncOperations.addHandlers([
412-
NIOWebSocketFrameAggregator(
413-
minNonFinalFragmentSize: configuration.minNonFinalFragmentSize,
414-
maxAccumulatedFrameCount: configuration.maxAccumulatedFrameCount,
415-
maxAccumulatedFrameSize: configuration.maxAccumulatedFrameSize
416-
),
417-
WebSocketHandler(websocket: websocket),
418-
])
406+
try channel.pipeline.syncOperations.addHandler(
407+
WebSocketHandler(websocket: websocket, configuration: configuration)
408+
)
419409
return channel.eventLoop.makeSucceededFuture(UpgradeResult.websocket)
420410
} catch {
421411
return channel.eventLoop.makeFailedFuture(error)

Sources/LCLWebSocket/WebSocket.swift

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,6 @@ public final class WebSocket: Sendable {
294294
}
295295

296296
var data = frame.data
297-
if let maskKey = frame.maskKey {
298-
data.webSocketUnmask(maskKey)
299-
}
300297
let originalDataReaderIdx = data.readerIndex
301298

302299
switch frame.opcode {
@@ -383,7 +380,7 @@ public final class WebSocket: Sendable {
383380
preconditionFailure("WebSocket connection is not established.")
384381
}
385382
case .continuation:
386-
preconditionFailure("continuation frame is filtered by swiftnio")
383+
preconditionFailure("continuation frame is filtered by WebSocketHandler")
387384
case .ping:
388385
if frame.fin {
389386
self._onPing.value?(self, data)

Sources/LCLWebSocket/WebSocketHandler.swift

Lines changed: 129 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,23 @@
1111
//
1212

1313
import NIOCore
14+
import NIOFoundationCompat
1415
import NIOWebSocket
1516

1617
final class WebSocketHandler: ChannelInboundHandler {
1718
typealias InboundIn = WebSocketFrame
1819

1920
private let websocket: WebSocket
20-
init(websocket: WebSocket) {
21+
private var firstFrame: WebSocketFrame?
22+
private var bufferedFrameData: ByteBuffer
23+
private var totalBufferedFrameCount: Int
24+
private let configuration: LCLWebSocket.Configuration
25+
init(websocket: WebSocket, configuration: LCLWebSocket.Configuration) {
2126
self.websocket = websocket
27+
self.firstFrame = nil
28+
self.bufferedFrameData = ByteBuffer()
29+
self.totalBufferedFrameCount = 0
30+
self.configuration = configuration
2231
}
2332

2433
#if DEBUG
@@ -32,8 +41,61 @@ final class WebSocketHandler: ChannelInboundHandler {
3241
#endif // DEBUG
3342

3443
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
35-
let frame = self.unwrapInboundIn(data)
36-
self.websocket.handleFrame(frame)
44+
var frame = self.unwrapInboundIn(data)
45+
if let maskKey = frame.maskKey {
46+
frame.data.webSocketUnmask(maskKey)
47+
}
48+
49+
do {
50+
switch frame.opcode {
51+
case .continuation:
52+
guard let firstFrame = self.firstFrame else {
53+
// close channel due to policy violation
54+
throw LCLWebSocketError.receivedContinuationFrameWithoutPreviousFragmentFrame
55+
}
56+
57+
// buffer the frame
58+
try self.bufferFrame(frame)
59+
60+
guard frame.fin else {
61+
// continuation frame
62+
break
63+
}
64+
65+
// final frame is received
66+
// combine frame
67+
// clear buffer
68+
let combinedFrame = self.combineFrames(firstFrame: firstFrame, allocator: context.channel.allocator)
69+
try validateUFT8Encoding(of: combinedFrame.data)
70+
self.websocket.handleFrame(combinedFrame)
71+
self.clearBufferedFrames()
72+
73+
case .binary, .text:
74+
if frame.fin {
75+
// unfragmented frame
76+
guard self.firstFrame == nil else {
77+
// close channel due to policy violatioin
78+
throw LCLWebSocketError.receivedNewFrameWithoutFinishingPreviousOne
79+
}
80+
try validateUFT8Encoding(of: frame.data)
81+
82+
self.websocket.handleFrame(frame)
83+
} else {
84+
// fragmented frame
85+
try self.bufferFrame(frame)
86+
return
87+
}
88+
default:
89+
self.websocket.handleFrame(frame)
90+
}
91+
} catch LCLWebSocketError.invalidUTF8String, is ByteBuffer.ReadUTF8ValidationError {
92+
self.websocket.close(code: .dataInconsistentWithMessage, promise: nil)
93+
context.close(mode: .all, promise: nil)
94+
} catch {
95+
let reason = (error as? LCLWebSocketError)?.description
96+
self.websocket.close(code: .protocolError, reason: reason, promise: nil)
97+
context.close(mode: .all, promise: nil)
98+
}
3799
}
38100

39101
func errorCaught(context: ChannelHandlerContext, error: any Error) {
@@ -46,6 +108,70 @@ final class WebSocketHandler: ChannelInboundHandler {
46108
}
47109
context.close(mode: .all, promise: nil)
48110
}
111+
112+
private func bufferFrame(_ frame: WebSocketFrame) throws {
113+
guard self.firstFrame == nil || frame.opcode == .continuation else {
114+
throw LCLWebSocketError.receivedNewFrameWithoutFinishingPreviousOne
115+
}
116+
117+
guard frame.fin || frame.length >= self.configuration.minNonFinalFragmentSize else {
118+
throw LCLWebSocketError.nonFinalFragmentSizeIsTooSmall
119+
}
120+
121+
guard self.totalBufferedFrameCount < self.configuration.maxFrameSize else {
122+
throw LCLWebSocketError.tooManyFrameFragments
123+
}
124+
125+
guard frame.fin || (self.totalBufferedFrameCount + 1) < self.configuration.maxAccumulatedFrameCount else {
126+
throw LCLWebSocketError.tooManyFrameFragments
127+
}
128+
129+
if self.firstFrame == nil {
130+
self.firstFrame = frame
131+
}
132+
self.totalBufferedFrameCount += 1
133+
var frame = frame
134+
self.bufferedFrameData.writeBuffer(&frame.data)
135+
136+
guard self.bufferedFrameData.readableBytes <= self.configuration.maxAccumulatedFrameSize else {
137+
throw LCLWebSocketError.accumulatedFrameSizeIsTooLarge
138+
}
139+
}
140+
141+
private func validateUFT8Encoding(of data: ByteBuffer) throws {
142+
if data.readableBytes == 0 {
143+
return
144+
}
145+
146+
if #available(macOS 15, iOS 18, tvOS 18, watchOS 11, *) {
147+
_ = try self.bufferedFrameData.getUTF8ValidatedString(at: data.readableBytes, length: data.readableBytes)
148+
} else {
149+
guard let bytes = self.bufferedFrameData.getData(at: data.readerIndex, length: data.readableBytes),
150+
String(data: bytes, encoding: .utf8) != nil
151+
else {
152+
throw LCLWebSocketError.invalidUTF8String
153+
}
154+
}
155+
}
156+
157+
private func combineFrames(firstFrame: WebSocketFrame, allocator: ByteBufferAllocator) -> WebSocketFrame {
158+
WebSocketFrame(
159+
fin: firstFrame.fin,
160+
rsv1: firstFrame.rsv1,
161+
rsv2: firstFrame.rsv2,
162+
rsv3: firstFrame.rsv3,
163+
opcode: firstFrame.opcode,
164+
maskKey: firstFrame.maskKey,
165+
data: self.bufferedFrameData,
166+
extensionData: firstFrame.extensionData
167+
)
168+
}
169+
170+
private func clearBufferedFrames() {
171+
self.firstFrame = nil
172+
self.bufferedFrameData.clear()
173+
self.totalBufferedFrameCount = 0
174+
}
49175
}
50176

51177
extension WebSocketErrorCode {

0 commit comments

Comments
 (0)