@@ -231,6 +231,126 @@ def test_xlang_jdbc_write_read(self, database):
231231
232232 assert_that (result , equal_to (expected_row ))
233233
@parameterized.expand(['postgres', 'mysql'])
def test_xlang_jdbc_read_with_explicit_schema(self, database):
  """Reads a table via ReadFromJdbc using an explicitly supplied schema.

  The table is created and populated directly through SQLAlchemy, then read
  back with ``ReadFromJdbc(schema=CustomSchemaRow)``. ``CustomSchemaRow``
  deliberately uses field names that differ from the table's column names;
  the rows read must equal the inserted values expressed as
  ``CustomSchemaRow`` instances.

  Args:
    database: Key into ``DB_CONTAINER_CLASSPATH_STRING`` ('postgres' or
      'mysql') selecting the container, classpath, URL scheme and driver.
  """
  container_init, classpath, db_string, driver = (
      CrossLanguageJdbcIOTest.DB_CONTAINER_CLASSPATH_STRING[database])
  self._setUpTestCase(container_init, db_string, driver)
  table_name = 'jdbc_schema_test'

  if database == 'postgres':
    # postgres does not have BINARY and VARBINARY type, use equivalent.
    binary_type = ('BYTEA', 'BYTEA')
  else:
    binary_type = ('BINARY(10)', 'VARBINARY(10)')

  # Create a test table.
  with self.engine.begin() as connection:
    connection.execute(
        sqlalchemy.text(
            "CREATE TABLE IF NOT EXISTS {}".format(table_name) +
            "(f_id INTEGER, " + "f_float DOUBLE PRECISION, " +
            "f_char CHAR(10), " + "f_varchar VARCHAR(10), " +
            f"f_bytes {binary_type[0]}, " + f"f_varbytes {binary_type[1]}, " +
            "f_timestamp TIMESTAMP(3), " + "f_decimal DECIMAL(10, 2), " +
            "f_date DATE, " + "f_time TIME(3))"))

  # Build the rows to insert.
  # NOTE(review): `i % 31 + 1` combined with `i % 12 + 1` first produces an
  # invalid date at i == 61 (Feb 31); fine as long as ROW_COUNT stays below
  # that -- confirm against the ROW_COUNT definition.
  inserted_rows = [
      JdbcTestRow(
          i,
          i + 0.1,
          f'Test{i}',
          f'Test{i}',
          f'Test{i}'.encode(),
          f'Test{i}'.encode(),
          # Rounded to milliseconds to match the TIMESTAMP(3) column.
          Timestamp.of(seconds=round(time.time(), 3)),
          Decimal(f'{i - 1}.23'),
          datetime.date(1969 + i, i % 12 + 1, i % 31 + 1),
          datetime.time(i % 24, i % 60, i % 60, (i * 1000) % 1_000_000))
      for i in range(ROW_COUNT)
  ]

  # Insert the data using SQLAlchemy (bound parameters, one row at a time).
  with self.engine.begin() as connection:
    for row in inserted_rows:
      connection.execute(
          sqlalchemy.text(
              f"INSERT INTO {table_name} "
              f"VALUES (:f_id, :f_float, :f_char, :f_varchar, "
              f":f_bytes, :f_varbytes, :f_timestamp, :f_decimal, "
              f":f_date, :f_time)"),
          {
              "f_id": row.f_id,
              "f_float": row.f_float,
              "f_char": row.f_char,
              "f_varchar": row.f_varchar,
              "f_bytes": row.f_bytes,
              "f_varbytes": row.f_varbytes,
              "f_timestamp": row.f_timestamp.to_utc_datetime(),
              "f_decimal": row.f_decimal,
              "f_date": row.f_date,
              "f_time": row.f_time,
          })

  # Define a custom schema with different field names.
  CustomSchemaRow = typing.NamedTuple(
      "CustomSchemaRow",
      [
          ("custom_id", int),
          ("custom_float", float),
          ("custom_char", str),
          ("custom_varchar", str),
          ("custom_bytes", bytes),
          ("custom_varbytes", bytes),
          ("custom_timestamp", Timestamp),
          ("custom_decimal", Decimal),
          ("custom_date", datetime.date),
          ("custom_time", datetime.time),
      ],
  )
  coders.registry.register_coder(CustomSchemaRow, coders.RowCoder)

  # Register MillisInstant logical type to override the mapping from Timestamp
  LogicalType.register_logical_type(MillisInstant)

  # Expected results with custom field names.
  expected_row = []
  for row in inserted_rows:
    # The expected CHAR(10) value is space-padded to the full column width.
    f_char = row.f_char + ' ' * (10 - len(row.f_char))
    if database != 'postgres':
      # The expected BINARY(10) value is NUL-padded; BYTEA (postgres) is not.
      f_bytes = row.f_bytes + b'\0' * (10 - len(row.f_bytes))
    else:
      f_bytes = row.f_bytes
    expected_row.append(
        CustomSchemaRow(
            row.f_id,
            row.f_float,
            f_char,
            row.f_varchar,
            f_bytes,
            # Variable-length column: compare against the varbytes field
            # itself, unpadded.
            row.f_varbytes,
            row.f_timestamp,
            row.f_decimal,
            row.f_date,
            row.f_time))

  with TestPipeline() as p:
    p.not_use_test_runner_api = True
    result = (
        p
        | 'Read from jdbc with schema' >> ReadFromJdbc(
            table_name=table_name,
            driver_class_name=self.driver_class_name,
            jdbc_url=self.jdbc_url,
            username=self.username,
            password=self.password,
            classpath=classpath,
            schema=CustomSchemaRow))

    assert_that(result, equal_to(expected_row))
353+
234354 # Creating a container with testcontainers sometimes raises ReadTimeout
235355 # error. In java there are 2 retries set by default.
236356 def start_db_container (self , retries , container_init ):
0 commit comments