Skip to content

Commit bd4b48a

Browse files
committed
Merge pull request #209 from KeepSafe/middleware
Middleware, second edition
2 parents 266b386 + 32c3857 commit bd4b48a

File tree

2 files changed

+133
-3
lines changed

2 files changed

+133
-3
lines changed

aiohttp/web.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,7 @@ def __init__(self, manager, app, router, **kwargs):
11731173
self._manager = manager
11741174
self._app = app
11751175
self._router = router
1176+
self._middlewares = app.middlewares
11761177

11771178
def connection_made(self, transport):
11781179
super().connection_made(transport)
@@ -1188,7 +1189,8 @@ def connection_lost(self, exc):
11881189
def handle_request(self, message, payload):
11891190
now = self._loop.time()
11901191

1191-
request = Request(self._app, message, payload,
1192+
app = self._app
1193+
request = Request(app, message, payload,
11921194
self.transport, self.writer, self.keep_alive_timeout)
11931195
try:
11941196
match_info = yield from self._router.resolve(request)
@@ -1198,7 +1200,10 @@ def handle_request(self, message, payload):
11981200
request._match_info = match_info
11991201
handler = match_info.handler
12001202

1203+
for factory in reversed(self._middlewares):
1204+
handler = yield from factory(app, handler)
12011205
resp = yield from handler(request)
1206+
12021207
if not isinstance(resp, StreamResponse):
12031208
raise RuntimeError(
12041209
("Handler should return response instance, got {!r}")
@@ -1273,8 +1278,8 @@ def __call__(self):
12731278
class Application(dict):
12741279

12751280
def __init__(self, *, logger=web_logger, loop=None,
1276-
router=None, handler_factory=RequestHandlerFactory, **kwargs):
1277-
# TODO: explicitly accept *debug* param
1281+
router=None, handler_factory=RequestHandlerFactory,
1282+
middlewares=(), **kwargs):
12781283
if loop is None:
12791284
loop = asyncio.get_event_loop()
12801285
if router is None:
@@ -1288,6 +1293,9 @@ def __init__(self, *, logger=web_logger, loop=None,
12881293
self.logger = logger
12891294

12901295
self.update(**kwargs)
1296+
for factory in middlewares:
1297+
assert asyncio.iscoroutinefunction(factory), factory
1298+
self._middlewares = tuple(middlewares)
12911299

12921300
@property
12931301
def router(self):
@@ -1297,6 +1305,10 @@ def router(self):
12971305
def loop(self):
12981306
return self._loop
12991307

1308+
@property
1309+
def middlewares(self):
1310+
return self._middlewares
1311+
13001312
def make_handler(self, **kwargs):
13011313
return self._handler_factory(
13021314
self, self.router, loop=self.loop, **kwargs)

tests/test_web_middleware.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import asyncio
2+
import socket
3+
import unittest
4+
from aiohttp import web, request
5+
6+
7+
class TestWebFunctional(unittest.TestCase):
8+
9+
def setUp(self):
10+
self.loop = asyncio.new_event_loop()
11+
asyncio.set_event_loop(None)
12+
13+
def tearDown(self):
14+
self.loop.close()
15+
16+
def find_unused_port(self):
17+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
18+
s.bind(('127.0.0.1', 0))
19+
port = s.getsockname()[1]
20+
s.close()
21+
return port
22+
23+
@asyncio.coroutine
24+
def create_server(self, method, path, handler, *middlewares):
25+
app = web.Application(loop=self.loop, middlewares=middlewares)
26+
app.router.add_route(method, path, handler)
27+
28+
port = self.find_unused_port()
29+
srv = yield from self.loop.create_server(
30+
app.make_handler(debug=True), '127.0.0.1', port)
31+
url = "http://127.0.0.1:{}".format(port) + path
32+
self.addCleanup(srv.close)
33+
return app, srv, url
34+
35+
def test_middleware_modifies_response(self):
36+
37+
@asyncio.coroutine
38+
def handler(request):
39+
return web.Response(body=b'OK')
40+
41+
@asyncio.coroutine
42+
def middleware_factory(app, handler):
43+
def middleware(request):
44+
resp = yield from handler(request)
45+
self.assertEqual(200, resp.status)
46+
resp.set_status(201)
47+
resp.text = resp.text + '[MIDDLEWARE]'
48+
return resp
49+
return middleware
50+
51+
@asyncio.coroutine
52+
def go():
53+
_, _, url = yield from self.create_server('GET', '/', handler,
54+
middleware_factory)
55+
resp = yield from request('GET', url, loop=self.loop)
56+
self.assertEqual(201, resp.status)
57+
txt = yield from resp.text()
58+
self.assertEqual('OK[MIDDLEWARE]', txt)
59+
60+
self.loop.run_until_complete(go())
61+
62+
def test_middleware_handles_exception(self):
63+
64+
@asyncio.coroutine
65+
def handler(request):
66+
raise RuntimeError('Error text')
67+
68+
@asyncio.coroutine
69+
def middleware_factory(app, handler):
70+
def middleware(request):
71+
with self.assertRaises(RuntimeError) as ctx:
72+
yield from handler(request)
73+
return web.Response(status=501,
74+
text=str(ctx.exception) + '[MIDDLEWARE]')
75+
76+
return middleware
77+
78+
@asyncio.coroutine
79+
def go():
80+
_, _, url = yield from self.create_server('GET', '/', handler,
81+
middleware_factory)
82+
resp = yield from request('GET', url, loop=self.loop)
83+
self.assertEqual(501, resp.status)
84+
txt = yield from resp.text()
85+
self.assertEqual('Error text[MIDDLEWARE]', txt)
86+
87+
self.loop.run_until_complete(go())
88+
89+
def test_middleware_chain(self):
90+
91+
@asyncio.coroutine
92+
def handler(request):
93+
return web.Response(text='OK')
94+
95+
def make_factory(num):
96+
97+
@asyncio.coroutine
98+
def factory(app, handler):
99+
100+
def middleware(request):
101+
resp = yield from handler(request)
102+
resp.text = resp.text + '[{}]'.format(num)
103+
return resp
104+
105+
return middleware
106+
return factory
107+
108+
@asyncio.coroutine
109+
def go():
110+
_, _, url = yield from self.create_server('GET', '/', handler,
111+
make_factory(1),
112+
make_factory(2))
113+
resp = yield from request('GET', url, loop=self.loop)
114+
self.assertEqual(200, resp.status)
115+
txt = yield from resp.text()
116+
self.assertEqual('OK[2][1]', txt)
117+
118+
self.loop.run_until_complete(go())

0 commit comments

Comments
 (0)