3838import io .cdap .plugin .common .ReferenceBatchSink ;
3939import io .cdap .plugin .common .ReferencePluginConfig ;
4040import io .cdap .plugin .common .batch .sink .SinkOutputFormatProvider ;
41+ import io .cdap .plugin .db .ColumnType ;
4142import io .cdap .plugin .db .CommonSchemaReader ;
4243import io .cdap .plugin .db .ConnectionConfig ;
4344import io .cdap .plugin .db .ConnectionConfigAccessor ;
5960import java .sql .ResultSetMetaData ;
6061import java .sql .SQLException ;
6162import java .sql .Statement ;
63+ import java .sql .Types ;
6264import java .util .ArrayList ;
6365import java .util .Collections ;
6466import java .util .HashSet ;
@@ -80,10 +82,9 @@ public abstract class AbstractDBSink extends ReferenceBatchSink<StructuredRecord
8082 private final DBSinkConfig dbSinkConfig ;
8183 private Class <? extends Driver > driverClass ;
8284 private DriverCleanup driverCleanup ;
83- protected int [] columnTypes ;
8485 protected List <String > columns ;
86+ protected List <ColumnType > columnTypes ;
8587 protected String dbColumns ;
86- private Schema outputSchema ;
8788
8889 public AbstractDBSink (DBSinkConfig dbSinkConfig ) {
8990 super (new ReferencePluginConfig (dbSinkConfig .referenceName ));
@@ -116,7 +117,7 @@ public void prepareRun(BatchSinkContext context) {
116117 dbSinkConfig .jdbcPluginName ,
117118 connectionString );
118119
119- outputSchema = context .getInputSchema ();
120+ Schema outputSchema = context .getInputSchema ();
120121
121122 // Load the plugin class to make sure it is available.
122123 Class <? extends Driver > driverClass = context .loadPluginClass (getJDBCPluginId ());
@@ -176,7 +177,7 @@ protected void setColumnsInfo(List<Schema.Field> fields) {
176177 public void initialize (BatchRuntimeContext context ) throws Exception {
177178 super .initialize (context );
178179 driverClass = context .loadPluginClass (getJDBCPluginId ());
179- outputSchema = Optional .ofNullable (context .getInputSchema ()).orElse (inferSchema (driverClass ));
180+ Schema outputSchema = Optional .ofNullable (context .getInputSchema ()).orElse (inferSchema (driverClass ));
180181
181182 setColumnsInfo (outputSchema .getFields ());
182183 setResultSetMetadata ();
@@ -209,23 +210,11 @@ private Schema inferSchema(Class<? extends Driver> driverClass) {
209210
210211 @ Override
211212 public void transform (StructuredRecord input , Emitter <KeyValue <DBRecord , NullWritable >> emitter ) {
212- // Create StructuredRecord that only has the columns in this.columns
213- List <Schema .Field > outputFields = new ArrayList <>();
214- for (Schema .Field field : input .getSchema ().getFields ()) {
215- Preconditions .checkArgument (columns .contains (field .getName ()), "Input field '%s' is not found in columns" ,
216- field .getName ());
217- outputFields .add (field );
218- }
219- StructuredRecord .Builder output = StructuredRecord .builder (outputSchema );
220- for (String column : columns ) {
221- output .set (column , input .get (column ));
222- }
223-
224- emitter .emit (new KeyValue <>(getDBRecord (output ), null ));
213+ emitter .emit (new KeyValue <>(getDBRecord (input ), null ));
225214 }
226215
227- protected DBRecord getDBRecord (StructuredRecord . Builder output ) {
228- return new DBRecord (output . build () , columnTypes );
216+ protected DBRecord getDBRecord (StructuredRecord output ) {
217+ return new DBRecord (output , columnTypes );
229218 }
230219
231220 protected SchemaReader getSchemaReader () {
@@ -272,12 +261,12 @@ private void setResultSetMetadata() throws Exception {
272261 }
273262 }
274263
275- columnTypes = new int [ columns .size ()];
276- for ( int i = 0 ; i < columnTypes . length ; i ++) {
277- String name = columns . get ( i );
278- Preconditions . checkArgument ( columnToType . containsKey ( name ), "Missing column '%s' in SQL table" , name );
279- columnTypes [ i ] = columnToType . get ( name );
280- }
264+ this . columnTypes = columns .stream ()
265+ . map ( name -> {
266+ Preconditions . checkArgument ( columnToType . containsKey ( name ), "Missing column '%s' in SQL table" , name );
267+ return new ColumnType ( name , columnToType . get ( name ) );
268+ })
269+ . collect ( Collectors . collectingAndThen ( Collectors . toList (), Collections :: unmodifiableList ));
281270 }
282271
283272 private void validateSchema (Class <? extends Driver > jdbcDriverClass , String tableName , Schema inputSchema ) {
@@ -324,7 +313,23 @@ private void validateFields(Schema inputSchema, ResultSet rs) throws SQLExceptio
324313 Set <String > invalidFields = new HashSet <>();
325314 for (Schema .Field field : inputSchema .getFields ()) {
326315 int columnIndex = rs .findColumn (field .getName ());
316+ boolean isColumnNullable = (ResultSetMetaData .columnNullable == rsMetaData .isNullable (columnIndex ));
317+ boolean isNotNullAssignable = !isColumnNullable && field .getSchema ().isNullable ();
318+ if (isNotNullAssignable ) {
319+ LOG .error ("Field '{}' was given as nullable but the database column is not nullable" , field .getName ());
320+ invalidFields .add (field .getName ());
321+ }
322+
327323 if (!isFieldCompatible (field , rsMetaData , columnIndex )) {
324+ String sqlTypeName = rsMetaData .getColumnTypeName (columnIndex );
325+ Schema fieldSchema = field .getSchema ().isNullable () ? field .getSchema ().getNonNullable () : field .getSchema ();
326+ Schema .Type fieldType = fieldSchema .getType ();
327+ Schema .LogicalType fieldLogicalType = fieldSchema .getLogicalType ();
328+ LOG .error ("Field '{}' was given as type '{}' but the database column is actually of type '{}'." ,
329+ field .getName (),
330+ fieldLogicalType != null ? fieldLogicalType .getToken () : fieldType ,
331+ sqlTypeName
332+ );
328333 invalidFields .add (field .getName ());
329334 }
330335 }
@@ -335,34 +340,70 @@ private void validateFields(Schema inputSchema, ResultSet rs) throws SQLExceptio
335340 }
336341
337342 /**
338- * Checks if field of the input schema is compatible with corresponding database column.
343+ * Checks if field is compatible to be written into database column of the given sql index .
339344 * @param field field of the explicit input schema.
340345 * @param metadata resultSet metadata.
341346 * @param index sql column index.
342- * @return 'true' if field is compatible, 'false' otherwise.
347+ * @return 'true' if field is compatible to be written , 'false' otherwise.
343348 */
344349 protected boolean isFieldCompatible (Schema .Field field , ResultSetMetaData metadata , int index ) throws SQLException {
345- boolean isColumnNullable = (ResultSetMetaData .columnNullable == metadata .isNullable (index ));
346- boolean isNotNullAssignable = !isColumnNullable && field .getSchema ().isNullable ();
347- if (isNotNullAssignable ) {
348- LOG .error ("Field '{}' was given as nullable but the database column is not nullable" , field .getName ());
349- return false ;
350+ Schema fieldSchema = field .getSchema ().isNullable () ? field .getSchema ().getNonNullable () : field .getSchema ();
351+ Schema .Type fieldType = fieldSchema .getType ();
352+ Schema .LogicalType fieldLogicalType = fieldSchema .getLogicalType ();
353+
354+ int sqlType = metadata .getColumnType (index );
355+
356+ // Handle logical types first
357+ if (fieldLogicalType != null ) {
358+ switch (fieldLogicalType ) {
359+ case DATE :
360+ return sqlType == Types .DATE ;
361+ case TIME_MILLIS :
362+ case TIME_MICROS :
363+ return sqlType == Types .TIME ;
364+ case TIMESTAMP_MILLIS :
365+ case TIMESTAMP_MICROS :
366+ return sqlType == Types .TIMESTAMP ;
367+ case DECIMAL :
368+ return sqlType == Types .NUMERIC
369+ || sqlType == Types .DECIMAL ;
370+ }
350371 }
351372
352- int type = metadata .getColumnType (index );
353- int precision = metadata .getPrecision (index );
354- int scale = metadata .getScale (index );
355-
356- Schema inputFieldSchema = field .getSchema ().isNullable () ? field .getSchema ().getNonNullable () : field .getSchema ();
357- Schema outputFieldSchema = DBUtils .getSchema (type , precision , scale );
358- if (!Objects .equals (inputFieldSchema .getType (), outputFieldSchema .getType ()) ||
359- !Objects .equals (inputFieldSchema .getLogicalType (), outputFieldSchema .getLogicalType ())) {
360- LOG .error ("Field '{}' was given as type '{}' but the database column is actually of type '{}'." ,
361- field .getName (), inputFieldSchema .getType (), outputFieldSchema .getType ());
362- return false ;
373+ switch (fieldType ) {
374+ case NULL :
375+ return true ;
376+ case BOOLEAN :
377+ return sqlType == Types .BOOLEAN
378+ || sqlType == Types .BIT ;
379+ case INT :
380+ return sqlType == Types .INTEGER
381+ || sqlType == Types .SMALLINT
382+ || sqlType == Types .TINYINT ;
383+ case LONG :
384+ return sqlType == Types .BIGINT ;
385+ case FLOAT :
386+ return sqlType == Types .REAL
387+ || sqlType == Types .FLOAT ;
388+ case DOUBLE :
389+ return sqlType == Types .DOUBLE ;
390+ case BYTES :
391+ return sqlType == Types .BINARY
392+ || sqlType == Types .VARBINARY
393+ || sqlType == Types .LONGVARBINARY
394+ || sqlType == Types .BLOB ;
395+ case STRING :
396+ return sqlType == Types .VARCHAR
397+ || sqlType == Types .CHAR
398+ || sqlType == Types .CLOB
399+ || sqlType == Types .LONGNVARCHAR
400+ || sqlType == Types .LONGVARCHAR
401+ || sqlType == Types .NCHAR
402+ || sqlType == Types .NCLOB
403+ || sqlType == Types .NVARCHAR ;
404+ default :
405+ return false ;
363406 }
364-
365- return true ;
366407 }
367408
368409 private void emitLineage (BatchSinkContext context , List <Schema .Field > fields ) {
0 commit comments