Skip to content

Commit ff51e9c

Browse files
authored
Merge pull request #35 from danschultzer/allow-to-for-websocket-init
Allow `:to` option for `TestServer.websocket_init/3`
2 parents 9bbb266 + be4fcd1 commit ff51e9c

File tree

5 files changed

+88
-33
lines changed

5 files changed

+88
-33
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## v0.1.19 (TBA)
4+
5+
- Allow `:to` plug to be set for `TestServer.websocket_init/3` for handshake
6+
37
## v0.1.18 (2024-12-29)
48

59
- Limit number of Bandit acceptors to 1 to improve performance

lib/test_server.ex

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -353,17 +353,21 @@ defmodule TestServer do
353353
"""
354354
@spec add(pid(), binary(), keyword()) :: :ok
355355
def add(instance, uri, options) when is_pid(instance) and is_binary(uri) and is_list(options) do
356-
instance_alive!(instance)
357-
358-
[_first_module_entry | stacktrace] = get_stacktrace()
359-
360356
options = Keyword.put_new(options, :to, &default_response_handler/1)
361357

362-
{:ok, _route} = Instance.register(instance, {:plug_router_to, {uri, options, stacktrace}})
358+
{:ok, _route} = register_route(instance, uri, options)
363359

364360
:ok
365361
end
366362

363+
defp register_route(instance, uri, options) do
364+
instance_alive!(instance)
365+
366+
[_register_route, _first_module_entry | stacktrace] = get_stacktrace()
367+
368+
Instance.register(instance, {:plug_router_to, {uri, options, stacktrace}})
369+
end
370+
367371
defp get_stacktrace do
368372
{:current_stacktrace, [{Process, :info, _, _} | stacktrace]} =
369373
Process.info(self(), :current_stacktrace)
@@ -480,6 +484,10 @@ defmodule TestServer do
480484
@doc """
481485
Adds a websocket route to current test server.
482486
487+
The `:to` option can be overridden the same way as for `add/2`, and will be
488+
called during the HTTP handshake. If the `conn.state` is `:unset` the
489+
websocket will be initiated otherwise response is returned as-is.
490+
483491
## Options
484492
485493
Takes the same options as `add/2`, except `:to`.
@@ -499,6 +507,17 @@ defmodule TestServer do
499507
end)
500508
501509
assert {:ok, _client} = WebSocketClient.start_link(TestServer.url("/ws?token=secret"))
510+
511+
`:to` option is also called during the HTTP handshake:
512+
513+
TestServer.websocket_init("/ws",
514+
to: fn conn ->
515+
Plug.Conn.send_resp(conn, 403, "Forbidden")
516+
end
517+
)
518+
519+
assert {:error, %WebSockex.RequestError{code: 403}} =
520+
WebSocketClient.start_link(TestServer.url("/ws"))
502521
"""
503522
@spec websocket_init(binary(), keyword()) :: {:ok, websocket_socket()}
504523
def websocket_init(uri, options) when is_binary(uri) do
@@ -519,16 +538,12 @@ defmodule TestServer do
519538
"""
520539
@spec websocket_init(pid(), binary(), keyword()) :: {:ok, websocket_socket()}
521540
def websocket_init(instance, uri, options) do
522-
instance_alive!(instance)
523-
524-
if Keyword.has_key?(options, :to), do: raise(ArgumentError, "`:to` is an invalid option")
525-
526-
[_first_module_entry | stacktrace] = get_stacktrace()
527-
528-
options = Keyword.put(options, :to, :websocket)
541+
options =
542+
options
543+
|> Keyword.put(:websocket, true)
544+
|> Keyword.put_new(:to, & &1)
529545

530-
{:ok, %{ref: ref}} =
531-
Instance.register(instance, {:plug_router_to, {uri, options, stacktrace}})
546+
{:ok, %{ref: ref}} = register_route(instance, uri, options)
532547

533548
{:ok, {instance, ref}}
534549
end

lib/test_server/instance.ex

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ defmodule TestServer.Instance do
1616
@spec register(pid(), {:plug_router_to, {binary(), keyword(), TestServer.stacktrace()}}) ::
1717
{:ok, %{ref: reference()}}
1818
def register(instance, {:plug_router_to, {uri, options, stacktrace}}) do
19-
options[:to] != :websocket && ensure_plug!(options[:to])
19+
ensure_plug!(options[:to])
2020
options[:match] && ensure_function!(options[:match])
2121

2222
GenServer.call(instance, {:register, {:plug_router_to, {uri, options, stacktrace}}})
@@ -319,16 +319,16 @@ defmodule TestServer.Instance do
319319
plugs -> plugs
320320
end
321321
|> Enum.reduce_while(conn, fn %{plug: plug, stacktrace: stacktrace}, conn ->
322-
case try_run_plug(plug, stacktrace, conn) do
322+
case try_run_plug(conn, plug, stacktrace) do
323323
{:ok, conn} -> {:cont, conn}
324324
{:error, error} -> {:halt, {:error, error}}
325325
end
326326
end)
327327
end
328328

329-
defp try_run_plug(plug, stacktrace, conn) do
330-
plug
331-
|> run_plug(conn)
329+
defp try_run_plug(conn, plug, stacktrace) do
330+
conn
331+
|> run_plug(plug)
332332
|> check_halted!(plug, stacktrace)
333333
rescue
334334
error -> {:error, {error, __STACKTRACE__}}
@@ -345,11 +345,11 @@ defmodule TestServer.Instance do
345345

346346
defp check_halted!(conn, _plug, _stacktrace), do: {:ok, conn}
347347

348-
defp run_plug(plug, conn) when is_function(plug) do
348+
defp run_plug(conn, plug) when is_function(plug) do
349349
plug.(conn)
350350
end
351351

352-
defp run_plug(plug, conn) when is_atom(plug) do
352+
defp run_plug(conn, plug) when is_atom(plug) do
353353
options = plug.init([])
354354
plug.call(conn, options)
355355
end
@@ -367,16 +367,12 @@ defmodule TestServer.Instance do
367367
{{:error, {:not_found, conn}}, state}
368368

369369
index ->
370+
%{to: plug, stacktrace: stacktrace} = route = Enum.at(state.routes, index)
371+
370372
result =
371-
case Enum.at(state.routes, index) do
372-
%{to: :websocket, options: options} = route ->
373-
websocket = {{self(), route.ref}, Keyword.get(options, :init_state)}
374-
conn = Map.put(conn, :private, %{websocket: websocket})
375-
{:ok, conn}
376-
377-
%{to: plug, stacktrace: stacktrace} ->
378-
try_run_plug(plug, stacktrace, conn)
379-
end
373+
conn
374+
|> maybe_put_websocket(route)
375+
|> try_run_plug(plug, stacktrace)
380376

381377
routes =
382378
List.update_at(state.routes, index, fn route ->
@@ -393,6 +389,17 @@ defmodule TestServer.Instance do
393389
error -> {:error, {error, __STACKTRACE__}}
394390
end
395391

392+
def maybe_put_websocket(conn, route) do
393+
case route.options[:websocket] do
394+
true ->
395+
websocket = {{self(), route.ref}, Keyword.get(route.options, :init_state)}
396+
Map.put(conn, :private, %{websocket: websocket})
397+
398+
_false ->
399+
conn
400+
end
401+
end
402+
396403
defp run_websocket_handlers({_instance, route_ref}, frame, websocket_state, state) do
397404
state.websocket_handlers
398405
|> Enum.map(&{&1.route_ref == route_ref, &1})

lib/test_server/plug.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ defmodule TestServer.Plug do
88

99
def call(conn, {http_server, args, instance}) do
1010
case Instance.dispatch(instance, {:plug, conn}) do
11-
{:ok, %{private: %{websocket: {socket, state}}} = conn} ->
11+
{:ok, %{state: :unset, private: %{websocket: {socket, state}}} = conn} ->
1212
:ok = Instance.put_websocket_connection(socket, http_server.get_socket_pid(conn))
1313
Plug.Conn.upgrade_adapter(conn, :websocket, {http_server, {socket, state}, args})
1414

test/test_server_test.exs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,8 +558,12 @@ defmodule TestServerTest do
558558
end
559559

560560
test "invalid options" do
561-
assert_raise ArgumentError, "`:to` is an invalid option", fn ->
562-
TestServer.websocket_init("/", to: MyPlug)
561+
assert_raise BadFunctionError, ~r/expected a function, got: :invalid/, fn ->
562+
TestServer.websocket_init("/", to: :invalid)
563+
end
564+
565+
assert_raise BadFunctionError, ~r/expected a function, got: :invalid/, fn ->
566+
TestServer.websocket_init("/", match: :invalid)
563567
end
564568

565569
assert_raise BadFunctionError, ~r/expected a function, got: :invalid/, fn ->
@@ -577,6 +581,31 @@ defmodule TestServerTest do
577581
TestServer.websocket_init("/ws")
578582
end
579583
end
584+
585+
test "with handshake callback function with set conn" do
586+
assert {:ok, _socket} =
587+
TestServer.websocket_init("/ws",
588+
to: fn conn ->
589+
Plug.Conn.resp(conn, 403, "Forbidden")
590+
end
591+
)
592+
593+
assert {:error, %WebSockex.RequestError{code: 403}} =
594+
WebSocketClient.start_link(TestServer.url("/ws"))
595+
end
596+
597+
test "with handshake callback function with unset conn" do
598+
assert {:ok, _socket} =
599+
TestServer.websocket_init("/ws",
600+
to: fn conn ->
601+
assert Plug.Conn.get_req_header(conn, "upgrade") == ["websocket"]
602+
603+
conn
604+
end
605+
)
606+
607+
assert {:ok, _client} = WebSocketClient.start_link(TestServer.url("/ws"))
608+
end
580609
end
581610

582611
describe "websocket_handle/3" do

0 commit comments

Comments
 (0)