1616
1717import  asyncio 
1818import  os 
19- from  typing  import  Any 
19+ from  typing  import  Any ,  Union 
2020
2121import  asyncpg 
2222import  sqlalchemy 
2323import  sqlalchemy .ext .asyncio 
2424
2525from  google .cloud .sql .connector  import  Connector 
26+ from  google .cloud .sql .connector  import  DefaultResolver 
27+ from  google .cloud .sql .connector  import  DnsResolver 
2628
2729
2830async  def  create_sqlalchemy_engine (
@@ -31,6 +33,7 @@ async def create_sqlalchemy_engine(
3133    password : str ,
3234    db : str ,
3335    refresh_strategy : str  =  "background" ,
36+     resolver : Union [type [DefaultResolver ], type [DnsResolver ]] =  DefaultResolver ,
3437) ->  tuple [sqlalchemy .ext .asyncio .engine .AsyncEngine , Connector ]:
3538    """Creates a connection pool for a Cloud SQL instance and returns the pool 
3639    and the connector. Callers are responsible for closing the pool and the 
@@ -64,9 +67,16 @@ async def create_sqlalchemy_engine(
6467            Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" 
6568            or "background". For serverless environments use "lazy" to avoid 
6669            errors resulting from CPU being throttled. 
70+         resolver (Optional[google.cloud.sql.connector.DefaultResolver]): 
71+             Resolver class for resolving instance connection name. Use 
72+             google.cloud.sql.connector.DnsResolver when resolving DNS domain 
73+             names or google.cloud.sql.connector.DefaultResolver for regular 
74+             instance connection names ("my-project:my-region:my-instance"). 
6775    """ 
6876    loop  =  asyncio .get_running_loop ()
69-     connector  =  Connector (loop = loop , refresh_strategy = refresh_strategy )
77+     connector  =  Connector (
78+         loop = loop , refresh_strategy = refresh_strategy , resolver = resolver 
79+     )
7080
7181    async  def  getconn () ->  asyncpg .Connection :
7282        conn : asyncpg .Connection  =  await  connector .connect_async (
@@ -183,6 +193,24 @@ async def test_lazy_sqlalchemy_connection_with_asyncpg() -> None:
183193    await  connector .close_async ()
184194
185195
196+ async  def  test_custom_SAN_with_dns_sqlalchemy_connection_with_asyncpg () ->  None :
197+     """Basic test to get time from database.""" 
198+     inst_conn_name  =  os .environ ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME" ]
199+     user  =  os .environ ["POSTGRES_USER" ]
200+     password  =  os .environ ["POSTGRES_CUSTOMER_CAS_PASS" ]
201+     db  =  os .environ ["POSTGRES_DB" ]
202+ 
203+     pool , connector  =  await  create_sqlalchemy_engine (
204+         inst_conn_name , user , password , db , resolver = DnsResolver 
205+     )
206+ 
207+     async  with  pool .connect () as  conn :
208+         res  =  (await  conn .execute (sqlalchemy .text ("SELECT 1" ))).fetchone ()
209+         assert  res [0 ] ==  1 
210+ 
211+     await  connector .close_async ()
212+ 
213+ 
186214async  def  test_connection_with_asyncpg () ->  None :
187215    """Basic test to get time from database.""" 
188216    inst_conn_name  =  os .environ ["POSTGRES_CONNECTION_NAME" ]
0 commit comments