| 
13 | 13 | See the License for the specific language governing permissions and  | 
14 | 14 | limitations under the License.  | 
15 | 15 | """  | 
 | 16 | + | 
16 | 17 | import asyncio  | 
17 | 18 | import os  | 
18 |  | -from typing import AsyncGenerator  | 
19 |  | -import uuid  | 
 | 19 | +from typing import Tuple  | 
20 | 20 | 
 
  | 
21 | 21 | import asyncpg  | 
22 |  | -import pytest  | 
23 | 22 | import sqlalchemy  | 
24 |  | -from sqlalchemy.ext.asyncio import AsyncEngine  | 
25 |  | -from sqlalchemy.ext.asyncio import create_async_engine  | 
 | 23 | +import sqlalchemy.ext.asyncio  | 
26 | 24 | 
 
  | 
27 | 25 | from google.cloud.sql.connector import Connector  | 
28 | 26 | 
 
  | 
29 |  | -table_name = f"books_{uuid.uuid4().hex}"  | 
30 | 27 | 
 
  | 
 | 28 | +async def create_sqlalchemy_engine(  | 
 | 29 | +    instance_connection_name: str,  | 
 | 30 | +    user: str,  | 
 | 31 | +    password: str,  | 
 | 32 | +    db: str,  | 
 | 33 | +    refresh_strategy: str = "background",  | 
 | 34 | +) -> Tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, Connector]:  | 
 | 35 | +    """Creates a connection pool for a Cloud SQL instance and returns the pool  | 
 | 36 | +    and the connector. Callers are responsible for closing the pool and the  | 
 | 37 | +    connector.  | 
 | 38 | +
  | 
 | 39 | +    A sample invocation looks like:  | 
 | 40 | +
  | 
 | 41 | +        engine, connector = await create_sqlalchemy_engine(  | 
 | 42 | +            inst_conn_name,  | 
 | 43 | +            user,  | 
 | 44 | +            password,  | 
 | 45 | +            db,  | 
 | 46 | +        )  | 
 | 47 | +        async with engine.connect() as conn:  | 
 | 48 | +            time = (await conn.execute(sqlalchemy.text("SELECT NOW()"))).fetchone()  | 
 | 49 | +            curr_time = time[0]  | 
 | 50 | +            # do something with query result  | 
 | 51 | +            await connector.close_async()  | 
 | 52 | +
  | 
 | 53 | +    Args:  | 
 | 54 | +        instance_connection_name (str):  | 
 | 55 | +            The instance connection name specifies the instance relative to the  | 
 | 56 | +            project and region. For example: "my-project:my-region:my-instance"  | 
 | 57 | +        user (str):  | 
 | 58 | +            The database user name, e.g., postgres  | 
 | 59 | +        password (str):  | 
 | 60 | +            The database user's password, e.g., secret-password  | 
 | 61 | +        db (str):  | 
 | 62 | +            The name of the database, e.g., mydb  | 
 | 63 | +        refresh_strategy (Optional[str]):  | 
 | 64 | +            Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"  | 
 | 65 | +            or "background". For serverless environments use "lazy" to avoid  | 
 | 66 | +            errors resulting from CPU being throttled.  | 
 | 67 | +    """  | 
 | 68 | +    loop = asyncio.get_running_loop()  | 
 | 69 | +    connector = Connector(loop=loop, refresh_strategy=refresh_strategy)  | 
31 | 70 | 
 
  | 
32 |  | -# The Cloud SQL Python Connector can be used along with SQLAlchemy using the  | 
33 |  | -# 'async_creator' argument to 'create_async_engine'  | 
34 |  | -async def init_connection_pool() -> AsyncEngine:  | 
35 | 71 |     async def getconn() -> asyncpg.Connection:  | 
36 |  | -        loop = asyncio.get_running_loop()  | 
37 |  | -        # initialize Connector object for connections to Cloud SQL  | 
38 |  | -        async with Connector(loop=loop) as connector:  | 
39 |  | -            conn: asyncpg.Connection = await connector.connect_async(  | 
40 |  | -                os.environ["POSTGRES_CONNECTION_NAME"],  | 
41 |  | -                "asyncpg",  | 
42 |  | -                user=os.environ["POSTGRES_USER"],  | 
43 |  | -                password=os.environ["POSTGRES_PASS"],  | 
44 |  | -                db=os.environ["POSTGRES_DB"],  | 
45 |  | -            )  | 
46 |  | -            return conn  | 
 | 72 | +        conn: asyncpg.Connection = await connector.connect_async(  | 
 | 73 | +            instance_connection_name,  | 
 | 74 | +            "asyncpg",  | 
 | 75 | +            user=user,  | 
 | 76 | +            password=password,  | 
 | 77 | +            db=db,  | 
 | 78 | +            ip_type="public",  # can also be "private" or "psc"  | 
 | 79 | +        )  | 
 | 80 | +        return conn  | 
47 | 81 | 
 
  | 
48 | 82 |     # create SQLAlchemy connection pool  | 
49 |  | -    pool = create_async_engine(  | 
 | 83 | +    engine = sqlalchemy.ext.asyncio.create_async_engine(  | 
50 | 84 |         "postgresql+asyncpg://",  | 
51 | 85 |         async_creator=getconn,  | 
52 | 86 |         execution_options={"isolation_level": "AUTOCOMMIT"},  | 
53 | 87 |     )  | 
54 |  | -    return pool  | 
 | 88 | +    return engine, connector  | 
55 | 89 | 
 
  | 
56 | 90 | 
 
  | 
57 |  | -@pytest.fixture(name="pool")  | 
58 |  | -async def setup() -> AsyncGenerator:  | 
59 |  | -    pool = await init_connection_pool()  | 
60 |  | -    async with pool.connect() as conn:  | 
61 |  | -        await conn.execute(  | 
62 |  | -            sqlalchemy.text(  | 
63 |  | -                f"CREATE TABLE IF NOT EXISTS {table_name}"  | 
64 |  | -                " ( id CHAR(20) NOT NULL, title TEXT NOT NULL );"  | 
65 |  | -            )  | 
66 |  | -        )  | 
 | 91 | +async def test_connection_with_asyncpg() -> None:  | 
 | 92 | +    """Basic test to get time from database."""  | 
 | 93 | +    inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"]  | 
 | 94 | +    user = os.environ["POSTGRES_USER"]  | 
 | 95 | +    password = os.environ["POSTGRES_PASS"]  | 
 | 96 | +    db = os.environ["POSTGRES_DB"]  | 
67 | 97 | 
 
  | 
68 |  | -    yield pool  | 
 | 98 | +    pool, connector = await create_sqlalchemy_engine(inst_conn_name, user, password, db)  | 
69 | 99 | 
 
  | 
70 | 100 |     async with pool.connect() as conn:  | 
71 |  | -        await conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS {table_name}"))  | 
72 |  | -    # dispose of asyncpg connection pool  | 
73 |  | -    await pool.dispose()  | 
 | 101 | +        res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone()  | 
 | 102 | +        assert res[0] == 1  | 
74 | 103 | 
 
  | 
 | 104 | +    await connector.close_async()  | 
75 | 105 | 
 
  | 
76 |  | -@pytest.mark.asyncio  | 
77 |  | -async def test_connection_with_asyncpg(pool: AsyncEngine) -> None:  | 
78 |  | -    insert_stmt = sqlalchemy.text(  | 
79 |  | -        f"INSERT INTO {table_name} (id, title) VALUES (:id, :title)",  | 
 | 106 | + | 
 | 107 | +async def test_lazy_connection_with_asyncpg() -> None:  | 
 | 108 | +    """Basic test to get time from database."""  | 
 | 109 | +    inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"]  | 
 | 110 | +    user = os.environ["POSTGRES_USER"]  | 
 | 111 | +    password = os.environ["POSTGRES_PASS"]  | 
 | 112 | +    db = os.environ["POSTGRES_DB"]  | 
 | 113 | + | 
 | 114 | +    pool, connector = await create_sqlalchemy_engine(  | 
 | 115 | +        inst_conn_name, user, password, db, "lazy"  | 
80 | 116 |     )  | 
81 |  | -    async with pool.connect() as conn:  | 
82 |  | -        await conn.execute(insert_stmt, parameters={"id": "book1", "title": "Book One"})  | 
83 |  | -        await conn.execute(insert_stmt, parameters={"id": "book2", "title": "Book Two"})  | 
84 | 117 | 
 
  | 
85 |  | -        select_stmt = sqlalchemy.text(f"SELECT title FROM {table_name} ORDER BY ID;")  | 
86 |  | -        rows = (await conn.execute(select_stmt)).fetchall()  | 
87 |  | -        titles = [row[0] for row in rows]  | 
 | 118 | +    async with pool.connect() as conn:  | 
 | 119 | +        res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone()  | 
 | 120 | +        assert res[0] == 1  | 
88 | 121 | 
 
  | 
89 |  | -    assert titles == ["Book One", "Book Two"]  | 
 | 122 | +    await connector.close_async()  | 
0 commit comments