88from pytest_lazy_fixtures import lf
99
1010from litestar import Controller , Litestar , Request , WebSocket
11+ from litestar .connection import ASGIConnection
1112from litestar .datastructures import State
1213from litestar .di import Provide
1314from litestar .dto import DataclassDTO , dto_field
1415from litestar .exceptions import ImproperlyConfiguredException
1516from litestar .handlers import WebsocketListenerRouteHandler
17+ from litestar .handlers .base import BaseRouteHandler
1618from litestar .handlers .websocket_handlers import WebsocketListener , websocket_listener
1719from litestar .routes import WebSocketRoute
1820from litestar .testing import create_test_client
@@ -76,8 +78,7 @@ def test_listener_receive_bytes(receive_mode: WebSocketMode, mock: MagicMock) ->
7678 def handler (data : bytes ) -> None :
7779 mock (data )
7880
79- client = create_test_client ([handler ])
80- with client .websocket_connect ("/" ) as ws :
81+ with create_test_client ([handler ]) as client , client .websocket_connect ("/" ) as ws :
8182 ws .send ("foo" , mode = receive_mode )
8283
8384 mock .assert_called_once_with (b"foo" )
@@ -89,8 +90,7 @@ def test_listener_receive_string(receive_mode: WebSocketMode, mock: MagicMock) -
8990 def handler (data : str ) -> None :
9091 mock (data )
9192
92- client = create_test_client ([handler ])
93- with client .websocket_connect ("/" ) as ws :
93+ with create_test_client ([handler ]) as client , client .websocket_connect ("/" ) as ws :
9494 ws .send ("foo" , mode = receive_mode )
9595
9696 mock .assert_called_once_with ("foo" )
@@ -102,8 +102,7 @@ def test_listener_receive_json(receive_mode: WebSocketMode, mock: MagicMock) ->
102102 def handler (data : list [str ]) -> None :
103103 mock (data )
104104
105- client = create_test_client ([handler ])
106- with client .websocket_connect ("/" ) as ws :
105+ with create_test_client ([handler ]) as client , client .websocket_connect ("/" ) as ws :
107106 ws .send_json (["foo" , "bar" ], mode = receive_mode )
108107
109108 mock .assert_called_once_with (["foo" , "bar" ])
@@ -140,8 +139,7 @@ def test_listener_return_bytes(send_mode: WebSocketMode) -> None:
140139 def handler (data : str ) -> bytes :
141140 return data .encode ("utf-8" )
142141
143- client = create_test_client ([handler ])
144- with client .websocket_connect ("/" ) as ws :
142+ with create_test_client ([handler ]) as client , client .websocket_connect ("/" ) as ws :
145143 ws .send_text ("foo" )
146144 if send_mode == "text" :
147145 assert ws .receive_text () == "foo"
@@ -155,8 +153,7 @@ def test_listener_send_json(send_mode: WebSocketMode) -> None:
155153 def handler (data : str ) -> dict [str , str ]:
156154 return {"data" : data }
157155
158- client = create_test_client ([handler ])
159- with client .websocket_connect ("/" ) as ws :
156+ with create_test_client ([handler ]) as client , client .websocket_connect ("/" ) as ws :
160157 ws .send_text ("foo" )
161158 assert ws .receive_json (mode = send_mode ) == {"data" : "foo" }
162159
@@ -174,8 +171,7 @@ class User:
174171 def handler (data : User ) -> User :
175172 return data
176173
177- client = create_test_client ([handler ])
178- with client .websocket_connect ("/" ) as ws :
174+ with create_test_client ([handler ]) as client , client .websocket_connect ("/" ) as ws :
179175 ws .send_json ({"name" : "litestar user" })
180176 assert ws .receive_json (mode = send_mode ) == {"name" : "litestar user" }
181177
@@ -185,8 +181,7 @@ def test_listener_return_none() -> None:
185181 def handler (data : str ) -> None :
186182 return data # type: ignore[return-value]
187183
188- client = create_test_client ([handler ])
189- with client .websocket_connect ("/" ) as ws :
184+ with create_test_client ([handler ]) as client , client .websocket_connect ("/" ) as ws :
190185 ws .send_text ("foo" )
191186
192187
@@ -195,8 +190,7 @@ def test_listener_return_optional_none() -> None:
195190 def handler (data : str ) -> Optional [str ]:
196191 return "world" if data == "hello" else None
197192
198- client = create_test_client ([handler ])
199- with client .websocket_connect ("/" ) as ws :
193+ with create_test_client ([handler ]) as client , client .websocket_connect ("/" ) as ws :
200194 ws .send_text ("hello" )
201195 assert ws .receive_text () == "world"
202196 ws .send_text ("goodbye" )
@@ -208,8 +202,7 @@ def handler(data: str, socket: WebSocket) -> dict[str, str]:
208202 mock (socket = socket )
209203 return {"data" : data }
210204
211- client = create_test_client ([handler ])
212- with client .websocket_connect ("/" ) as ws :
205+ with create_test_client ([handler ]) as client , client .websocket_connect ("/" ) as ws :
213206 ws .send_text ("foo" )
214207 assert ws .receive_json () == {"data" : "foo" }
215208
@@ -227,8 +220,7 @@ async def foo_dependency(state: State) -> int:
227220 def handler (data : str , foo : int ) -> dict [str , Union [str , int ]]:
228221 return {"data" : data , "foo" : foo }
229222
230- client = create_test_client ([handler ])
231- with client .websocket_connect ("/" ) as ws :
223+ with create_test_client ([handler ]) as client , client .websocket_connect ("/" ) as ws :
232224 ws .send_text ("something" )
233225 ws .send_text ("something" )
234226 assert ws .receive_json () == {"data" : "something" , "foo" : 1 }
@@ -267,8 +259,7 @@ async def accept_connection(socket: WebSocket) -> None:
267259 def handler (data : bytes ) -> None :
268260 return None
269261
270- client = create_test_client ([handler ])
271- with client .websocket_connect ("/" ) as ws :
262+ with create_test_client ([handler ]) as client , client .websocket_connect ("/" ) as ws :
272263 assert ws .extra_headers == [(b"cookie" , b"custom-cookie" )]
273264
274265
@@ -441,3 +432,21 @@ async def lifespan() -> AsyncGenerator[None, None]:
441432 @websocket_listener ("/" , ** {hook_name : hook_callback }, connection_lifespan = lifespan ) # pyright: ignore
442433 def handler (data : bytes ) -> None :
443434 pass
435+
436+
437+ def test_websocket_listener_applies_guards () -> None :
438+ guard_called = False
439+
440+ async def custom_guard (connection : ASGIConnection , _ : BaseRouteHandler ) -> None :
441+ nonlocal guard_called
442+ guard_called = True
443+
444+ @websocket_listener ("/" , guards = [custom_guard ])
445+ async def handler (data : str ) -> str :
446+ return data
447+
448+ with create_test_client ([handler ]) as client , client .websocket_connect ("/" ) as ws :
449+ ws .send_text ("test" )
450+
451+ assert ws .receive_text () == "test"
452+ assert guard_called is True
0 commit comments