@@ -244,7 +244,6 @@ def test_xlang_jdbc_read_with_explicit_schema(self, database):
244244 else :
245245 binary_type = ('BINARY(10)' , 'VARBINARY(10)' )
246246
247- # Create a test table
248247 with self .engine .begin () as connection :
249248 connection .execute (
250249 sqlalchemy .text (
@@ -255,7 +254,6 @@ def test_xlang_jdbc_read_with_explicit_schema(self, database):
255254 "f_timestamp TIMESTAMP(3), " + "f_decimal DECIMAL(10, 2), " +
256255 "f_date DATE, " + "f_time TIME(3))" ))
257256
258- # Insert test data
259257 inserted_rows = [
260258 JdbcTestRow (
261259 i ,
@@ -271,63 +269,54 @@ def test_xlang_jdbc_read_with_explicit_schema(self, database):
271269 for i in range (ROW_COUNT )
272270 ]
273271
274- # Insert the data using SQLAlchemy
275- with self .engine .begin () as connection :
276- for row in inserted_rows :
277- connection .execute (
278- sqlalchemy .text (
279- f"INSERT INTO { table_name } "
280- f"VALUES (:f_id, :f_float, :f_char, :f_varchar, "
281- f":f_bytes, :f_varbytes, :f_timestamp, :f_decimal, "
282- f":f_date, :f_time)" ),
283- {
284- "f_id" : row .f_id ,
285- "f_float" : row .f_float ,
286- "f_char" : row .f_char ,
287- "f_varchar" : row .f_varchar ,
288- "f_bytes" : row .f_bytes ,
289- "f_varbytes" : row .f_varbytes ,
290- "f_timestamp" : row .f_timestamp .to_utc_datetime (),
291- "f_decimal" : row .f_decimal ,
292- "f_date" : row .f_date ,
293- "f_time" : row .f_time ,
294- })
272+ with TestPipeline () as p :
273+ p .not_use_test_runner_api = True
274+ _ = (
275+ p
276+ | beam .Create (inserted_rows ).with_output_types (JdbcTestRow )
277+ | 'Write to jdbc' >> WriteToJdbc (
278+ table_name = table_name ,
279+ driver_class_name = self .driver_class_name ,
280+ jdbc_url = self .jdbc_url ,
281+ username = self .username ,
282+ password = self .password ,
283+ classpath = classpath ,
284+ ))
295285
296286 # Define a custom schema with different field names
297287 CustomSchemaRow = typing .NamedTuple (
298288 "CustomSchemaRow" ,
299289 [
300- ("custom_id " , int ),
301- ("custom_float " , float ),
302- ("custom_char " , str ),
303- ("custom_varchar " , str ),
304- ("custom_bytes " , bytes ),
305- ("custom_varbytes " , bytes ),
306- ("custom_timestamp " , Timestamp ),
307- ("custom_decimal " , Decimal ),
308- ("custom_date " , datetime .date ),
309- ("custom_time " , datetime .time ),
290+ ("renamed_id " , int ),
291+ ("renamed_float " , float ),
292+ ("renamed_char " , str ),
293+ ("renamed_varchar " , str ),
294+ ("renamed_bytes " , bytes ),
295+ ("renamed_varbytes " , bytes ),
296+ ("renamed_timestamp " , Timestamp ),
297+ ("renamed_decimal " , Decimal ),
298+ ("renamed_date " , datetime .date ),
299+ ("renamed_time " , datetime .time ),
310300 ],
311301 )
312302 coders .registry .register_coder (CustomSchemaRow , coders .RowCoder )
313303
314304 # Register MillisInstant logical type to override the mapping from Timestamp
315305 LogicalType .register_logical_type (MillisInstant )
316306
317- # Expected results with custom field names
318307 expected_row = []
319308 for row in inserted_rows :
320- f_char = row .f_char + ' ' * (10 - len (row .f_char ))
321309 if database != 'postgres' :
322- # padding expected results
310+ # padding expected results for binary fields
323311 f_bytes = row .f_bytes + b'\0 ' * (10 - len (row .f_bytes ))
324312 else :
325313 f_bytes = row .f_bytes
314+
326315 expected_row .append (
327316 CustomSchemaRow (
328317 row .f_id ,
329318 row .f_float ,
330- f_char ,
319+ row . f_char + ' ' * ( 10 - len ( row . f_char )) ,
331320 row .f_varchar ,
332321 f_bytes ,
333322 row .f_bytes ,
@@ -336,6 +325,28 @@ def test_xlang_jdbc_read_with_explicit_schema(self, database):
336325 row .f_date ,
337326 row .f_time ))
338327
328+ # Custom equals function that verifies field names and handles padding
329+ def custom_equals (expected , actual ):
330+
331+ # Then, compare the field values
332+ return (
333+ expected .renamed_id == actual .renamed_id and
334+ expected .renamed_float == actual .renamed_float and
335+ expected .renamed_char == actual .renamed_char and
336+ expected .renamed_varchar == actual .renamed_varchar and (
337+ expected .renamed_bytes == actual .renamed_bytes or
338+ # Handle potential padding differences in binary fields
339+ expected .renamed_bytes .rstrip (b'\0 ' ) ==
340+ actual .renamed_bytes .rstrip (b'\0 ' )) and (
341+ expected .renamed_varbytes == actual .renamed_varbytes or
342+ # Handle potential padding differences in binary fields
343+ expected .renamed_varbytes .rstrip (b'\0 ' ) ==
344+ actual .renamed_varbytes .rstrip (b'\0 ' )) and
345+ expected .renamed_timestamp == actual .renamed_timestamp and
346+ expected .renamed_decimal == actual .renamed_decimal and
347+ expected .renamed_date == actual .renamed_date and
348+ expected .renamed_time == actual .renamed_time )
349+
339350 with TestPipeline () as p :
340351 p .not_use_test_runner_api = True
341352 result = (
@@ -349,7 +360,8 @@ def test_xlang_jdbc_read_with_explicit_schema(self, database):
349360 classpath = classpath ,
350361 schema = CustomSchemaRow ))
351362
352- assert_that (result , equal_to (expected_row ))
363+ # Use our custom equals function that verifies field names
364+ assert_that (result , equal_to (expected_row , equals_fn = custom_equals ))
353365
354366 # Creating a container with testcontainers sometimes raises ReadTimeout
355367 # error. In java there are 2 retries set by default.
0 commit comments