From d60bb5eab707422c1ad3e62ca9398dd6e0bb4ef8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Attilio=20Don=C3=A0?= Date: Thu, 21 Nov 2024 10:39:30 +0100 Subject: [PATCH] Fix WebSocket concurrents send --- src/WebSockets.jl | 16 ++++++----- test/runtests.jl | 1 + test/websockets/multiple_writers.jl | 43 +++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 test/websockets/multiple_writers.jl diff --git a/src/WebSockets.jl b/src/WebSockets.jl index 0e2ed56e4..1afb5f761 100644 --- a/src/WebSockets.jl +++ b/src/WebSockets.jl @@ -182,22 +182,24 @@ end # writing a single frame function writeframe(io::IO, x::Frame) - n = write(io.io, hton(uint16(x.flags))) + buff = IOBuffer() + n = write(buff, hton(uint16(x.flags))) if x.extendedlen !== nothing - n += write(io.io, hton(x.extendedlen)) + n += write(buff, hton(x.extendedlen)) end if x.mask != EMPTY_MASK - n += write(io.io, UInt32(x.mask)) + n += write(buff, UInt32(x.mask)) end pl = x.payload # manually unroll a few known type cases to help the compiler if pl isa Vector{UInt8} - n += write(io.io, pl) - elseif pl isa Base.CodeUnits{UInt8, String} - n += write(io.io, pl) + n += write(buff, pl) + elseif pl isa Base.CodeUnits{UInt8,String} + n += write(buff, pl) else - n += write(io.io, pl) + n += write(buff, pl) end + write(io.io, take!(buff)) return n end diff --git a/test/runtests.jl b/test/runtests.jl index c89ccdedc..d8cd6c2c7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,7 @@ isok(r) = r.status == 200 "mwe.jl", "httpversion.jl", "websockets/autobahn.jl", + "websockets/multiple_writers.jl", ] # ARGS can be most easily passed like this: # import Pkg; Pkg.test("HTTP"; test_args=`ascii.jl parser.jl`) diff --git a/test/websockets/multiple_writers.jl b/test/websockets/multiple_writers.jl new file mode 100644 index 000000000..0662b7665 --- /dev/null +++ b/test/websockets/multiple_writers.jl @@ -0,0 +1,43 @@ +using Test +using HTTP.WebSockets + +function write_message(ws, msg) + send(ws, msg) +end + +function client_twin(ws) + for count in 1:10 + @async write_message(ws, count) + end +end + +function serve(ch) + WebSockets.listen!("127.0.0.1", 8081) do ws + client_twin(ws) + response = receive(ws) + put!(ch, response) + end +end + +ch = Channel(1) +srvtask = @async serve(ch) + +WebSockets.open("ws://127.0.0.1:8081") do ws + try + while true + s = receive(ws) + if s == "10" + send(ws, "ok") + end + end + catch e + if e.message.status !== 1000 + @error "Ws client: $e" + !ws.writeclosed && send(ws, "error") + end + end +end; + +@testset "WebSocket multiple writes" begin + @test take!(ch) == "ok" +end