|
1 | 1 | import multiprocessing |
| 2 | +import os |
2 | 3 | from contextlib import asynccontextmanager |
3 | 4 | from pathlib import Path |
| 5 | +from random import randint, sample |
| 6 | +from typing import Any |
4 | 7 |
|
5 | 8 | import asyncpg |
6 | | -import os |
7 | | - |
8 | 9 | import orjson |
9 | | -from litestar import Litestar, Request, get, MediaType |
10 | | - |
11 | | -from random import randint, sample |
12 | | - |
| 10 | +from litestar import Litestar, MediaType, Request, get |
13 | 11 | from litestar.contrib.jinja import JinjaTemplateEngine |
14 | 12 | from litestar.response import Template |
15 | 13 | from litestar.template import TemplateConfig |
16 | 14 |
|
17 | 15 | READ_ROW_SQL = 'SELECT "id", "randomnumber" FROM "world" WHERE id = $1' |
18 | 16 | WRITE_ROW_SQL = 'UPDATE "world" SET "randomnumber"=$1 WHERE id=$2' |
19 | 17 | ADDITIONAL_ROW = [0, "Additional fortune added at request time."] |
20 | | -MAX_POOL_SIZE = 1000//multiprocessing.cpu_count() |
| 18 | +MAX_POOL_SIZE = 1000 // multiprocessing.cpu_count() |
21 | 19 | MIN_POOL_SIZE = max(int(MAX_POOL_SIZE / 2), 1) |
22 | 20 |
|
23 | 21 |
|
24 | 22 | def get_num_queries(queries): |
25 | | - try: |
26 | | - query_count = int(queries) |
27 | | - except (ValueError, TypeError): |
28 | | - return 1 |
| 23 | + try: |
| 24 | + query_count = int(queries) |
| 25 | + except (ValueError, TypeError): |
| 26 | + return 1 |
29 | 27 |
|
30 | | - if query_count < 1: |
31 | | - return 1 |
32 | | - if query_count > 500: |
33 | | - return 500 |
34 | | - return query_count |
| 28 | + if query_count < 1: |
| 29 | + return 1 |
| 30 | + if query_count > 500: |
| 31 | + return 500 |
| 32 | + return query_count |
35 | 33 |
|
36 | 34 |
|
37 | 35 | connection_pool = None |
38 | 36 |
|
39 | 37 |
|
40 | | - |
41 | 38 | async def setup_database(): |
42 | | - return await asyncpg.create_pool( |
43 | | - user=os.getenv("PGUSER", "benchmarkdbuser"), |
44 | | - password=os.getenv("PGPASS", "benchmarkdbpass"), |
45 | | - database="hello_world", |
46 | | - host="tfb-database", |
47 | | - port=5432, |
48 | | - min_size=MIN_POOL_SIZE, |
49 | | - max_size=MAX_POOL_SIZE, |
50 | | - ) |
| 39 | + return await asyncpg.create_pool( |
| 40 | + user=os.getenv("PGUSER", "benchmarkdbuser"), |
| 41 | + password=os.getenv("PGPASS", "benchmarkdbpass"), |
| 42 | + database="hello_world", |
| 43 | + host="tfb-database", |
| 44 | + port=5432, |
| 45 | + min_size=MIN_POOL_SIZE, |
| 46 | + max_size=MAX_POOL_SIZE, |
| 47 | + ) |
51 | 48 |
|
52 | 49 |
|
53 | 50 | @asynccontextmanager |
54 | 51 | async def lifespan(app: Litestar): |
55 | | - # Set up the database connection pool |
56 | | - app.state.connection_pool = await setup_database() |
57 | | - yield |
58 | | - # Close the database connection pool |
59 | | - await app.state.connection_pool.close() |
60 | | - |
61 | | - |
62 | | -app = Litestar(lifespan=[lifespan], template_config=TemplateConfig( |
63 | | - directory=Path("templates"), |
64 | | - engine=JinjaTemplateEngine, |
65 | | - ),) |
| 52 | + # Set up the database connection pool |
| 53 | + app.state.connection_pool = await setup_database() |
| 54 | + yield |
| 55 | + # Close the database connection pool |
| 56 | + await app.state.connection_pool.close() |
66 | 57 |
|
67 | 58 |
|
68 | 59 | @get("/json") |
69 | | -async def json_serialization(): |
70 | | - return orjson.dumps({"message": "Hello, world!"}) |
| 60 | +async def json_serialization() -> bytes: |
| 61 | + return orjson.dumps({"message": "Hello, world!"}) |
71 | 62 |
|
72 | 63 |
|
73 | 64 | @get("/db") |
74 | | -async def single_database_query(): |
75 | | - row_id = randint(1, 10000) |
76 | | - async with app.state.connection_pool.acquire() as connection: |
77 | | - number = await connection.fetchval(READ_ROW_SQL, row_id) |
| 65 | +async def single_database_query() -> bytes: |
| 66 | + row_id = randint(1, 10000) |
| 67 | + async with app.state.connection_pool.acquire() as connection: |
| 68 | + number = await connection.fetchval(READ_ROW_SQL, row_id) |
78 | 69 |
|
79 | | - return orjson.dumps({"id": row_id, "randomNumber": number}) |
| 70 | + return orjson.dumps({"id": row_id, "randomNumber": number}) |
80 | 71 |
|
81 | 72 |
|
82 | 73 | @get("/queries") |
83 | | -async def multiple_database_queries(queries = None): |
84 | | - num_queries = get_num_queries(queries) |
85 | | - row_ids = sample(range(1, 10000), num_queries) |
86 | | - worlds = [] |
| 74 | +async def multiple_database_queries(queries: Any = None) -> bytes: |
| 75 | + num_queries = get_num_queries(queries) |
| 76 | + row_ids = sample(range(1, 10000), num_queries) |
| 77 | + worlds = [] |
87 | 78 |
|
88 | | - async with app.state.connection_pool.acquire() as connection: |
89 | | - statement = await connection.prepare(READ_ROW_SQL) |
90 | | - for row_id in row_ids: |
91 | | - number = await statement.fetchval(row_id) |
92 | | - worlds.append({"id": row_id, "randomNumber": number}) |
| 79 | + async with app.state.connection_pool.acquire() as connection: |
| 80 | + statement = await connection.prepare(READ_ROW_SQL) |
| 81 | + for row_id in row_ids: |
| 82 | + number = await statement.fetchval(row_id) |
| 83 | + worlds.append({"id": row_id, "randomNumber": number}) |
93 | 84 |
|
94 | | - return orjson.dumps(worlds) |
| 85 | + return orjson.dumps(worlds) |
95 | 86 |
|
96 | 87 |
|
97 | 88 | @get("/fortunes") |
98 | | -async def fortunes(request: Request): |
99 | | - async with app.state.connection_pool.acquire() as connection: |
100 | | - fortunes = await connection.fetch("SELECT * FROM Fortune") |
| 89 | +async def fortunes(request: Request) -> Template: |
| 90 | + async with app.state.connection_pool.acquire() as connection: |
| 91 | + fortunes = await connection.fetch("SELECT * FROM Fortune") |
101 | 92 |
|
102 | | - fortunes.append(ADDITIONAL_ROW) |
103 | | - fortunes.sort(key=lambda row: row[1]) |
104 | | - return Template("fortune.html", context={"fortunes": fortunes, "request": request}) |
| 93 | + fortunes.append(ADDITIONAL_ROW) |
| 94 | + fortunes.sort(key=lambda row: row[1]) |
| 95 | + return Template("fortune.html", context={"fortunes": fortunes, "request": request}) |
105 | 96 |
|
106 | 97 |
|
107 | 98 | @get("/updates") |
108 | | -async def database_updates(queries = None): |
109 | | - num_queries = get_num_queries(queries) |
110 | | - # To avoid deadlock |
111 | | - ids = sorted(sample(range(1, 10000 + 1), num_queries)) |
112 | | - numbers = sorted(sample(range(1, 10000), num_queries)) |
113 | | - updates = list(zip(ids, numbers)) |
| 99 | +async def database_updates(queries: Any = None) -> bytes: |
| 100 | + num_queries = get_num_queries(queries) |
| 101 | + # To avoid deadlock |
| 102 | + ids = sorted(sample(range(1, 10000 + 1), num_queries)) |
| 103 | + numbers = sorted(sample(range(1, 10000), num_queries)) |
| 104 | + updates = list(zip(ids, numbers, strict=False)) |
114 | 105 |
|
115 | | - worlds = [ |
116 | | - {"id": row_id, "randomNumber": number} for row_id, number in updates |
117 | | - ] |
| 106 | + worlds = [{"id": row_id, "randomNumber": number} for row_id, number in updates] |
118 | 107 |
|
119 | | - async with app.state.connection_pool.acquire() as connection: |
120 | | - statement = await connection.prepare(READ_ROW_SQL) |
121 | | - for row_id, _ in updates: |
122 | | - await statement.fetchval(row_id) |
123 | | - await connection.executemany(WRITE_ROW_SQL, updates) |
| 108 | + async with app.state.connection_pool.acquire() as connection: |
| 109 | + statement = await connection.prepare(READ_ROW_SQL) |
| 110 | + for row_id, _ in updates: |
| 111 | + await statement.fetchval(row_id) |
| 112 | + await connection.executemany(WRITE_ROW_SQL, updates) |
124 | 113 |
|
125 | | - return orjson.dumps(worlds) |
| 114 | + return orjson.dumps(worlds) |
126 | 115 |
|
127 | 116 |
|
128 | 117 | @get("/plaintext", media_type=MediaType.TEXT) |
129 | | -async def plaintext(): |
130 | | - return b"Hello, world!" |
| 118 | +async def plaintext() -> bytes: |
| 119 | + return b"Hello, world!" |
| 120 | + |
| 121 | + |
| 122 | +app = Litestar( |
| 123 | + lifespan=[lifespan], |
| 124 | + template_config=TemplateConfig( |
| 125 | + directory=Path("templates"), |
| 126 | + engine=JinjaTemplateEngine, |
| 127 | + ), |
| 128 | + route_handlers=[ |
| 129 | + json_serialization, |
| 130 | + single_database_query, |
| 131 | + multiple_database_queries, |
| 132 | + fortunes, |
| 133 | + database_updates, |
| 134 | + plaintext, |
| 135 | + ], |
| 136 | +) |
0 commit comments