Skip to content

Commit 946d6b7

Browse files
authored
fix(websockets): Correctly pass guards to underlying handler (#4414)
1 parent 59800ae commit 946d6b7

File tree

4 files changed

+56
-24
lines changed

4 files changed

+56
-24
lines changed

litestar/handlers/websocket_handlers/listener.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ def decorator(fn: AnyCallable) -> WebsocketListenerRouteHandler:
529529
dependencies=dependencies,
530530
dto=dto,
531531
exception_handlers=exception_handlers,
532-
guard=guards,
532+
guards=guards,
533533
middleware=middleware,
534534
receive_mode=receive_mode,
535535
send_mode=send_mode,

litestar/handlers/websocket_handlers/stream.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def decorator(fn: Callable[..., AsyncGenerator[Any, Any]]) -> WebsocketRouteHand
188188
path=path,
189189
dependencies=dependencies,
190190
exception_handlers=exception_handlers,
191-
guard=guards,
191+
guards=guards,
192192
middleware=middleware,
193193
name=name,
194194
opt=opt,

tests/unit/test_handlers/test_websocket_handlers/test_listeners.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
from pytest_lazy_fixtures import lf
99

1010
from litestar import Controller, Litestar, Request, WebSocket
11+
from litestar.connection import ASGIConnection
1112
from litestar.datastructures import State
1213
from litestar.di import Provide
1314
from litestar.dto import DataclassDTO, dto_field
1415
from litestar.exceptions import ImproperlyConfiguredException
1516
from litestar.handlers import WebsocketListenerRouteHandler
17+
from litestar.handlers.base import BaseRouteHandler
1618
from litestar.handlers.websocket_handlers import WebsocketListener, websocket_listener
1719
from litestar.routes import WebSocketRoute
1820
from litestar.testing import create_test_client
@@ -76,8 +78,7 @@ def test_listener_receive_bytes(receive_mode: WebSocketMode, mock: MagicMock) ->
7678
def handler(data: bytes) -> None:
7779
mock(data)
7880

79-
client = create_test_client([handler])
80-
with client.websocket_connect("/") as ws:
81+
with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
8182
ws.send("foo", mode=receive_mode)
8283

8384
mock.assert_called_once_with(b"foo")
@@ -89,8 +90,7 @@ def test_listener_receive_string(receive_mode: WebSocketMode, mock: MagicMock) -
8990
def handler(data: str) -> None:
9091
mock(data)
9192

92-
client = create_test_client([handler])
93-
with client.websocket_connect("/") as ws:
93+
with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
9494
ws.send("foo", mode=receive_mode)
9595

9696
mock.assert_called_once_with("foo")
@@ -102,8 +102,7 @@ def test_listener_receive_json(receive_mode: WebSocketMode, mock: MagicMock) ->
102102
def handler(data: list[str]) -> None:
103103
mock(data)
104104

105-
client = create_test_client([handler])
106-
with client.websocket_connect("/") as ws:
105+
with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
107106
ws.send_json(["foo", "bar"], mode=receive_mode)
108107

109108
mock.assert_called_once_with(["foo", "bar"])
@@ -140,8 +139,7 @@ def test_listener_return_bytes(send_mode: WebSocketMode) -> None:
140139
def handler(data: str) -> bytes:
141140
return data.encode("utf-8")
142141

143-
client = create_test_client([handler])
144-
with client.websocket_connect("/") as ws:
142+
with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
145143
ws.send_text("foo")
146144
if send_mode == "text":
147145
assert ws.receive_text() == "foo"
@@ -155,8 +153,7 @@ def test_listener_send_json(send_mode: WebSocketMode) -> None:
155153
def handler(data: str) -> dict[str, str]:
156154
return {"data": data}
157155

158-
client = create_test_client([handler])
159-
with client.websocket_connect("/") as ws:
156+
with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
160157
ws.send_text("foo")
161158
assert ws.receive_json(mode=send_mode) == {"data": "foo"}
162159

@@ -174,8 +171,7 @@ class User:
174171
def handler(data: User) -> User:
175172
return data
176173

177-
client = create_test_client([handler])
178-
with client.websocket_connect("/") as ws:
174+
with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
179175
ws.send_json({"name": "litestar user"})
180176
assert ws.receive_json(mode=send_mode) == {"name": "litestar user"}
181177

@@ -185,8 +181,7 @@ def test_listener_return_none() -> None:
185181
def handler(data: str) -> None:
186182
return data # type: ignore[return-value]
187183

188-
client = create_test_client([handler])
189-
with client.websocket_connect("/") as ws:
184+
with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
190185
ws.send_text("foo")
191186

192187

@@ -195,8 +190,7 @@ def test_listener_return_optional_none() -> None:
195190
def handler(data: str) -> Optional[str]:
196191
return "world" if data == "hello" else None
197192

198-
client = create_test_client([handler])
199-
with client.websocket_connect("/") as ws:
193+
with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
200194
ws.send_text("hello")
201195
assert ws.receive_text() == "world"
202196
ws.send_text("goodbye")
@@ -208,8 +202,7 @@ def handler(data: str, socket: WebSocket) -> dict[str, str]:
208202
mock(socket=socket)
209203
return {"data": data}
210204

211-
client = create_test_client([handler])
212-
with client.websocket_connect("/") as ws:
205+
with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
213206
ws.send_text("foo")
214207
assert ws.receive_json() == {"data": "foo"}
215208

@@ -227,8 +220,7 @@ async def foo_dependency(state: State) -> int:
227220
def handler(data: str, foo: int) -> dict[str, Union[str, int]]:
228221
return {"data": data, "foo": foo}
229222

230-
client = create_test_client([handler])
231-
with client.websocket_connect("/") as ws:
223+
with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
232224
ws.send_text("something")
233225
ws.send_text("something")
234226
assert ws.receive_json() == {"data": "something", "foo": 1}
@@ -267,8 +259,7 @@ async def accept_connection(socket: WebSocket) -> None:
267259
def handler(data: bytes) -> None:
268260
return None
269261

270-
client = create_test_client([handler])
271-
with client.websocket_connect("/") as ws:
262+
with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
272263
assert ws.extra_headers == [(b"cookie", b"custom-cookie")]
273264

274265

@@ -441,3 +432,21 @@ async def lifespan() -> AsyncGenerator[None, None]:
441432
@websocket_listener("/", **{hook_name: hook_callback}, connection_lifespan=lifespan) # pyright: ignore
442433
def handler(data: bytes) -> None:
443434
pass
435+
436+
437+
def test_websocket_listener_applies_guards() -> None:
438+
guard_called = False
439+
440+
async def custom_guard(connection: ASGIConnection, _: BaseRouteHandler) -> None:
441+
nonlocal guard_called
442+
guard_called = True
443+
444+
@websocket_listener("/", guards=[custom_guard])
445+
async def handler(data: str) -> str:
446+
return data
447+
448+
with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
449+
ws.send_text("test")
450+
451+
assert ws.receive_text() == "test"
452+
assert guard_called is True

tests/unit/test_handlers/test_websocket_handlers/test_stream.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import dataclasses
55
from collections.abc import AsyncGenerator, Generator
6+
from typing import TYPE_CHECKING
67
from unittest.mock import MagicMock
78

89
import pytest
@@ -13,6 +14,10 @@
1314
from litestar.handlers.websocket_handlers import websocket_stream
1415
from litestar.testing import create_test_client
1516

17+
if TYPE_CHECKING:
18+
from litestar.connection import ASGIConnection
19+
from litestar.handlers.base import BaseRouteHandler
20+
1621

1722
def test_websocket_stream() -> None:
1823
@websocket_stream("/")
@@ -150,3 +155,21 @@ def foo() -> bytes:
150155
return b""
151156

152157
Litestar([foo])
158+
159+
160+
def test_websocket_stream_applies_guards() -> None:
161+
guard_called = False
162+
163+
async def custom_guard(connection: ASGIConnection, _: BaseRouteHandler) -> None:
164+
nonlocal guard_called
165+
guard_called = True
166+
167+
@websocket_stream("/", guards=[custom_guard])
168+
async def handler() -> AsyncGenerator[dict[str, str], None]:
169+
yield {"Urfaust": "Gespinnst"}
170+
yield {"Des": "Verderbens"}
171+
172+
with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
173+
assert ws.receive_json(timeout=0.1) == {"Urfaust": "Gespinnst"}
174+
assert ws.receive_json(timeout=0.1) == {"Des": "Verderbens"}
175+
assert guard_called is True

0 commit comments

Comments
 (0)