Skip to content

Commit 214b1bc

Browse files
sdks/python: address gemini feedback
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 948b8ba commit 214b1bc

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ class TableFieldsQueryConfig:
6363
def __post_init__(self):
6464
if not self.table_id or not self.where_clause_template:
6565
raise ValueError(
66-
"TableFieldsQueryConfig and " +
67-
"TableFunctionQueryConfig must provide table_id " +
68-
"and where_clause_template")
66+
"TableFieldsQueryConfig must provide table_id and " +
67+
"where_clause_template")
6968

7069
if not self.where_clause_fields:
7170
raise ValueError(
@@ -83,9 +82,8 @@ class TableFunctionQueryConfig:
8382
def __post_init__(self):
8483
if not self.table_id or not self.where_clause_template:
8584
raise ValueError(
86-
"TableFieldsQueryConfig and " +
87-
"TableFunctionQueryConfig must provide table_id " +
88-
"and where_clause_template")
85+
"TableFunctionQueryConfig must provide table_id and " +
86+
"where_clause_template")
8987

9088
if not self.where_clause_value_fn:
9189
raise ValueError(
@@ -264,7 +262,7 @@ def __init__(
264262
connection_config = CloudSQLConnectionConfig(
265263
db_adapter=DatabaseTypeAdapter.POSTGRESQL,
266264
instance_connection_uri="apache-beam-testing:us-central1:itests",
267-
user=postgres,
265+
user='postgres',
268266
password= os.getenv("CLOUDSQL_PG_PASSWORD"))
269267
query_config=TableFieldsQueryConfig('my_table',"id = '{}'",['id']),
270268
cloudsql_handler = CloudSQLEnrichmentHandler(
@@ -319,6 +317,7 @@ def __enter__(self):
319317
url=self._connection_config.get_db_url(), creator=connector)
320318

321319
def _execute_query(self, query: str, is_batch: bool, **params):
320+
connection = None
322321
try:
323322
connection = self._engine.connect()
324323
transaction = connection.begin()
@@ -328,7 +327,8 @@ def _execute_query(self, query: str, is_batch: bool, **params):
328327
if is_batch:
329328
data = [row._asdict() for row in result]
330329
else:
331-
data = result.first()._asdict()
330+
result_row = result.first()
331+
data = result_row._asdict() if result_row else {}
332332
# Explicitly commit the transaction.
333333
transaction.commit()
334334
return data
@@ -337,8 +337,8 @@ def _execute_query(self, query: str, is_batch: bool, **params):
337337
raise RuntimeError(f"Database operation failed: {e}")
338338
except Exception as e:
339339
raise Exception(
340-
f'Could not execute the query: {query}. Please check if '
341-
f'the query is properly formatted and the table exists. {e}')
340+
f'Could not execute the query. Please check if the query is properly '
341+
f'formatted and the table exists. {e}')
342342
finally:
343343
if connection:
344344
connection.close()
@@ -454,8 +454,8 @@ def get_cache_key(self, request: Union[beam.Row, list[beam.Row]]):
454454
req_dict[field]
455455
for field in self._query_config.where_clause_fields
456456
]
457-
key = ";".join(["%s"] * len(current_values))
458-
cache_keys.extend([key % tuple(current_values)])
457+
key = ';'.join(map(repr, current_values))
458+
cache_keys.append(key)
459459
except KeyError as e:
460460
raise KeyError(
461461
"Make sure the values passed in `where_clause_fields` are the "

sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def start_sql_db_container(
103103
sql_db_container.start()
104104
host = sql_db_container.get_container_host_ip()
105105
port = int(sql_db_container.get_exposed_port(5432))
106-
107106
elif database_type == DatabaseTypeAdapter.MYSQL:
108107
user, password, db_id = "test", "test", "test"
109108
sql_db_container = MySqlContainer(
@@ -115,7 +114,6 @@ def start_sql_db_container(
115114
sql_db_container.start()
116115
host = sql_db_container.get_container_host_ip()
117116
port = int(sql_db_container.get_exposed_port(3306))
118-
119117
elif database_type == DatabaseTypeAdapter.SQLSERVER:
120118
user, password, db_id = "SA", "A_Str0ng_Required_Password", "tempdb"
121119
sql_db_container = SqlServerContainer(
@@ -457,7 +455,10 @@ def test_sql_enrichment_on_non_existent_table(self):
457455
with TestPipeline() as p:
458456
_ = (p | beam.Create(requests) | Enrichment(handler))
459457

460-
expect_err_msg_contains = "Could not execute the query"
458+
expect_err_msg_contains = (
459+
"Could not execute the query. Please check if the query is properly "
460+
"formatted and the table exists."
461+
)
461462
self.assertIn(expect_err_msg_contains, str(context.exception))
462463

463464
@pytest.mark.usefixtures("cache_container")
@@ -498,10 +499,12 @@ def test_sql_enrichment_with_redis(self):
498499
if not response:
499500
raise ValueError("No cache entry found for %s" % key)
500501

501-
# Mock the CloudSQL enrichment handler to avoid actual database calls.
502-
# This simulates a cache hit scenario by returning predefined data.
502+
# Mocks the CloudSQL enrichment handler to prevent actual database calls.
503+
# This ensures that a cache hit scenario does not trigger any database
504+
# interaction, raising an exception if an unexpected call occurs.
503505
actual = CloudSQLEnrichmentHandler.__call__
504-
CloudSQLEnrichmentHandler.__call__ = MagicMock(return_value=(beam.Row()))
506+
CloudSQLEnrichmentHandler.__call__ = MagicMock(
507+
side_effect=Exception("Database should not be called on a cache hit."))
505508

506509
# Run a second pipeline to verify cache is being used.
507510
with TestPipeline(is_integration_test=True) as test_pipeline:
@@ -542,7 +545,6 @@ def setUpClass(cls):
542545
@classmethod
543546
def tearDownClass(cls):
544547
super().tearDownClass()
545-
cls._db = None
546548

547549

548550
@unittest.skipUnless(

0 commit comments

Comments
 (0)