Skip to content

Commit f60859d

Browse files
committed
Fix xlang test.
1 parent 1bb076a commit f60859d

File tree

1 file changed

+50
-38
lines changed

1 file changed

+50
-38
lines changed

sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,6 @@ def test_xlang_jdbc_read_with_explicit_schema(self, database):
244244
else:
245245
binary_type = ('BINARY(10)', 'VARBINARY(10)')
246246

247-
# Create a test table
248247
with self.engine.begin() as connection:
249248
connection.execute(
250249
sqlalchemy.text(
@@ -255,7 +254,6 @@ def test_xlang_jdbc_read_with_explicit_schema(self, database):
255254
"f_timestamp TIMESTAMP(3), " + "f_decimal DECIMAL(10, 2), " +
256255
"f_date DATE, " + "f_time TIME(3))"))
257256

258-
# Insert test data
259257
inserted_rows = [
260258
JdbcTestRow(
261259
i,
@@ -271,63 +269,54 @@ def test_xlang_jdbc_read_with_explicit_schema(self, database):
271269
for i in range(ROW_COUNT)
272270
]
273271

274-
# Insert the data using SQLAlchemy
275-
with self.engine.begin() as connection:
276-
for row in inserted_rows:
277-
connection.execute(
278-
sqlalchemy.text(
279-
f"INSERT INTO {table_name} "
280-
f"VALUES (:f_id, :f_float, :f_char, :f_varchar, "
281-
f":f_bytes, :f_varbytes, :f_timestamp, :f_decimal, "
282-
f":f_date, :f_time)"),
283-
{
284-
"f_id": row.f_id,
285-
"f_float": row.f_float,
286-
"f_char": row.f_char,
287-
"f_varchar": row.f_varchar,
288-
"f_bytes": row.f_bytes,
289-
"f_varbytes": row.f_varbytes,
290-
"f_timestamp": row.f_timestamp.to_utc_datetime(),
291-
"f_decimal": row.f_decimal,
292-
"f_date": row.f_date,
293-
"f_time": row.f_time,
294-
})
272+
with TestPipeline() as p:
273+
p.not_use_test_runner_api = True
274+
_ = (
275+
p
276+
| beam.Create(inserted_rows).with_output_types(JdbcTestRow)
277+
| 'Write to jdbc' >> WriteToJdbc(
278+
table_name=table_name,
279+
driver_class_name=self.driver_class_name,
280+
jdbc_url=self.jdbc_url,
281+
username=self.username,
282+
password=self.password,
283+
classpath=classpath,
284+
))
295285

296286
# Define a custom schema with different field names
297287
CustomSchemaRow = typing.NamedTuple(
298288
"CustomSchemaRow",
299289
[
300-
("custom_id", int),
301-
("custom_float", float),
302-
("custom_char", str),
303-
("custom_varchar", str),
304-
("custom_bytes", bytes),
305-
("custom_varbytes", bytes),
306-
("custom_timestamp", Timestamp),
307-
("custom_decimal", Decimal),
308-
("custom_date", datetime.date),
309-
("custom_time", datetime.time),
290+
("renamed_id", int),
291+
("renamed_float", float),
292+
("renamed_char", str),
293+
("renamed_varchar", str),
294+
("renamed_bytes", bytes),
295+
("renamed_varbytes", bytes),
296+
("renamed_timestamp", Timestamp),
297+
("renamed_decimal", Decimal),
298+
("renamed_date", datetime.date),
299+
("renamed_time", datetime.time),
310300
],
311301
)
312302
coders.registry.register_coder(CustomSchemaRow, coders.RowCoder)
313303

314304
# Register MillisInstant logical type to override the mapping from Timestamp
315305
LogicalType.register_logical_type(MillisInstant)
316306

317-
# Expected results with custom field names
318307
expected_row = []
319308
for row in inserted_rows:
320-
f_char = row.f_char + ' ' * (10 - len(row.f_char))
321309
if database != 'postgres':
322-
# padding expected results
310+
# padding expected results for binary fields
323311
f_bytes = row.f_bytes + b'\0' * (10 - len(row.f_bytes))
324312
else:
325313
f_bytes = row.f_bytes
314+
326315
expected_row.append(
327316
CustomSchemaRow(
328317
row.f_id,
329318
row.f_float,
330-
f_char,
319+
row.f_char + ' ' * (10 - len(row.f_char)),
331320
row.f_varchar,
332321
f_bytes,
333322
row.f_bytes,
@@ -336,6 +325,28 @@ def test_xlang_jdbc_read_with_explicit_schema(self, database):
336325
row.f_date,
337326
row.f_time))
338327

328+
# Custom equals function that verifies field names and handles padding
329+
def custom_equals(expected, actual):
330+
331+
# Then, compare the field values
332+
return (
333+
expected.renamed_id == actual.renamed_id and
334+
expected.renamed_float == actual.renamed_float and
335+
expected.renamed_char == actual.renamed_char and
336+
expected.renamed_varchar == actual.renamed_varchar and (
337+
expected.renamed_bytes == actual.renamed_bytes or
338+
# Handle potential padding differences in binary fields
339+
expected.renamed_bytes.rstrip(b'\0') ==
340+
actual.renamed_bytes.rstrip(b'\0')) and (
341+
expected.renamed_varbytes == actual.renamed_varbytes or
342+
# Handle potential padding differences in binary fields
343+
expected.renamed_varbytes.rstrip(b'\0') ==
344+
actual.renamed_varbytes.rstrip(b'\0')) and
345+
expected.renamed_timestamp == actual.renamed_timestamp and
346+
expected.renamed_decimal == actual.renamed_decimal and
347+
expected.renamed_date == actual.renamed_date and
348+
expected.renamed_time == actual.renamed_time)
349+
339350
with TestPipeline() as p:
340351
p.not_use_test_runner_api = True
341352
result = (
@@ -349,7 +360,8 @@ def test_xlang_jdbc_read_with_explicit_schema(self, database):
349360
classpath=classpath,
350361
schema=CustomSchemaRow))
351362

352-
assert_that(result, equal_to(expected_row))
363+
# Use our custom equals function that verifies field names
364+
assert_that(result, equal_to(expected_row, equals_fn=custom_equals))
353365

354366
# Creating a container with testcontainers sometimes raises ReadTimeout
355367
# error. In java there are 2 retries set by default.

0 commit comments

Comments
 (0)