From ae654261b86fc89bf9acaade833e52e1e0eab238 Mon Sep 17 00:00:00 2001 From: Alessandro Genova Date: Fri, 22 Mar 2024 13:25:16 -0400 Subject: [PATCH 1/4] feat(msgpack): use msgpack to serialize/deserialize messages BREAKING CHANGE: replace json serialization with msgpack --- js/package-lock.json | 11 +- js/package.json | 2 +- js/src/WebsocketConnection/session.js | 194 ++++++++------------------ python/requirements.txt | 1 + python/src/wslink/protocol.py | 175 ++++++----------------- python/src/wslink/publish.py | 39 +----- 6 files changed, 124 insertions(+), 298 deletions(-) diff --git a/js/package-lock.json b/js/package-lock.json index 49dd3665..2793014a 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -9,7 +9,7 @@ "version": "0.0.0-semantically-release", "license": "BSD-3-Clause", "dependencies": { - "json5": "2.2.3" + "@msgpack/msgpack": "^2.8.0" }, "devDependencies": { "@babel/core": "7.20.12", @@ -2008,6 +2008,14 @@ "@jridgewell/sourcemap-codec": "1.4.14" } }, + "node_modules/@msgpack/msgpack": { + "version": "2.8.0", + "resolved": "https://registry.npmjs.org/@msgpack/msgpack/-/msgpack-2.8.0.tgz", + "integrity": "sha512-h9u4u/jiIRKbq25PM+zymTyW6bhTzELvOoUd+AvYriWOAKpLGnIamaET3pnHYoI5iYphAHBI4ayx0MehR+VVPQ==", + "engines": { + "node": ">= 10" + } + }, "node_modules/@nodelib/fs.scandir": { "version": "2.1.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", @@ -6291,6 +6299,7 @@ "version": "2.2.3", "resolved": "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz", "integrity": "sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==", + "dev": true, "bin": { "json5": "lib/cli.js" }, diff --git a/js/package.json b/js/package.json index eff7564c..fcc1595c 100644 --- a/js/package.json +++ b/js/package.json @@ -48,7 +48,7 @@ } }, "dependencies": { - "json5": "2.2.3" + "@msgpack/msgpack": "^2.8.0" }, "release": { "plugins": [ diff --git a/js/src/WebsocketConnection/session.js b/js/src/WebsocketConnection/session.js index 748e0328..c647fb07 100644 --- a/js/src/WebsocketConnection/session.js +++ b/js/src/WebsocketConnection/session.js @@ -1,6 +1,6 @@ // Helper borrowed from paraviewweb/src/Common/Core import CompositeClosureHelper from '../CompositeClosureHelper'; -import JSON5 from 'json5'; +import { Encoder, Decoder } from "@msgpack/msgpack"; function defer() { const deferred = {}; @@ -17,10 +17,6 @@ function Session(publicAPI, model) { const CLIENT_ERROR = -32099; let msgCount = 0; const inFlightRpc = {}; - const attachments = []; - const attachmentsToSend = {}; - let attachmentId = 1; - const regexAttach = /^wslink_bin[\d]+$/; // matches 'rpc:client3:21' // client may be dot-separated and include '_' // number is message count - unique. 
@@ -28,76 +24,20 @@ function Session(publicAPI, model) {
   const regexRPC = /^(rpc|publish|system):(\w+(?:\.\w+)*):(?:\d+)$/;
   const subscriptions = {};
   let clientID = null;
+  const encoder = new CustomEncoder();
+  const decoder = new Decoder();
 
   // --------------------------------------------------------------------------
   // Private helpers
   // --------------------------------------------------------------------------
 
-  function sendBinary(key) {
-    if (key in attachmentsToSend) {
-      // binary header
-      model.ws.send(
-        JSON.stringify({
-          wslink: '1.0',
-          method: 'wslink.binary.attachment',
-          args: [key],
-        })
-      );
-
-      // send binary
-      model.ws.send(attachmentsToSend[key], { binary: true });
-      delete attachmentsToSend[key];
-    }
-  }
-
-  // --------------------------------------------------------------------------
-
-  function findBinary(o) {
-    if (o) {
-      if (Array.isArray(o)) {
-        o.forEach((v) => findBinary(v));
-      } else if (o.constructor === Object) {
-        Object.keys(o).forEach((k) => findBinary(o[k]));
-      } else if (regexAttach.test(o)) {
-        sendBinary(o);
-      }
-    }
-  }
-
-  // --------------------------------------------------------------------------
-  // split out to support a message with a bare binary attachment.
-  // --------------------------------------------------------------------------
-
-  function getAttachment(binaryKey) {
-    // console.log('Adding binary attachment', binaryKey);
-    const index = attachments.findIndex((att) => att.key === binaryKey);
-    if (index !== -1) {
-      const result = attachments[index].data;
-      // TODO if attachment is sent mulitple times, we shouldn't remove it yet.
-      attachments.splice(index, 1);
-      return result;
-    }
-    console.error('Binary attachment key found without matching attachment');
-    return null;
-  }
-
-  // --------------------------------------------------------------------------
-  // To do a full traversal of nested objects/lists, we need recursion.
-  // --------------------------------------------------------------------------
-
-  function addAttachment(obj_list) {
-    for (let key in obj_list) {
-      if (
-        typeof obj_list[key] === 'string' &&
-        regexAttach.test(obj_list[key])
-      ) {
-        const binaryKey = obj_list[key];
-        const replacement = getAttachment(binaryKey);
-        if (replacement !== null) obj_list[key] = replacement;
-      } else if (typeof obj_list[key] === 'object') {
-        // arrays are also 'object' with this test.
-        addAttachment(obj_list[key]);
-      }
+  async function decodeFromBlob(blob) {
+    if (blob.stream) {
+      // Blob#stream(): ReadableStream<Uint8Array> (recommended)
+      return await decoder.decodeAsync(blob.stream());
+    } else {
+      // Blob#arrayBuffer(): Promise<ArrayBuffer> (if stream() is not available)
+      return decoder.decode(await blob.arrayBuffer());
     }
   }
@@ -110,15 +50,18 @@ function Session(publicAPI, model) {
     const deferred = defer();
     const id = 'system:c0:0';
     inFlightRpc[id] = deferred;
-    model.ws.send(
-      JSON.stringify({
-        wslink: '1.0',
-        id,
-        method: 'wslink.hello',
-        args: [{ secret: model.secret }],
-        kwargs: {},
-      })
-    );
+
+    const wrapper = {
+      wslink: '1.0',
+      id,
+      method: 'wslink.hello',
+      args: [{ secret: model.secret }],
+      kwargs: {},
+    }
+
+    const packedWrapper = encoder.encode(wrapper);
+
+    model.ws.send(packedWrapper, { binary: true });
 
     return deferred.promise;
   };
@@ -131,16 +74,11 @@ function Session(publicAPI, model) {
     if (model.ws && clientID && model.ws.readyState === 1) {
       const id = `rpc:${clientID}:${msgCount++}`;
       inFlightRpc[id] = deferred;
-      const msg = JSON.stringify({ wslink: '1.0', id, method, args, kwargs });
-      if (Object.keys(attachmentsToSend).length) {
-        findBinary(args);
-        findBinary(kwargs);
-      }
+      const wrapper = { wslink: '1.0', id, method, args, kwargs };
 
-      model.ws.send(
-        JSON.stringify({ wslink: '1.0', id, method, args, kwargs })
-      );
+      const packedWrapper = encoder.encode(wrapper);
+
+      model.ws.send(packedWrapper, { binary: true });
     } else {
       deferred.reject({
         code: CLIENT_ERROR,
@@ -212,45 +150,18 @@ function Session(publicAPI, model) {
 
   // --------------------------------------------------------------------------
 
-  publicAPI.onmessage = (event) => {
-    if (event.data instanceof ArrayBuffer || event.data instanceof Blob) {
-      // we've gotten a header with the keys for this binary data.
-      // we will soon receive a json message with embedded ids of the binary objects.
-      // Save with it's key, in order.
-      // console.log('Saving binary attachment');
-      let foundIt = false;
-      for (let i = 0; i < attachments.length; i++) {
-        if (attachments[i].data === null) {
-          attachments[i].data = event.data;
-          foundIt = true;
-          break;
-        }
-      }
-      if (!foundIt) {
-        console.error('Missing header for received binary message');
-      }
-    } else {
+  publicAPI.onmessage = async (event) => {
+    {
       let payload;
       try {
-        payload = JSON5.parse(event.data);
+        payload = await decodeFromBlob(event.data);
       } catch (e) {
         console.error('Malformed message: ', event.data);
         // debugger;
       }
       if (!payload) return;
-      if (!payload.id) {
-        // Notification-only message from the server - should be binary attachment header
-        // console.log('Notify', payload);
-        if (payload.method === 'wslink.binary.attachment') {
-          payload.args.forEach((key) => {
-            attachments.push({ key, data: null });
-          });
-        }
-        return;
-      }
+      if (!payload.id) return;
       if (payload.error) {
-        // kill any attachments
-        attachments.length = 0;
         const deferred = inFlightRpc[payload.id];
         if (deferred) {
           deferred.reject(payload.error);
@@ -258,17 +169,6 @@ function Session(publicAPI, model) {
           console.error('Server error:', payload.error);
         }
       } else {
-        if (payload.result && attachments.length > 0) {
-          if (
-            typeof payload.result === 'string' &&
-            regexAttach.test(payload.result)
-          ) {
-            const replacement = getAttachment(payload.result);
-            if (replacement !== null) payload.result = replacement;
-          } else {
-            addAttachment(payload.result);
-          }
-        }
         const match = regexRPC.exec(payload.id);
         if (match) {
           const type = match[1];
@@ -322,10 +222,18 @@ function Session(publicAPI, model) {
   // --------------------------------------------------------------------------
 
   publicAPI.addAttachment = (payload) => {
-    const binaryId = `wslink_bin${attachmentId}`;
-    attachmentsToSend[binaryId] = payload;
-    attachmentId++;
-    return binaryId;
+    // Deprecated method, kept to avoid breaking compatibility.
+    // Now that we use msgpack to pack/unpack messages,
+    // we can have binary data directly in the object itself,
+    // without needing to transfer it separately from the rest.
+    //
+    // If an ArrayBuffer is passed, ensure it gets wrapped in
+    // a DataView (which is what the encoder expects).
+    if (payload instanceof ArrayBuffer) {
+      return new DataView(payload);
+    }
+
+    return payload;
   };
 }
 
@@ -350,3 +258,23 @@ export const newInstance = CompositeClosureHelper.newInstance(extend);
 // ----------------------------------------------------------------------------
 
 export default { newInstance, extend };
+
+class CustomEncoder extends Encoder {
+  // Unfortunately @msgpack/msgpack only supports
+  // views of an ArrayBuffer (DataView, Uint8Array, ...),
+  // but not an ArrayBuffer itself.
+  // They suggest using custom type extensions to support it,
+  // but that would yield a different packed payload
+  // (1 byte larger, but most importantly it would require
+  // dealing with the custom type when unpacking on the server).
+  // Since this type is too trivial to be treated differently,
+  // and since I don't want to rely on the users always wrapping
+  // their ArrayBuffers in a view, I'm subclassing the encoder.
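+  //
+  // Illustrative effect (the payload name is hypothetical):
+  //   new CustomEncoder().encode({ frame: anArrayBuffer })
+  // packs the buffer as msgpack binary data, where the stock Encoder
+  // would reject the bare ArrayBuffer.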
+ encodeObject(object, depth) { + if (object instanceof ArrayBuffer) { + object = new DataView(object); + } + + return super.encodeObject(object, depth); + } +} diff --git a/python/requirements.txt b/python/requirements.txt index 31f000eb..d3b73084 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,4 +1,5 @@ aiohttp>=3.7.4,<4 +msgpack>=1.0.8,<2 # platform specific pypiwin32==223; sys_platform == 'win32' diff --git a/python/src/wslink/protocol.py b/python/src/wslink/protocol.py index e2642d95..fdd50ffa 100644 --- a/python/src/wslink/protocol.py +++ b/python/src/wslink/protocol.py @@ -3,6 +3,7 @@ import inspect import json import logging +import msgpack import re import traceback @@ -235,34 +236,9 @@ async def handleSystemMessage(self, rpcid, methodName, args, client_id): return False async def onMessage(self, is_binary, msg, client_id): - payload = msg.data - - if is_binary: - if self.isClientAuthenticated(client_id): - # assume all binary messages are attachments - try: - key = self.attachmentsRecvQueue.pop(0) - self.attachmentsReceived[key] = payload - except: - pass - return - - # handles issue https://bugs.python.org/issue10976 - # `payload` is type bytes in Python 3. Unfortunately, json.loads - # doesn't support taking bytes until Python 3.6. - if type(payload) is bytes: - payload = payload.decode("utf-8") - - rpc = json.loads(payload) + rpc = msgpack.unpackb(msg.data) logger.debug("wslink incoming msg %s", self.payloadWithSecretStripped(rpc)) if "id" not in rpc: - # should be a binary attachment header - if rpc.get("method") == "wslink.binary.attachment": - keys = rpc.get("args", []) - if isinstance(keys, list): - for k in keys: - # wait for an attachment by it's order - self.attachmentsRecvQueue.append(k) return # TODO validate @@ -303,73 +279,37 @@ async def onMessage(self, is_binary, msg, client_id): return obj, func = self.functionMap[methodName] + args.insert(0, obj) + try: - # get any attachments - def findAttachments(o): - if ( - isinstance(o, str) - and re.match(r"^wslink_bin\d+$", o) - and o in self.attachmentsReceived - ): - attachment = self.attachmentsReceived[o] - del self.attachmentsReceived[o] - return attachment - elif isinstance(o, list): - for i, v in enumerate(o): - o[i] = findAttachments(v) - elif isinstance(o, dict): - for k in o: - o[k] = findAttachments(o[k]) - return o - - args = findAttachments(args) - kwargs = findAttachments(kwargs) - - args.insert(0, obj) - - try: - self.web_app.last_active_client_id = client_id - results = func(*args, **kwargs) - if inspect.isawaitable(results): - results = await results - - if self.connections[client_id].closed: - # Connection was closed during RPC call. - return - - await self.sendWrappedMessage( - rpcid, results, method=methodName, client_id=client_id - ) - except Exception as e_inst: - captured_trace = traceback.format_exc() - logger.error("Exception raised") - logger.error(repr(e_inst)) - logger.error(captured_trace) - await self.sendWrappedError( - rpcid, - EXCEPTION_ERROR, - "Exception raised", - { - "method": methodName, - "exception": repr(e_inst), - "trace": captured_trace, - }, - client_id=client_id, - ) + self.web_app.last_active_client_id = client_id + results = func(*args, **kwargs) + if inspect.isawaitable(results): + results = await results - except Exception as e: + if self.connections[client_id].closed: + # Connection was closed during RPC call. 
+ return + + await self.sendWrappedMessage( + rpcid, results, method=methodName, client_id=client_id + ) + except Exception as e_inst: + captured_trace = traceback.format_exc() + logger.error("Exception raised") + logger.error(repr(e_inst)) + logger.error(captured_trace) await self.sendWrappedError( rpcid, EXCEPTION_ERROR, "Exception raised", { "method": methodName, - "exception": repr(e), - "trace": traceback.format_exc(), + "exception": repr(e_inst), + "trace": captured_trace, }, client_id=client_id, ) - return def payloadWithSecretStripped(self, payload): payload = copy.deepcopy(payload) @@ -428,9 +368,10 @@ async def sendWrappedMessage( "id": rpcid, "result": content, } + try: - encMsg = json.dumps(wrapper, ensure_ascii=False) - except TypeError as e: + packed_wrapper = msgpack.packb(wrapper) + except Exception: # the content which is not serializable might be arbitrarily large, don't include. # repr(content) would do that... await self.sendWrappedError( @@ -444,47 +385,13 @@ async def sendWrappedMessage( websockets = self.getAuthenticatedWebsockets(client_id, skip_last_active_client) - # Check if any attachments in the map go with this message - attachments = self.pub_manager.getAttachmentMap() - found_keys = [] - if attachments: - for key in attachments: - # string match the encoded attachment key - if key in encMsg: - if key not in found_keys: - found_keys.append(key) - # increment for key - self.pub_manager.registerAttachment(key) - - for key in found_keys: - # send header - header = { - "wslink": "1.0", - "method": "wslink.binary.attachment", - "args": [key], - } - json_header = json.dumps(header, ensure_ascii=False) - - # aiohttp can not handle pending ws.send_bytes() - # tried with semaphore but got exception with >1 - # https://github.com/aio-libs/aiohttp/issues/2934 - async with self.attachment_atomic: - for ws in websockets: - if ws is not None: - # Send binary header - await ws.send_str(json_header) - # Send binary message - await ws.send_bytes(attachments[key]) - - # decrement for key - self.pub_manager.unregisterAttachment(key) - - for ws in websockets: - if ws is not None: - await ws.send_str(encMsg) - - loop = asyncio.get_event_loop() - loop.call_soon(self.pub_manager.freeAttachments, found_keys) + # aiohttp can not handle pending ws.send_bytes() + # tried with semaphore but got exception with >1 + # https://github.com/aio-libs/aiohttp/issues/2934 + async with self.attachment_atomic: + for ws in websockets: + if ws is not None: + await ws.send_bytes(packed_wrapper) async def sendWrappedError(self, rpcid, code, message, data=None, client_id=None): wrapper = { @@ -497,15 +404,25 @@ async def sendWrappedError(self, rpcid, code, message, data=None, client_id=None } if data: wrapper["error"]["data"] = data - encMsg = json.dumps(wrapper, ensure_ascii=False) + + try: + packed_wrapper = msgpack.packb(wrapper) + except Exception: + del wrapper["error"]["data"] + packed_wrapper = msgpack.packb(wrapper) + websockets = ( [self.connections[client_id]] if client_id else [self.connections[c] for c in self.connections] ) - for ws in websockets: - if ws is not None: - await ws.send_str(encMsg) + # aiohttp can not handle pending ws.send_bytes() + # tried with semaphore but got exception with >1 + # https://github.com/aio-libs/aiohttp/issues/2934 + async with self.attachment_atomic: + for ws in websockets: + if ws is not None: + await ws.send_bytes(packed_wrapper) def publish(self, topic, data, client_id=None, skip_last_active_client=False): client_list = [client_id] if client_id else 
[c_id for c_id in self.connections]

diff --git a/python/src/wslink/publish.py b/python/src/wslink/publish.py
index ee002920..894ee225 100644
--- a/python/src/wslink/publish.py
+++ b/python/src/wslink/publish.py
@@ -7,9 +7,6 @@
 class PublishManager(object):
     def __init__(self):
         self.protocols = []
-        self.attachmentMap = {}
-        self.attachmentRefCounts = {}  # keyed same as attachment map
-        self.attachmentId = 0
         self.publishCount = 0
 
     def registerProtocol(self, protocol):
@@ -19,38 +16,12 @@ def unregisterProtocol(self, protocol):
         if protocol in self.protocols:
             self.protocols.remove(protocol)
 
-    def getAttachmentMap(self):
-        return self.attachmentMap
-
-    def clearAttachmentMap(self):
-        self.attachmentMap.clear()
-
-    def registerAttachment(self, attachKey):
-        self.attachmentRefCounts[attachKey] += 1
-
-    def unregisterAttachment(self, attachKey):
-        self.attachmentRefCounts[attachKey] -= 1
-
-    def freeAttachments(self, keys=None):
-        keys_to_delete = []
-        keys_to_check = keys if keys is not None else [k for k in self.attachmentMap]
-
-        for key in keys_to_check:
-            if self.attachmentRefCounts.get(key) == 0:
-                keys_to_delete.append(key)
-
-        for key in keys_to_delete:
-            self.attachmentMap.pop(key)
-            self.attachmentRefCounts.pop(key)
-
     def addAttachment(self, payload):
-        # print("attachment", self, self.attachmentId)
-        # use a string flag in place of the binary attachment.
-        binaryId = "wslink_bin{0}".format(self.attachmentId)
-        self.attachmentMap[binaryId] = payload
-        self.attachmentRefCounts[binaryId] = 0
-        self.attachmentId += 1
-        return binaryId
+        """Deprecated method, kept to avoid breaking compatibility.
+        Now that we use msgpack to pack/unpack messages,
+        we can have binary data directly in the object itself,
+        without needing to transfer it separately from the rest."""
+        return payload
 
     def publish(self, topic, data, client_id=None, skip_last_active_client=False):
         for protocol in self.protocols:

From 9bb8bc460e3d4c6d70f1e93a7c07ba86e0e1e5cd Mon Sep 17 00:00:00 2001
From: Alessandro Genova
Date: Sun, 7 Apr 2024 23:23:40 -0400
Subject: [PATCH 2/4] feat(chunking): implement chunking of client/server
 messages

BREAKING CHANGE: each message has a header and is possibly chunked
---
 js/src/WebsocketConnection/chunking.js | 211 ++++++++++++++++++
 js/src/WebsocketConnection/chunking.ts | 296 +++++++++++++++++++++++++
 js/src/WebsocketConnection/session.js  | 160 ++++++-------
 python/src/wslink/chunking.py          | 209 +++++++++++++++++
 python/src/wslink/protocol.py          |  44 +++-
 5 files changed, 832 insertions(+), 88 deletions(-)
 create mode 100644 js/src/WebsocketConnection/chunking.js
 create mode 100644 js/src/WebsocketConnection/chunking.ts
 create mode 100644 python/src/wslink/chunking.py

diff --git a/js/src/WebsocketConnection/chunking.js b/js/src/WebsocketConnection/chunking.js
new file mode 100644
index 00000000..41fdb029
--- /dev/null
+++ b/js/src/WebsocketConnection/chunking.js
@@ -0,0 +1,211 @@
+// Project not set up for TypeScript; chunking.ts is compiled to chunking.js manually:
+// npx tsc chunking.ts --target esnext
+const UINT32_LENGTH = 4;
+const ID_LOCATION = 0;
+const ID_LENGTH = UINT32_LENGTH;
+const MESSAGE_OFFSET_LOCATION = ID_LOCATION + ID_LENGTH;
+const MESSAGE_OFFSET_LENGTH = UINT32_LENGTH;
+const MESSAGE_SIZE_LOCATION = MESSAGE_OFFSET_LOCATION + MESSAGE_OFFSET_LENGTH;
+const MESSAGE_SIZE_LENGTH = UINT32_LENGTH;
+const HEADER_LENGTH = ID_LENGTH + MESSAGE_OFFSET_LENGTH + MESSAGE_SIZE_LENGTH;
+function encodeHeader(id, offset, size) {
+  const buffer = new ArrayBuffer(HEADER_LENGTH);
+  const header = new 
Uint8Array(buffer); + const view = new DataView(buffer); + view.setUint32(ID_LOCATION, id, true); + view.setUint32(MESSAGE_OFFSET_LOCATION, offset, true); + view.setUint32(MESSAGE_SIZE_LOCATION, size, true); + return header; +} +function decodeHeader(header) { + const view = new DataView(header.buffer); + const id = view.getUint32(ID_LOCATION, true); + const offset = view.getUint32(MESSAGE_OFFSET_LOCATION, true); + const size = view.getUint32(MESSAGE_SIZE_LOCATION, true); + return { id, offset, size }; +} +function* generateChunks(message, maxSize) { + const totalSize = message.byteLength; + let maxContentSize; + if (maxSize === 0) { + maxContentSize = totalSize; + } else { + maxContentSize = Math.max(maxSize - HEADER_LENGTH, 1); + } + const id = new Uint32Array(1); + crypto.getRandomValues(id); + let offset = 0; + while (offset < totalSize) { + const contentSize = Math.min(maxContentSize, totalSize - offset); + const chunk = new Uint8Array(new ArrayBuffer(HEADER_LENGTH + contentSize)); + const header = encodeHeader(id[0], offset, totalSize); + chunk.set(new Uint8Array(header.buffer), 0); + chunk.set(message.subarray(offset, offset + contentSize), HEADER_LENGTH); + yield chunk; + offset += contentSize; + } + return; +} +/* + This un-chunker is vulnerable to DOS. + If it receives a message with a header claiming a large incoming message + it will allocate the memory blindly even without actually receiving the content + Chunks for a given message can come in any order + Chunks across messages can be interleaved. +*/ +class UnChunker { + pendingMessages; + constructor() { + this.pendingMessages = {}; + } + releasePendingMessages() { + this.pendingMessages = {}; + } + async processChunk(chunk, decoderFactory) { + const headerBlob = chunk.slice(0, HEADER_LENGTH); + const contentBlob = chunk.slice(HEADER_LENGTH); + const header = new Uint8Array(await headerBlob.arrayBuffer()); + const { id, offset, size: totalSize } = decodeHeader(header); + let pendingMessage = this.pendingMessages[id]; + if (!pendingMessage) { + pendingMessage = { + receivedSize: 0, + content: new Uint8Array(totalSize), + decoder: decoderFactory(), + }; + this.pendingMessages[id] = pendingMessage; + } + // This should never happen, but still check it + if (totalSize !== pendingMessage.content.byteLength) { + delete this.pendingMessages[id]; + throw new Error( + `Total size in chunk header for message ${id} does not match total size declared by previous chunk.` + ); + } + const chunkContent = new Uint8Array(await contentBlob.arrayBuffer()); + const content = pendingMessage.content; + content.set(chunkContent, offset); + pendingMessage.receivedSize += chunkContent.byteLength; + if (pendingMessage.receivedSize >= totalSize) { + delete this.pendingMessages[id]; + try { + return pendingMessage['decoder'].decode(content); + } catch (e) { + console.error('Malformed message: ', content.slice(0, 100)); + // debugger; + } + } + return undefined; + } +} +// Makes sure messages are processed in order of arrival, +export class SequentialTaskQueue { + taskId; + pendingTaskId; + tasks; + constructor() { + this.taskId = 0; + this.pendingTaskId = -1; + this.tasks = {}; + } + enqueue(fn, ...args) { + return new Promise((resolve, reject) => { + const taskId = this.taskId++; + this.tasks[taskId] = { fn, args, resolve, reject }; + this._maybeExecuteNext(); + }); + } + _maybeExecuteNext() { + let pendingTask = this.tasks[this.pendingTaskId]; + if (pendingTask) { + return; + } + const nextPendingTaskId = this.pendingTaskId + 1; + pendingTask = 
this.tasks[nextPendingTaskId]; + if (!pendingTask) { + return; + } + this.pendingTaskId = nextPendingTaskId; + const { fn, args, resolve, reject } = pendingTask; + fn(...args) + .then((result) => { + resolve(result); + delete this.tasks[nextPendingTaskId]; + this._maybeExecuteNext(); + }) + .catch((err) => { + reject(err); + delete this.tasks[nextPendingTaskId]; + this._maybeExecuteNext(); + }); + } +} +/* + This un-chunker is more memory efficient + (each chunk is passed immediately to msgpack) + and it will only allocate memory when it receives content. + Chunks for a given message are expected to come sequentially + Chunks across messages can be interleaved. +*/ +class StreamUnChunker { + pendingMessages; + constructor() { + this.pendingMessages = {}; + } + processChunk = async (chunk, decoderFactory) => { + const headerBlob = chunk.slice(0, HEADER_LENGTH); + const header = new Uint8Array(await headerBlob.arrayBuffer()); + const { id, offset, size: totalSize } = decodeHeader(header); + const contentBlob = chunk.slice(HEADER_LENGTH); + let pendingMessage = this.pendingMessages[id]; + if (!pendingMessage) { + pendingMessage = { + receivedSize: 0, + totalSize: totalSize, + decoder: decoderFactory(), + }; + this.pendingMessages[id] = pendingMessage; + } + // This should never happen, but still check it + if (totalSize !== pendingMessage.totalSize) { + delete this.pendingMessages[id]; + throw new Error( + `Total size in chunk header for message ${id} does not match total size declared by previous chunk.` + ); + } + // This should never happen, but still check it + if (offset !== pendingMessage.receivedSize) { + delete this.pendingMessages[id]; + throw new Error(`Received an unexpected chunk for message ${id}. + Expected offset = ${pendingMessage.receivedSize}, + Received offset = ${offset}.`); + } + let result; + try { + result = await pendingMessage.decoder.decodeAsync(contentBlob.stream()); + } catch (e) { + if (e instanceof RangeError) { + // More data is needed, it should come in the next chunk + result = undefined; + } + } + pendingMessage.receivedSize += contentBlob.size; + /* + In principle feeding a stream to the unpacker could yield multiple outputs + for example unpacker.feed(b'0123') would yield b'0', b'1', ect + or concatenated packed payloads would yield two or more unpacked objects + but in our use case we expect a full message to be mapped to a single object + */ + if (result && pendingMessage.receivedSize < totalSize) { + delete this.pendingMessages[id]; + throw new Error(`Received a parsable payload shorter than expected for message ${id}. 
+      Expected size = ${totalSize},
+      Received size = ${pendingMessage.receivedSize}.`);
+    }
+    if (pendingMessage.receivedSize >= totalSize) {
+      delete this.pendingMessages[id];
+    }
+    return result;
+  };
+}
+export { UnChunker, StreamUnChunker, generateChunks };
diff --git a/js/src/WebsocketConnection/chunking.ts b/js/src/WebsocketConnection/chunking.ts
new file mode 100644
index 00000000..42d6b9eb
--- /dev/null
+++ b/js/src/WebsocketConnection/chunking.ts
@@ -0,0 +1,296 @@
+// Project not set up for TypeScript; chunking.ts is compiled to chunking.js manually:
+// npx tsc chunking.ts --target esnext
+
+const UINT32_LENGTH = 4;
+const ID_LOCATION = 0;
+const ID_LENGTH = UINT32_LENGTH;
+const MESSAGE_OFFSET_LOCATION = ID_LOCATION + ID_LENGTH;
+const MESSAGE_OFFSET_LENGTH = UINT32_LENGTH;
+const MESSAGE_SIZE_LOCATION = MESSAGE_OFFSET_LOCATION + MESSAGE_OFFSET_LENGTH;
+const MESSAGE_SIZE_LENGTH = UINT32_LENGTH;
+
+const HEADER_LENGTH = ID_LENGTH + MESSAGE_OFFSET_LENGTH + MESSAGE_SIZE_LENGTH;
+
+function encodeHeader(id: number, offset: number, size: number): Uint8Array {
+  const buffer = new ArrayBuffer(HEADER_LENGTH);
+  const header = new Uint8Array(buffer);
+  const view = new DataView(buffer);
+  view.setUint32(ID_LOCATION, id, true);
+  view.setUint32(MESSAGE_OFFSET_LOCATION, offset, true);
+  view.setUint32(MESSAGE_SIZE_LOCATION, size, true);
+
+  return header;
+}
+
+function decodeHeader(header: Uint8Array) {
+  const view = new DataView(header.buffer);
+  const id = view.getUint32(ID_LOCATION, true);
+  const offset = view.getUint32(MESSAGE_OFFSET_LOCATION, true);
+  const size = view.getUint32(MESSAGE_SIZE_LOCATION, true);
+
+  return { id, offset, size };
+}
+
+function* generateChunks(message: Uint8Array, maxSize: number) {
+  const totalSize = message.byteLength;
+  let maxContentSize: number;
+
+  if (maxSize === 0) {
+    maxContentSize = totalSize;
+  } else {
+    maxContentSize = Math.max(maxSize - HEADER_LENGTH, 1);
+  }
+
+  const id = new Uint32Array(1);
+  crypto.getRandomValues(id);
+
+  let offset = 0;
+
+  while (offset < totalSize) {
+    const contentSize = Math.min(maxContentSize, totalSize - offset);
+    const chunk = new Uint8Array(new ArrayBuffer(HEADER_LENGTH + contentSize));
+    const header = encodeHeader(id[0], offset, totalSize);
+    chunk.set(new Uint8Array(header.buffer), 0);
+    chunk.set(message.subarray(offset, offset + contentSize), HEADER_LENGTH);
+
+    yield chunk;
+
+    offset += contentSize;
+  }
+
+  return;
+}
+
+type PendingMessage = {
+  receivedSize: number;
+  content: Uint8Array;
+  decoder: any;
+};
+
+/*
+  This un-chunker is vulnerable to DOS.
+  If it receives a message with a header claiming a large incoming message
+  it will allocate the memory blindly even without actually receiving the content
+  Chunks for a given message can come in any order
+  Chunks across messages can be interleaved.
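+
+  A usage sketch (mirroring what session.js does with each incoming
+  WebSocket Blob; decoderFactory returns a fresh msgpack Decoder):
+
+    const unchunker = new UnChunker();
+    const payload = await unchunker.processChunk(blob, () => new Decoder());
+    if (payload !== undefined) {
+      // payload is a complete, decoded message
+    }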
+*/
+class UnChunker {
+  private pendingMessages: { [key: number]: PendingMessage };
+
+  constructor() {
+    this.pendingMessages = {};
+  }
+
+  releasePendingMessages() {
+    this.pendingMessages = {};
+  }
+
+  async processChunk(
+    chunk: Blob,
+    decoderFactory: () => any
+  ): Promise<unknown> {
+    const headerBlob = chunk.slice(0, HEADER_LENGTH);
+    const contentBlob = chunk.slice(HEADER_LENGTH);
+
+    const header = new Uint8Array(await headerBlob.arrayBuffer());
+    const { id, offset, size: totalSize } = decodeHeader(header);
+
+    let pendingMessage = this.pendingMessages[id];
+
+    if (!pendingMessage) {
+      pendingMessage = {
+        receivedSize: 0,
+        content: new Uint8Array(totalSize),
+        decoder: decoderFactory(),
+      };
+
+      this.pendingMessages[id] = pendingMessage;
+    }
+
+    // This should never happen, but still check it
+    if (totalSize !== pendingMessage.content.byteLength) {
+      delete this.pendingMessages[id];
+      throw new Error(
+        `Total size in chunk header for message ${id} does not match total size declared by previous chunk.`
+      );
+    }
+
+    const chunkContent = new Uint8Array(await contentBlob.arrayBuffer());
+    const content = pendingMessage.content;
+    content.set(chunkContent, offset);
+    pendingMessage.receivedSize += chunkContent.byteLength;
+
+    if (pendingMessage.receivedSize >= totalSize) {
+      delete this.pendingMessages[id];
+
+      try {
+        return pendingMessage['decoder'].decode(content);
+      } catch (e) {
+        console.error('Malformed message: ', content.slice(0, 100));
+        // debugger;
+      }
+    }
+
+    return undefined;
+  }
+}
+
+type StreamPendingMessage = {
+  receivedSize: number;
+  totalSize: number;
+  decoder: any;
+};
+
+// Makes sure messages are processed in order of arrival,
+export class SequentialTaskQueue {
+  taskId: number;
+  pendingTaskId: number;
+  tasks: {
+    [id: number]: {
+      fn: (...args: any) => Promise<any>;
+      args: any[];
+      resolve: (value: any) => void;
+      reject: (err: any) => void;
+    };
+  };
+
+  constructor() {
+    this.taskId = 0;
+    this.pendingTaskId = -1;
+    this.tasks = {};
+  }
+
+  enqueue(fn: (...args: any) => Promise<any>, ...args: any[]) {
+    return new Promise((resolve, reject) => {
+      const taskId = this.taskId++;
+      this.tasks[taskId] = { fn, args, resolve, reject };
+      this._maybeExecuteNext();
+    });
+  }
+
+  _maybeExecuteNext() {
+    let pendingTask = this.tasks[this.pendingTaskId];
+
+    if (pendingTask) {
+      return;
+    }
+
+    const nextPendingTaskId = this.pendingTaskId + 1;
+
+    pendingTask = this.tasks[nextPendingTaskId];
+
+    if (!pendingTask) {
+      return;
+    }
+
+    this.pendingTaskId = nextPendingTaskId;
+
+    const { fn, args, resolve, reject } = pendingTask;
+
+    fn(...args)
+      .then((result) => {
+        resolve(result);
+        delete this.tasks[nextPendingTaskId];
+        this._maybeExecuteNext();
+      })
+      .catch((err) => {
+        reject(err);
+        delete this.tasks[nextPendingTaskId];
+        this._maybeExecuteNext();
+      });
+  }
+}
+
+/*
+  This un-chunker is more memory efficient
+  (each chunk is passed immediately to msgpack)
+  and it will only allocate memory when it receives content.
+  Chunks for a given message are expected to come sequentially
+  Chunks across messages can be interleaved.
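+
+  Note: session.js currently instantiates UnChunker; this streaming variant
+  accepts the same processChunk(chunk, decoderFactory) call shape and can be
+  swapped in where the in-order chunk assumption above holds.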
+*/
+class StreamUnChunker {
+  private pendingMessages: { [key: number]: StreamPendingMessage };
+
+  constructor() {
+    this.pendingMessages = {};
+  }
+
+  processChunk = async (
+    chunk: Blob,
+    decoderFactory: () => any
+  ): Promise<unknown> => {
+    const headerBlob = chunk.slice(0, HEADER_LENGTH);
+
+    const header = new Uint8Array(await headerBlob.arrayBuffer());
+    const { id, offset, size: totalSize } = decodeHeader(header);
+
+    const contentBlob = chunk.slice(HEADER_LENGTH);
+
+    let pendingMessage = this.pendingMessages[id];
+
+    if (!pendingMessage) {
+      pendingMessage = {
+        receivedSize: 0,
+        totalSize: totalSize,
+        decoder: decoderFactory(),
+      };
+
+      this.pendingMessages[id] = pendingMessage;
+    }
+
+    // This should never happen, but still check it
+    if (totalSize !== pendingMessage.totalSize) {
+      delete this.pendingMessages[id];
+      throw new Error(
+        `Total size in chunk header for message ${id} does not match total size declared by previous chunk.`
+      );
+    }
+
+    // This should never happen, but still check it
+    if (offset !== pendingMessage.receivedSize) {
+      delete this.pendingMessages[id];
+      throw new Error(
+        `Received an unexpected chunk for message ${id}.
+          Expected offset = ${pendingMessage.receivedSize},
+          Received offset = ${offset}.`
+      );
+    }
+
+    let result: unknown;
+    try {
+      result = await pendingMessage.decoder.decodeAsync(
+        contentBlob.stream() as any
+      );
+    } catch (e) {
+      if (e instanceof RangeError) {
+        // More data is needed, it should come in the next chunk
+        result = undefined;
+      }
+    }
+
+    pendingMessage.receivedSize += contentBlob.size;
+
+    /*
+      In principle feeding a stream to the unpacker could yield multiple outputs
+      for example unpacker.feed(b'0123') would yield b'0', b'1', ect
+      or concatenated packed payloads would yield two or more unpacked objects
+      but in our use case we expect a full message to be mapped to a single object
+    */
+    if (result && pendingMessage.receivedSize < totalSize) {
+      delete this.pendingMessages[id];
+      throw new Error(
+        `Received a parsable payload shorter than expected for message ${id}. 
+          Expected size = ${totalSize},
+          Received size = ${pendingMessage.receivedSize}.`
+      );
+    }
+
+    if (pendingMessage.receivedSize >= totalSize) {
+      delete this.pendingMessages[id];
+    }
+
+    return result;
+  };
+}
+
+export { UnChunker, StreamUnChunker, generateChunks };
diff --git a/js/src/WebsocketConnection/session.js b/js/src/WebsocketConnection/session.js
index c647fb07..4517103c 100644
--- a/js/src/WebsocketConnection/session.js
+++ b/js/src/WebsocketConnection/session.js
@@ -1,6 +1,7 @@
 // Helper borrowed from paraviewweb/src/Common/Core
 import CompositeClosureHelper from '../CompositeClosureHelper';
-import { Encoder, Decoder } from "@msgpack/msgpack";
+import { UnChunker, generateChunks } from './chunking';
+import { Encoder, Decoder } from '@msgpack/msgpack';
 
 function defer() {
   const deferred = {};
@@ -24,21 +25,72 @@ function Session(publicAPI, model) {
   const regexRPC = /^(rpc|publish|system):(\w+(?:\.\w+)*):(?:\d+)$/;
   const subscriptions = {};
   let clientID = null;
-  const encoder = new CustomEncoder();
-  const decoder = new Decoder();
+  let MAX_MSG_SIZE = 512 * 1024;
+  const unchunker = new UnChunker();
 
   // --------------------------------------------------------------------------
   // Private helpers
   // --------------------------------------------------------------------------
 
-  async function decodeFromBlob(blob) {
-    if (blob.stream) {
-      // Blob#stream(): ReadableStream<Uint8Array> (recommended)
-      return await decoder.decodeAsync(blob.stream());
+  function onCompleteMessage(payload) {
+    if (!payload) return;
+    if (!payload.id) return;
+    if (payload.error) {
+      const deferred = inFlightRpc[payload.id];
+      if (deferred) {
+        deferred.reject(payload.error);
+      } else {
+        console.error('Server error:', payload.error);
+      }
     } else {
-      // Blob#arrayBuffer(): Promise<ArrayBuffer> (if stream() is not available)
-      return decoder.decode(await blob.arrayBuffer());
+      const match = regexRPC.exec(payload.id);
+      if (match) {
+        const type = match[1];
+        if (type === 'rpc') {
+          const deferred = inFlightRpc[payload.id];
+          if (!deferred) {
+            console.log(
+              'session message id without matching call, dropped',
+              payload
+            );
+            return;
+          }
+          deferred.resolve(payload.result);
+        } else if (type == 'publish') {
+          console.assert(
+            inFlightRpc[payload.id] === undefined,
+            'publish message received matching in-flight rpc call'
+          );
+          // regex extracts the topic for us.
+          const topic = match[2];
+          if (!subscriptions[topic]) {
+            return;
+          }
+          // for each callback, provide the message data. 
Wrap in an array, for back-compatibility with WAMP + subscriptions[topic].forEach((callback) => + callback([payload.result]) + ); + } else if (type == 'system') { + // console.log('DBG system:', payload.id, payload.result); + const deferred = inFlightRpc[payload.id]; + if (payload.id === 'system:c0:0') { + clientID = payload.result.clientID; + MAX_MSG_SIZE = payload.result.maxMsgSize || MAX_MSG_SIZE; + if (deferred) deferred.resolve(clientID); + } else { + console.error('Unknown system message', payload.id); + if (deferred) + deferred.reject({ + code: CLIENT_ERROR, + message: `Unknown system message ${payload.id}`, + }); + } + } else { + console.error('Unknown rpc id format', payload.id); + } + } } + delete inFlightRpc[payload.id]; } // -------------------------------------------------------------------------- @@ -57,11 +109,15 @@ function Session(publicAPI, model) { method: 'wslink.hello', args: [{ secret: model.secret }], kwargs: {}, - } + }; + const encoder = new CustomEncoder(); const packedWrapper = encoder.encode(wrapper); - model.ws.send(packedWrapper, { binary: true }); + for (let chunk of generateChunks(packedWrapper, MAX_MSG_SIZE)) { + model.ws.send(chunk, { binary: true }); + } + return deferred.promise; }; @@ -76,9 +132,13 @@ function Session(publicAPI, model) { inFlightRpc[id] = deferred; const wrapper = { wslink: '1.0', id, method, args, kwargs }; + + const encoder = new CustomEncoder(); const packedWrapper = encoder.encode(wrapper); - model.ws.send(packedWrapper, { binary: true }); + for (let chunk of generateChunks(packedWrapper, MAX_MSG_SIZE)) { + model.ws.send(chunk, { binary: true }); + } } else { deferred.reject({ code: CLIENT_ERROR, @@ -144,78 +204,22 @@ function Session(publicAPI, model) { const deferred = defer(); // some transports might be able to close the session without closing the connection. Not true for websocket... model.ws.close(); + unchunker.releasePendingMessages(); deferred.resolve(); return deferred.promise; }; // -------------------------------------------------------------------------- + function createDecoder() { + return new Decoder(); + } + publicAPI.onmessage = async (event) => { - { - let payload; - try { - payload = await decodeFromBlob(event.data); - } catch (e) { - console.error('Malformed message: ', event.data); - // debugger; - } - if (!payload) return; - if (!payload.id) return; - if (payload.error) { - const deferred = inFlightRpc[payload.id]; - if (deferred) { - deferred.reject(payload.error); - } else { - console.error('Server error:', payload.error); - } - } else { - const match = regexRPC.exec(payload.id); - if (match) { - const type = match[1]; - if (type === 'rpc') { - const deferred = inFlightRpc[payload.id]; - if (!deferred) { - console.log( - 'session message id without matching call, dropped', - payload - ); - return; - } - deferred.resolve(payload.result); - } else if (type == 'publish') { - console.assert( - inFlightRpc[payload.id] === undefined, - 'publish message received matching in-flight rpc call' - ); - // regex extracts the topic for us. - const topic = match[2]; - if (!subscriptions[topic]) { - return; - } - // for each callback, provide the message data. 
Wrap in an array, for back-compatibility with WAMP - subscriptions[topic].forEach((callback) => - callback([payload.result]) - ); - } else if (type == 'system') { - // console.log('DBG system:', payload.id, payload.result); - const deferred = inFlightRpc[payload.id]; - if (payload.id === 'system:c0:0') { - clientID = payload.result.clientID; - if (deferred) deferred.resolve(clientID); - } else { - console.error('Unknown system message', payload.id); - if (deferred) - deferred.reject({ - code: CLIENT_ERROR, - message: `Unknown system message ${payload.id}`, - }); - } - } else { - console.error('Unknown rpc id format', payload.id); - } - } - } - delete inFlightRpc[payload.id]; + const message = await unchunker.processChunk(event.data, createDecoder); + + if (message) { + onCompleteMessage(message); } }; @@ -275,6 +279,6 @@ class CustomEncoder extends Encoder { object = new DataView(object); } - return super.encodeObject(object, depth); + return super.encodeObject.call(this, object, depth); } } diff --git a/python/src/wslink/chunking.py b/python/src/wslink/chunking.py new file mode 100644 index 00000000..940c057b --- /dev/null +++ b/python/src/wslink/chunking.py @@ -0,0 +1,209 @@ +import random +from typing import TypedDict, Dict, Tuple, Union +import msgpack + +UINT32_LENGTH = 4 +ID_LOCATION = 0 +ID_LENGTH = UINT32_LENGTH +MESSAGE_OFFSET_LOCATION = ID_LOCATION + ID_LENGTH +MESSAGE_OFFSET_LENGTH = UINT32_LENGTH +MESSAGE_SIZE_LOCATION = MESSAGE_OFFSET_LOCATION + MESSAGE_OFFSET_LENGTH +MESSAGE_SIZE_LENGTH = UINT32_LENGTH + +HEADER_LENGTH = ID_LENGTH + MESSAGE_OFFSET_LENGTH + MESSAGE_SIZE_LENGTH + + +def _encode_header(id: bytes, offset: int, size: int) -> bytes: + return ( + id + + offset.to_bytes(MESSAGE_OFFSET_LENGTH, "little", signed=False) + + size.to_bytes(MESSAGE_SIZE_LENGTH, "little", signed=False) + ) + + +def _decode_header(header: bytes) -> Tuple[bytes, int, int]: + id = header[ID_LOCATION:ID_LENGTH] + offset = int.from_bytes( + header[ + MESSAGE_OFFSET_LOCATION : MESSAGE_OFFSET_LOCATION + MESSAGE_OFFSET_LENGTH + ], + "little", + signed=False, + ) + size = int.from_bytes( + header[MESSAGE_SIZE_LOCATION : MESSAGE_SIZE_LOCATION + MESSAGE_SIZE_LENGTH], + "little", + signed=False, + ) + return id, offset, size + + +def generate_chunks(message: bytes, max_size: int): + total_size = len(message) + + if max_size == 0: + max_content_size = total_size + else: + max_content_size = max(max_size - HEADER_LENGTH, 1) + + id = random.randbytes(ID_LENGTH) + + offset = 0 + + while offset < total_size: + header = _encode_header(id, offset, total_size) + chunk_content = message[offset : offset + max_content_size] + + yield header + chunk_content + + offset += max_content_size + + return + + +class PendingMessage(TypedDict): + received_size: int + content: bytearray + + +# This un-chunker is vulnerable to DOS. +# If it receives a message with a header claiming a large incoming message +# it will allocate the memory blindly even without actually receiving the content +# Chunks for a given message can come in any order +# Chunks across messages can be interleaved. 
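+#
+# Chunk wire format consumed below (all integers are uint32, little-endian):
+#
+#   | id (4 bytes) | offset (4 bytes) | total size (4 bytes) | content ... |
+#
+# A worked example (values assumed for illustration): for a 4096-byte message
+# sent with max_size=1036 and id=b"\x01\x02\x03\x04", the second chunk starts
+# at offset 1024 and its 12-byte header reads
+#   01 02 03 04  00 04 00 00  00 10 00 00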
+class UnChunker: + pending_messages: Dict[bytes, PendingMessage] + max_message_size: int + + def __init__(self): + self.pending_messages = {} + self.max_message_size = 512 + + def set_max_message_size(self, size): + self.max_message_size = size + + def release_pending_messages(self): + self.pending_messages = {} + + def process_chunk(self, chunk: bytes) -> Union[bytes, None]: + header, chunk_content = chunk[:HEADER_LENGTH], chunk[HEADER_LENGTH:] + id, offset, total_size = _decode_header(header) + + pending_message = self.pending_messages.get(id, None) + + if pending_message is None: + if total_size > self.max_message_size: + raise ValueError( + f"""Total size for message {id} exceeds the allocation limit allowed. + Maximum size = {self.max_message_size}, + Received size = {total_size}.""" + ) + + pending_message = PendingMessage( + received_size=0, content=bytearray(total_size) + ) + self.pending_messages[id] = pending_message + + # This should never happen, but still check it + if total_size != len(pending_message["content"]): + del self.pending_messages[id] + raise ValueError( + f"Total size in chunk header for message {id} does not match total size declared by previous chunk." + ) + + content_size = len(chunk_content) + content_view = memoryview(pending_message["content"]) + content_view[offset : offset + content_size] = chunk_content + pending_message["received_size"] += content_size + + if pending_message["received_size"] >= total_size: + full_message = pending_message["content"] + del self.pending_messages[id] + return msgpack.unpackb(bytes(full_message)) + + return None + + +class StreamPendingMessage(TypedDict): + received_size: int + total_size: int + unpacker: msgpack.Unpacker + + +# This un-chunker is more memory efficient +# (each chunk is passed immediately to msgpack) +# and it will only allocate memory when it receives content. +# Chunks for a given message are expected to come sequentially +# Chunks across messages can be interleaved. +class StreamUnChunker: + pending_messages: Dict[bytes, StreamPendingMessage] + + def __init__(self): + self.pending_messages = {} + + def set_max_message_size(self, _size): + pass + + def release_pending_messages(self): + self.pending_messages = {} + + def process_chunk(self, chunk: bytes) -> Union[bytes, None]: + header, chunk_content = chunk[:HEADER_LENGTH], chunk[HEADER_LENGTH:] + id, offset, total_size = _decode_header(header) + + pending_message = self.pending_messages.get(id, None) + + if pending_message is None: + pending_message = StreamPendingMessage( + received_size=0, + total_size=total_size, + unpacker=msgpack.Unpacker(max_buffer_size=total_size), + ) + self.pending_messages[id] = pending_message + + # This should never happen, but still check it + if offset != pending_message["received_size"]: + del self.pending_messages[id] + raise ValueError( + f"""Received an unexpected chunk for message {id}. + Expected offset = {pending_message['received_size']}, + Received offset = {offset}.""" + ) + + # This should never happen, but still check it + if total_size != pending_message["total_size"]: + del self.pending_messages[id] + raise ValueError( + f"""Received an unexpected total size in chunk header for message {id}. 
+ Expected size = {pending_message['total_size']}, + Received size = {total_size}.""" + ) + + content_size = len(chunk_content) + pending_message["received_size"] += content_size + + unpacker = pending_message["unpacker"] + unpacker.feed(chunk_content) + + full_message = None + + try: + full_message = unpacker.unpack() + except msgpack.OutOfData: + pass # message is incomplete, keep ingesting chunks + + if full_message is not None: + del self.pending_messages[id] + + if pending_message["received_size"] < total_size: + # In principle feeding a stream to the unpacker could yield multiple outputs + # for example unpacker.feed(b'0123') would yield b'0', b'1', ect + # or concatenated packed payloads would yield two or more unpacked objects + # but in our use case we expect a full message to be mapped to a single object + raise ValueError( + f"""Received a parsable payload shorter than expected for message {id}. + Expected size = {total_size}, + Received size = {pending_message['received_size']}.""" + ) + + return full_message diff --git a/python/src/wslink/protocol.py b/python/src/wslink/protocol.py index fdd50ffa..6ffaef04 100644 --- a/python/src/wslink/protocol.py +++ b/python/src/wslink/protocol.py @@ -1,14 +1,14 @@ import asyncio import copy import inspect -import json import logging import msgpack -import re +import os import traceback from wslink import schedule_coroutine from wslink.publish import PublishManager +from wslink.chunking import generate_chunks, UnChunker # from http://www.jsonrpc.org/specification, section 5.1 METHOD_NOT_FOUND = -32601 @@ -18,6 +18,9 @@ # used in client JS code: CLIENT_ERROR = -32099 +# 4MB is the default inside aiohttp +MAX_MSG_SIZE = int(os.environ.get("WSLINK_MAX_MSG_SIZE", 4194304)) + logger = logging.getLogger(__name__) @@ -148,6 +151,7 @@ def __init__(self, protocol=None, web_app=None): self.authentified_client_ids = set() self.attachment_atomic = asyncio.Lock() self.pub_manager = PublishManager() + self.unchunkers = {} # Build the rpc method dictionary, assuming we were given a serverprotocol if self.getServerProtocol(): @@ -184,6 +188,8 @@ def reverse_connection_client_id(self): return "reverse_connection_client_id" async def onConnect(self, request, client_id): + self.unchunkers[client_id] = UnChunker() + if not self.serverProtocol: return if hasattr(self.serverProtocol, "onConnect"): @@ -193,6 +199,8 @@ async def onConnect(self, request, client_id): linkProtocol.onConnect(request, client_id) async def onClose(self, client_id): + del self.unchunkers[client_id] + if not self.serverProtocol: return if hasattr(self.serverProtocol, "onClose"): @@ -213,9 +221,16 @@ async def handleSystemMessage(self, rpcid, methodName, args, client_id): and await self.validateToken(args[0]["secret"], client_id) ): self.authentified_client_ids.add(client_id) + # Once a client is authenticated let the unchunker allocate memory unrestricted + self.unchunkers[client_id].set_max_message_size( + 4 * 1024 * 1024 * 1024 + ) # 4GB await self.sendWrappedMessage( rpcid, - {"clientID": "c{0}".format(client_id)}, + { + "clientID": "c{0}".format(client_id), + "maxMsgSize": MAX_MSG_SIZE, + }, client_id=client_id, ) else: @@ -236,7 +251,14 @@ async def handleSystemMessage(self, rpcid, methodName, args, client_id): return False async def onMessage(self, is_binary, msg, client_id): - rpc = msgpack.unpackb(msg.data) + if not is_binary: + return + + full_message = self.unchunkers[client_id].process_chunk(msg.data) + if full_message is not None: + await self.onCompleteMessage(full_message, 
client_id) + + async def onCompleteMessage(self, rpc, client_id): logger.debug("wslink incoming msg %s", self.payloadWithSecretStripped(rpc)) if "id" not in rpc: return @@ -389,9 +411,10 @@ async def sendWrappedMessage( # tried with semaphore but got exception with >1 # https://github.com/aio-libs/aiohttp/issues/2934 async with self.attachment_atomic: - for ws in websockets: - if ws is not None: - await ws.send_bytes(packed_wrapper) + for chunk in generate_chunks(packed_wrapper, MAX_MSG_SIZE): + for ws in websockets: + if ws is not None: + await ws.send_bytes(chunk) async def sendWrappedError(self, rpcid, code, message, data=None, client_id=None): wrapper = { @@ -420,9 +443,10 @@ async def sendWrappedError(self, rpcid, code, message, data=None, client_id=None # tried with semaphore but got exception with >1 # https://github.com/aio-libs/aiohttp/issues/2934 async with self.attachment_atomic: - for ws in websockets: - if ws is not None: - await ws.send_bytes(packed_wrapper) + for chunk in generate_chunks(packed_wrapper, MAX_MSG_SIZE): + for ws in websockets: + if ws is not None: + await ws.send_bytes(chunk) def publish(self, topic, data, client_id=None, skip_last_active_client=False): client_list = [client_id] if client_id else [c_id for c_id in self.connections] From 45e9bc57feda6eceab3040f8fb0334e5b4aba786 Mon Sep 17 00:00:00 2001 From: Alessandro Genova Date: Mon, 8 Apr 2024 11:29:26 -0400 Subject: [PATCH 3/4] fix(python): make msgpack/chunking compatible down to python 3.7 --- python/requirements.txt | 2 +- python/setup.py | 2 +- python/src/wslink/chunking.py | 11 ++++++++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/requirements.txt b/python/requirements.txt index d3b73084..7ada13a8 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,5 +1,5 @@ aiohttp>=3.7.4,<4 -msgpack>=1.0.8,<2 +msgpack>=1,<2 # platform specific pypiwin32==223; sys_platform == 'win32' diff --git a/python/setup.py b/python/setup.py index 66618811..b715fe13 100644 --- a/python/setup.py +++ b/python/setup.py @@ -45,7 +45,7 @@ keywords="websocket javascript rpc pubsub", packages=find_packages("src", exclude=("tests.*", "tests")), package_dir={"": "src"}, - install_requires=["aiohttp<4"], + install_requires=["aiohttp<4", "msgpack>=1,<2"], extras_require={ "ssl": ["cryptography"], }, diff --git a/python/src/wslink/chunking.py b/python/src/wslink/chunking.py index 940c057b..3c0f6a92 100644 --- a/python/src/wslink/chunking.py +++ b/python/src/wslink/chunking.py @@ -1,6 +1,11 @@ -import random -from typing import TypedDict, Dict, Tuple, Union +import sys +import secrets import msgpack +from typing import Dict, Tuple, Union +if sys.version_info >= (3, 8): + from typing import TypedDict # pylint: disable=no-name-in-module +else: + from typing_extensions import TypedDict UINT32_LENGTH = 4 ID_LOCATION = 0 @@ -46,7 +51,7 @@ def generate_chunks(message: bytes, max_size: int): else: max_content_size = max(max_size - HEADER_LENGTH, 1) - id = random.randbytes(ID_LENGTH) + id = secrets.token_bytes(ID_LENGTH) offset = 0 From dafd70271147aa6d49b8f0cd8489bea208de74ba Mon Sep 17 00:00:00 2001 From: Alessandro Genova Date: Wed, 10 Apr 2024 17:06:46 -0400 Subject: [PATCH 4/4] DO NOT MERGE: enable wslink to have typescript source files --- js/package.json | 4 +++- js/tsconfig.json | 11 +++++++++++ js/webpack-test-simple.config.js | 8 ++++++++ js/webpack.config.js | 8 ++++++++ 4 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 js/tsconfig.json diff --git a/js/package.json 
b/js/package.json index fcc1595c..8f09977f 100644 --- a/js/package.json +++ b/js/package.json @@ -12,7 +12,7 @@ "homepage": "https://github.com/kitware/wslink#readme", "main": "dist/wslink.js", "scripts": { - "prettier": "prettier --config ./prettier.config.js --write \"src/**/*.js\" \"test/**/*.js\"", + "prettier": "prettier --config ./prettier.config.js --write \"src/**/*.js\" \"src/**/*.ts\" \"test/**/*.js\"", "test": "npm run build:test && python ../tests/simple/server/simple.py --content ../tests/simple/www --debug", "build": "webpack", "build:test": "webpack --config webpack-test-simple.config.js", @@ -39,6 +39,8 @@ "prettier": "2.8.4", "semantic-release": "22.0.5", "semantic-release-pypi": "2.5.2", + "ts-loader": "^9.5.1", + "typescript": "^5.4.3", "webpack": "^5.75.0", "webpack-cli": "4.7.2" }, diff --git a/js/tsconfig.json b/js/tsconfig.json new file mode 100644 index 00000000..e8977ec3 --- /dev/null +++ b/js/tsconfig.json @@ -0,0 +1,11 @@ +{ + "compilerOptions": { + "outDir": "./dist/", + "noImplicitAny": true, + "module": "es6", + "target": "es5", + "jsx": "react", + "allowJs": true, + "moduleResolution": "node" + } +} diff --git a/js/webpack-test-simple.config.js b/js/webpack-test-simple.config.js index da176a5a..1f923355 100644 --- a/js/webpack-test-simple.config.js +++ b/js/webpack-test-simple.config.js @@ -28,8 +28,16 @@ module.exports = { }, ], }, + { + test: /\.tsx?$/, + use: 'ts-loader', + exclude: /node_modules/, + }, ] }, + resolve: { + extensions: ['.tsx', '.ts', '.js'], + }, plugins: [ new HtmlWebpackPlugin(), ] diff --git a/js/webpack.config.js b/js/webpack.config.js index 2d8efed1..cd653f32 100644 --- a/js/webpack.config.js +++ b/js/webpack.config.js @@ -24,6 +24,14 @@ module.exports = { }, ], }, + { + test: /\.tsx?$/, + use: 'ts-loader', + exclude: /node_modules/, + }, ] }, + resolve: { + extensions: ['.tsx', '.ts', '.js'], + }, };
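
For reference, a minimal round-trip of the chunked msgpack framing introduced
by this series (a sketch; it assumes the new wslink.chunking module is
importable, and the id/payload/sizes are illustrative):

    import msgpack
    from wslink.chunking import UnChunker, generate_chunks

    # Pack a wslink-style wrapper that embeds binary data directly.
    payload = msgpack.packb(
        {"wslink": "1.0", "id": "rpc:c0:1", "result": b"\x00" * 10000}
    )

    unchunker = UnChunker()
    # Default allocation cap is 512 bytes until raised after authentication.
    unchunker.set_max_message_size(4 * 1024 * 1024)

    message = None
    for chunk in generate_chunks(payload, 4096):  # 12-byte header + content
        message = unchunker.process_chunk(chunk)  # None until the last chunk

    assert message is not None
    assert message["id"] == "rpc:c0:1"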