Skip to content

Commit cf173b4

Browse files
authored
Merge pull request #1084 from data-integrations/feature/add-sampling-options
Add sampling options to BigQuery connector
2 parents 0c7dfcf + 9ee9ddb commit cf173b4

File tree

2 files changed

+262
-21
lines changed

2 files changed

+262
-21
lines changed

src/main/java/io/cdap/plugin/gcp/bigquery/connector/BigQueryConnector.java

Lines changed: 99 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import com.google.api.gax.paging.Page;
2020
import com.google.auth.oauth2.GoogleCredentials;
21+
import com.google.cloud.RetryOption;
2122
import com.google.cloud.bigquery.BigQuery;
2223
import com.google.cloud.bigquery.BigQueryException;
2324
import com.google.cloud.bigquery.Dataset;
@@ -49,6 +50,7 @@
4950
import io.cdap.cdap.etl.api.connector.DirectConnector;
5051
import io.cdap.cdap.etl.api.connector.PluginSpec;
5152
import io.cdap.cdap.etl.api.connector.SampleRequest;
53+
import io.cdap.cdap.etl.api.connector.SampleType;
5254
import io.cdap.cdap.etl.api.engine.sql.BatchSQLEngine;
5355
import io.cdap.cdap.etl.api.validation.ValidationException;
5456
import io.cdap.plugin.common.ConfigUtil;
@@ -62,6 +64,7 @@
6264
import io.cdap.plugin.gcp.bigquery.util.BigQueryDataParser;
6365
import io.cdap.plugin.gcp.bigquery.util.BigQueryUtil;
6466
import io.cdap.plugin.gcp.common.GCPUtils;
67+
import org.threeten.bp.Duration;
6568

6669
import java.io.IOException;
6770
import java.util.HashMap;
@@ -95,8 +98,13 @@ public List<StructuredRecord> sample(ConnectorContext context, SampleRequest sam
9598
throw new IllegalArgumentException("Path should contain both dataset and table name.");
9699
}
97100
String dataset = path.getDataset();
98-
return getTableData(getBigQuery(config.getProject()), config.getDatasetProject(), dataset, table,
99-
sampleRequest.getLimit());
101+
String query = getTableQuery(String.format("`%s.%s.%s`", config.getDatasetProject(), dataset, table),
102+
sampleRequest.getLimit(),
103+
SampleType.fromString(sampleRequest.getProperties().get("sampleType")),
104+
sampleRequest.getProperties().get("strata"),
105+
UUID.randomUUID().toString().replace("-", "_"));
106+
String id = UUID.randomUUID().toString();
107+
return getQueryResult(waitForJob(getBigQuery(config.getProject()), query, sampleRequest.getTimeoutMs(), id), id);
100108
}
101109

102110
@Override
@@ -117,7 +125,7 @@ public void test(ConnectorContext context) throws ValidationException {
117125
GCPUtils.loadServiceAccountCredentials(config.getServiceAccount(), config.isServiceAccountFilePath());
118126
} catch (Exception e) {
119127
failureCollector.addFailure(String.format("Service account key provided is not valid: %s", e.getMessage()),
120-
"Please provide a valid service account key.");
128+
"Please provide a valid service account key.");
121129
}
122130
}
123131
// if either project or credentials cannot be loaded , no need to continue
@@ -130,7 +138,7 @@ public void test(ConnectorContext context) throws ValidationException {
130138
bigQuery.listDatasets(BigQuery.DatasetListOption.pageSize(1));
131139
} catch (Exception e) {
132140
failureCollector.addFailure(String.format("Could not connect to BigQuery: %s", e.getMessage()),
133-
"Please specify correct connection properties.");
141+
"Please specify correct connection properties.");
134142
}
135143
}
136144

@@ -143,15 +151,15 @@ public BrowseDetail browse(ConnectorContext context, BrowseRequest browseRequest
143151
if (dataset == null) {
144152
// browse project to list all datasets
145153
return config.rootDataset == null ?
146-
listDatasets(getBigQuery(config.getDatasetProject()), browseRequest.getLimit()) :
147-
BrowseDetail.builder().setTotalCount(1).addEntity(
148-
BrowseEntity.builder(config.rootDataset, "/" + config.rootDataset, ENTITY_TYPE_DATASET)
149-
.canBrowse(true).build()).build();
154+
listDatasets(getBigQuery(config.getDatasetProject()), browseRequest.getLimit()) :
155+
BrowseDetail.builder().setTotalCount(1).addEntity(
156+
BrowseEntity.builder(config.rootDataset, "/" + config.rootDataset, ENTITY_TYPE_DATASET)
157+
.canBrowse(true).build()).build();
150158
}
151159
String table = path.getTable();
152160
if (table == null) {
153161
return listTables(getBigQuery(config.getProject()), config.getDatasetProject(), dataset,
154-
browseRequest.getLimit());
162+
browseRequest.getLimit());
155163
}
156164
return getTableDetail(getBigQuery(config.getProject()), config.getDatasetProject(), dataset, table);
157165
}
@@ -202,7 +210,7 @@ private BrowseDetail listTables(BigQuery bigQuery, String datasetProject, String
202210

203211
private BrowseDetail listDatasets(BigQuery bigQuery, Integer limit) {
204212
Page<Dataset> datasetPage = config.showHiddenDatasets() ?
205-
bigQuery.listDatasets(BigQuery.DatasetListOption.all()) : bigQuery.listDatasets();
213+
bigQuery.listDatasets(BigQuery.DatasetListOption.all()) : bigQuery.listDatasets();
206214
int countLimit = limit == null || limit <= 0 ? Integer.MAX_VALUE : limit;
207215
int count = 0;
208216
BrowseDetail.Builder browseDetailBuilder = BrowseDetail.builder();
@@ -233,31 +241,99 @@ private BigQuery getBigQuery(String project) throws IOException {
233241
return GCPUtils.getBigQuery(project, credentials);
234242
}
235243

236-
private List<StructuredRecord> getTableData(BigQuery bigQuery, String datasetProject, String dataset, String table,
237-
int limit)
238-
throws IOException {
239-
String query =
240-
String.format("SELECT * FROM `%s.%s.%s` LIMIT %d", datasetProject, dataset, table, limit);
244+
/**
245+
* Get the SQL query used to sample the table
246+
* @param tableName name of the table
247+
* @param limit limit on rows returned
248+
* @param sampleType sampling method
249+
* @param strata strata column (if given)
250+
* @param sessionID UUID
251+
* @return String
252+
* @throws IllegalArgumentException if no strata column is given for a stratified query
253+
*/
254+
protected String getTableQuery(String tableName, int limit, SampleType sampleType, @Nullable String strata,
255+
String sessionID) {
256+
switch (sampleType) {
257+
case RANDOM:
258+
return String.format("WITH table AS (\n" +
259+
" SELECT *, RAND() AS r_%s\n" +
260+
" FROM %s\n" +
261+
" WHERE RAND() < 2*%d/(SELECT COUNT(*) FROM %s)\n" +
262+
")\n" +
263+
"SELECT * EXCEPT (r_%s)\n" +
264+
"FROM table\n" +
265+
"ORDER BY r_%s\n" +
266+
"LIMIT %d",
267+
sessionID, tableName, limit, tableName, sessionID, sessionID, limit);
268+
case STRATIFIED:
269+
if (strata == null) {
270+
throw new IllegalArgumentException("No strata column given.");
271+
}
272+
return String.format("SELECT * EXCEPT (`sqn_%s`, `c_%s`)\n" +
273+
"FROM (\n" +
274+
"SELECT *, row_number() OVER (ORDER BY %s, RAND()) AS sqn_%s,\n" +
275+
"COUNT(*) OVER () as c_%s,\n" +
276+
"FROM %s\n" +
277+
") %s\n" +
278+
"WHERE MOD(sqn_%s, CAST(c_%s / %d AS INT64)) = 1\n" +
279+
"ORDER BY %s\n" +
280+
"LIMIT %d",
281+
sessionID, sessionID, strata, sessionID, sessionID, tableName, tableName, sessionID,
282+
sessionID, limit, strata, limit);
283+
default:
284+
return String.format("SELECT * FROM %s LIMIT %d", tableName, limit);
285+
}
286+
}
287+
288+
/**
289+
* Wait for job to complete or time out (if timeout is given)
290+
* @param bigQuery BigQuery client
291+
* @param query SQL query
292+
* @param timeoutMs timeout (if given)
293+
* @param id job ID
294+
* @return Job
295+
* @throws IOException if the job is interrupted
296+
*/
297+
private Job waitForJob(BigQuery bigQuery, String query, @Nullable Long timeoutMs, String id) throws IOException {
298+
299+
// set up job
241300
QueryJobConfiguration queryConfig = QueryJobConfiguration.newBuilder(query).build();
242-
String id = UUID.randomUUID().toString();
243301
JobId jobId = JobId.of(id);
244302
Job queryJob = bigQuery.create(JobInfo.newBuilder(queryConfig).setJobId(jobId).build());
245-
// Wait for the job to finish
303+
304+
// wait for job
246305
try {
247-
queryJob = queryJob.waitFor();
306+
if (timeoutMs == null) {
307+
return queryJob.waitFor();
308+
} else {
309+
return queryJob.waitFor(RetryOption.totalTimeout(Duration.ofMillis(timeoutMs)));
310+
}
248311
} catch (InterruptedException e) {
249312
throw new IOException(String.format("Query job %s interrupted.", id), e);
250313
}
314+
}
251315

252-
// check for errors
316+
/**
317+
* Retrieve the results of a SQL query
318+
* @param queryJob query job after completion or timeout
319+
* @param id job ID
320+
* @return List of structured records
321+
* @throws IOException if query encounters an error or times out
322+
*/
323+
protected List<StructuredRecord> getQueryResult(@Nullable Job queryJob, String id) throws IOException {
324+
325+
// Check for errors
253326
if (queryJob == null) {
254327
throw new IOException(String.format("Job %s no longer exists.", id));
328+
} else if (!queryJob.isDone()) {
329+
queryJob.cancel();
330+
throw new IOException(String.format("Job %s timed out.", id));
255331
} else if (queryJob.getStatus().getError() != null) {
256332
throw new IOException(String.format("Failed to query table : %s", queryJob.getStatus().getError().toString()));
257333
}
258334

259335
// Get the results
260-
TableResult result = null;
336+
TableResult result;
261337
try {
262338
result = queryJob.getQueryResults();
263339
} catch (InterruptedException e) {
@@ -266,7 +342,6 @@ private List<StructuredRecord> getTableData(BigQuery bigQuery, String datasetPro
266342
return BigQueryDataParser.parse(result);
267343
}
268344

269-
270345
@Override
271346
public ConnectorSpec generateSpec(ConnectorContext context,
272347
ConnectorSpecRequest connectorSpecRequest) throws IOException {
@@ -297,6 +372,9 @@ public ConnectorSpec generateSpec(ConnectorContext context,
297372
.addRelatedPlugin(new PluginSpec(BigQuerySink.NAME, BatchSink.PLUGIN_TYPE, properties))
298373
.addRelatedPlugin(new PluginSpec(BigQueryMultiSink.NAME, BatchSink.PLUGIN_TYPE, properties))
299374
.addRelatedPlugin(new PluginSpec(BigQuerySQLEngine.NAME, BatchSQLEngine.PLUGIN_TYPE, properties))
375+
.addSupportedSampleType(SampleType.RANDOM)
376+
.addSupportedSampleType(SampleType.STRATIFIED)
300377
.build();
301378
}
302379
}
380+
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/*
2+
* Copyright © 2022 Cask Data, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
5+
* use this file except in compliance with the License. You may obtain a copy of
6+
* the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
* License for the specific language governing permissions and limitations under
14+
* the License.
15+
*/
16+
17+
package io.cdap.plugin.gcp.bigquery.connector;
18+
19+
import com.google.cloud.bigquery.BigQueryError;
20+
import com.google.cloud.bigquery.EmptyTableResult;
21+
import com.google.cloud.bigquery.Job;
22+
import com.google.cloud.bigquery.JobStatus;
23+
import com.google.cloud.bigquery.TableResult;
24+
import io.cdap.cdap.api.data.format.StructuredRecord;
25+
import io.cdap.cdap.etl.api.connector.SampleType;
26+
import io.cdap.plugin.gcp.bigquery.util.BigQueryDataParser;
27+
import org.junit.Assert;
28+
import org.junit.Before;
29+
import org.junit.Rule;
30+
import org.junit.Test;
31+
import org.junit.rules.ExpectedException;
32+
import org.junit.runner.RunWith;
33+
import org.mockito.junit.MockitoJUnitRunner;
34+
35+
import static org.powermock.api.mockito.PowerMockito.doAnswer;
36+
import static org.powermock.api.mockito.PowerMockito.doThrow;
37+
import static org.powermock.api.mockito.PowerMockito.mock;
38+
39+
import java.io.IOException;
40+
import java.util.List;
41+
import java.util.UUID;
42+
43+
@RunWith(MockitoJUnitRunner.class)
44+
public class BigQueryConnectorUnitTest {
45+
46+
@Rule
47+
public ExpectedException expectedEx = ExpectedException.none();
48+
49+
private static final BigQueryConnector CONNECTOR = new BigQueryConnector(null);
50+
private String tableName;
51+
private String sessionID;
52+
private int limit;
53+
private String strata;
54+
private Job queryJob;
55+
56+
@Before
57+
public void setUp() {
58+
tableName = "`project.dataset.table`";
59+
sessionID = UUID.randomUUID().toString().replace('-', '_');
60+
limit = 100;
61+
strata = "strata";
62+
}
63+
64+
/**
65+
* Unit tests for getTableQuery()
66+
*/
67+
@Test
68+
public void getTableQueryTest() {
69+
// random query
70+
Assert.assertEquals(String.format("WITH table AS (\n" +
71+
" SELECT *, RAND() AS r_%s\n" +
72+
" FROM %s\n" +
73+
" WHERE RAND() < 2*%d/(SELECT COUNT(*) FROM %s)\n" +
74+
")\n" +
75+
"SELECT * EXCEPT (r_%s)\n" +
76+
"FROM table\n" +
77+
"ORDER BY r_%s\n" +
78+
"LIMIT %d",
79+
sessionID, tableName, limit, tableName, sessionID, sessionID, limit),
80+
CONNECTOR.getTableQuery(tableName, limit, SampleType.RANDOM, null, sessionID));
81+
82+
// stratified query
83+
Assert.assertEquals(String.format("SELECT * EXCEPT (`sqn_%s`, `c_%s`)\n" +
84+
"FROM (\n" +
85+
"SELECT *, row_number() OVER (ORDER BY %s, RAND()) AS sqn_%s,\n" +
86+
"COUNT(*) OVER () as c_%s,\n" +
87+
"FROM %s\n" +
88+
") %s\n" +
89+
"WHERE MOD(sqn_%s, CAST(c_%s / %d AS INT64)) = 1\n" +
90+
"ORDER BY %s\n" +
91+
"LIMIT %d",
92+
sessionID, sessionID, strata, sessionID, sessionID, tableName, tableName,
93+
sessionID, sessionID, limit, strata, limit),
94+
CONNECTOR.getTableQuery(tableName, limit, SampleType.STRATIFIED, strata, sessionID));
95+
96+
// default query
97+
Assert.assertEquals(String.format("SELECT * FROM %s LIMIT %d", tableName, limit),
98+
CONNECTOR.getTableQuery(tableName, limit, SampleType.DEFAULT, null, sessionID));
99+
}
100+
101+
/**
102+
* Test for {@link IllegalArgumentException} from getTableQuery when attempting stratified query with null strata
103+
* @throws IllegalArgumentException expected
104+
*/
105+
@Test
106+
public void getTableQueryNullStrataTest() throws IllegalArgumentException {
107+
expectedEx.expect(IllegalArgumentException.class);
108+
CONNECTOR.getTableQuery(tableName, limit, SampleType.STRATIFIED, null, sessionID);
109+
}
110+
111+
/**
112+
* Test for {@link IOException} from getQueryResult() when attempting on null query job
113+
* @throws IOException expected
114+
*/
115+
@Test
116+
public void getQueryResultNullJobTest() throws IOException {
117+
expectedEx.expect(IOException.class);
118+
CONNECTOR.getQueryResult(null, sessionID);
119+
}
120+
121+
/**
122+
* Test for {@link IOException} from getQueryResult() if job timed out
123+
* @throws IOException expected
124+
*/
125+
@Test
126+
public void getQueryResultTimedOutTest() throws IOException {
127+
expectedEx.expect(IOException.class);
128+
queryJob = mock(Job.class);
129+
doAnswer(invocation -> false).when(queryJob).isDone();
130+
CONNECTOR.getQueryResult(queryJob, sessionID);
131+
}
132+
133+
/**
134+
* Test for {@link IOException} from getQueryResult() if query has error
135+
* @throws IOException expected
136+
*/
137+
@Test
138+
public void getQueryResultErrorTest() throws IOException {
139+
expectedEx.expect(IOException.class);
140+
queryJob = mock(Job.class);
141+
doAnswer(invocation -> true).when(queryJob).isDone();
142+
JobStatus status = mock(JobStatus.class);
143+
doAnswer(invocation -> mock(BigQueryError.class)).when(status).getError();
144+
doAnswer(invocation -> status).when(queryJob).getStatus();
145+
CONNECTOR.getQueryResult(queryJob, sessionID);
146+
}
147+
148+
/**
149+
* Test for {@link IOException} from getQueryResult() if query is interrupted
150+
* @throws InterruptedException to get IOException
151+
* @throws IOException expected
152+
*/
153+
@Test
154+
public void getQueryResultInterruptedTest() throws InterruptedException, IOException {
155+
expectedEx.expect(IOException.class);
156+
queryJob = mock(Job.class);
157+
doAnswer(invocation -> true).when(queryJob).isDone();
158+
doAnswer(invocation -> mock(JobStatus.class)).when(queryJob).getStatus();
159+
doThrow(mock(InterruptedException.class)).when(queryJob).getQueryResults();
160+
CONNECTOR.getQueryResult(queryJob, sessionID);
161+
}
162+
163+
}

0 commit comments

Comments
 (0)