
Commit 791eec6

Add ability to set schema in JdbcIO.java and jdbc.py Read.
1 parent 0f67c75 commit 791eec6

3 files changed (+57, -8 lines)

sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java

Lines changed: 31 additions & 3 deletions
@@ -745,6 +745,9 @@ public abstract static class ReadRows extends PTransform<PBegin, PCollection<Row
     @Pure
     abstract boolean getDisableAutoCommit();

+    @Pure
+    abstract @Nullable Schema getSchema();
+
     abstract Builder toBuilder();

     @AutoValue.Builder
@@ -762,6 +765,8 @@ abstract Builder setDataSourceProviderFn(

       abstract Builder setDisableAutoCommit(boolean disableAutoCommit);

+      abstract Builder setSchema(@Nullable Schema schema);
+
       abstract ReadRows build();
     }

@@ -789,6 +794,10 @@ public ReadRows withStatementPreparator(StatementPreparator statementPreparator)
       return toBuilder().setStatementPreparator(statementPreparator).build();
     }

+    public ReadRows withSchema(Schema schema) {
+      return toBuilder().setSchema(schema).build();
+    }
+
     /**
      * This method is used to set the size of the data that is going to be fetched and loaded in
      * memory per every database call. Please refer to: {@link java.sql.Statement#setFetchSize(int)}
@@ -830,7 +839,14 @@ public PCollection<Row> expand(PBegin input) {
           getDataSourceProviderFn(),
           "withDataSourceConfiguration() or withDataSourceProviderFn() is required");

-      Schema schema = inferBeamSchema(dataSourceProviderFn.apply(null), query.get());
+      // Use the provided schema if it's not null, otherwise infer it
+      Schema schema;
+      if (getSchema() != null) {
+        schema = getSchema();
+      } else {
+        schema = inferBeamSchema(dataSourceProviderFn.apply(null), query.get());
+      }
+
       PCollection<Row> rows =
           input.apply(
               JdbcIO.<Row>read()
@@ -1292,6 +1308,9 @@ public abstract static class ReadWithPartitions<T, PartitionColumnT>
     @Pure
     abstract boolean getUseBeamSchema();

+    @Pure
+    abstract @Nullable Schema getSchema();
+
     @Pure
     abstract @Nullable PartitionColumnT getLowerBound();

@@ -1333,6 +1352,8 @@ abstract Builder<T, PartitionColumnT> setDataSourceProviderFn(

       abstract Builder<T, PartitionColumnT> setUseBeamSchema(boolean useBeamSchema);

+      abstract Builder setSchema(@Nullable Schema schema);
+
       abstract Builder<T, PartitionColumnT> setFetchSize(int fetchSize);

       abstract Builder<T, PartitionColumnT> setTable(String tableName);
@@ -1424,6 +1445,10 @@ public ReadWithPartitions<T, PartitionColumnT> withTable(String tableName) {
       return toBuilder().setTable(tableName).build();
     }

+    public ReadWithPartitions<T, PartitionColumnT> withSchema(Schema schema) {
+      return toBuilder().setSchema(schema).build();
+    }
+
     private static final int EQUAL = 0;

     @Override
@@ -1532,8 +1557,11 @@ public KV<Long, KV<PartitionColumnT, PartitionColumnT>> apply(
       Schema schema = null;
       if (getUseBeamSchema()) {
         schema =
-            ReadRows.inferBeamSchema(
-                dataSourceProviderFn.apply(null), String.format("SELECT * FROM %s", getTable()));
+            getSchema() != null
+                ? getSchema()
+                : ReadRows.inferBeamSchema(
+                    dataSourceProviderFn.apply(null),
+                    String.format("SELECT * FROM %s", getTable()));
         rowMapper = (RowMapper<T>) SchemaUtil.BeamRowMapper.of(schema);
       } else {
         rowMapper = getRowMapper();
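
With this change a pipeline author can hand ReadRows an explicit Beam Schema through the new withSchema(...) and skip the inferBeamSchema(...) lookup against the database. A minimal usage sketch is below; the driver class, JDBC URL, query, and field names are placeholders for illustration only, not part of the commit.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.jdbc.JdbcIO;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;

public class JdbcReadWithExplicitSchema {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();

    // Placeholder schema matching the query's result columns (assumed layout).
    Schema schema =
        Schema.builder().addInt64Field("id").addStringField("name").build();

    // ReadRows normally infers the schema from database metadata; passing one
    // explicitly via the new withSchema(...) bypasses that inference step.
    PCollection<Row> rows =
        pipeline.apply(
            JdbcIO.readRows()
                .withDataSourceConfiguration(
                    JdbcIO.DataSourceConfiguration.create(
                        "org.postgresql.Driver", // placeholder driver
                        "jdbc:postgresql://localhost:5432/example")) // placeholder URL
                .withQuery("SELECT id, name FROM example_table") // placeholder query
                .withSchema(schema));

    pipeline.run().waitUntilFinish();
  }
}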

sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java

Lines changed: 13 additions & 2 deletions
@@ -84,7 +84,7 @@ public Schema configurationSchema() {
    */
   @Override
   public JdbcSchemaIO from(String location, Row configuration, @Nullable Schema dataSchema) {
-    return new JdbcSchemaIO(location, configuration);
+    return new JdbcSchemaIO(location, configuration, dataSchema);
   }

   @Override
@@ -101,10 +101,12 @@ public PCollection.IsBounded isBounded() {
   static class JdbcSchemaIO implements SchemaIO, Serializable {
     protected final Row config;
     protected final String location;
+    protected final @Nullable Schema dataSchema;

-    JdbcSchemaIO(String location, Row config) {
+    JdbcSchemaIO(String location, Row config, @Nullable Schema dataSchema) {
       this.config = config;
       this.location = location;
+      this.dataSchema = dataSchema;
     }

     @Override
@@ -147,6 +149,10 @@ public PCollection<Row> expand(PBegin input) {
             readRows = readRows.withDisableAutoCommit(disableAutoCommit);
           }

+          if (dataSchema != null) {
+            readRows = readRows.withSchema(dataSchema);
+          }
+
           return input.apply(readRows);
         } else {

@@ -175,6 +181,11 @@ public PCollection<Row> expand(PBegin input) {
             readRows = readRows.withDisableAutoCommit(disableAutoCommit);
           }

+          // If a schema was provided, use it
+          if (dataSchema != null) {
+            readRows = readRows.withSchema(dataSchema);
+          }
+
           return input.apply(readRows);
         }
       }
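
On the SchemaIO path (used, for example, by cross-language expansion), the explicit schema now arrives as the dataSchema argument of from(location, configuration, dataSchema) and, when non-null, is forwarded to ReadRows.withSchema(...). Below is a minimal sketch of the kind of Beam Schema that could be passed through this way; the column names and types are placeholders, not taken from the commit.

import org.apache.beam.sdk.schemas.Schema;

public class DataSchemaSketch {
  // Builds an example dataSchema that a caller of the SchemaIO provider might
  // supply instead of relying on schema inference (placeholder columns).
  static Schema exampleDataSchema() {
    return Schema.builder()
        .addInt64Field("id") // placeholder column
        .addStringField("name") // placeholder column
        .addNullableField("email", Schema.FieldType.STRING) // placeholder nullable column
        .build();
  }
}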

sdks/python/apache_beam/io/jdbc.py

Lines changed: 13 additions & 3 deletions
@@ -114,7 +114,8 @@ def default_io_expansion_service(classpath=None):

 JdbcConfigSchema = typing.NamedTuple(
     'JdbcConfigSchema',
-    [('location', str), ('config', bytes)],
+    [('location', str), ('config', bytes),
+     ('dataSchema', typing.Optional[bytes])],
 )

 Config = typing.NamedTuple(
@@ -305,7 +306,7 @@ def __init__(
       driver_jars=None,
       expansion_service=None,
       classpath=None,
-  ):
+      schema=None):
     """
     Initializes a read operation from Jdbc.

@@ -343,6 +344,14 @@ def __init__(
         driver.
     """
     classpath = classpath or DEFAULT_JDBC_CLASSPATH
+
+    dataSchema = None
+    if schema is not None:
+      # Convert Python schema to Beam Schema proto
+      schema_proto = typing_to_runner_api(schema).row_type.schema
+      # Serialize the proto to bytes for transmission
+      dataSchema = schema_proto.SerializeToString()
+
     super().__init__(
         self.URN,
         NamedTupleBasedPayloadBuilder(
@@ -367,7 +376,8 @@ def __init__(
                 max_connections=max_connections,
                 driver_jars=driver_jars,
                 partition_column=partition_column,
-                partitions=partitions))),
+                partitions=partitions)),
+            dataSchema=dataSchema),
         ),
         expansion_service or default_io_expansion_service(classpath),
     )
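
On the Python side, the new schema argument takes a NamedTuple row type; it is converted to a Beam Schema proto via typing_to_runner_api(...) and serialized into the payload's dataSchema field for the Java expansion service. A minimal usage sketch follows; the connection details and the ExampleRow type are placeholders, not part of the commit.

import typing

import apache_beam as beam
from apache_beam.io.jdbc import ReadFromJdbc

# Hypothetical row type describing the table's columns (placeholder names/types).
ExampleRow = typing.NamedTuple('ExampleRow', [('id', int), ('name', str)])

with beam.Pipeline() as p:
  rows = (
      p
      | 'ReadWithExplicitSchema' >> ReadFromJdbc(
          table_name='example_table',  # placeholder table
          driver_class_name='org.postgresql.Driver',  # placeholder driver
          jdbc_url='jdbc:postgresql://localhost:5432/example',  # placeholder URL
          username='user',
          password='pass',
          # New in this commit: pass the row type so the Java side can skip
          # schema inference.
          schema=ExampleRow))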
