diff --git a/examples/java/src/main/java/org/apache/beam/examples/cookbook/BigQueryStreamingTornadoes.java b/examples/java/src/main/java/org/apache/beam/examples/cookbook/BigQueryStreamingTornadoes.java new file mode 100644 index 000000000000..395da115e0ca --- /dev/null +++ b/examples/java/src/main/java/org/apache/beam/examples/cookbook/BigQueryStreamingTornadoes.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.examples.cookbook; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryDynamicReadDescriptor; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.Validation; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.PeriodicImpulse; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An example that reads periodically the public samples of weather data from BigQuery, counts the + * number of tornadoes that occur in each month, and writes the results to BigQuery. + * + *

Concepts: Reading/writing BigQuery; counting a PCollection; user-defined PTransforms + * + *

Note: Before running this example, you must create a BigQuery dataset to contain your output + * table. + * + *
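Because the output is written with {@code CreateDisposition.CREATE_NEVER}, the destination table itself + * must also exist before the pipeline runs. As an illustration (all names are placeholders), the dataset and + * table could be created with the BigQuery CLI: + * + * <pre>{@code + * bq mk --dataset YOUR_PROJECT_ID:DATASET_ID + * bq mk --table YOUR_PROJECT_ID:DATASET_ID.TABLE_ID ts:STRING,month:INTEGER,tornado_count:INTEGER + * }</pre> + * + *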

To execute this pipeline locally, specify the BigQuery table for the output with the form: + * + *

<pre>{@code
+ * --output=YOUR_PROJECT_ID:DATASET_ID.TABLE_ID
+ * }</pre>
+ * + *

To change the runner, specify: + * + *

<pre>{@code
+ * --runner=YOUR_SELECTED_RUNNER
+ * }</pre>
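+ * + * For example, a Dataflow invocation might combine these flags (all values are placeholders): + * + * <pre>{@code + * --runner=DataflowRunner --project=YOUR_PROJECT_ID --region=YOUR_REGION \ + *     --tempLocation=gs://YOUR_BUCKET/temp --output=YOUR_PROJECT_ID:DATASET_ID.TABLE_ID + * }</pre>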
+ * + * See examples/java/README.md for instructions about how to configure different runners. + * + *

The BigQuery input table defaults to {@code apache-beam-testing.samples.weather_stations} and + * can be overridden with {@code --input}. + */ +public class BigQueryStreamingTornadoes { + private static final Logger LOG = LoggerFactory.getLogger(BigQueryStreamingTornadoes.class); + + // Default to reading the public weather stations table apache-beam-testing.samples.weather_stations. + private static final String WEATHER_SAMPLES_TABLE = + "apache-beam-testing.samples.weather_stations"; + + /** + * Examines each row in the input table. If a tornado was recorded in that sample, the month in + * which it occurred is output. + */ + static class ExtractTornadoesFn extends DoFn<TableRow, Integer> { + @ProcessElement + public void processElement(ProcessContext c) { + TableRow row = c.element(); + if (Boolean.TRUE.equals(row.get("tornado"))) { + c.output(Integer.parseInt((String) row.get("month"))); + } + } + } + + /** + * Prepares the data for writing to BigQuery by building a TableRow object containing an integer + * representation of month and the number of tornadoes that occurred in each month. + */ + static class FormatCountsFn extends DoFn<KV<Integer, Long>, TableRow> { + @ProcessElement + public void processElement(ProcessContext c) { + TableRow row = + new TableRow() + .set("ts", c.timestamp().toString()) + .set("month", c.element().getKey()) + .set("tornado_count", c.element().getValue()); + c.output(row); + } + } + + /** + * Takes rows from a table and generates a table of counts. + * + *

The input schema is described by https://developers.google.com/bigquery/docs/dataset-gsod . + * The output contains the total number of tornadoes found in each month in the following schema: + * + * <ul> + *   <li>ts: string + *   <li>month: integer + *   <li>tornado_count: integer + * </ul> + */ + static class CountTornadoes extends PTransform<PCollection<TableRow>, PCollection<TableRow>> { + @Override + public PCollection<TableRow> expand(PCollection<TableRow> rows) { + + // row... => month... + PCollection<Integer> tornadoes = rows.apply(ParDo.of(new ExtractTornadoesFn())); + + // month... => <month,count>... + PCollection<KV<Integer, Long>> tornadoCounts = tornadoes.apply(Count.perElement()); + + // <month,count>... => row... + PCollection<TableRow> results = tornadoCounts.apply(ParDo.of(new FormatCountsFn())); + + return results; + } + } + + /** + * Options supported by {@link BigQueryStreamingTornadoes}. + * + *

Inherits standard configuration options. + */ + public interface Options extends PipelineOptions { + @Description("Table to read from, specified as :.") + @Default.String(WEATHER_SAMPLES_TABLE) + String getInput(); + + void setInput(String value); + + @Description("Write method to use to write to BigQuery") + @Default.Enum("DEFAULT") + BigQueryIO.Write.Method getWriteMethod(); + + void setWriteMethod(BigQueryIO.Write.Method value); + + @Description( + "BigQuery table to write to, specified as " + + ":.. The dataset must already exist.") + @Validation.Required + String getOutput(); + + void setOutput(String value); + } + + public static void applyBigQueryStreamingTornadoes(Pipeline p, Options options) { + List fields = new ArrayList<>(); + fields.add(new TableFieldSchema().setName("ts").setType("STRING")); + fields.add(new TableFieldSchema().setName("month").setType("INTEGER")); + fields.add(new TableFieldSchema().setName("tornado_count").setType("INTEGER")); + TableSchema schema = new TableSchema().setFields(fields); + + PCollection descriptors = + p.apply("Impulse", PeriodicImpulse.create().withInterval(Duration.standardSeconds(60))) + .apply( + "Create query", + MapElements.into(TypeDescriptor.of(BigQueryDynamicReadDescriptor.class)) + .via( + (Instant t) -> + BigQueryDynamicReadDescriptor.table( + WEATHER_SAMPLES_TABLE, null, null))); + + PCollection readDynamically = + descriptors.apply("Read dynamically", BigQueryIO.readDynamicallyTableRows()); + readDynamically + .apply(Window.into(FixedWindows.of(Duration.standardMinutes(1)))) + .apply(new CountTornadoes()) + .apply( + BigQueryIO.writeTableRows() + .to(options.getOutput()) + .withSchema(schema) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_NEVER) + .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_APPEND) + .withMethod(options.getWriteMethod())); + } + + public static void runBigQueryTornadoes(Options options) { + LOG.info("Running BigQuery Tornadoes with options " + options.toString()); + Pipeline p = Pipeline.create(options); + applyBigQueryStreamingTornadoes(p, options); + p.run().waitUntilFinish(); + } + + public static void main(String[] args) { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + runBigQueryTornadoes(options); + } +} diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle index b5b27003b944..f535ccd859cd 100644 --- a/sdks/java/io/google-cloud-platform/build.gradle +++ b/sdks/java/io/google-cloud-platform/build.gradle @@ -202,6 +202,8 @@ task integrationTest(type: Test, dependsOn: processTestResources) { exclude '**/BigQueryIOStorageQueryIT.class' exclude '**/BigQueryIOStorageReadIT.class' exclude '**/BigQueryIOStorageWriteIT.class' + exclude '**/BigQueryIODynamicQueryIT.class' + exclude '**/BigQueryIODynamicReadIT.class' exclude '**/BigQueryToTableIT.class' maxParallelForks 4 @@ -271,6 +273,7 @@ task bigQueryEarlyRolloutIntegrationTest(type: Test, dependsOn: processTestResou include '**/BigQueryToTableIT.class' include '**/BigQueryIOJsonIT.class' include '**/BigQueryIOStorageReadTableRowIT.class' + include '**/BigQueryIODynamicReadTableRowIT.class' // storage write api include '**/StorageApiDirectWriteProtosIT.class' include '**/StorageApiSinkFailedRowsIT.class' diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryDynamicReadDescriptor.java 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryDynamicReadDescriptor.java new file mode 100644 index 000000000000..b6da635ea1ec --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryDynamicReadDescriptor.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; + +import com.google.auto.value.AutoValue; +import java.io.Serializable; +import java.util.List; +import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.annotations.SchemaCreate; +import org.apache.beam.sdk.schemas.annotations.SchemaFieldName; +import org.apache.beam.sdk.schemas.annotations.SchemaFieldNumber; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.dataflow.qual.Pure; + +/** Represents a BigQuery source description used for dynamic read. 
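+ * + * Exactly one of a table or a SQL query must be provided. A table descriptor may additionally carry + * selected fields and a row restriction, while a query descriptor must also state whether to flatten + * results and whether to use legacy SQL (see the argument checks in {@link #create}). Descriptors are + * typically produced upstream in the pipeline and consumed by {@code BigQueryIO.readDynamicallyTableRows()} + * or {@code BigQueryIO.readDynamically(parseFn, coder)}. For example (illustrative values): + * + * <pre>{@code + * BigQueryDynamicReadDescriptor.table("project:dataset.table", null, null); + * BigQueryDynamicReadDescriptor.query("SELECT month FROM `project.dataset.table`", false, false); + * }</pre>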
*/ +@DefaultSchema(AutoValueSchema.class) +@AutoValue +public abstract class BigQueryDynamicReadDescriptor implements Serializable { + @SchemaFieldName("query") + @SchemaFieldNumber("0") + @Pure + abstract @Nullable String getQuery(); + + @SchemaFieldName("table") + @SchemaFieldNumber("1") + @Pure + abstract @Nullable String getTable(); + + @SchemaFieldName("flattenResults") + @SchemaFieldNumber("2") + @Pure + abstract @Nullable Boolean getFlattenResults(); + + @SchemaFieldName("legacySql") + @SchemaFieldNumber("3") + @Pure + abstract @Nullable Boolean getUseLegacySql(); + + @SchemaFieldName("selectedFields") + @SchemaFieldNumber("4") + @Pure + abstract @Nullable List getSelectedFields(); + + @SchemaFieldName("rowRestriction") + @SchemaFieldNumber("5") + @Pure + abstract @Nullable String getRowRestriction(); + + @SchemaCreate + public static BigQueryDynamicReadDescriptor create( + @Nullable String query, + @Nullable String table, + @Nullable Boolean flattenResults, + @Nullable Boolean useLegacySql, + @Nullable List selectedFields, + @Nullable String rowRestriction) { + checkArgument((query != null || table != null), "Either query or table has to be specified."); + checkArgument( + !(query != null && table != null), "Either query or table has to be specified not both."); + checkArgument( + !(table != null && (flattenResults != null || useLegacySql != null)), + "Specifies a table with a result flattening preference or legacySql, which only applies to queries"); + checkArgument( + !(query != null && (selectedFields != null || rowRestriction != null)), + "Selected fields and row restriction are only applicable for table reads"); + checkArgument( + !(query != null && (flattenResults == null || useLegacySql == null)), + "If query is used, flattenResults and legacySql have to be set as well."); + + return new AutoValue_BigQueryDynamicReadDescriptor( + query, table, flattenResults, useLegacySql, selectedFields, rowRestriction); + } + + public static BigQueryDynamicReadDescriptor query( + String query, Boolean flattenResults, Boolean useLegacySql) { + return create(query, null, flattenResults, useLegacySql, null, null); + } + + public static BigQueryDynamicReadDescriptor table( + String table, @Nullable List selectedFields, @Nullable String rowRestriction) { + return create(null, table, null, null, selectedFields, rowRestriction); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java index f986e802f1ca..4e3509d331d8 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java @@ -76,6 +76,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; import org.apache.beam.sdk.extensions.avro.io.AvroSource; @@ -118,11 +119,13 @@ import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Redistribute; import org.apache.beam.sdk.transforms.Reshuffle; import org.apache.beam.sdk.transforms.SerializableFunction; import 
org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.transforms.View; +import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.errorhandling.BadRecord; import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; @@ -669,6 +672,33 @@ public static TypedRead readTableRowsWithSchema() { BigQueryUtils.tableRowToBeamRow(), BigQueryUtils.tableRowFromBeamRow()); } + /** @deprecated this method may have breaking changes introduced, use with caution */ + @Deprecated + public static DynamicRead readDynamicallyTableRows() { + return new AutoValue_BigQueryIO_DynamicRead.Builder() + .setBigQueryServices(new BigQueryServicesImpl()) + .setParseFn(new TableRowParser()) + .setFormat(DataFormat.AVRO) + .setOutputCoder(TableRowJsonCoder.of()) + .setProjectionPushdownApplied(false) + .setBadRecordErrorHandler(new DefaultErrorHandler<>()) + .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER) + .build(); + } + /** @deprecated this method may have breaking changes introduced, use with caution */ + @Deprecated + public static DynamicRead readDynamically( + SerializableFunction parseFn, Coder outputCoder) { + return new AutoValue_BigQueryIO_DynamicRead.Builder() + .setBigQueryServices(new BigQueryServicesImpl()) + .setParseFn(parseFn) + .setFormat(DataFormat.AVRO) + .setOutputCoder(outputCoder) + .setProjectionPushdownApplied(false) + .setBadRecordErrorHandler(new DefaultErrorHandler<>()) + .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER) + .build(); + } private static class TableSchemaFunction implements Serializable, Function<@Nullable String, @Nullable TableSchema> { @@ -804,6 +834,208 @@ public TableRow apply(SchemaAndRecord schemaAndRecord) { return BigQueryAvroUtils.convertGenericRecordToTableRow(schemaAndRecord.getRecord()); } } + /** @deprecated this class may have breaking changes introduced, use with caution */ + @Deprecated + @AutoValue + public abstract static class DynamicRead + extends PTransform, PCollection> { + + abstract BigQueryServices getBigQueryServices(); + + abstract DataFormat getFormat(); + + abstract @Nullable SerializableFunction getParseFn(); + + abstract @Nullable Coder getOutputCoder(); + + abstract boolean getProjectionPushdownApplied(); + + abstract BadRecordRouter getBadRecordRouter(); + + abstract ErrorHandler getBadRecordErrorHandler(); + + abstract @Nullable String getQueryLocation(); + + abstract @Nullable String getQueryTempDataset(); + + abstract @Nullable String getQueryTempProject(); + + abstract @Nullable String getKmsKey(); + + abstract DynamicRead.Builder toBuilder(); + + public DynamicRead withQueryLocation(String location) { + return toBuilder().setQueryLocation(location).build(); + } + + public DynamicRead withQueryTempProject(String tempProject) { + return toBuilder().setQueryTempProject(tempProject).build(); + } + + public DynamicRead withQueryTempDataset(String tempDataset) { + return toBuilder().setQueryTempDataset(tempDataset).build(); + } + + public DynamicRead withKmsKey(String kmsKey) { + return toBuilder().setKmsKey(kmsKey).build(); + } + + public DynamicRead withFormat(DataFormat format) { + return toBuilder().setFormat(format).build(); + } + + public DynamicRead withBadRecordErrorHandler( + ErrorHandler badRecordErrorHandler) { + return toBuilder() + .setBadRecordRouter(RECORDING_ROUTER) + .setBadRecordErrorHandler(badRecordErrorHandler) + .build(); + } + + 
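+ /** Returns a copy of this transform that reads using the given {@link BigQueryServices}; intended for tests to inject fake services. */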
@VisibleForTesting + public DynamicRead withTestServices(BigQueryServices testServices) { + return toBuilder().setBigQueryServices(testServices).build(); + } + + @AutoValue.Builder + abstract static class Builder { + + abstract Builder setFormat(DataFormat format); + + abstract Builder setBigQueryServices(BigQueryServices bigQueryServices); + + abstract Builder setParseFn(SerializableFunction parseFn); + + abstract Builder setOutputCoder(Coder coder); + + abstract Builder setProjectionPushdownApplied(boolean projectionPushdownApplied); + + abstract Builder setBadRecordErrorHandler( + ErrorHandler badRecordErrorHandler); + + abstract Builder setBadRecordRouter(BadRecordRouter badRecordRouter); + + abstract DynamicRead build(); + + abstract Builder setKmsKey(String kmsKey); + + abstract Builder setQueryLocation(String queryLocation); + + abstract Builder setQueryTempDataset(String queryTempDataset); + + abstract Builder setQueryTempProject(String queryTempProject); + } + + DynamicRead() {} + + class CreateBoundedSourceForTable + extends DoFn, BigQueryStorageStreamSource> { + + @ProcessElement + public void processElement( + OutputReceiver> receiver, + @Element KV kv, + PipelineOptions options) + throws Exception { + + BigQueryDynamicReadDescriptor descriptor = kv.getValue(); + if (descriptor.getTable() != null) { + BigQueryStorageTableSource output = + BigQueryStorageTableSource.create( + StaticValueProvider.of(BigQueryHelpers.parseTableSpec(descriptor.getTable())), + getFormat(), + descriptor.getSelectedFields() != null + ? StaticValueProvider.of(descriptor.getSelectedFields()) + : null, + descriptor.getRowRestriction() != null + ? StaticValueProvider.of(descriptor.getRowRestriction()) + : null, + getParseFn(), + getOutputCoder(), + getBigQueryServices(), + getProjectionPushdownApplied()); + // 1mb --> 1 shard; 1gb --> 32 shards; 1tb --> 1000 shards, 1pb --> 32k + // shards + long desiredChunkSize = getDesiredChunkSize(options, output); + List> split = output.split(desiredChunkSize, options); + split.stream().forEach(source -> receiver.output(source)); + } else { + // run query + BigQueryStorageQuerySource querySource = + BigQueryStorageQuerySource.create( + kv.getKey(), + StaticValueProvider.of(descriptor.getQuery()), + descriptor.getFlattenResults(), + descriptor.getUseLegacySql(), + TypedRead.QueryPriority.INTERACTIVE, + getQueryLocation(), + getQueryTempDataset(), + getQueryTempProject(), + getKmsKey(), + getFormat(), + getParseFn(), + getOutputCoder(), + getBigQueryServices()); + Table queryResultTable = querySource.getTargetTable(options.as(BigQueryOptions.class)); + + BigQueryStorageTableSource output = + BigQueryStorageTableSource.create( + StaticValueProvider.of(queryResultTable.getTableReference()), + getFormat(), + null, + null, + getParseFn(), + getOutputCoder(), + getBigQueryServices(), + false); + // 1mb --> 1 shard; 1gb --> 32 shards; 1tb --> 1000 shards, 1pb --> 32k + // shards + long desiredChunkSize = getDesiredChunkSize(options, output); + List> split = output.split(desiredChunkSize, options); + split.stream().forEach(source -> receiver.output(source)); + } + } + + private long getDesiredChunkSize( + PipelineOptions options, BigQueryStorageTableSource output) throws Exception { + return Math.max(1 << 20, (long) (1000 * Math.sqrt(output.getEstimatedSizeBytes(options)))); + } + } + + @Override + public PCollection expand(PCollection input) { + TupleTag rowTag = new TupleTag<>(); + PCollection> addJobId = + input + .apply( + "Add job id", + WithKeys.of( + new 
SimpleFunction() { + @Override + public String apply(BigQueryDynamicReadDescriptor input) { + return BigQueryHelpers.randomUUIDString(); + } + })) + .apply("Checkpoint", Redistribute.byKey()); + + PCollectionTuple resultTuple = + addJobId + .apply("Create streams", ParDo.of(new CreateBoundedSourceForTable())) + .setCoder( + SerializableCoder.of(new TypeDescriptor>() {})) + .apply("Redistribute", Redistribute.arbitrarily()) + .apply( + "Read Streams with storage read api", + ParDo.of( + new TypedRead.ReadTableSource( + rowTag, getParseFn(), getBadRecordRouter())) + .withOutputTags(rowTag, TupleTagList.of(BAD_RECORD_TAG))); + getBadRecordErrorHandler() + .addErrorCollection( + resultTuple.get(BAD_RECORD_TAG).setCoder(BadRecord.getCoder(input.getPipeline()))); + return resultTuple.get(rowTag).setCoder(getOutputCoder()); + } + } /** Implementation of {@link BigQueryIO#read()}. */ public static class Read extends PTransform> { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java index 5dbebc7fb79d..124a708eed6b 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java @@ -52,6 +52,7 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.util.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Objects; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; import org.checkerframework.checker.nullness.qual.RequiresNonNull; @@ -79,6 +80,26 @@ public static BigQueryStorageStreamSource create( bqServices); } + @Override + public boolean equals(@Nullable Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + BigQueryStorageStreamSource other = (BigQueryStorageStreamSource) obj; + return readSession.equals(other.readSession) + && readStream.equals(other.readStream) + && jsonTableSchema.equals(other.jsonTableSchema) + && outputCoder.equals(other.outputCoder); + } + + @Override + public int hashCode() { + return Objects.hashCode(readSession, readStream, jsonTableSchema, outputCoder); + } + /** * Creates a new source with the same properties as this one, except with a different {@link * ReadStream}. diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIODynamicQueryIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIODynamicQueryIT.java new file mode 100644 index 000000000000..7ea512bec355 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIODynamicQueryIT.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static org.apache.beam.sdk.io.gcp.bigquery.TestBigQueryOptions.BIGQUERY_EARLY_ROLLOUT_REGION; + +import com.google.api.services.bigquery.model.TableRow; +import java.util.Map; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TableRowParser; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.Validation; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandlingTestUtils.ErrorSinkTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Integration tests for {@link BigQueryIO#read(SerializableFunction)} using {@link + * Method#DIRECT_READ} to read query results. This test runs a simple "SELECT *" query over a + * pre-defined table and asserts that the number of records read is equal to the expected count. + */ +@RunWith(JUnit4.class) +public class BigQueryIODynamicQueryIT { + + private static final Map EXPECTED_NUM_RECORDS = + ImmutableMap.of( + "empty", 0L, + "1M", 10592L, + "1G", 11110839L, + "1T", 11110839000L); + + private static final String DATASET_ID = + TestPipeline.testingPipelineOptions() + .as(TestBigQueryOptions.class) + .getBigQueryLocation() + .equals(BIGQUERY_EARLY_ROLLOUT_REGION) + ? "big_query_storage_day0" + : "big_query_storage"; + private static final String TABLE_PREFIX = "storage_read_"; + + private BigQueryIOQueryOptions options; + + /** Customized {@link TestPipelineOptions} for BigQueryIOStorageQuery pipelines. 
*/ + public interface BigQueryIOQueryOptions extends TestPipelineOptions, ExperimentalOptions { + @Description("The table to be queried") + @Validation.Required + String getInputTable(); + + void setInputTable(String table); + + @Description("The expected number of records") + @Validation.Required + long getNumRecords(); + + void setNumRecords(long numRecords); + } + + private void setUpTestEnvironment(String tableSize) { + PipelineOptionsFactory.register(BigQueryIOQueryOptions.class); + options = TestPipeline.testingPipelineOptions().as(BigQueryIOQueryOptions.class); + options.setNumRecords(EXPECTED_NUM_RECORDS.get(tableSize)); + String project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + options.setInputTable(project + '.' + DATASET_ID + '.' + TABLE_PREFIX + tableSize); + } + + private void runBigQueryIODynamicQueryPipeline() { + Pipeline p = Pipeline.create(options); + PCollection count = + p.apply( + Create.of( + BigQueryDynamicReadDescriptor.create( + "SELECT * FROM `" + options.getInputTable() + "`", + null, + false, + false, + null, + null))) + .apply( + "DynamicRead", + BigQueryIO.readDynamically(TableRowParser.INSTANCE, TableRowJsonCoder.of())) + .apply("Count", Count.globally()); + + PAssert.thatSingleton(count).isEqualTo(options.getNumRecords()); + p.run().waitUntilFinish(); + } + + @Test + public void testBigQueryDynamicQuery1G() throws Exception { + setUpTestEnvironment("1G"); + runBigQueryIODynamicQueryPipeline(); + } + + static class FailingTableRowParser implements SerializableFunction { + + public static final BigQueryIOStorageReadIT.FailingTableRowParser INSTANCE = + new BigQueryIOStorageReadIT.FailingTableRowParser(); + + private int parseCount = 0; + + @Override + public TableRow apply(SchemaAndRecord schemaAndRecord) { + parseCount++; + if (parseCount % 50 == 0) { + throw new RuntimeException("ExpectedException"); + } + return TableRowParser.INSTANCE.apply(schemaAndRecord); + } + } + + @Test + public void testBigQueryDynamicQueryWithErrorHandling1M() throws Exception { + setUpTestEnvironment("1M"); + Pipeline p = Pipeline.create(options); + ErrorHandler> errorHandler = + p.registerBadRecordErrorHandler(new ErrorSinkTransform()); + PCollection count = + p.apply( + Create.of( + BigQueryDynamicReadDescriptor.create( + "SELECT * FROM `" + options.getInputTable() + "`", + null, + false, + false, + null, + null))) + .apply( + "DynamicRead", + BigQueryIO.readDynamically(FailingTableRowParser.INSTANCE, TableRowJsonCoder.of()) + .withBadRecordErrorHandler(errorHandler)) + .apply("Count", Count.globally()); + errorHandler.close(); + + // When 1/50 elements fail sequentially, this is the expected success count + PAssert.thatSingleton(count).isEqualTo(10381L); + // this is the total elements, less the successful elements + PAssert.thatSingleton(errorHandler.getOutput()).isEqualTo(10592L - 10381L); + p.run().waitUntilFinish(); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIODynamicReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIODynamicReadIT.java new file mode 100644 index 000000000000..742a390c8bd1 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIODynamicReadIT.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static org.apache.beam.sdk.io.gcp.bigquery.TestBigQueryOptions.BIGQUERY_EARLY_ROLLOUT_REGION; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.bigquery.storage.v1.DataFormat; +import java.util.Map; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TableRowParser; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.Validation; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandlingTestUtils.ErrorSinkTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Integration tests for {@link BigQueryIO#readDynamically(SerializableFunction, + * org.apache.beam.sdk.coders.Coder)} using {@link Method#DIRECT_READ}. This test reads from a + * pre-defined table and asserts that the number of records read is equal to the expected count. + */ +@RunWith(JUnit4.class) +public class BigQueryIODynamicReadIT { + + private static final Map EXPECTED_NUM_RECORDS = + ImmutableMap.of( + "empty", 0L, + "1M", 10592L, + "1G", 11110839L, + "1T", 11110839000L, + "multi_field", 11110839L); + + private static final String DATASET_ID = + TestPipeline.testingPipelineOptions() + .as(TestBigQueryOptions.class) + .getBigQueryLocation() + .equals(BIGQUERY_EARLY_ROLLOUT_REGION) + ? "big_query_storage_day0" + : "big_query_storage"; + private static final String TABLE_PREFIX = "storage_read_"; + + private BigQueryIODynamicReadOptions options; + + /** Customized {@link TestPipelineOptions} for BigQueryIOStorageRead pipelines. 
*/ + public interface BigQueryIODynamicReadOptions extends TestPipelineOptions, ExperimentalOptions { + @Description("The table to be read") + @Validation.Required + String getInputTable(); + + void setInputTable(String table); + + @Description("The expected number of records") + @Validation.Required + long getNumRecords(); + + void setNumRecords(long numRecords); + + @Description("The data format to use") + @Validation.Required + DataFormat getDataFormat(); + + void setDataFormat(DataFormat dataFormat); + } + + private void setUpTestEnvironment(String tableSize, DataFormat format) { + PipelineOptionsFactory.register(BigQueryIODynamicReadOptions.class); + options = TestPipeline.testingPipelineOptions().as(BigQueryIODynamicReadOptions.class); + options.setNumRecords(EXPECTED_NUM_RECORDS.get(tableSize)); + options.setDataFormat(format); + String project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + options.setInputTable(project + ":" + DATASET_ID + "." + TABLE_PREFIX + tableSize); + } + + private void runBigQueryIODynamicReadPipeline() { + Pipeline p = Pipeline.create(options); + PCollection count = + p.apply( + Create.of( + BigQueryDynamicReadDescriptor.create( + null, options.getInputTable(), null, null, null, null))) + .apply( + "Read", + BigQueryIO.readDynamically(TableRowParser.INSTANCE, TableRowJsonCoder.of()) + .withFormat(options.getDataFormat())) + .apply("Count", Count.globally()); + PAssert.thatSingleton(count).isEqualTo(options.getNumRecords()); + p.run().waitUntilFinish(); + } + + static class FailingTableRowParser implements SerializableFunction { + + public static final FailingTableRowParser INSTANCE = new FailingTableRowParser(); + + private int parseCount = 0; + + @Override + public TableRow apply(SchemaAndRecord schemaAndRecord) { + parseCount++; + if (parseCount % 50 == 0) { + throw new RuntimeException("ExpectedException"); + } + return TableRowParser.INSTANCE.apply(schemaAndRecord); + } + } + + private void runBigQueryIODynamicReadPipelineErrorHandling() throws Exception { + Pipeline p = Pipeline.create(options); + ErrorHandler> errorHandler = + p.registerBadRecordErrorHandler(new ErrorSinkTransform()); + PCollection count = + p.apply( + Create.of( + BigQueryDynamicReadDescriptor.create( + null, options.getInputTable(), null, null, null, null))) + .apply( + "Read", + BigQueryIO.readDynamically(TableRowParser.INSTANCE, TableRowJsonCoder.of()) + .withFormat(options.getDataFormat()) + .withBadRecordErrorHandler(errorHandler)) + .apply("Count", Count.globally()); + + errorHandler.close(); + + // When 1/50 elements fail sequentially, this is the expected success count + PAssert.thatSingleton(count).isEqualTo(10381L); + // this is the total elements, less the successful elements + PAssert.thatSingleton(errorHandler.getOutput()).isEqualTo(10592L - 10381L); + p.run().waitUntilFinish(); + } + + @Test + public void testBigQueryDynamicRead1GAvro() throws Exception { + setUpTestEnvironment("1G", DataFormat.AVRO); + runBigQueryIODynamicReadPipeline(); + } + + @Test + public void testBigQueryDynamicRead1GArrow() throws Exception { + setUpTestEnvironment("1G", DataFormat.ARROW); + runBigQueryIODynamicReadPipeline(); + } + + @Test + public void testBigQueryDynamicRead1MErrorHandlingAvro() throws Exception { + setUpTestEnvironment("1M", DataFormat.AVRO); + runBigQueryIODynamicReadPipelineErrorHandling(); + } + + @Test + public void testBigQueryDynamicRead1MErrorHandlingArrow() throws Exception { + setUpTestEnvironment("1M", DataFormat.ARROW); + 
runBigQueryIODynamicReadPipelineErrorHandling(); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIODynamicReadTableRowIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIODynamicReadTableRowIT.java new file mode 100644 index 000000000000..4fecb18ce507 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIODynamicReadTableRowIT.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static org.apache.beam.sdk.io.gcp.bigquery.TestBigQueryOptions.BIGQUERY_EARLY_ROLLOUT_REGION; + +import com.google.api.services.bigquery.model.TableRow; +import java.util.HashSet; +import java.util.Set; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TableRowParser; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.Validation; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.transforms.join.CoGbkResult; +import org.apache.beam.sdk.transforms.join.CoGroupByKey; +import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TupleTag; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Integration tests for {@link BigQueryIO#readTableRows()} using {@link Method#DIRECT_READ} in + * combination with {@link TableRowParser} to generate output in {@link TableRow} form. + */ +@RunWith(JUnit4.class) +public class BigQueryIODynamicReadTableRowIT { + + private static final String DATASET_ID = + TestPipeline.testingPipelineOptions() + .as(TestBigQueryOptions.class) + .getBigQueryLocation() + .equals(BIGQUERY_EARLY_ROLLOUT_REGION) + ? 
"big_query_import_export_day0" + : "big_query_import_export"; + private static final String TABLE_PREFIX = "parallel_read_table_row_"; + + private BigQueryIODynamicReadTableRowOptions options; + + /** Private pipeline options for the test. */ + public interface BigQueryIODynamicReadTableRowOptions + extends TestPipelineOptions, ExperimentalOptions { + @Description("The table to be read") + @Validation.Required + String getInputTable(); + + void setInputTable(String table); + } + + private static class TableRowToKVPairFn extends SimpleFunction> { + @Override + public KV apply(TableRow input) { + Integer rowId = Integer.parseInt((String) input.get("id")); + return KV.of(rowId, BigQueryHelpers.toJsonString(input)); + } + } + + private void setUpTestEnvironment(String tableName) { + PipelineOptionsFactory.register(BigQueryIODynamicReadTableRowOptions.class); + options = TestPipeline.testingPipelineOptions().as(BigQueryIODynamicReadTableRowOptions.class); + String project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + options.setInputTable(project + ":" + DATASET_ID + "." + TABLE_PREFIX + tableName); + options.setTempLocation( + FileSystems.matchNewDirectory(options.getTempRoot(), "temp-it").toString()); + } + + private static void runPipeline(BigQueryIODynamicReadTableRowOptions pipelineOptions) { + Pipeline pipeline = Pipeline.create(pipelineOptions); + + PCollection> jsonTableRowsFromExport = + pipeline + .apply( + Create.of( + BigQueryDynamicReadDescriptor.create( + null, pipelineOptions.getInputTable(), null, null, null, null))) + .apply("DynamicRead", BigQueryIO.readDynamicallyTableRows()) + .apply("MapExportedRows", MapElements.via(new TableRowToKVPairFn())); + + PCollection> jsonTableRowsFromDirectRead = + pipeline + .apply( + "DirectReadTable", + BigQueryIO.readTableRows() + .from(pipelineOptions.getInputTable()) + .withMethod(Method.DIRECT_READ)) + .apply("MapDirectReadRows", MapElements.via(new TableRowToKVPairFn())); + + final TupleTag exportTag = new TupleTag<>(); + final TupleTag directReadTag = new TupleTag<>(); + + PCollection>> unmatchedRows = + KeyedPCollectionTuple.of(exportTag, jsonTableRowsFromExport) + .and(directReadTag, jsonTableRowsFromDirectRead) + .apply(CoGroupByKey.create()) + .apply( + ParDo.of( + new DoFn, KV>>() { + @ProcessElement + public void processElement(ProcessContext c) { + KV element = c.element(); + + // Add all the exported rows for the key to a collection. + Set uniqueRows = new HashSet<>(); + for (String row : element.getValue().getAll(exportTag)) { + uniqueRows.add(row); + } + + // Compute the disjunctive union of the rows in the direct read collection. + for (String row : element.getValue().getAll(directReadTag)) { + if (uniqueRows.contains(row)) { + uniqueRows.remove(row); + } else { + uniqueRows.add(row); + } + } + + // Emit any rows in the result set. 
+ if (!uniqueRows.isEmpty()) { + c.output(KV.of(element.getKey(), uniqueRows)); + } + } + })); + + PAssert.that(unmatchedRows).empty(); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testBigQueryDynamicReadTableRow100() { + setUpTestEnvironment("100"); + runPipeline(options); + } + + @Test + public void testBigQueryDynamicReadTableRow1k() { + setUpTestEnvironment("1K"); + runPipeline(options); + } + + @Test + public void testBigQueryDynamicReadTableRow10k() { + setUpTestEnvironment("10K"); + runPipeline(options); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIODynamicReadTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIODynamicReadTest.java new file mode 100644 index 000000000000..9fd777b477b4 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIODynamicReadTest.java @@ -0,0 +1,786 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static java.util.Arrays.asList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import com.google.api.services.bigquery.model.JobStatistics; +import com.google.api.services.bigquery.model.JobStatistics2; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.bigquery.storage.v1.ArrowRecordBatch; +import com.google.cloud.bigquery.storage.v1.ArrowSchema; +import com.google.cloud.bigquery.storage.v1.AvroRows; +import com.google.cloud.bigquery.storage.v1.AvroSchema; +import com.google.cloud.bigquery.storage.v1.CreateReadSessionRequest; +import com.google.cloud.bigquery.storage.v1.DataFormat; +import com.google.cloud.bigquery.storage.v1.ReadRowsRequest; +import com.google.cloud.bigquery.storage.v1.ReadRowsResponse; +import com.google.cloud.bigquery.storage.v1.ReadSession; +import com.google.cloud.bigquery.storage.v1.ReadStream; +import com.google.cloud.bigquery.storage.v1.StreamStats; +import com.google.cloud.bigquery.storage.v1.StreamStats.Progress; +import com.google.protobuf.ByteString; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.util.Text; +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.Encoder; +import org.apache.avro.io.EncoderFactory; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.extensions.protobuf.ByteStringCoder; +import org.apache.beam.sdk.extensions.protobuf.ProtoCoder; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StorageClient; +import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; +import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices.FakeBigQueryServerStream; +import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; +import org.apache.beam.sdk.io.gcp.testing.FakeJobService; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.SerializableFunction; 
+import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandlingTestUtils; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.junit.runners.model.Statement; + +/** + * Tests for {@link BigQueryIO#readDynamically(SerializableFunction, Coder)} limited to direct read. + */ +@RunWith(JUnit4.class) +public class BigQueryIODynamicReadTest { + + private static final EncoderFactory ENCODER_FACTORY = EncoderFactory.get(); + private static final String AVRO_SCHEMA_STRING = + "{\"namespace\": \"example.avro\",\n" + + " \"type\": \"record\",\n" + + " \"name\": \"RowRecord\",\n" + + " \"fields\": [\n" + + " {\"name\": \"name\", \"type\": \"string\"},\n" + + " {\"name\": \"number\", \"type\": \"long\"}\n" + + " ]\n" + + "}"; + private static final Schema AVRO_SCHEMA = new Schema.Parser().parse(AVRO_SCHEMA_STRING); + private static final String TRIMMED_AVRO_SCHEMA_STRING = + "{\"namespace\": \"example.avro\",\n" + + "\"type\": \"record\",\n" + + "\"name\": \"RowRecord\",\n" + + "\"fields\": [\n" + + " {\"name\": \"name\", \"type\": \"string\"}\n" + + " ]\n" + + "}"; + private static final Schema TRIMMED_AVRO_SCHEMA = + new Schema.Parser().parse(TRIMMED_AVRO_SCHEMA_STRING); + private static final TableSchema TABLE_SCHEMA = + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("name").setType("STRING").setMode("REQUIRED"), + new TableFieldSchema().setName("number").setType("INTEGER").setMode("REQUIRED"))); + private static final org.apache.arrow.vector.types.pojo.Schema ARROW_SCHEMA = + new org.apache.arrow.vector.types.pojo.Schema( + asList( + field("name", new ArrowType.Utf8()), field("number", new ArrowType.Int(64, true)))); + private final transient TemporaryFolder testFolder = new TemporaryFolder(); + private final FakeDatasetService fakeDatasetService = new FakeDatasetService(); + @Rule public transient ExpectedException thrown = ExpectedException.none(); + private transient GcpOptions options; + private transient TestPipeline p; + + @Rule + public final transient TestRule folderThenPipeline = + new TestRule() { + @Override + public Statement apply(Statement base, Description description) { + // We need to set up the temporary folder, and then set up the TestPipeline based on the + // chosen folder. Unfortunately, since rule evaluation order is unspecified and unrelated + // to field order, and is separate from construction, that requires manually creating this + // TestRule. 
+ Statement withPipeline = + new Statement() { + @Override + public void evaluate() throws Throwable { + options = TestPipeline.testingPipelineOptions().as(GcpOptions.class); + options.as(BigQueryOptions.class).setProject("project-id"); + if (description.getAnnotations().stream() + .anyMatch(a -> a.annotationType().equals(ProjectOverride.class))) { + options.as(BigQueryOptions.class).setBigQueryProject("bigquery-project-id"); + } + options + .as(BigQueryOptions.class) + .setTempLocation(testFolder.getRoot().getAbsolutePath()); + p = TestPipeline.fromOptions(options); + p.apply(base, description).evaluate(); + } + }; + return testFolder.apply(withPipeline, description); + } + }; + + private BufferAllocator allocator; + + private static ByteString serializeArrowSchema( + org.apache.arrow.vector.types.pojo.Schema arrowSchema) { + ByteArrayOutputStream byteOutputStream = new ByteArrayOutputStream(); + try { + MessageSerializer.serialize( + new WriteChannel(Channels.newChannel(byteOutputStream)), arrowSchema); + } catch (IOException ex) { + throw new RuntimeException("Failed to serialize arrow schema.", ex); + } + return ByteString.copyFrom(byteOutputStream.toByteArray()); + } + + private static ReadRowsResponse createResponse( + Schema schema, + Collection genericRecords, + double progressAtResponseStart, + double progressAtResponseEnd) + throws Exception { + GenericDatumWriter writer = new GenericDatumWriter<>(schema); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + Encoder binaryEncoder = ENCODER_FACTORY.binaryEncoder(outputStream, null); + for (GenericRecord genericRecord : genericRecords) { + writer.write(genericRecord, binaryEncoder); + } + + binaryEncoder.flush(); + + return ReadRowsResponse.newBuilder() + .setAvroRows( + AvroRows.newBuilder() + .setSerializedBinaryRows(ByteString.copyFrom(outputStream.toByteArray())) + .setRowCount(genericRecords.size())) + .setRowCount(genericRecords.size()) + .setStats( + StreamStats.newBuilder() + .setProgress( + Progress.newBuilder() + .setAtResponseStart(progressAtResponseStart) + .setAtResponseEnd(progressAtResponseEnd))) + .build(); + } + + private static GenericRecord createRecord(String name, Schema schema) { + GenericRecord genericRecord = new GenericData.Record(schema); + genericRecord.put("name", name); + return genericRecord; + } + + private static GenericRecord createRecord(String name, long number, Schema schema) { + GenericRecord genericRecord = new GenericData.Record(schema); + genericRecord.put("name", name); + genericRecord.put("number", number); + return genericRecord; + } + + private static org.apache.arrow.vector.types.pojo.Field field( + String name, + boolean nullable, + ArrowType type, + org.apache.arrow.vector.types.pojo.Field... children) { + return new org.apache.arrow.vector.types.pojo.Field( + name, + new org.apache.arrow.vector.types.pojo.FieldType(nullable, type, null, null), + asList(children)); + } + + static org.apache.arrow.vector.types.pojo.Field field( + String name, ArrowType type, org.apache.arrow.vector.types.pojo.Field... 
children) { + return field(name, false, type, children); + } + + @Before + public void setUp() throws Exception { + FakeDatasetService.setUp(); + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @After + public void teardown() { + allocator.close(); + } + + @Test + public void testCreateWithQuery() { + String query = "SELECT * FROM dataset.table"; + Boolean flattenResults = true; + Boolean legacySql = false; + + BigQueryDynamicReadDescriptor descriptor = + BigQueryDynamicReadDescriptor.create(query, null, flattenResults, legacySql, null, null); + + assertNotNull(descriptor); + } + + @Test + public void testCreateWithTable() { + String table = "dataset.table"; + + BigQueryDynamicReadDescriptor descriptor = + BigQueryDynamicReadDescriptor.create(null, table, null, null, null, null); + + assertNotNull(descriptor); + } + + @Test + public void testCreateWithTableAndSelectedFieldsAndRowRestriction() { + String table = "dataset.table"; + List selectedFields = Arrays.asList("field1", "field2"); + String rowRestriction = "field1 > 10"; + + BigQueryDynamicReadDescriptor descriptor = + BigQueryDynamicReadDescriptor.create( + null, table, null, null, selectedFields, rowRestriction); + + assertNotNull(descriptor); + } + + @Test + public void testCreateWithNullQueryAndTableShouldThrowException() { + assertThrows( + IllegalArgumentException.class, + () -> BigQueryDynamicReadDescriptor.create(null, null, null, null, null, null)); + } + + @Test + public void testCreateWithBothQueryAndTableShouldThrowException() { + String query = "SELECT * FROM dataset.table"; + String table = "dataset.table"; + assertThrows( + IllegalArgumentException.class, + () -> BigQueryDynamicReadDescriptor.create(query, table, null, null, null, null)); + } + + @Test + public void testCreateWithTableAndFlattenResultsShouldThrowException() { + String table = "dataset.table"; + Boolean flattenResults = true; + assertThrows( + IllegalArgumentException.class, + () -> BigQueryDynamicReadDescriptor.create(null, table, flattenResults, null, null, null)); + } + + @Test + public void testCreateWithTableAndLegacySqlShouldThrowException() { + String table = "dataset.table"; + Boolean legacySql = true; + assertThrows( + IllegalArgumentException.class, + () -> BigQueryDynamicReadDescriptor.create(null, table, null, legacySql, null, null)); + } + + @Test + public void testCreateWithQueryAndSelectedFieldsShouldThrowException() { + String query = "SELECT * FROM dataset.table"; + Boolean flattenResults = true; + Boolean legacySql = false; + List selectedFields = Arrays.asList("field1", "field2"); + + assertThrows( + IllegalArgumentException.class, + () -> + BigQueryDynamicReadDescriptor.create( + query, null, flattenResults, legacySql, selectedFields, null)); + } + + @Test + public void testCreateWithQueryAndRowRestrictionShouldThrowException() { + String query = "SELECT * FROM dataset.table"; + Boolean flattenResults = true; + Boolean legacySql = false; + String rowRestriction = "field1 > 10"; + + assertThrows( + IllegalArgumentException.class, + () -> + BigQueryDynamicReadDescriptor.create( + query, null, flattenResults, legacySql, null, rowRestriction)); + } + + @Test + public void testCreateWithQueryAndNullFlattenResultsShouldThrowException() { + String query = "SELECT * FROM dataset.table"; + Boolean legacySql = false; + + assertThrows( + IllegalArgumentException.class, + () -> BigQueryDynamicReadDescriptor.create(query, null, null, legacySql, null, null)); + } + + @Test + public void 
+
+  @Test
+  public void testCreateWithQueryAndNullLegacySqlShouldThrowException() {
+    String query = "SELECT * FROM dataset.table";
+    Boolean flattenResults = true;
+
+    assertThrows(
+        IllegalArgumentException.class,
+        () -> BigQueryDynamicReadDescriptor.create(query, null, flattenResults, null, null, null));
+  }
+
+  @Test
+  public void testCoderInference() {
+    // Lambdas erase too much type information -- use an anonymous class here.
+    SerializableFunction<SchemaAndRecord, KV<ByteString, ReadSession>> parseFn =
+        new SerializableFunction<SchemaAndRecord, KV<ByteString, ReadSession>>() {
+          @Override
+          public KV<ByteString, ReadSession> apply(SchemaAndRecord input) {
+            return null;
+          }
+        };
+
+    assertEquals(
+        KvCoder.of(ByteStringCoder.of(), ProtoCoder.of(ReadSession.class)),
+        BigQueryIO.read(parseFn).inferCoder(CoderRegistry.createDefault()));
+  }
+
+  private ReadRowsResponse createResponseArrow(
+      org.apache.arrow.vector.types.pojo.Schema arrowSchema,
+      List<String> name,
+      List<Long> number,
+      double progressAtResponseStart,
+      double progressAtResponseEnd) {
+    ArrowRecordBatch serializedRecord;
+    try (VectorSchemaRoot schemaRoot = VectorSchemaRoot.create(arrowSchema, allocator)) {
+      schemaRoot.allocateNew();
+      schemaRoot.setRowCount(name.size());
+      VarCharVector strVector = (VarCharVector) schemaRoot.getFieldVectors().get(0);
+      BigIntVector bigIntVector = (BigIntVector) schemaRoot.getFieldVectors().get(1);
+      for (int i = 0; i < name.size(); i++) {
+        bigIntVector.set(i, number.get(i));
+        strVector.set(i, new Text(name.get(i)));
+      }
+
+      VectorUnloader unLoader = new VectorUnloader(schemaRoot);
+      try (org.apache.arrow.vector.ipc.message.ArrowRecordBatch records =
+          unLoader.getRecordBatch()) {
+        try (ByteArrayOutputStream os = new ByteArrayOutputStream()) {
+          MessageSerializer.serialize(new WriteChannel(Channels.newChannel(os)), records);
+          serializedRecord =
+              ArrowRecordBatch.newBuilder()
+                  .setRowCount(records.getLength())
+                  .setSerializedRecordBatch(ByteString.copyFrom(os.toByteArray()))
+                  .build();
+        } catch (IOException e) {
+          throw new RuntimeException("Error writing to byte array output stream", e);
+        }
+      }
+    }
+
+    return ReadRowsResponse.newBuilder()
+        .setArrowRecordBatch(serializedRecord)
+        .setRowCount(name.size())
+        .setStats(
+            StreamStats.newBuilder()
+                .setProgress(
+                    Progress.newBuilder()
+                        .setAtResponseStart(progressAtResponseStart)
+                        .setAtResponseEnd(progressAtResponseEnd)))
+        .build();
+  }
+
+  private static final class ParseKeyValue
+      implements SerializableFunction<SchemaAndRecord, KV<String, Long>> {
+
+    @Override
+    public KV<String, Long> apply(SchemaAndRecord input) {
+      return KV.of(
+          input.getRecord().get("name").toString(), (Long) input.getRecord().get("number"));
+    }
+  }
+
+  @Test
+  public void testReadFromBigQueryIO() throws Exception {
+    fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null);
+    TableReference tableRef = BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table");
+    Table table = new Table().setTableReference(tableRef).setNumBytes(10L).setSchema(TABLE_SCHEMA);
+    fakeDatasetService.createTable(table);
+
+    CreateReadSessionRequest expectedCreateReadSessionRequest =
+        CreateReadSessionRequest.newBuilder()
+            .setParent("projects/project-id")
+            .setReadSession(
+                ReadSession.newBuilder()
+                    .setTable("projects/foo.com:project/datasets/dataset/tables/table")
+                    .setDataFormat(DataFormat.AVRO)
+                    .setReadOptions(ReadSession.TableReadOptions.newBuilder()))
+            .setMaxStreamCount(10)
+            .build();
+
+    ReadSession readSession =
+        ReadSession.newBuilder()
+            .setName("readSessionName")
+            .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING))
+            .addStreams(ReadStream.newBuilder().setName("streamName"))
+            .setDataFormat(DataFormat.AVRO)
+            .build();
+
+    ReadRowsRequest expectedReadRowsRequest =
+        ReadRowsRequest.newBuilder().setReadStream("streamName").build();
+
+    List<GenericRecord> records =
+        Lists.newArrayList(
+            createRecord("A", 1, AVRO_SCHEMA),
+            createRecord("B", 2, AVRO_SCHEMA),
+            createRecord("C", 3, AVRO_SCHEMA),
+            createRecord("D", 4, AVRO_SCHEMA));
+
+    List<ReadRowsResponse> readRowsResponses =
+        Lists.newArrayList(
+            createResponse(AVRO_SCHEMA, records.subList(0, 2), 0.0, 0.50),
+            createResponse(AVRO_SCHEMA, records.subList(2, 4), 0.5, 0.75));
+
+    StorageClient fakeStorageClient = mock(StorageClient.class, withSettings().serializable());
+    when(fakeStorageClient.createReadSession(expectedCreateReadSessionRequest))
+        .thenReturn(readSession);
+    when(fakeStorageClient.readRows(expectedReadRowsRequest, ""))
+        .thenReturn(new FakeBigQueryServerStream<>(readRowsResponses));
+
+    PCollection<KV<String, Long>> output =
+        p.apply(
+                Create.of(
+                    BigQueryDynamicReadDescriptor.create(
+                        null, "foo.com:project:dataset.table", null, null, null, null)))
+            .apply(
+                BigQueryIO.readDynamically(
+                        new ParseKeyValue(), KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of()))
+                    .withFormat(DataFormat.AVRO)
+                    .withTestServices(
+                        new FakeBigQueryServices()
+                            .withDatasetService(fakeDatasetService)
+                            .withStorageClient(fakeStorageClient)));
+
+    PAssert.that(output)
+        .containsInAnyOrder(
+            ImmutableList.of(KV.of("A", 1L), KV.of("B", 2L), KV.of("C", 3L), KV.of("D", 4L)));
+
+    p.run();
+  }
Lists.newArrayList("name"), + null))) + .apply( + BigQueryIO.readDynamicallyTableRows() + .withFormat(DataFormat.AVRO) + .withTestServices( + new FakeBigQueryServices() + .withDatasetService(fakeDatasetService) + .withStorageClient(fakeStorageClient))); + + PAssert.that(output) + .containsInAnyOrder( + ImmutableList.of( + new TableRow().set("name", "A"), + new TableRow().set("name", "B"), + new TableRow().set("name", "C"), + new TableRow().set("name", "D"))); + + p.run(); + } + + @Test + public void testReadFromBigQueryIOArrow() throws Exception { + fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null); + TableReference tableRef = BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table"); + Table table = new Table().setTableReference(tableRef).setNumBytes(10L).setSchema(TABLE_SCHEMA); + fakeDatasetService.createTable(table); + + CreateReadSessionRequest expectedCreateReadSessionRequest = + CreateReadSessionRequest.newBuilder() + .setParent("projects/project-id") + .setReadSession( + ReadSession.newBuilder() + .setTable("projects/foo.com:project/datasets/dataset/tables/table") + .setDataFormat(DataFormat.ARROW) + .setReadOptions(ReadSession.TableReadOptions.newBuilder())) + .setMaxStreamCount(10) + .build(); + + ReadSession readSession = + ReadSession.newBuilder() + .setName("readSessionName") + .setArrowSchema( + ArrowSchema.newBuilder() + .setSerializedSchema(serializeArrowSchema(ARROW_SCHEMA)) + .build()) + .addStreams(ReadStream.newBuilder().setName("streamName")) + .setDataFormat(DataFormat.ARROW) + .build(); + + ReadRowsRequest expectedReadRowsRequest = + ReadRowsRequest.newBuilder().setReadStream("streamName").build(); + + List names = Arrays.asList("A", "B", "C", "D"); + List values = Arrays.asList(1L, 2L, 3L, 4L); + List readRowsResponses = + Lists.newArrayList( + createResponseArrow(ARROW_SCHEMA, names.subList(0, 2), values.subList(0, 2), 0.0, 0.50), + createResponseArrow( + ARROW_SCHEMA, names.subList(2, 4), values.subList(2, 4), 0.5, 0.75)); + + StorageClient fakeStorageClient = mock(StorageClient.class, withSettings().serializable()); + when(fakeStorageClient.createReadSession(expectedCreateReadSessionRequest)) + .thenReturn(readSession); + when(fakeStorageClient.readRows(expectedReadRowsRequest, "")) + .thenReturn(new FakeBigQueryServerStream<>(readRowsResponses)); + + PCollection> output = + p.apply( + Create.of( + BigQueryDynamicReadDescriptor.create( + null, "foo.com:project:dataset.table", null, null, null, null))) + .apply( + BigQueryIO.readDynamically( + new ParseKeyValue(), KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of())) + .withFormat(DataFormat.ARROW) + .withTestServices( + new FakeBigQueryServices() + .withDatasetService(fakeDatasetService) + .withStorageClient(fakeStorageClient))); + + PAssert.that(output) + .containsInAnyOrder( + ImmutableList.of(KV.of("A", 1L), KV.of("B", 2L), KV.of("C", 3L), KV.of("D", 4L))); + + p.run(); + } + + private FakeJobService fakeJobService = new FakeJobService(); + + public PCollection> configureDynamicRead( + Pipeline p, + SerializableFunction> parseFn, + ErrorHandler> errorHandler) + throws Exception { + TableReference sourceTableRef = BigQueryHelpers.parseTableSpec("project:dataset.table"); + + fakeDatasetService.createDataset( + sourceTableRef.getProjectId(), + sourceTableRef.getDatasetId(), + "asia-northeast1", + "Fake plastic tree^H^H^H^Htables", + null); + + fakeDatasetService.createTable( + new Table().setTableReference(sourceTableRef).setLocation("asia-northeast1")); + + Table queryResultTable 
= new Table().setSchema(TABLE_SCHEMA).setNumBytes(0L); + + String encodedQuery = FakeBigQueryServices.encodeQueryResult(queryResultTable); + + fakeJobService.expectDryRunQuery( + options.getProject(), + encodedQuery, + new JobStatistics() + .setQuery( + new JobStatistics2() + .setTotalBytesProcessed(1024L * 1024L) + .setReferencedTables(ImmutableList.of(sourceTableRef)))); + + ReadSession readSession = + ReadSession.newBuilder() + .setName("readSessionName") + .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING)) + .addStreams(ReadStream.newBuilder().setName("streamName")) + .setDataFormat(DataFormat.AVRO) + .build(); + + ReadRowsRequest expectedReadRowsRequest = + ReadRowsRequest.newBuilder().setReadStream("streamName").build(); + + List records = + Lists.newArrayList( + createRecord("A", 1, AVRO_SCHEMA), + createRecord("B", 2, AVRO_SCHEMA), + createRecord("C", 3, AVRO_SCHEMA), + createRecord("D", 4, AVRO_SCHEMA)); + + List readRowsResponses = + Lists.newArrayList( + createResponse(AVRO_SCHEMA, records.subList(0, 2), 0.0, 0.500), + createResponse(AVRO_SCHEMA, records.subList(2, 4), 0.5, 0.875)); + + // + // Note that since the temporary table name is generated by the pipeline, we can't match the + // expected create read session request exactly. For now, match against any appropriately typed + // proto object. + // + + StorageClient fakeStorageClient = mock(StorageClient.class, withSettings().serializable()); + when(fakeStorageClient.createReadSession(any())).thenReturn(readSession); + when(fakeStorageClient.readRows(expectedReadRowsRequest, "")) + .thenReturn(new FakeBigQueryServerStream<>(readRowsResponses)); + + BigQueryIO.DynamicRead> t = + BigQueryIO.readDynamically(parseFn, KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of())) + .withTestServices( + new FakeBigQueryServices() + .withDatasetService(fakeDatasetService) + .withJobService(fakeJobService) + .withStorageClient(fakeStorageClient)); + if (errorHandler != null) { + t = t.withBadRecordErrorHandler(errorHandler); + } + return p.apply( + Create.of( + BigQueryDynamicReadDescriptor.create(encodedQuery, null, false, false, null, null))) + .apply("read", t); + } + + @Test + public void testReadQueryFromBigQueryIO() throws Exception { + PCollection> output = configureDynamicRead(p, new ParseKeyValue(), null); + + PAssert.that(output) + .containsInAnyOrder( + ImmutableList.of(KV.of("A", 1L), KV.of("B", 2L), KV.of("C", 3L), KV.of("D", 4L))); + + p.run(); + } + + private static final class FailingParseKeyValue + implements SerializableFunction> { + @Override + public KV apply(SchemaAndRecord input) { + if (input.getRecord().get("name").toString().equals("B")) { + throw new RuntimeException("ExpectedException"); + } + return KV.of( + input.getRecord().get("name").toString(), (Long) input.getRecord().get("number")); + } + } + + @Test + public void testReadFromBigQueryWithExceptionHandling() throws Exception { + ErrorHandler> errorHandler = + p.registerBadRecordErrorHandler(new ErrorHandlingTestUtils.ErrorSinkTransform()); + PCollection> output = + configureDynamicRead(p, new FailingParseKeyValue(), errorHandler); + + errorHandler.close(); + + PAssert.that(output) + .containsInAnyOrder(ImmutableList.of(KV.of("A", 1L), KV.of("C", 3L), KV.of("D", 4L))); + + PAssert.thatSingleton(errorHandler.getOutput()).isEqualTo(1L); + + p.run(); + } +}