|
16 | 16 |
|
17 | 17 | import os
|
18 | 18 | from contextlib import asynccontextmanager
|
| 19 | +from typing import Literal |
19 | 20 |
|
20 | 21 | import asyncpg
|
21 | 22 | import geojson
|
|
28 | 29 |
|
29 | 30 |
|
30 | 31 | @asynccontextmanager
|
31 |
| -async def get_connection(): |
| 32 | +async def get_connection(database: Literal["oqapidb", "ohsomedb"] = "oqapidb"): |
32 | 33 | # DNS in libpq connection URI format
|
33 |
| - dns = "postgres://{user}:{password}@{host}:{port}/{database}".format( |
34 |
| - host=get_config_value("postgres_host"), |
35 |
| - port=get_config_value("postgres_port"), |
36 |
| - database=get_config_value("postgres_db"), |
37 |
| - user=get_config_value("postgres_user"), |
38 |
| - password=get_config_value("postgres_password"), |
39 |
| - ) |
| 34 | + match database: |
| 35 | + case "oqapidb": |
| 36 | + dns = "postgres://{user}:{password}@{host}:{port}/{database}".format( |
| 37 | + host=get_config_value("postgres_host"), |
| 38 | + port=get_config_value("postgres_port"), |
| 39 | + database=get_config_value("postgres_db"), |
| 40 | + user=get_config_value("postgres_user"), |
| 41 | + password=get_config_value("postgres_password"), |
| 42 | + ) |
| 43 | + case "ohsomedb": |
| 44 | + dns = "postgres://{user}:{password}@{host}:{port}/{database}".format( |
| 45 | + host=get_config_value("ohsomedb_host"), |
| 46 | + port=get_config_value("ohsomedb_port"), |
| 47 | + database=get_config_value("ohsomedb_db"), |
| 48 | + user=get_config_value("ohsomedb_user"), |
| 49 | + password=get_config_value("ohsomedb_password"), |
| 50 | + ) |
| 51 | + case _: |
| 52 | + raise ValueError() |
40 | 53 | conn = await asyncpg.connect(dns)
|
41 | 54 | try:
|
42 | 55 | yield conn
|
43 | 56 | finally:
|
44 | 57 | await conn.close()
|
45 | 58 |
|
46 | 59 |
|
47 |
| -async def fetch(query: str, *args) -> list: |
48 |
| - async with get_connection() as conn: |
| 60 | +async def fetch( |
| 61 | + query: str, |
| 62 | + *args, |
| 63 | + database: Literal["oqapidb", "ohsomedb"] = "oqapidb", |
| 64 | +) -> list: |
| 65 | + async with get_connection(database) as conn: |
49 | 66 | return await conn.fetch(query, *args)
|
50 | 67 |
|
51 | 68 |
|
|
0 commit comments