Skip to content

Commit 19286b0

Browse files
authored
Merge pull request #285 from data-integrations/feature/add-db-sampling-options
Add Resampling Capabilities for Multiple Database Connectors
2 parents 6076300 + b1503a8 commit 19286b0

File tree

16 files changed

+553
-64
lines changed

16 files changed

+553
-64
lines changed

cloudsql-mysql-plugin/src/main/java/io/cdap/plugin/cloudsql/mysql/CloudSQLMySQLConnector.java

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,8 @@ public StructuredRecord transform(LongWritable longWritable, DBRecord dbRecord)
7272
}
7373

7474
@Override
75-
protected String getTableQuery(String database, String schema, String table) {
76-
return String.format("SELECT * FROM `%s`.`%s`", database, table);
77-
}
78-
79-
@Override
80-
protected String getTableQuery(String database, String schema, String table, int limit) {
81-
return String.format("SELECT * FROM `%s`.`%s` LIMIT %d", database, table, limit);
75+
protected String getTableName(String database, String schema, String table) {
76+
return String.format("`%s`.`%s`", database, table);
8277
}
8378

8479
@Override

cloudsql-postgresql-plugin/src/main/java/io/cdap/plugin/cloudsql/postgres/CloudSQLPostgreSQLConnector.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,13 @@ public StructuredRecord transform(LongWritable longWritable, PostgresDBRecord po
7979
}
8080

8181
@Override
82-
protected SchemaReader getSchemaReader() {
83-
return new PostgresSchemaReader();
82+
protected SchemaReader getSchemaReader(String sessionID) {
83+
return new PostgresSchemaReader(sessionID);
8484
}
8585

8686
@Override
87-
protected String getTableQuery(String database, String schema, String table) {
88-
return String.format("SELECT * FROM \"%s\".\"%s\"", schema, table);
89-
}
90-
91-
@Override
92-
protected String getTableQuery(String database, String schema, String table, int limit) {
93-
return String.format("SELECT * FROM \"%s\".\"%s\" LIMIT %d", schema, table, limit);
87+
protected String getTableName(String database, String schema, String table) {
88+
return String.format("\"%s\".\"%s\"", schema, table);
9489
}
9590

9691
@Override

database-commons/src/main/java/io/cdap/plugin/db/ConnectionConfigAccessor.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,5 +111,4 @@ public Integer getFetchSize() {
111111
public Configuration getConfiguration() {
112112
return configuration;
113113
}
114-
115114
}

database-commons/src/main/java/io/cdap/plugin/db/connector/AbstractDBSpecificConnector.java

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright © 2021 Cask Data, Inc.
2+
* Copyright © 2021-2022 Cask Data, Inc.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
55
* use this file except in compliance with the License. You may obtain a copy of
@@ -23,13 +23,13 @@
2323
import io.cdap.cdap.etl.api.connector.ConnectorContext;
2424
import io.cdap.cdap.etl.api.connector.ConnectorSpecRequest;
2525
import io.cdap.cdap.etl.api.connector.SampleRequest;
26+
import io.cdap.cdap.etl.api.connector.SampleType;
2627
import io.cdap.plugin.common.ConfigUtil;
2728
import io.cdap.plugin.common.SourceInputFormatProvider;
2829
import io.cdap.plugin.common.db.AbstractDBConnector;
2930
import io.cdap.plugin.common.db.DBConnectorPath;
3031
import io.cdap.plugin.common.util.ExceptionUtils;
3132
import io.cdap.plugin.db.CommonSchemaReader;
32-
import io.cdap.plugin.db.ConnectionConfig;
3333
import io.cdap.plugin.db.ConnectionConfigAccessor;
3434
import io.cdap.plugin.db.SchemaReader;
3535
import io.cdap.plugin.db.batch.source.DataDrivenETLDBInputFormat;
@@ -44,9 +44,12 @@
4444
import java.sql.SQLException;
4545
import java.sql.Statement;
4646
import java.util.Map;
47+
import java.util.UUID;
48+
import javax.annotation.Nullable;
4749

4850
/**
4951
* An Abstract DB Specific Connector those specific DB connectors can inherits
52+
*
5053
* @param <T> the Record type that specific DB Record Reader may return while sample the data with InputFormat
5154
*/
5255
public abstract class AbstractDBSpecificConnector<T extends DBWritable> extends AbstractDBConnector
@@ -63,7 +66,7 @@ protected AbstractDBSpecificConnector(AbstractDBConnectorConfig config) {
6366

6467
protected abstract Class<? extends DBWritable> getDBRecordType();
6568

66-
protected SchemaReader getSchemaReader() {
69+
protected SchemaReader getSchemaReader(String sessionID) {
6770
return new CommonSchemaReader();
6871
}
6972

@@ -84,56 +87,98 @@ public InputFormatProvider getInputFormatProvider(ConnectorContext context, Samp
8487
ConnectionConfigAccessor connectionConfigAccessor = new ConnectionConfigAccessor();
8588
if (config.getUser() == null && config.getPassword() == null) {
8689
DBConfiguration.configureDB(connectionConfigAccessor.getConfiguration(), driverClass.getName(),
87-
getConnectionString(path.getDatabase()));
90+
getConnectionString(path.getDatabase()));
8891
} else {
8992
DBConfiguration.configureDB(connectionConfigAccessor.getConfiguration(), driverClass.getName(),
90-
getConnectionString(path.getDatabase()), config.getUser(), config.getPassword());
93+
getConnectionString(path.getDatabase()), config.getUser(), config.getPassword());
9194
}
92-
String tableQuery = getTableQuery(path.getDatabase(), path.getSchema(), path.getTable(), request.getLimit());
95+
String sessionID = generateSessionID();
96+
String tableQuery = getTableQuery(path.getDatabase(), path.getSchema(), path.getTable(), request.getLimit(),
97+
request.getProperties().get("sampleType"), request.getProperties().get("strata"), sessionID);
9398
DataDrivenETLDBInputFormat.setInput(connectionConfigAccessor.getConfiguration(), getDBRecordType(),
94-
tableQuery, null, false);
99+
tableQuery, null, false);
95100
connectionConfigAccessor.setConnectionArguments(Maps.fromProperties(config.getConnectionArgumentsProperties()));
96101
connectionConfigAccessor.getConfiguration().setInt(MRJobConfig.NUM_MAPS, 1);
97102
Map<String, String> additionalArguments = config.getAdditionalArguments();
98103
for (Map.Entry<String, String> argument : additionalArguments.entrySet()) {
99104
connectionConfigAccessor.getConfiguration().set(argument.getKey(), argument.getValue());
100105
}
101106
try {
102-
connectionConfigAccessor.setSchema(loadTableSchema(getConnection(path), tableQuery).toString());
107+
Long timeoutMs = request.getTimeoutMs();
108+
Integer timeoutSec = timeoutMs != null ? (int) (timeoutMs / 1000) : null;
109+
connectionConfigAccessor
110+
.setSchema(loadTableSchema(getConnection(path), tableQuery, timeoutSec, sessionID).toString());
103111
} catch (SQLException e) {
104112
throw new IOException(String.format("Failed to get table schema due to: %s.",
105-
ExceptionUtils.getRootCauseMessage(e)), e);
113+
ExceptionUtils.getRootCauseMessage(e)), e);
106114
}
107115

108-
109116
return new SourceInputFormatProvider(DataDrivenETLDBInputFormat.class, connectionConfigAccessor.getConfiguration());
110117
}
111118

112119
protected Connection getConnection(DBConnectorPath path) {
113-
return getConnection(getConnectionString(path.getDatabase()) , config.getConnectionArgumentsProperties());
120+
return getConnection(getConnectionString(path.getDatabase()), config.getConnectionArgumentsProperties());
114121
}
115122

116123
protected String getConnectionString(String database) {
117124
return config.getConnectionString();
118125
}
119126

127+
protected String getTableName(String database, String schema, String table) {
128+
return schema == null ? String.format("\"%s\".\"%s\"", database, table)
129+
: String.format("\"%s\".\"%s\".\"%s\"", database, schema, table);
130+
}
131+
120132
protected String getTableQuery(String database, String schema, String table) {
121-
return schema == null ? String.format("SELECT * FROM \"%s\".\"%s\"", database, table)
122-
: String.format("SELECT * FROM \"%s\".\"%s\".\"%s\"", database, schema, table);
133+
String tableName = getTableName(database, schema, table);
134+
return String.format("SELECT * FROM %s", tableName);
123135
}
124136

125137
protected String getTableQuery(String database, String schema, String table, int limit) {
126-
return schema == null ?
127-
String.format("SELECT * FROM \"%s\".\"%s\" LIMIT %d", database, table, limit) :
128-
String.format(
129-
"SELECT * FROM \"%s\".\"%s\".\"%s\" LIMIT %d", database, schema, table, limit);
138+
String tableName = getTableName(database, schema, table);
139+
return String.format("SELECT * FROM %s LIMIT %d", tableName, limit);
140+
}
141+
142+
protected String getTableQuery(String database, String schema, String table, int limit, String sampleType,
143+
String strata, String sessionID) throws IOException {
144+
if (sampleType == null) {
145+
return getTableQuery(database, schema, table, limit);
146+
}
147+
String tableName = getTableName(database, schema, table);
148+
switch (SampleType.fromString(sampleType)) {
149+
case RANDOM:
150+
return getRandomQuery(tableName, limit);
151+
case STRATIFIED:
152+
if (strata == null) {
153+
throw new IllegalArgumentException("No strata column given.");
154+
}
155+
return getStratifiedQuery(tableName, limit, strata, sessionID);
156+
default:
157+
return getTableQuery(database, schema, table, limit);
158+
}
159+
}
160+
161+
// Get the query to use for randomized sampling.
162+
// By default, databases don't support randomized sampling; this method must be overridden
163+
protected String getRandomQuery(String tableName, int limit) throws IOException {
164+
throw new IOException("Connection does not support random sampling.");
130165
}
131166

132-
protected Schema loadTableSchema(Connection connection, String query) throws SQLException {
167+
// Get the query to use for stratified sampling.
168+
// By default, databases don't support stratified sampling; this method must be overridden
169+
protected String getStratifiedQuery(String tableName, int limit, String strata, String sessionID) throws IOException {
170+
throw new IOException("Connection does not support stratified sampling.");
171+
}
172+
173+
protected Schema loadTableSchema(Connection connection, String query, @Nullable Integer timeoutSec, String sessionID)
174+
throws SQLException {
133175
Statement statement = connection.createStatement();
134176
statement.setMaxRows(1);
177+
if (timeoutSec != null) {
178+
statement.setQueryTimeout(timeoutSec);
179+
}
135180
ResultSet resultSet = statement.executeQuery(query);
136-
return Schema.recordOf("outputSchema", getSchemaReader().getSchemaFields(resultSet));
181+
return Schema.recordOf("outputSchema", getSchemaReader(sessionID).getSchemaFields(resultSet));
137182
}
138183

139184
protected void setConnectionProperties(Map<String, String> properties, ConnectorSpecRequest request) {
@@ -144,7 +189,12 @@ protected void setConnectionProperties(Map<String, String> properties, Connector
144189
@Override
145190
protected Schema getTableSchema(Connection connection, String database,
146191
String schema, String table) throws SQLException {
192+
String sessionID = generateSessionID();
193+
return loadTableSchema(getConnection(), getTableQuery(database, schema, table),
194+
null, sessionID);
195+
}
147196

148-
return loadTableSchema(getConnection(), getTableQuery(database, schema, table));
197+
protected String generateSessionID() {
198+
return UUID.randomUUID().toString().replace('-', '_');
149199
}
150200
}

mssql-plugin/src/main/java/io/cdap/plugin/mssql/SqlServerConnector.java

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import io.cdap.cdap.etl.api.connector.ConnectorSpec;
2929
import io.cdap.cdap.etl.api.connector.ConnectorSpecRequest;
3030
import io.cdap.cdap.etl.api.connector.PluginSpec;
31+
import io.cdap.cdap.etl.api.connector.SampleType;
3132
import io.cdap.plugin.common.Constants;
3233
import io.cdap.plugin.common.ReferenceNames;
3334
import io.cdap.plugin.common.db.DBConnectorPath;
@@ -37,6 +38,7 @@
3738
import org.apache.hadoop.io.LongWritable;
3839
import org.apache.hadoop.mapreduce.lib.db.DBWritable;
3940

41+
import java.io.IOException;
4042
import java.sql.SQLException;
4143
import java.util.HashMap;
4244
import java.util.Map;
@@ -75,7 +77,9 @@ protected void setConnectorSpec(ConnectorSpecRequest request, DBConnectorPath pa
7577
setConnectionProperties(sinkProperties, request);
7678
builder
7779
.addRelatedPlugin(new PluginSpec(SqlServerConstants.PLUGIN_NAME, BatchSource.PLUGIN_TYPE, sourceProperties))
78-
.addRelatedPlugin(new PluginSpec(SqlServerConstants.PLUGIN_NAME, BatchSink.PLUGIN_TYPE, sinkProperties));
80+
.addRelatedPlugin(new PluginSpec(SqlServerConstants.PLUGIN_NAME, BatchSink.PLUGIN_TYPE, sinkProperties))
81+
.addSupportedSampleType(SampleType.RANDOM)
82+
.addSupportedSampleType(SampleType.STRATIFIED);
7983

8084
String database = path.getDatabase();
8185
if (database != null) {
@@ -101,19 +105,48 @@ protected void setConnectorSpec(ConnectorSpecRequest request, DBConnectorPath pa
101105
}
102106

103107
@Override
104-
protected SchemaReader getSchemaReader() {
105-
return new SqlServerSourceSchemaReader();
108+
protected SchemaReader getSchemaReader(String sessionID) {
109+
return new SqlServerSourceSchemaReader(sessionID);
106110
}
107111

108112
@Override
109113
public StructuredRecord transform(LongWritable longWritable, SqlServerSourceDBRecord record) {
110114
return record.getRecord();
111115
}
112116

117+
@Override
118+
protected String getTableName(String database, String schema, String table) {
119+
return String.format("\"%s\".\"%s\".\"%s\"", database, schema, table);
120+
}
121+
113122
@Override
114123
protected String getTableQuery(String database, String schema, String table, int limit) {
115-
return String.format(
116-
"SELECT TOP(%d) * FROM \"%s\".\"%s\".\"%s\"", limit, database, schema, table);
124+
String tableName = getTableName(database, schema, table);
125+
return String.format("SELECT TOP %d * FROM %s", limit, tableName);
126+
}
127+
128+
@Override
129+
protected String getRandomQuery(String tableName, int limit) {
130+
// This query doesn't guarantee exactly "limit" number of rows
131+
return String.format("SELECT * FROM %s " +
132+
"WHERE (ABS(CAST((BINARY_CHECKSUM(*) * RAND()) as int)) %% 100) " +
133+
"< %d / (SELECT COUNT(*) FROM %s)",
134+
tableName, limit * 100, tableName);
135+
}
136+
137+
@Override
138+
protected String getStratifiedQuery(String tableName, int limit, String strata, String sessionID) {
139+
return String.format("WITH t_%s AS (\n" +
140+
" SELECT *,\n" +
141+
" ROW_NUMBER() OVER (ORDER BY %s, RAND()) AS sqn_%s,\n" +
142+
" COUNT(*) OVER () AS c_%s\n" +
143+
" FROM %s\n" +
144+
" )\n" +
145+
"SELECT TOP %d * FROM t_%s\n" +
146+
"WHERE sqn_%s %% CAST(0.5 * ((c_%s / %d + 1) + ABS(c_%s / %d - 1)) AS bigint) = 1\n" +
147+
"ORDER BY %s",
148+
sessionID, strata, sessionID, sessionID, tableName, limit, sessionID, sessionID, sessionID,
149+
limit, sessionID, limit, strata);
117150
}
118151

119152
@Override

mssql-plugin/src/main/java/io/cdap/plugin/mssql/SqlServerSourceSchemaReader.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@ public class SqlServerSourceSchemaReader extends CommonSchemaReader {
3434
public static final int SQL_VARIANT = -156;
3535
public static final String DATETIME_TYPE_PREFIX = "datetime";
3636

37+
private final String sessionID;
38+
39+
public SqlServerSourceSchemaReader() {
40+
this(null);
41+
}
42+
43+
public SqlServerSourceSchemaReader(String sessionID) {
44+
super();
45+
this.sessionID = sessionID;
46+
}
47+
3748
@Override
3849
public Schema getSchema(ResultSetMetaData metadata, int index) throws SQLException {
3950
int columnSqlType = metadata.getColumnType(index);
@@ -72,4 +83,13 @@ public static boolean shouldConvertToDatetime(ResultSetMetaData metadata, int in
7283
public static boolean shouldConvertToDatetime(String typeName) {
7384
return typeName.startsWith(DATETIME_TYPE_PREFIX);
7485
}
86+
87+
@Override
88+
public boolean shouldIgnoreColumn(ResultSetMetaData metadata, int index) throws SQLException {
89+
if (sessionID == null) {
90+
return false;
91+
}
92+
return metadata.getColumnName(index).equals("c_" + sessionID) ||
93+
metadata.getColumnName(index).equals("sqn_" + sessionID);
94+
}
7595
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright © 2022 Cask Data, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
5+
* use this file except in compliance with the License. You may obtain a copy of
6+
* the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
* License for the specific language governing permissions and limitations under
14+
* the License.
15+
*/
16+
17+
package io.cdap.plugin.mssql;
18+
19+
import org.junit.Assert;
20+
import org.junit.Rule;
21+
import org.junit.Test;
22+
import org.junit.rules.ExpectedException;
23+
24+
/**
25+
* Unit tests for {@link SqlServerConnector}
26+
*/
27+
public class SqlServerConnectorUnitTest {
28+
@Rule
29+
public ExpectedException expectedEx = ExpectedException.none();
30+
31+
private static final SqlServerConnector CONNECTOR = new SqlServerConnector(null);
32+
33+
/**
34+
* Unit tests for getTableQuery()
35+
*/
36+
@Test
37+
public void getTableQueryTest() {
38+
String tableName = "\"db\".\"schema\".\"table\"";
39+
40+
// default query
41+
Assert.assertEquals(String.format("SELECT TOP %d * FROM %s", 100, tableName),
42+
CONNECTOR.getTableQuery("db", "schema", "table",
43+
100));
44+
45+
// random query
46+
Assert.assertEquals(String.format("SELECT * FROM %s " +
47+
"WHERE (ABS(CAST((BINARY_CHECKSUM(*) * RAND()) as int)) %% 100) " +
48+
"< %d / (SELECT COUNT(*) FROM %s)",
49+
tableName, 100, tableName),
50+
CONNECTOR.getRandomQuery(tableName, 100));
51+
}
52+
}

0 commit comments

Comments
 (0)