@@ -63,9 +63,8 @@ class TableFieldsQueryConfig:
6363 def __post_init__ (self ):
6464 if not self .table_id or not self .where_clause_template :
6565 raise ValueError (
66- "TableFieldsQueryConfig and " +
67- "TableFunctionQueryConfig must provide table_id " +
68- "and where_clause_template" )
66+ "TableFieldsQueryConfig must provide table_id and " +
67+ "where_clause_template" )
6968
7069 if not self .where_clause_fields :
7170 raise ValueError (
@@ -83,9 +82,8 @@ class TableFunctionQueryConfig:
8382 def __post_init__ (self ):
8483 if not self .table_id or not self .where_clause_template :
8584 raise ValueError (
86- "TableFieldsQueryConfig and " +
87- "TableFunctionQueryConfig must provide table_id " +
88- "and where_clause_template" )
85+ "TableFunctionQueryConfig must provide table_id and " +
86+ "where_clause_template" )
8987
9088 if not self .where_clause_value_fn :
9189 raise ValueError (
@@ -264,7 +262,7 @@ def __init__(
264262 connection_config = CloudSQLConnectionConfig(
265263 db_adapter=DatabaseTypeAdapter.POSTGRESQL,
266264 instance_connection_uri="apache-beam-testing:us-central1:itests",
267- user=postgres,
265+ user=' postgres' ,
268266 password= os.getenv("CLOUDSQL_PG_PASSWORD"))
269267 query_config=TableFieldsQueryConfig('my_table',"id = '{}'",['id']),
270268 cloudsql_handler = CloudSQLEnrichmentHandler(
@@ -319,6 +317,7 @@ def __enter__(self):
319317 url = self ._connection_config .get_db_url (), creator = connector )
320318
321319 def _execute_query (self , query : str , is_batch : bool , ** params ):
320+ connection = None
322321 try :
323322 connection = self ._engine .connect ()
324323 transaction = connection .begin ()
@@ -328,7 +327,8 @@ def _execute_query(self, query: str, is_batch: bool, **params):
328327 if is_batch :
329328 data = [row ._asdict () for row in result ]
330329 else :
331- data = result .first ()._asdict ()
330+ result_row = result .first ()
331+ data = result_row ._asdict () if result_row else {}
332332 # Explicitly commit the transaction.
333333 transaction .commit ()
334334 return data
@@ -337,8 +337,8 @@ def _execute_query(self, query: str, is_batch: bool, **params):
337337 raise RuntimeError (f"Database operation failed: { e } " )
338338 except Exception as e :
339339 raise Exception (
340- f'Could not execute the query: { query } . Please check if '
341- f'the query is properly formatted and the table exists. { e } ' )
340+ f'Could not execute the query. Please check if the query is properly '
341+ f'formatted and the table exists. { e } ' )
342342 finally :
343343 if connection :
344344 connection .close ()
@@ -454,8 +454,8 @@ def get_cache_key(self, request: Union[beam.Row, list[beam.Row]]):
454454 req_dict [field ]
455455 for field in self ._query_config .where_clause_fields
456456 ]
457- key = ";" .join ([ "%s" ] * len ( current_values ))
458- cache_keys .extend ([ key % tuple ( current_values )] )
457+ key = ';' .join (map ( repr , current_values ))
458+ cache_keys .append ( key )
459459 except KeyError as e :
460460 raise KeyError (
461461 "Make sure the values passed in `where_clause_fields` are the "
0 commit comments