Skip to content

Commit 957895c

Browse files
authored
Fix rate limiting test (#268)
* fix rate limiting test * simplify client args * ignore type warnings
1 parent 8b5cc86 commit 957895c

File tree

2 files changed

+32
-17
lines changed

2 files changed

+32
-17
lines changed

piccolo_api/openapi/endpoints.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,10 @@ class OAuthRedirectEndpoint(HTTPEndpoint):
7979
def get(self, request: Request):
8080
return get_swagger_ui_oauth2_redirect_html()
8181

82-
router.add_route("/", endpoint=DocsEndpoint)
83-
router.add_route("/oauth2-redirect/", endpoint=OAuthRedirectEndpoint)
82+
router.add_route("/", endpoint=DocsEndpoint) # type: ignore
83+
router.add_route(
84+
"/oauth2-redirect/",
85+
endpoint=OAuthRedirectEndpoint, # type: ignore
86+
)
8487

8588
return router

tests/rate_limiting/test_rate_middleware.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import asyncio
12
from time import sleep
23
from unittest import TestCase
34

5+
from httpx import ASGITransport, AsyncClient
46
from starlette.endpoints import HTTPEndpoint
57
from starlette.responses import JSONResponse
68
from starlette.routing import Route, Router
7-
from starlette.testclient import TestClient
89

910
from piccolo_api.rate_limiting.middleware import (
1011
InMemoryLimitProvider,
@@ -27,23 +28,34 @@ def test_limit(self):
2728
InMemoryLimitProvider(limit=5, timespan=1, block_duration=1),
2829
)
2930

30-
client = TestClient(app)
31+
# We have to use `httpx.AsyncClient` directly, because `TestClient`
32+
# was broken in this PR:
33+
# https://github.com/encode/starlette/pull/2377
34+
# `TestClient` no longer sends the client IP and port.
35+
# If a fix is released, we can go back to using `TestClient` directly.
36+
client = AsyncClient(
37+
transport=ASGITransport(app=app),
38+
base_url="http://testserver",
39+
)
40+
41+
async def run_test():
42+
successful = 0
43+
for _ in range(20):
44+
response = await client.get("/")
45+
if response.status_code == 429:
46+
break
47+
else:
48+
successful += 1
3149

32-
successful = 0
33-
for i in range(20):
34-
response = client.get("/")
35-
if response.status_code == 429:
36-
break
37-
else:
38-
successful += 1
50+
self.assertEqual(successful, 5)
3951

40-
self.assertEqual(successful, 5)
52+
# After the 'block_duration' has expired, requests should be
53+
# allowed again.
54+
sleep(1.1)
55+
response = await client.get("/")
56+
self.assertEqual(response.status_code, 200)
4157

42-
# After the 'block_duration' has expired, requests should be allowed
43-
# again.
44-
sleep(1.1)
45-
response = client.get("/")
46-
self.assertEqual(response.status_code, 200)
58+
asyncio.run(run_test())
4759

4860
def test_memory_usage(self):
4961
"""

0 commit comments

Comments
 (0)