11/*
2- * Copyright © 2021 Cask Data, Inc.
2+ * Copyright © 2021-2022 Cask Data, Inc.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
55 * use this file except in compliance with the License. You may obtain a copy of
2323import io .cdap .cdap .etl .api .connector .ConnectorContext ;
2424import io .cdap .cdap .etl .api .connector .ConnectorSpecRequest ;
2525import io .cdap .cdap .etl .api .connector .SampleRequest ;
26+ import io .cdap .cdap .etl .api .connector .SampleType ;
2627import io .cdap .plugin .common .ConfigUtil ;
2728import io .cdap .plugin .common .SourceInputFormatProvider ;
2829import io .cdap .plugin .common .db .AbstractDBConnector ;
2930import io .cdap .plugin .common .db .DBConnectorPath ;
3031import io .cdap .plugin .common .util .ExceptionUtils ;
3132import io .cdap .plugin .db .CommonSchemaReader ;
32- import io .cdap .plugin .db .ConnectionConfig ;
3333import io .cdap .plugin .db .ConnectionConfigAccessor ;
3434import io .cdap .plugin .db .SchemaReader ;
3535import io .cdap .plugin .db .batch .source .DataDrivenETLDBInputFormat ;
4444import java .sql .SQLException ;
4545import java .sql .Statement ;
4646import java .util .Map ;
47+ import java .util .UUID ;
48+ import javax .annotation .Nullable ;
4749
4850/**
 * An abstract DB-specific connector that concrete DB connectors can inherit from.
 *
 * @param <T> the record type that the DB-specific record reader returns while sampling data with the InputFormat
5154 */
5255public abstract class AbstractDBSpecificConnector <T extends DBWritable > extends AbstractDBConnector
@@ -63,7 +66,7 @@ protected AbstractDBSpecificConnector(AbstractDBConnectorConfig config) {
6366
6467 protected abstract Class <? extends DBWritable > getDBRecordType ();
6568
66- protected SchemaReader getSchemaReader () {
69+ protected SchemaReader getSchemaReader (String sessionID ) {
6770 return new CommonSchemaReader ();
6871 }
6972
@@ -84,56 +87,98 @@ public InputFormatProvider getInputFormatProvider(ConnectorContext context, Samp
8487 ConnectionConfigAccessor connectionConfigAccessor = new ConnectionConfigAccessor ();
8588 if (config .getUser () == null && config .getPassword () == null ) {
8689 DBConfiguration .configureDB (connectionConfigAccessor .getConfiguration (), driverClass .getName (),
87- getConnectionString (path .getDatabase ()));
90+ getConnectionString (path .getDatabase ()));
8891 } else {
8992 DBConfiguration .configureDB (connectionConfigAccessor .getConfiguration (), driverClass .getName (),
90- getConnectionString (path .getDatabase ()), config .getUser (), config .getPassword ());
93+ getConnectionString (path .getDatabase ()), config .getUser (), config .getPassword ());
9194 }
92- String tableQuery = getTableQuery (path .getDatabase (), path .getSchema (), path .getTable (), request .getLimit ());
95+ String sessionID = generateSessionID ();
96+ String tableQuery = getTableQuery (path .getDatabase (), path .getSchema (), path .getTable (), request .getLimit (),
97+ request .getProperties ().get ("sampleType" ), request .getProperties ().get ("strata" ), sessionID );
9398 DataDrivenETLDBInputFormat .setInput (connectionConfigAccessor .getConfiguration (), getDBRecordType (),
94- tableQuery , null , false );
99+ tableQuery , null , false );
95100 connectionConfigAccessor .setConnectionArguments (Maps .fromProperties (config .getConnectionArgumentsProperties ()));
96101 connectionConfigAccessor .getConfiguration ().setInt (MRJobConfig .NUM_MAPS , 1 );
97102 Map <String , String > additionalArguments = config .getAdditionalArguments ();
98103 for (Map .Entry <String , String > argument : additionalArguments .entrySet ()) {
99104 connectionConfigAccessor .getConfiguration ().set (argument .getKey (), argument .getValue ());
100105 }
101106 try {
102- connectionConfigAccessor .setSchema (loadTableSchema (getConnection (path ), tableQuery ).toString ());
107+ Long timeoutMs = request .getTimeoutMs ();
108+ Integer timeoutSec = timeoutMs != null ? (int ) (timeoutMs / 1000 ) : null ;
109+ connectionConfigAccessor
110+ .setSchema (loadTableSchema (getConnection (path ), tableQuery , timeoutSec , sessionID ).toString ());
103111 } catch (SQLException e ) {
104112 throw new IOException (String .format ("Failed to get table schema due to: %s." ,
105- ExceptionUtils .getRootCauseMessage (e )), e );
113+ ExceptionUtils .getRootCauseMessage (e )), e );
106114 }
107115
108-
109116 return new SourceInputFormatProvider (DataDrivenETLDBInputFormat .class , connectionConfigAccessor .getConfiguration ());
110117 }
111118
112119 protected Connection getConnection (DBConnectorPath path ) {
113- return getConnection (getConnectionString (path .getDatabase ()) , config .getConnectionArgumentsProperties ());
120+ return getConnection (getConnectionString (path .getDatabase ()), config .getConnectionArgumentsProperties ());
114121 }
115122
116123 protected String getConnectionString (String database ) {
117124 return config .getConnectionString ();
118125 }
119126
127+ protected String getTableName (String database , String schema , String table ) {
128+ return schema == null ? String .format ("\" %s\" .\" %s\" " , database , table )
129+ : String .format ("\" %s\" .\" %s\" .\" %s\" " , database , schema , table );
130+ }
131+
120132 protected String getTableQuery (String database , String schema , String table ) {
121- return schema == null ? String . format ( "SELECT * FROM \" %s \" . \" %s \" " , database , table )
122- : String .format ("SELECT * FROM \" %s \" . \" %s \" . \" %s \" " , database , schema , table );
133+ String tableName = getTableName ( database , schema , table );
134+ return String .format ("SELECT * FROM %s" , tableName );
123135 }
124136
125137 protected String getTableQuery (String database , String schema , String table , int limit ) {
126- return schema == null ?
127- String .format ("SELECT * FROM \" %s\" .\" %s\" LIMIT %d" , database , table , limit ) :
128- String .format (
129- "SELECT * FROM \" %s\" .\" %s\" .\" %s\" LIMIT %d" , database , schema , table , limit );
138+ String tableName = getTableName (database , schema , table );
139+ return String .format ("SELECT * FROM %s LIMIT %d" , tableName , limit );
140+ }
141+
142+ protected String getTableQuery (String database , String schema , String table , int limit , String sampleType ,
143+ String strata , String sessionID ) throws IOException {
144+ if (sampleType == null ) {
145+ return getTableQuery (database , schema , table , limit );
146+ }
147+ String tableName = getTableName (database , schema , table );
148+ switch (SampleType .fromString (sampleType )) {
149+ case RANDOM :
150+ return getRandomQuery (tableName , limit );
151+ case STRATIFIED :
152+ if (strata == null ) {
153+ throw new IllegalArgumentException ("No strata column given." );
154+ }
155+ return getStratifiedQuery (tableName , limit , strata , sessionID );
156+ default :
157+ return getTableQuery (database , schema , table , limit );
158+ }
159+ }
160+
161+ // Get the query to use for randomized sampling.
162+ // By default, databases don't support randomized sampling; this method must be overridden
163+ protected String getRandomQuery (String tableName , int limit ) throws IOException {
164+ throw new IOException ("Connection does not support random sampling." );
130165 }
131166
132- protected Schema loadTableSchema (Connection connection , String query ) throws SQLException {
167+ // Get the query to use for stratified sampling.
168+ // By default, databases don't support stratified sampling; this method must be overridden
169+ protected String getStratifiedQuery (String tableName , int limit , String strata , String sessionID ) throws IOException {
170+ throw new IOException ("Connection does not support stratified sampling." );
171+ }
172+
173+ protected Schema loadTableSchema (Connection connection , String query , @ Nullable Integer timeoutSec , String sessionID )
174+ throws SQLException {
133175 Statement statement = connection .createStatement ();
134176 statement .setMaxRows (1 );
177+ if (timeoutSec != null ) {
178+ statement .setQueryTimeout (timeoutSec );
179+ }
135180 ResultSet resultSet = statement .executeQuery (query );
136- return Schema .recordOf ("outputSchema" , getSchemaReader ().getSchemaFields (resultSet ));
181+ return Schema .recordOf ("outputSchema" , getSchemaReader (sessionID ).getSchemaFields (resultSet ));
137182 }
138183
139184 protected void setConnectionProperties (Map <String , String > properties , ConnectorSpecRequest request ) {
@@ -144,7 +189,12 @@ protected void setConnectionProperties(Map<String, String> properties, Connector
144189 @ Override
145190 protected Schema getTableSchema (Connection connection , String database ,
146191 String schema , String table ) throws SQLException {
192+ String sessionID = generateSessionID ();
193+ return loadTableSchema (getConnection (), getTableQuery (database , schema , table ),
194+ null , sessionID );
195+ }
147196
148- return loadTableSchema (getConnection (), getTableQuery (database , schema , table ));
197+ protected String generateSessionID () {
198+ return UUID .randomUUID ().toString ().replace ('-' , '_' );
149199 }
150200}
0 commit comments