3333import org .apache .beam .sdk .schemas .transforms .SchemaTransformProvider ;
3434import org .apache .beam .sdk .schemas .transforms .TypedSchemaTransformProvider ;
3535import org .apache .beam .sdk .values .PCollectionRowTuple ;
36+ import org .apache .beam .sdk .values .Row ;
3637import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .base .Strings ;
3738import org .checkerframework .checker .initialization .qual .Initialized ;
3839import org .checkerframework .checker .nullness .qual .NonNull ;
@@ -213,21 +214,52 @@ protected JdbcIO.DataSourceConfiguration dataSourceConfiguration() {
213214
214215 @ Override
215216 public PCollectionRowTuple expand (PCollectionRowTuple input ) {
216- String query = config .getReadQuery ();
217- if (query == null ) {
218- query = String .format ("SELECT * FROM %s" , config .getLocation ());
217+ config .validate ();
218+ // If we define a partition column, we follow a different route.
219+ @ Nullable String partitionColumn = config .getPartitionColumn ();
220+ @ Nullable String location = config .getLocation ();
221+ if (partitionColumn != null ) {
222+ JdbcIO .ReadWithPartitions <Row , ?> readRowsWithParitions =
223+ JdbcIO .<Row >readWithPartitions ()
224+ .withDataSourceConfiguration (dataSourceConfiguration ())
225+ .withTable (location )
226+ .withPartitionColumn (partitionColumn )
227+ .withRowOutput ();
228+
229+ @ Nullable Integer partitions = config .getNumPartitions ();
230+ if (partitions != null ) {
231+ readRowsWithParitions = readRowsWithParitions .withNumPartitions (partitions );
232+ }
233+
234+ @ Nullable Integer fetchSize = config .getFetchSize ();
235+ if (fetchSize != null && fetchSize > 0 ) {
236+ readRowsWithParitions = readRowsWithParitions .withFetchSize (fetchSize );
237+ }
238+
239+ @ Nullable Boolean disableAutoCommit = config .getDisableAutoCommit ();
240+ if (disableAutoCommit != null ) {
241+ readRowsWithParitions = readRowsWithParitions .withDisableAutoCommit (disableAutoCommit );
242+ }
243+ return PCollectionRowTuple .of ("output" , input .getPipeline ().apply (readRowsWithParitions ));
244+ }
245+ @ Nullable String readQuery = config .getReadQuery ();
246+ if (readQuery == null ) {
247+ readQuery = String .format ("SELECT * FROM %s" , location );
219248 }
220249 JdbcIO .ReadRows readRows =
221- JdbcIO .readRows ().withDataSourceConfiguration (dataSourceConfiguration ()).withQuery (query );
222- Integer fetchSize = config .getFetchSize ();
250+ JdbcIO .readRows ()
251+ .withDataSourceConfiguration (dataSourceConfiguration ())
252+ .withQuery (readQuery );
253+
254+ @ Nullable Integer fetchSize = config .getFetchSize ();
223255 if (fetchSize != null && fetchSize > 0 ) {
224256 readRows = readRows .withFetchSize (fetchSize );
225257 }
226- Boolean outputParallelization = config .getOutputParallelization ();
258+ @ Nullable Boolean outputParallelization = config .getOutputParallelization ();
227259 if (outputParallelization != null ) {
228260 readRows = readRows .withOutputParallelization (outputParallelization );
229261 }
230- Boolean disableAutoCommit = config .getDisableAutoCommit ();
262+ @ Nullable Boolean disableAutoCommit = config .getDisableAutoCommit ();
231263 if (disableAutoCommit != null ) {
232264 readRows = readRows .withDisableAutoCommit (disableAutoCommit );
233265 }
@@ -294,6 +326,14 @@ public abstract static class JdbcReadSchemaTransformConfiguration implements Ser
294326 @ Nullable
295327 public abstract String getLocation ();
296328
329+ @ SchemaFieldDescription ("Name of a column of numeric type that will be used for partitioning." )
330+ @ Nullable
331+ public abstract String getPartitionColumn ();
332+
333+ @ SchemaFieldDescription ("The number of partitions" )
334+ @ Nullable
335+ public abstract Integer getNumPartitions ();
336+
297337 @ SchemaFieldDescription (
298338 "Whether to reshuffle the resulting PCollection so results are distributed to all workers." )
299339 @ Nullable
@@ -340,13 +380,20 @@ public void validate(String jdbcType) throws IllegalArgumentException {
340380
341381 boolean readQueryPresent = (getReadQuery () != null && !"" .equals (getReadQuery ()));
342382 boolean locationPresent = (getLocation () != null && !"" .equals (getLocation ()));
383+ boolean partitionColumnPresent =
384+ (getPartitionColumn () != null && !"" .equals (getPartitionColumn ()));
343385
386+ // If you specify a readQuery, it is to be used instead of a table.
344387 if (readQueryPresent && locationPresent ) {
345388 throw new IllegalArgumentException ("Query and Table are mutually exclusive configurations" );
346389 }
347390 if (!readQueryPresent && !locationPresent ) {
348391 throw new IllegalArgumentException ("Either Query or Table must be specified." );
349392 }
393+ // Reading with partitions only supports table argument.
394+ if (partitionColumnPresent && !locationPresent ) {
395+ throw new IllegalArgumentException ("Table must be specified to read with partitions." );
396+ }
350397 }
351398
352399 public static Builder builder () {
@@ -368,6 +415,10 @@ public abstract static class Builder {
368415
369416 public abstract Builder setLocation (String value );
370417
418+ public abstract Builder setPartitionColumn (String value );
419+
420+ public abstract Builder setNumPartitions (Integer value );
421+
371422 public abstract Builder setReadQuery (String value );
372423
373424 public abstract Builder setConnectionProperties (String value );
0 commit comments