From 29d321d70940e104fe51d5e76333ce0453c1e745 Mon Sep 17 00:00:00 2001 From: Vasily Bondarenko Date: Wed, 31 Jul 2024 14:18:05 +0100 Subject: [PATCH 1/4] Added truncateMode write configuration To allow the overwrite save mode to keep collection options SPARK-384 Original PR: #123 - removed recreate mode due to fragility --------- Co-authored-by: Ross Lawley --- .../spark/sql/connector/RoundTripTest.java | 16 +- .../connector/write/TruncateModesTest.java | 153 ++++++++++++++++++ .../sql/connector/config/WriteConfig.java | 78 +++++++++ .../sql/connector/write/MongoBatchWrite.java | 3 +- .../sql/connector/config/MongoConfigTest.java | 15 ++ 5 files changed, 260 insertions(+), 5 deletions(-) create mode 100644 src/integrationTest/java/com/mongodb/spark/sql/connector/write/TruncateModesTest.java diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java index 284ad28a..e8ec2859 100644 --- a/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java +++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java @@ -25,6 +25,8 @@ import com.mongodb.spark.sql.connector.beans.ComplexBean; import com.mongodb.spark.sql.connector.beans.DateTimeBean; import com.mongodb.spark.sql.connector.beans.PrimitiveBean; +import com.mongodb.spark.sql.connector.config.WriteConfig; +import com.mongodb.spark.sql.connector.config.WriteConfig.TruncateMode; import com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorTestCase; import java.sql.Date; import java.sql.Timestamp; @@ -41,6 +43,8 @@ import org.apache.spark.sql.Encoders; import org.apache.spark.sql.SparkSession; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; public class RoundTripTest extends MongoSparkConnectorTestCase { @@ -68,8 +72,9 @@ void testPrimitiveBean() { assertIterableEquals(dataSetOriginal, dataSetMongo); } - @Test - void testBoxedBean() { + @ParameterizedTest + @EnumSource(TruncateMode.class) + void testBoxedBean(final TruncateMode mode) { // Given List dataSetOriginal = singletonList(new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, true)); @@ -79,7 +84,12 @@ void testBoxedBean() { Encoder encoder = Encoders.bean(BoxedBean.class); Dataset dataset = spark.createDataset(dataSetOriginal, encoder); - dataset.write().format("mongodb").mode("Overwrite").save(); + dataset + .write() + .format("mongodb") + .mode("Overwrite") + .option(WriteConfig.TRUNCATE_MODE_CONFIG, mode.name()) + .save(); // Then List dataSetMongo = spark diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/write/TruncateModesTest.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/write/TruncateModesTest.java new file mode 100644 index 00000000..a5ba976a --- /dev/null +++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/write/TruncateModesTest.java @@ -0,0 +1,153 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package com.mongodb.spark.sql.connector.write; + +import static com.mongodb.spark.sql.connector.config.WriteConfig.TRUNCATE_MODE_CONFIG; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertIterableEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.Collation; +import com.mongodb.client.model.CollationStrength; +import com.mongodb.client.model.CreateCollectionOptions; +import com.mongodb.client.model.IndexOptions; +import com.mongodb.spark.sql.connector.beans.BoxedBean; +import com.mongodb.spark.sql.connector.config.WriteConfig; +import com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorTestCase; +import java.util.ArrayList; +import java.util.List; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.SparkSession; +import org.bson.Document; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class TruncateModesTest extends MongoSparkConnectorTestCase { + + public static final String INT_FIELD_INDEX = "intFieldIndex"; + public static final String ID_INDEX = "_id_"; + + @BeforeEach + void setup() { + MongoDatabase database = getDatabase(); + getCollection().drop(); + CreateCollectionOptions createCollectionOptions = new CreateCollectionOptions() + .collation(Collation.builder() + .locale("en") + .collationStrength(CollationStrength.SECONDARY) + .build()); + database.createCollection(getCollectionName(), createCollectionOptions); + MongoCollection collection = database.getCollection(getCollectionName()); + collection.insertOne(new Document().append("intField", null)); + collection.createIndex( + new Document().append("intField", 1), new IndexOptions().name(INT_FIELD_INDEX)); + } + + @Test + void testCollectionDroppedOnOverwrite() { + // Given + List dataSetOriginal = singletonList(getBoxedBean()); + + // when + SparkSession spark = getOrCreateSparkSession(); + Encoder encoder = Encoders.bean(BoxedBean.class); + Dataset dataset = spark.createDataset(dataSetOriginal, encoder); + dataset + .write() + .format("mongodb") + .mode("Overwrite") + .option(TRUNCATE_MODE_CONFIG, WriteConfig.TruncateMode.DROP.toString()) + .save(); + + // Then + List dataSetMongo = spark + .read() + .format("mongodb") + .schema(encoder.schema()) + .load() + .as(encoder) + .collectAsList(); + assertIterableEquals(dataSetOriginal, dataSetMongo); + + List indexes = + getCollection().listIndexes().map(it -> it.getString("name")).into(new ArrayList<>()); + assertEquals(indexes, singletonList(ID_INDEX)); + Document options = getCollectionOptions(); + assertTrue(options.isEmpty()); + } + + @Test + void testOptionKeepingOverwrites() { + // Given + List dataSetOriginal = singletonList(getBoxedBean()); + + // when + SparkSession spark = getOrCreateSparkSession(); + Encoder encoder = Encoders.bean(BoxedBean.class); + Dataset dataset = spark.createDataset(dataSetOriginal, encoder); + dataset + .write() + .format("mongodb") + .mode("Overwrite") + .option(TRUNCATE_MODE_CONFIG, WriteConfig.TruncateMode.TRUNCATE.toString()) + .save(); + + // Then + List 
dataSetMongo = spark + .read() + .format("mongodb") + .schema(encoder.schema()) + .load() + .as(encoder) + .collectAsList(); + assertIterableEquals(dataSetOriginal, dataSetMongo); + + List indexes = + getCollection().listIndexes().map(it -> it.getString("name")).into(new ArrayList<>()); + assertEquals(indexes, asList(ID_INDEX, INT_FIELD_INDEX)); + + Document options = getCollectionOptions(); + assertTrue(options.containsKey("collation")); + assertEquals("en", options.get("collation", new Document()).get("locale", "NA"), "en"); + } + + private @NotNull BoxedBean getBoxedBean() { + return new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, true); + } + + private Document getCollectionOptions() { + Document getCollectionMeta = new Document() + .append("listCollections", 1) + .append("filter", new Document().append("name", getCollectionName())); + + Document foundMeta = getDatabase().runCommand(getCollectionMeta); + Document cursor = foundMeta.get("cursor", Document.class); + List firstBatch = cursor.getList("firstBatch", Document.class); + if (firstBatch.isEmpty()) { + return getCollectionMeta; + } + + return firstBatch.get(0).get("options", Document.class); + } +} diff --git a/src/main/java/com/mongodb/spark/sql/connector/config/WriteConfig.java b/src/main/java/com/mongodb/spark/sql/connector/config/WriteConfig.java index e5d0082b..6e92d8f5 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/config/WriteConfig.java +++ b/src/main/java/com/mongodb/spark/sql/connector/config/WriteConfig.java @@ -22,11 +22,13 @@ import com.mongodb.MongoNamespace; import com.mongodb.WriteConcern; +import com.mongodb.client.MongoCollection; import com.mongodb.spark.sql.connector.exceptions.ConfigException; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import org.bson.Document; import org.jetbrains.annotations.ApiStatus; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -116,6 +118,59 @@ public String toString() { } } + /** + * Determines how to truncate a collection when using {@link org.apache.spark.sql.SaveMode#Overwrite} + * + * @since 10.6 + */ + public enum TruncateMode { + /** + * Drops the collection + */ + DROP("drop") { + @Override + public void truncate(final WriteConfig writeConfig) { + writeConfig.doWithCollection(MongoCollection::drop); + } + }, + /** + * Deletes all entries in the collection preserving indexes, collection options and any sharding configuration + *
+   * <p>Warning: This operation is currently much more expensive than doing a simple drop operation.
+ */ + TRUNCATE("truncate") { + @Override + public void truncate(final WriteConfig writeConfig) { + writeConfig.doWithCollection(collection -> collection.deleteMany(new Document())); + } + }; + + private final String value; + + TruncateMode(final String value) { + this.value = value; + } + + static TruncateMode fromString(final String truncateMode) { + for (TruncateMode truncateModeType : TruncateMode.values()) { + if (truncateMode.equalsIgnoreCase(truncateModeType.value)) { + return truncateModeType; + } + } + throw new ConfigException(format("'%s' is not a valid Truncate Mode", truncateMode)); + } + + /** + * The truncation implementation for each different truncation type + * @param writeConfig the write config + */ + public abstract void truncate(WriteConfig writeConfig); + + @Override + public String toString() { + return value; + } + } + /** * The maximum batch size for the batch in the bulk operation. * @@ -243,6 +298,21 @@ public String toString() { private static final boolean IGNORE_NULL_VALUES_DEFAULT = false; + /** + * Truncate Mode + * + *
+   * <p>Configuration: {@value}
+   *
+   * <p>Default: {@code Drop}
+   *
+   * <p>
Determines how to truncate a collection when using {@link org.apache.spark.sql.SaveMode#Overwrite} + * + * @since 10.6 + */ + public static final String TRUNCATE_MODE_CONFIG = "truncateMode"; + + private static final String TRUNCATE_MODE_DEFAULT = TruncateMode.DROP.value; + private final WriteConcern writeConcern; private final OperationType operationType; @@ -319,6 +389,14 @@ public boolean ignoreNullValues() { return getBoolean(IGNORE_NULL_VALUES_CONFIG, IGNORE_NULL_VALUES_DEFAULT); } + /** + * @return the truncate mode for use when overwriting collections + * @since 10.6 + */ + public TruncateMode truncateMode() { + return TruncateMode.fromString(getOrDefault(TRUNCATE_MODE_CONFIG, TRUNCATE_MODE_DEFAULT)); + } + @Override CollectionsConfig parseAndValidateCollectionsConfig() { CollectionsConfig collectionsConfig = super.parseAndValidateCollectionsConfig(); diff --git a/src/main/java/com/mongodb/spark/sql/connector/write/MongoBatchWrite.java b/src/main/java/com/mongodb/spark/sql/connector/write/MongoBatchWrite.java index 09063f0d..19b964dd 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/write/MongoBatchWrite.java +++ b/src/main/java/com/mongodb/spark/sql/connector/write/MongoBatchWrite.java @@ -19,7 +19,6 @@ import static java.lang.String.format; -import com.mongodb.client.MongoCollection; import com.mongodb.spark.sql.connector.config.WriteConfig; import com.mongodb.spark.sql.connector.exceptions.DataException; import java.util.Arrays; @@ -62,7 +61,7 @@ final class MongoBatchWrite implements BatchWrite { @Override public DataWriterFactory createBatchWriterFactory(final PhysicalWriteInfo physicalWriteInfo) { if (truncate) { - writeConfig.doWithCollection(MongoCollection::drop); + writeConfig.truncateMode().truncate(writeConfig); } return new MongoDataWriterFactory(info.schema(), writeConfig); } diff --git a/src/test/java/com/mongodb/spark/sql/connector/config/MongoConfigTest.java b/src/test/java/com/mongodb/spark/sql/connector/config/MongoConfigTest.java index 22e8ed20..0d5df138 100644 --- a/src/test/java/com/mongodb/spark/sql/connector/config/MongoConfigTest.java +++ b/src/test/java/com/mongodb/spark/sql/connector/config/MongoConfigTest.java @@ -159,6 +159,21 @@ void testWriteConfigConvertJson() { WriteConfig.ConvertJson.OBJECT_OR_ARRAY_ONLY); } + @Test + void testWriteConfigTruncateMode() { + WriteConfig writeConfig = MongoConfig.createConfig(CONFIG_MAP).toWriteConfig(); + assertEquals(writeConfig.truncateMode(), WriteConfig.TruncateMode.DROP); + assertEquals( + writeConfig.withOption("TruncateMode", "truncate").truncateMode(), + WriteConfig.TruncateMode.TRUNCATE); + assertEquals( + writeConfig.withOption("TruncateMode", "Drop").truncateMode(), + WriteConfig.TruncateMode.DROP); + assertThrows( + ConfigException.class, + () -> writeConfig.withOption("TruncateMode", "RECREATE").truncateMode()); + } + @Test void testMongoConfigOptionsParsing() { MongoConfig mongoConfig = MongoConfig.readConfig(OPTIONS_CONFIG_MAP); From 8aacd6c13d51f45d546ff91a31aa18b0ef8ca572 Mon Sep 17 00:00:00 2001 From: Vasily Bondarenko Date: Wed, 14 Aug 2024 12:31:39 +0100 Subject: [PATCH 2/4] Added MongoTable support for deletes - added support for delete - minor refactoring for filter transformations - minor fix for catalog to load table with resolved schema instead of empty one SPARK-414 Original PR: #124 --------- Co-authored-by: Ross Lawley --- .../spark/sql/connector/RoundTripTest.java | 31 +++ .../mongodb/MongoSparkConnectorHelper.java | 3 + .../sql/connector/ExpressionConverter.java | 253 
++++++++++++++++++ .../spark/sql/connector/MongoCatalog.java | 5 +- .../spark/sql/connector/MongoTable.java | 31 ++- .../sql/connector/read/MongoScanBuilder.java | 197 +------------- 6 files changed, 326 insertions(+), 194 deletions(-) create mode 100644 src/main/java/com/mongodb/spark/sql/connector/ExpressionConverter.java diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java index e8ec2859..68a24694 100644 --- a/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java +++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java @@ -17,8 +17,10 @@ package com.mongodb.spark.sql.connector; +import static com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorHelper.CATALOG; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertIterableEquals; import com.mongodb.spark.sql.connector.beans.BoxedBean; @@ -41,6 +43,7 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -172,4 +175,32 @@ void testComplexBean() { .collectAsList(); assertIterableEquals(dataSetOriginal, dataSetMongo); } + + @Test + void testCatalogAccessAndDelete() { + List dataSetOriginal = asList( + new BoxedBean((byte) 1, (short) 2, 0, 4L, 5.0f, 6.0, true), + new BoxedBean((byte) 1, (short) 2, 1, 4L, 5.0f, 6.0, true), + new BoxedBean((byte) 1, (short) 2, 2, 4L, 5.0f, 6.0, true), + new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, false), + new BoxedBean((byte) 1, (short) 2, 4, 4L, 5.0f, 6.0, false), + new BoxedBean((byte) 1, (short) 2, 5, 4L, 5.0f, 6.0, false)); + + SparkSession spark = getOrCreateSparkSession(); + Encoder encoder = Encoders.bean(BoxedBean.class); + spark + .createDataset(dataSetOriginal, encoder) + .write() + .format("mongodb") + .mode("Overwrite") + .save(); + + String tableName = CATALOG + "." + HELPER.getDatabaseName() + "." 
+ HELPER.getCollectionName(); + List rows = spark.sql("select * from " + tableName).collectAsList(); + assertEquals(6, rows.size()); + + spark.sql("delete from " + tableName + " where booleanField = false and intField > 3"); + rows = spark.sql("select * from " + tableName).collectAsList(); + assertEquals(4, rows.size()); + } } diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/mongodb/MongoSparkConnectorHelper.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/mongodb/MongoSparkConnectorHelper.java index d3ceddb6..fec63931 100644 --- a/src/integrationTest/java/com/mongodb/spark/sql/connector/mongodb/MongoSparkConnectorHelper.java +++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/mongodb/MongoSparkConnectorHelper.java @@ -28,6 +28,7 @@ import com.mongodb.client.model.UpdateOptions; import com.mongodb.client.model.Updates; import com.mongodb.connection.ClusterType; +import com.mongodb.spark.sql.connector.MongoCatalog; import com.mongodb.spark.sql.connector.config.MongoConfig; import java.io.File; import java.io.IOException; @@ -62,6 +63,7 @@ public class MongoSparkConnectorHelper "{_id: '%s', pk: '%s', dups: '%s', i: %d, s: '%s'}"; private static final String COMPLEX_SAMPLE_DATA_TEMPLATE = "{_id: '%s', nested: {pk: '%s', dups: '%s', i: %d}, s: '%s'}"; + public static final String CATALOG = "mongo_catalog"; private static final Logger LOGGER = LoggerFactory.getLogger(MongoSparkConnectorHelper.class); @@ -146,6 +148,7 @@ public SparkConf getSparkConf() { .set("spark.sql.streaming.checkpointLocation", getTempDirectory()) .set("spark.sql.streaming.forceDeleteTempCheckpointLocation", "true") .set("spark.app.id", "MongoSparkConnector") + .set("spark.sql.catalog." + CATALOG, MongoCatalog.class.getCanonicalName()) .set( MongoConfig.PREFIX + MongoConfig.CONNECTION_STRING_CONFIG, getConnectionString().getConnectionString()) diff --git a/src/main/java/com/mongodb/spark/sql/connector/ExpressionConverter.java b/src/main/java/com/mongodb/spark/sql/connector/ExpressionConverter.java new file mode 100644 index 00000000..79f32bc8 --- /dev/null +++ b/src/main/java/com/mongodb/spark/sql/connector/ExpressionConverter.java @@ -0,0 +1,253 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package com.mongodb.spark.sql.connector; + +import static com.mongodb.spark.sql.connector.schema.RowToBsonDocumentConverter.createObjectToBsonValue; +import static java.lang.String.format; + +import com.mongodb.client.model.Filters; +import com.mongodb.spark.sql.connector.assertions.Assertions; +import com.mongodb.spark.sql.connector.config.WriteConfig; +import com.mongodb.spark.sql.connector.schema.RowToBsonDocumentConverter; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualNullSafe; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.In; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; +import org.apache.spark.sql.sources.StringContains; +import org.apache.spark.sql.sources.StringEndsWith; +import org.apache.spark.sql.sources.StringStartsWith; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.bson.BsonValue; +import org.bson.conversions.Bson; +import org.jetbrains.annotations.Nullable; +import org.jetbrains.annotations.VisibleForTesting; + +/** + * Utility class to convert {@link Filter} expressions into MongoDB aggregation pipelines + * + * @since 10.6 + */ +public final class ExpressionConverter { + private final StructType schema; + + /** + * Construct a new instance + * @param schema the schema for the data + */ + public ExpressionConverter(final StructType schema) { + this.schema = schema; + } + + /** + * Processes {@link Filter} into aggregation pipelines if possible + * @param filter the filter to translate + * @return the {@link FilterAndPipelineStage} representing the Filter and pipeline stage if conversion is possible + */ + public FilterAndPipelineStage processFilter(final Filter filter) { + Assertions.ensureArgument(() -> filter != null, () -> "Invalid argument filter cannot be null"); + if (filter instanceof And) { + And andFilter = (And) filter; + FilterAndPipelineStage eitherLeft = processFilter(andFilter.left()); + FilterAndPipelineStage eitherRight = processFilter(andFilter.right()); + if (eitherLeft.hasPipelineStage() && eitherRight.hasPipelineStage()) { + return new FilterAndPipelineStage( + filter, Filters.and(eitherLeft.getPipelineStage(), eitherRight.getPipelineStage())); + } + } else if (filter instanceof EqualNullSafe) { + EqualNullSafe equalNullSafe = (EqualNullSafe) filter; + String fieldName = unquoteFieldName(equalNullSafe.attribute()); + return new FilterAndPipelineStage( + filter, + getBsonValue(fieldName, equalNullSafe.value()) + .map(bsonValue -> Filters.eq(fieldName, bsonValue)) + .orElse(null)); + } else if (filter instanceof EqualTo) { + EqualTo equalTo = (EqualTo) filter; + String fieldName = unquoteFieldName(equalTo.attribute()); + return new FilterAndPipelineStage( + filter, + getBsonValue(fieldName, equalTo.value()) + .map(bsonValue -> Filters.eq(fieldName, bsonValue)) + .orElse(null)); + } else if (filter instanceof GreaterThan) { + GreaterThan 
greaterThan = (GreaterThan) filter; + String fieldName = unquoteFieldName(greaterThan.attribute()); + return new FilterAndPipelineStage( + filter, + getBsonValue(fieldName, greaterThan.value()) + .map(bsonValue -> Filters.gt(fieldName, bsonValue)) + .orElse(null)); + } else if (filter instanceof GreaterThanOrEqual) { + GreaterThanOrEqual greaterThanOrEqual = (GreaterThanOrEqual) filter; + String fieldName = unquoteFieldName(greaterThanOrEqual.attribute()); + return new FilterAndPipelineStage( + filter, + getBsonValue(fieldName, greaterThanOrEqual.value()) + .map(bsonValue -> Filters.gte(fieldName, bsonValue)) + .orElse(null)); + } else if (filter instanceof In) { + In inFilter = (In) filter; + String fieldName = unquoteFieldName(inFilter.attribute()); + List values = Arrays.stream(inFilter.values()) + .map(v -> getBsonValue(fieldName, v)) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(Collectors.toList()); + + // Ensure all values were matched otherwise leave to Spark to filter. + Bson pipelineStage = null; + if (values.size() == inFilter.values().length) { + pipelineStage = Filters.in(fieldName, values); + } + return new FilterAndPipelineStage(filter, pipelineStage); + } else if (filter instanceof IsNull) { + IsNull isNullFilter = (IsNull) filter; + String fieldName = unquoteFieldName(isNullFilter.attribute()); + return new FilterAndPipelineStage(filter, Filters.eq(fieldName, null)); + } else if (filter instanceof IsNotNull) { + IsNotNull isNotNullFilter = (IsNotNull) filter; + String fieldName = unquoteFieldName(isNotNullFilter.attribute()); + return new FilterAndPipelineStage(filter, Filters.ne(fieldName, null)); + } else if (filter instanceof LessThan) { + LessThan lessThan = (LessThan) filter; + String fieldName = unquoteFieldName(lessThan.attribute()); + return new FilterAndPipelineStage( + filter, + getBsonValue(fieldName, lessThan.value()) + .map(bsonValue -> Filters.lt(fieldName, bsonValue)) + .orElse(null)); + } else if (filter instanceof LessThanOrEqual) { + LessThanOrEqual lessThanOrEqual = (LessThanOrEqual) filter; + String fieldName = unquoteFieldName(lessThanOrEqual.attribute()); + return new FilterAndPipelineStage( + filter, + getBsonValue(fieldName, lessThanOrEqual.value()) + .map(bsonValue -> Filters.lte(fieldName, bsonValue)) + .orElse(null)); + } else if (filter instanceof Not) { + Not notFilter = (Not) filter; + FilterAndPipelineStage notChild = processFilter(notFilter.child()); + if (notChild.hasPipelineStage()) { + return new FilterAndPipelineStage(filter, Filters.not(notChild.pipelineStage)); + } + } else if (filter instanceof Or) { + Or or = (Or) filter; + FilterAndPipelineStage eitherLeft = processFilter(or.left()); + FilterAndPipelineStage eitherRight = processFilter(or.right()); + if (eitherLeft.hasPipelineStage() && eitherRight.hasPipelineStage()) { + return new FilterAndPipelineStage( + filter, Filters.or(eitherLeft.getPipelineStage(), eitherRight.getPipelineStage())); + } + } else if (filter instanceof StringContains) { + StringContains stringContains = (StringContains) filter; + String fieldName = unquoteFieldName(stringContains.attribute()); + return new FilterAndPipelineStage( + filter, Filters.regex(fieldName, format(".*%s.*", stringContains.value()))); + } else if (filter instanceof StringEndsWith) { + StringEndsWith stringEndsWith = (StringEndsWith) filter; + String fieldName = unquoteFieldName(stringEndsWith.attribute()); + return new FilterAndPipelineStage( + filter, Filters.regex(fieldName, format(".*%s$", stringEndsWith.value()))); 
+ } else if (filter instanceof StringStartsWith) { + StringStartsWith stringStartsWith = (StringStartsWith) filter; + String fieldName = unquoteFieldName(stringStartsWith.attribute()); + return new FilterAndPipelineStage( + filter, Filters.regex(fieldName, format("^%s.*", stringStartsWith.value()))); + } + return new FilterAndPipelineStage(filter, null); + } + + @VisibleForTesting + static String unquoteFieldName(final String fieldName) { + // Spark automatically escapes hyphenated names using backticks + if (fieldName.contains("`")) { + return new Column(fieldName).toString(); + } + return fieldName; + } + + private Optional getBsonValue(final String fieldName, final Object value) { + try { + StructType localSchema = schema; + DataType localDataType = localSchema; + + for (String localFieldName : fieldName.split("\\.")) { + StructField localField = localSchema.apply(localFieldName); + localDataType = localField.dataType(); + if (localField.dataType() instanceof StructType) { + localSchema = (StructType) localField.dataType(); + } + } + RowToBsonDocumentConverter.ObjectToBsonValue objectToBsonValue = + createObjectToBsonValue(localDataType, WriteConfig.ConvertJson.FALSE, false); + return Optional.of(objectToBsonValue.apply(value)); + } catch (Exception e) { + // ignore + return Optional.empty(); + } + } + + /** FilterAndPipelineStage - contains an optional pipeline stage for the filter. */ + public static final class FilterAndPipelineStage { + + private final Filter filter; + private final Bson pipelineStage; + + private FilterAndPipelineStage(final Filter filter, @Nullable final Bson pipelineStage) { + this.filter = filter; + this.pipelineStage = pipelineStage; + } + + /** + * @return the filter + */ + public Filter getFilter() { + return filter; + } + + /** + * @return the equivalent pipeline for the filter or {@code null} if translation for the filter wasn't possible + */ + public Bson getPipelineStage() { + return pipelineStage; + } + + /** + * @return true if the {@link Filter} could be converted into a pipeline stage + */ + public boolean hasPipelineStage() { + return pipelineStage != null; + } + } +} diff --git a/src/main/java/com/mongodb/spark/sql/connector/MongoCatalog.java b/src/main/java/com/mongodb/spark/sql/connector/MongoCatalog.java index 22feed44..9a454781 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/MongoCatalog.java +++ b/src/main/java/com/mongodb/spark/sql/connector/MongoCatalog.java @@ -28,6 +28,7 @@ import com.mongodb.spark.sql.connector.config.ReadConfig; import com.mongodb.spark.sql.connector.config.WriteConfig; import com.mongodb.spark.sql.connector.exceptions.MongoSparkException; +import com.mongodb.spark.sql.connector.schema.InferSchema; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -239,7 +240,9 @@ public Table loadTable(final Identifier identifier) throws NoSuchTableException properties.put( MongoConfig.READ_PREFIX + MongoConfig.DATABASE_NAME_CONFIG, identifier.namespace()[0]); properties.put(MongoConfig.READ_PREFIX + MongoConfig.COLLECTION_NAME_CONFIG, identifier.name()); - return new MongoTable(MongoConfig.readConfig(properties)); + return new MongoTable( + InferSchema.inferSchema(new CaseInsensitiveStringMap(properties)), + MongoConfig.readConfig(properties)); } /** diff --git a/src/main/java/com/mongodb/spark/sql/connector/MongoTable.java b/src/main/java/com/mongodb/spark/sql/connector/MongoTable.java index 0dc7039d..4481b434 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/MongoTable.java +++ 
b/src/main/java/com/mongodb/spark/sql/connector/MongoTable.java @@ -19,7 +19,12 @@ import static java.util.Arrays.asList; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.Filters; import com.mongodb.spark.connector.Versions; +import com.mongodb.spark.sql.connector.ExpressionConverter.FilterAndPipelineStage; import com.mongodb.spark.sql.connector.config.MongoConfig; import com.mongodb.spark.sql.connector.config.ReadConfig; import com.mongodb.spark.sql.connector.config.WriteConfig; @@ -27,9 +32,12 @@ import com.mongodb.spark.sql.connector.write.MongoWriteBuilder; import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; +import org.apache.spark.sql.connector.catalog.SupportsDelete; import org.apache.spark.sql.connector.catalog.SupportsRead; import org.apache.spark.sql.connector.catalog.SupportsWrite; import org.apache.spark.sql.connector.catalog.Table; @@ -38,13 +46,16 @@ import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.write.LogicalWriteInfo; import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.bson.Document; +import org.bson.conversions.Bson; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** Represents a MongoDB Collection. */ -final class MongoTable implements Table, SupportsWrite, SupportsRead { +final class MongoTable implements Table, SupportsWrite, SupportsRead, SupportsDelete { private static final Logger LOGGER = LoggerFactory.getLogger(MongoTable.class); private static final Set TABLE_CAPABILITY_SET = new HashSet<>(asList( TableCapability.BATCH_WRITE, @@ -179,4 +190,22 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(partitioning); return result; } + + @Override + public void deleteWhere(final Filter[] filters) { + ExpressionConverter converter = new ExpressionConverter(schema); + + List stages = Arrays.stream(filters) + .map(converter::processFilter) + .filter(FilterAndPipelineStage::hasPipelineStage) + .map(FilterAndPipelineStage::getPipelineStage) + .collect(Collectors.toList()); + Bson query = Filters.and(stages); + WriteConfig writeConfig = mongoConfig.toWriteConfig(); + + MongoClient mongoClient = writeConfig.getMongoClient(); + MongoDatabase database = mongoClient.getDatabase(writeConfig.getDatabaseName()); + MongoCollection collection = database.getCollection(writeConfig.getCollectionName()); + collection.deleteMany(query); + } } diff --git a/src/main/java/com/mongodb/spark/sql/connector/read/MongoScanBuilder.java b/src/main/java/com/mongodb/spark/sql/connector/read/MongoScanBuilder.java index f8c5c643..31e77be3 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/read/MongoScanBuilder.java +++ b/src/main/java/com/mongodb/spark/sql/connector/read/MongoScanBuilder.java @@ -17,23 +17,19 @@ package com.mongodb.spark.sql.connector.read; -import static com.mongodb.spark.sql.connector.schema.RowToBsonDocumentConverter.createObjectToBsonValue; -import static java.lang.String.format; import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; import com.mongodb.client.model.Aggregates; import com.mongodb.client.model.Filters; -import 
com.mongodb.spark.sql.connector.assertions.Assertions; +import com.mongodb.spark.sql.connector.ExpressionConverter; +import com.mongodb.spark.sql.connector.ExpressionConverter.FilterAndPipelineStage; import com.mongodb.spark.sql.connector.config.MongoConfig; import com.mongodb.spark.sql.connector.config.ReadConfig; -import com.mongodb.spark.sql.connector.config.WriteConfig; -import com.mongodb.spark.sql.connector.schema.RowToBsonDocumentConverter; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Locale; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import org.apache.spark.sql.Column; @@ -42,30 +38,11 @@ import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.read.SupportsPushDownFilters; import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; -import org.apache.spark.sql.sources.And; -import org.apache.spark.sql.sources.EqualNullSafe; -import org.apache.spark.sql.sources.EqualTo; import org.apache.spark.sql.sources.Filter; -import org.apache.spark.sql.sources.GreaterThan; -import org.apache.spark.sql.sources.GreaterThanOrEqual; -import org.apache.spark.sql.sources.In; -import org.apache.spark.sql.sources.IsNotNull; -import org.apache.spark.sql.sources.IsNull; -import org.apache.spark.sql.sources.LessThan; -import org.apache.spark.sql.sources.LessThanOrEqual; -import org.apache.spark.sql.sources.Not; -import org.apache.spark.sql.sources.Or; -import org.apache.spark.sql.sources.StringContains; -import org.apache.spark.sql.sources.StringEndsWith; -import org.apache.spark.sql.sources.StringStartsWith; -import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.bson.BsonDocument; -import org.bson.BsonValue; -import org.bson.conversions.Bson; import org.jetbrains.annotations.ApiStatus; -import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.VisibleForTesting; /** A builder for a {@link MongoScan}. */ @@ -121,8 +98,10 @@ public Scan build() { */ @Override public Filter[] pushFilters(final Filter[] filters) { + ExpressionConverter converter = new ExpressionConverter(schema); + List processed = - Arrays.stream(filters).map(this::processFilter).collect(Collectors.toList()); + Arrays.stream(filters).map(converter::processFilter).collect(Collectors.toList()); List withPipelines = processed.stream() .filter(FilterAndPipelineStage::hasPipelineStage) @@ -166,127 +145,6 @@ private String getColumnName(final StructField field) { return field.name(); } - /** - * Processes the Filter and if possible creates the equivalent aggregation pipeline stage. - * - * @param filter the filter to be applied - * @return the FilterAndPipelineStage which contains a pipeline stage if the filter is convertible - * into an aggregation pipeline. 
- */ - private FilterAndPipelineStage processFilter(final Filter filter) { - Assertions.ensureArgument(() -> filter != null, () -> "Invalid argument filter cannot be null"); - if (filter instanceof And) { - And andFilter = (And) filter; - FilterAndPipelineStage eitherLeft = processFilter(andFilter.left()); - FilterAndPipelineStage eitherRight = processFilter(andFilter.right()); - if (eitherLeft.hasPipelineStage() && eitherRight.hasPipelineStage()) { - return new FilterAndPipelineStage( - filter, Filters.and(eitherLeft.getPipelineStage(), eitherRight.getPipelineStage())); - } - } else if (filter instanceof EqualNullSafe) { - EqualNullSafe equalNullSafe = (EqualNullSafe) filter; - String fieldName = unquoteFieldName(equalNullSafe.attribute()); - return new FilterAndPipelineStage( - filter, - getBsonValue(fieldName, equalNullSafe.value()) - .map(bsonValue -> Filters.eq(fieldName, bsonValue)) - .orElse(null)); - } else if (filter instanceof EqualTo) { - EqualTo equalTo = (EqualTo) filter; - String fieldName = unquoteFieldName(equalTo.attribute()); - return new FilterAndPipelineStage( - filter, - getBsonValue(fieldName, equalTo.value()) - .map(bsonValue -> Filters.eq(fieldName, bsonValue)) - .orElse(null)); - } else if (filter instanceof GreaterThan) { - GreaterThan greaterThan = (GreaterThan) filter; - String fieldName = unquoteFieldName(greaterThan.attribute()); - return new FilterAndPipelineStage( - filter, - getBsonValue(fieldName, greaterThan.value()) - .map(bsonValue -> Filters.gt(fieldName, bsonValue)) - .orElse(null)); - } else if (filter instanceof GreaterThanOrEqual) { - GreaterThanOrEqual greaterThanOrEqual = (GreaterThanOrEqual) filter; - String fieldName = unquoteFieldName(greaterThanOrEqual.attribute()); - return new FilterAndPipelineStage( - filter, - getBsonValue(fieldName, greaterThanOrEqual.value()) - .map(bsonValue -> Filters.gte(fieldName, bsonValue)) - .orElse(null)); - } else if (filter instanceof In) { - In inFilter = (In) filter; - String fieldName = unquoteFieldName(inFilter.attribute()); - List values = Arrays.stream(inFilter.values()) - .map(v -> getBsonValue(fieldName, v)) - .filter(Optional::isPresent) - .map(Optional::get) - .collect(Collectors.toList()); - - // Ensure all values were matched otherwise leave to Spark to filter. 
- Bson pipelineStage = null; - if (values.size() == inFilter.values().length) { - pipelineStage = Filters.in(fieldName, values); - } - return new FilterAndPipelineStage(filter, pipelineStage); - } else if (filter instanceof IsNull) { - IsNull isNullFilter = (IsNull) filter; - String fieldName = unquoteFieldName(isNullFilter.attribute()); - return new FilterAndPipelineStage(filter, Filters.eq(fieldName, null)); - } else if (filter instanceof IsNotNull) { - IsNotNull isNotNullFilter = (IsNotNull) filter; - String fieldName = unquoteFieldName(isNotNullFilter.attribute()); - return new FilterAndPipelineStage(filter, Filters.ne(fieldName, null)); - } else if (filter instanceof LessThan) { - LessThan lessThan = (LessThan) filter; - String fieldName = unquoteFieldName(lessThan.attribute()); - return new FilterAndPipelineStage( - filter, - getBsonValue(fieldName, lessThan.value()) - .map(bsonValue -> Filters.lt(fieldName, bsonValue)) - .orElse(null)); - } else if (filter instanceof LessThanOrEqual) { - LessThanOrEqual lessThanOrEqual = (LessThanOrEqual) filter; - String fieldName = unquoteFieldName(lessThanOrEqual.attribute()); - return new FilterAndPipelineStage( - filter, - getBsonValue(fieldName, lessThanOrEqual.value()) - .map(bsonValue -> Filters.lte(fieldName, bsonValue)) - .orElse(null)); - } else if (filter instanceof Not) { - Not notFilter = (Not) filter; - FilterAndPipelineStage notChild = processFilter(notFilter.child()); - if (notChild.hasPipelineStage()) { - return new FilterAndPipelineStage(filter, Filters.not(notChild.pipelineStage)); - } - } else if (filter instanceof Or) { - Or or = (Or) filter; - FilterAndPipelineStage eitherLeft = processFilter(or.left()); - FilterAndPipelineStage eitherRight = processFilter(or.right()); - if (eitherLeft.hasPipelineStage() && eitherRight.hasPipelineStage()) { - return new FilterAndPipelineStage( - filter, Filters.or(eitherLeft.getPipelineStage(), eitherRight.getPipelineStage())); - } - } else if (filter instanceof StringContains) { - StringContains stringContains = (StringContains) filter; - String fieldName = unquoteFieldName(stringContains.attribute()); - return new FilterAndPipelineStage( - filter, Filters.regex(fieldName, format(".*%s.*", stringContains.value()))); - } else if (filter instanceof StringEndsWith) { - StringEndsWith stringEndsWith = (StringEndsWith) filter; - String fieldName = unquoteFieldName(stringEndsWith.attribute()); - return new FilterAndPipelineStage( - filter, Filters.regex(fieldName, format(".*%s$", stringEndsWith.value()))); - } else if (filter instanceof StringStartsWith) { - StringStartsWith stringStartsWith = (StringStartsWith) filter; - String fieldName = unquoteFieldName(stringStartsWith.attribute()); - return new FilterAndPipelineStage( - filter, Filters.regex(fieldName, format("^%s.*", stringStartsWith.value()))); - } - return new FilterAndPipelineStage(filter, null); - } - @VisibleForTesting static String unquoteFieldName(final String fieldName) { // Spark automatically escapes hyphenated names using backticks @@ -295,49 +153,4 @@ static String unquoteFieldName(final String fieldName) { } return fieldName; } - - private Optional getBsonValue(final String fieldName, final Object value) { - try { - StructType localSchema = schema; - DataType localDataType = localSchema; - - for (String localFieldName : fieldName.split("\\.")) { - StructField localField = localSchema.apply(localFieldName); - localDataType = localField.dataType(); - if (localField.dataType() instanceof StructType) { - localSchema = (StructType) 
localField.dataType(); - } - } - RowToBsonDocumentConverter.ObjectToBsonValue objectToBsonValue = - createObjectToBsonValue(localDataType, WriteConfig.ConvertJson.FALSE, false); - return Optional.of(objectToBsonValue.apply(value)); - } catch (Exception e) { - // ignore - return Optional.empty(); - } - } - - /** FilterAndPipelineStage - contains an optional pipeline stage for the filter. */ - private static final class FilterAndPipelineStage { - - private final Filter filter; - private final Bson pipelineStage; - - private FilterAndPipelineStage(final Filter filter, @Nullable final Bson pipelineStage) { - this.filter = filter; - this.pipelineStage = pipelineStage; - } - - public Filter getFilter() { - return filter; - } - - public Bson getPipelineStage() { - return pipelineStage; - } - - boolean hasPipelineStage() { - return pipelineStage != null; - } - } } From c97fdcc4bad5d155bbdb6747c750923fdefe365c Mon Sep 17 00:00:00 2001 From: Vasily Bondarenko Date: Thu, 15 Aug 2024 18:24:40 +0100 Subject: [PATCH 3/4] Add support for spark.sql.datetime.java8API.enabled Motivation: when using connector with thrift server, any queries on data containing date/time types are crashing. It's caused by thrift server enabling `spark.sql.datetime.java8API.enabled` flag for all of its sessions. Supported JSR-310 types which are already supported by Spark SQL: - LocalDate - LocalDateTime - Instant Other types are causing Encoder to fail, so they won't work even if taken into account. As `spark.sql.datetime.java8API.enabled` is disabled by default, it shouldn't cause compatibility issues. Adds support for the Spark 3.4 TimeTimestampNTZType SPARK-453 Original PR: #125 --------- Co-authored-by: Ross Lawley --- .../spark/sql/connector/RoundTripTest.java | 21 ++++++++---- .../sql/connector/beans/DateTimeBean.java | 34 +++++++++++++++---- .../schema/BsonDocumentToRowConverter.java | 25 ++++++++++++-- .../sql/connector/schema/ConverterHelper.java | 25 ++++++++++++++ .../schema/RowToBsonDocumentConverter.java | 33 +++++++++++++++--- 5 files changed, 119 insertions(+), 19 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java index 68a24694..2c474d86 100644 --- a/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java +++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java @@ -18,10 +18,13 @@ package com.mongodb.spark.sql.connector; import static com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorHelper.CATALOG; +import static com.mongodb.spark.sql.connector.schema.ConverterHelper.TIMESTAMP_NTZ_TYPE; +import static java.time.ZoneOffset.UTC; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertIterableEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; import com.mongodb.spark.sql.connector.beans.BoxedBean; import com.mongodb.spark.sql.connector.beans.ComplexBean; @@ -34,7 +37,7 @@ import java.sql.Timestamp; import java.time.Instant; import java.time.LocalDate; -import java.time.ZoneOffset; +import java.time.LocalDateTime; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -48,6 +51,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; +import 
org.junit.jupiter.params.provider.ValueSource; public class RoundTripTest extends MongoSparkConnectorTestCase { @@ -105,24 +109,29 @@ void testBoxedBean(final TruncateMode mode) { assertIterableEquals(dataSetOriginal, dataSetMongo); } - @Test - void testDateTimeBean() { + @ParameterizedTest() + @ValueSource(strings = {"true", "false"}) + void testDateTimeBean(final String java8DateTimeAPI) { + assumeTrue(TIMESTAMP_NTZ_TYPE != null); TimeZone original = TimeZone.getDefault(); try { - TimeZone.setDefault(TimeZone.getTimeZone(ZoneOffset.UTC)); + TimeZone.setDefault(TimeZone.getTimeZone(UTC)); // Given long oneHour = TimeUnit.MILLISECONDS.convert(1, TimeUnit.HOURS); long oneDay = oneHour * 24; + Instant epoch = Instant.EPOCH; List dataSetOriginal = singletonList(new DateTimeBean( new Date(oneDay * 365), new Timestamp(oneDay + oneHour), LocalDate.of(2000, 1, 1), - Instant.EPOCH)); + epoch, + LocalDateTime.ofInstant(epoch, UTC))); // when - SparkSession spark = getOrCreateSparkSession(); + SparkSession spark = getOrCreateSparkSession( + getSparkConf().set("spark.sql.datetime.java8API.enabled", java8DateTimeAPI)); Encoder encoder = Encoders.bean(DateTimeBean.class); Dataset dataset = spark.createDataset(dataSetOriginal, encoder); diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/beans/DateTimeBean.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/beans/DateTimeBean.java index 5b4ff473..09538022 100644 --- a/src/integrationTest/java/com/mongodb/spark/sql/connector/beans/DateTimeBean.java +++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/beans/DateTimeBean.java @@ -21,6 +21,7 @@ import java.sql.Timestamp; import java.time.Instant; import java.time.LocalDate; +import java.time.LocalDateTime; import java.util.Objects; public class DateTimeBean implements Serializable { @@ -28,6 +29,7 @@ public class DateTimeBean implements Serializable { private java.sql.Timestamp sqlTimestamp; private java.time.LocalDate localDate; private java.time.Instant instant; + private java.time.LocalDateTime localDateTime; public DateTimeBean() {} @@ -35,10 +37,12 @@ public DateTimeBean( final Date sqlDate, final Timestamp sqlTimestamp, final LocalDate localDate, - final Instant instant) { + final Instant instant, + final LocalDateTime localDateTime) { this.sqlDate = sqlDate; this.sqlTimestamp = sqlTimestamp; this.localDate = localDate; + this.localDateTime = localDateTime; this.instant = instant; } @@ -66,6 +70,14 @@ public void setLocalDate(final LocalDate localDate) { this.localDate = localDate; } + public LocalDateTime getLocalDateTime() { + return localDateTime; + } + + public void setLocalDateTime(final LocalDateTime localDateTime) { + this.localDateTime = localDateTime; + } + public Instant getInstant() { return instant; } @@ -86,20 +98,28 @@ public boolean equals(final Object o) { return Objects.equals(sqlDate, that.sqlDate) && Objects.equals(sqlTimestamp, that.sqlTimestamp) && Objects.equals(localDate, that.localDate) + && Objects.equals(localDateTime, that.localDateTime) && Objects.equals(instant, that.instant); } @Override public int hashCode() { - return Objects.hash(sqlDate, sqlTimestamp, localDate, instant); + return Objects.hash(sqlDate, sqlTimestamp, localDate, localDateTime, instant); } @Override public String toString() { - return "DateTimeBean{" + "sqlDate=" - + sqlDate + ", sqlTimestamp=" - + sqlTimestamp + ", localDate=" - + localDate + ", instant=" - + instant + '}'; + return "DateTimeBean{" + + "sqlDate=" + + sqlDate + + ", sqlTimestamp=" + + sqlTimestamp 
+ + ", localDate=" + + localDate + + ", localDateTime=" + + localDateTime + + ", instant=" + + instant + + '}'; } } diff --git a/src/main/java/com/mongodb/spark/sql/connector/schema/BsonDocumentToRowConverter.java b/src/main/java/com/mongodb/spark/sql/connector/schema/BsonDocumentToRowConverter.java index a97b720a..7d0f9bbe 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/schema/BsonDocumentToRowConverter.java +++ b/src/main/java/com/mongodb/spark/sql/connector/schema/BsonDocumentToRowConverter.java @@ -19,6 +19,7 @@ import static com.mongodb.spark.sql.connector.schema.ConverterHelper.BSON_VALUE_CODEC; import static com.mongodb.spark.sql.connector.schema.ConverterHelper.getJsonWriterSettings; +import static com.mongodb.spark.sql.connector.schema.ConverterHelper.isTimestampNTZ; import static com.mongodb.spark.sql.connector.schema.ConverterHelper.toJson; import static java.lang.String.format; import static java.util.concurrent.TimeUnit.SECONDS; @@ -44,6 +45,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; +import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.ArrayType; import org.apache.spark.sql.types.BinaryType; import org.apache.spark.sql.types.BooleanType; @@ -76,6 +78,8 @@ import org.bson.types.Decimal128; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * The helper for conversion of BsonDocuments to GenericRowWithSchema instances. @@ -89,6 +93,7 @@ @NotNull public final class BsonDocumentToRowConverter implements Serializable { private static final long serialVersionUID = 1L; + private static final Logger LOGGER = LoggerFactory.getLogger(BsonDocumentToRowConverter.class); private final Function rowToInternalRowFunction; private final StructType schema; private final boolean outputExtendedJson; @@ -96,6 +101,7 @@ public final class BsonDocumentToRowConverter implements Serializable { private final boolean dropMalformed; private final String columnNameOfCorruptRecord; private final boolean schemaContainsCorruptRecordColumn; + private final boolean dataTimeJava8APIEnabled; private boolean corruptedRecord; @@ -114,6 +120,7 @@ public BsonDocumentToRowConverter(final StructType originalSchema, final ReadCon this.columnNameOfCorruptRecord = readConfig.getColumnNameOfCorruptRecord(); this.schemaContainsCorruptRecordColumn = !columnNameOfCorruptRecord.isEmpty() && Arrays.asList(schema.fieldNames()).contains(columnNameOfCorruptRecord); + this.dataTimeJava8APIEnabled = SQLConf.get().datetimeJava8ApiEnabled(); } /** @return the schema for the converter */ @@ -165,6 +172,7 @@ GenericRowWithSchema toRow(final BsonDocument bsonDocument) { @VisibleForTesting Object convertBsonValue( final String fieldName, final DataType dataType, final BsonValue bsonValue) { + LOGGER.info("converting bson to value: {} {} {}", fieldName, dataType, bsonValue); try { if (bsonValue.isNull()) { return null; @@ -179,9 +187,22 @@ Object convertBsonValue( } else if (dataType instanceof BooleanType) { return convertToBoolean(fieldName, dataType, bsonValue); } else if (dataType instanceof DateType) { - return convertToDate(fieldName, dataType, bsonValue); + Date date = convertToDate(fieldName, dataType, bsonValue); + if (dataTimeJava8APIEnabled) { + return date.toLocalDate(); + } else { + return date; + } } else if (dataType instanceof TimestampType) { - return 
convertToTimestamp(fieldName, dataType, bsonValue); + Timestamp timestamp = convertToTimestamp(fieldName, dataType, bsonValue); + if (dataTimeJava8APIEnabled) { + return timestamp.toInstant(); + } else { + return timestamp; + } + } else if (isTimestampNTZ(dataType)) { + Timestamp timestamp = convertToTimestamp(fieldName, dataType, bsonValue); + return timestamp.toLocalDateTime(); } else if (dataType instanceof FloatType) { return convertToFloat(fieldName, dataType, bsonValue); } else if (dataType instanceof IntegerType) { diff --git a/src/main/java/com/mongodb/spark/sql/connector/schema/ConverterHelper.java b/src/main/java/com/mongodb/spark/sql/connector/schema/ConverterHelper.java index 385986a3..43bbaf1b 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/schema/ConverterHelper.java +++ b/src/main/java/com/mongodb/spark/sql/connector/schema/ConverterHelper.java @@ -22,6 +22,8 @@ import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; import java.util.Base64; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; import org.bson.BsonDocument; import org.bson.BsonValue; import org.bson.codecs.BsonValueCodec; @@ -39,6 +41,29 @@ static JsonWriterSettings getJsonWriterSettings(final boolean outputExtendedJson return outputExtendedJson ? EXTENDED_JSON_WRITER_SETTINGS : RELAXED_JSON_WRITER_SETTINGS; } + /** + * The {{TimestampNTZType}} if available or null + * + *
+   * <p>Only available in Spark 3.4+
+   *
+   * <p>
TODO: SPARK-450 remove code for Spark 4.0 + */ + public static final DataType TIMESTAMP_NTZ_TYPE; + + static { + DataType timestampNTZType; + try { + timestampNTZType = + (DataType) DataTypes.class.getDeclaredField("TimestampNTZType").get(DataType.class); + } catch (IllegalAccessException | NoSuchFieldException e) { + timestampNTZType = null; + } + TIMESTAMP_NTZ_TYPE = timestampNTZType; + } + + static boolean isTimestampNTZ(final DataType dataType) { + return TIMESTAMP_NTZ_TYPE != null && TIMESTAMP_NTZ_TYPE.acceptsType(dataType); + } + private static final JsonWriterSettings RELAXED_JSON_WRITER_SETTINGS = JsonWriterSettings.builder() .outputMode(JsonMode.RELAXED) diff --git a/src/main/java/com/mongodb/spark/sql/connector/schema/RowToBsonDocumentConverter.java b/src/main/java/com/mongodb/spark/sql/connector/schema/RowToBsonDocumentConverter.java index d1378515..2c3e03a1 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/schema/RowToBsonDocumentConverter.java +++ b/src/main/java/com/mongodb/spark/sql/connector/schema/RowToBsonDocumentConverter.java @@ -17,6 +17,7 @@ package com.mongodb.spark.sql.connector.schema; +import static com.mongodb.spark.sql.connector.schema.ConverterHelper.isTimestampNTZ; import static java.lang.String.format; import static java.util.Arrays.asList; @@ -26,6 +27,10 @@ import com.mongodb.spark.sql.connector.interop.JavaScala; import java.io.Serializable; import java.math.BigDecimal; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; import java.util.Arrays; import java.util.Date; import java.util.List; @@ -137,9 +142,11 @@ public static ObjectToBsonValue createObjectToBsonValue( try { return cachedObjectToBsonValue.apply(data); } catch (Exception e) { - throw new DataException(format( - "Cannot cast %s into a BsonValue. %s has no matching BsonValue. Error: %s", - data, dataType, e.getMessage())); + throw new DataException( + format( + "Cannot cast %s into a BsonValue. %s has no matching BsonValue. Error: %s", + data, dataType, e.getMessage()), + e); } }; } @@ -177,12 +184,30 @@ private static ObjectToBsonValue objectToBsonValue( } else if (DataTypes.StringType.acceptsType(dataType)) { return (data) -> processString((String) data, convertJson); } else if (DataTypes.DateType.acceptsType(dataType) - || DataTypes.TimestampType.acceptsType(dataType)) { + || DataTypes.TimestampType.acceptsType(dataType) + || isTimestampNTZ(dataType)) { return (data) -> { if (data instanceof Date) { // Covers java.util.Date, java.sql.Date, java.sql.Timestamp return new BsonDateTime(((Date) data).getTime()); } + if (data instanceof Instant) { + return new BsonDateTime(((Instant) data).toEpochMilli()); + } + if (data instanceof LocalDateTime) { + LocalDateTime dateTime = (LocalDateTime) data; + return new BsonDateTime(Timestamp.valueOf(dateTime).getTime()); + } + if (data instanceof LocalDate) { + long epochSeconds = ((LocalDate) data).toEpochDay() * 24L * 3600L; + return new BsonDateTime(epochSeconds * 1000L); + } + + /* + NOTE 1: ZonedDateTime, OffsetDateTime, OffsetTime are not explicitly supported by Spark and cause the Encoder resolver to fail + due to cyclic dependency in the ZoneOffset. Subject for review after it changes (if ever). 
+ NOTE 2: LocalTime type is not represented neither in Bson nor in Spark + */ throw new MongoSparkException( "Unsupported date type: " + data.getClass().getSimpleName()); }; From c4b594125dd1420d07743c6910619f768bd7b089 Mon Sep 17 00:00:00 2001 From: Tomas Sedlak Date: Tue, 15 Jul 2025 17:49:25 +0200 Subject: [PATCH 4/4] Streaming added full document before change support SPARK-449 Original PR: #140 --------- Co-authored-by: Ross Lawley --- .../mongodb/MongoSparkConnectorTestCase.java | 4 ++ .../read/AbstractMongoStreamTest.java | 62 +++++++++++++++++++ .../sql/connector/config/ReadConfig.java | 43 +++++++++++++ .../read/MongoContinuousPartitionReader.java | 1 + .../read/MongoMicroBatchPartitionReader.java | 1 + .../sql/connector/config/MongoConfigTest.java | 24 +++++++ 6 files changed, 135 insertions(+) diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/mongodb/MongoSparkConnectorTestCase.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/mongodb/MongoSparkConnectorTestCase.java index 2150742c..35dea4bd 100644 --- a/src/integrationTest/java/com/mongodb/spark/sql/connector/mongodb/MongoSparkConnectorTestCase.java +++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/mongodb/MongoSparkConnectorTestCase.java @@ -101,6 +101,10 @@ public boolean isAtLeastFiveDotZero() { return getMaxWireVersion() >= 12; } + public boolean isAtLeastSixDotZero() { + return getMaxWireVersion() >= 17; + } + public boolean isAtLeastSevenDotZero() { return getMaxWireVersion() >= 21; } diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/read/AbstractMongoStreamTest.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/read/AbstractMongoStreamTest.java index fde7a44c..6ca8bf37 100644 --- a/src/integrationTest/java/com/mongodb/spark/sql/connector/read/AbstractMongoStreamTest.java +++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/read/AbstractMongoStreamTest.java @@ -39,6 +39,8 @@ import com.mongodb.client.MongoCollection; import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.ChangeStreamPreAndPostImagesOptions; +import com.mongodb.client.model.CreateCollectionOptions; import com.mongodb.client.model.Filters; import com.mongodb.client.model.InsertManyOptions; import com.mongodb.client.model.Updates; @@ -54,6 +56,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.concurrent.TimeoutException; import java.util.function.BiConsumer; @@ -304,6 +307,60 @@ void testStreamWithPublishFullDocumentOnly(final String collectionsConfigModeStr msg))); } + @Test + void testStreamFullDocumentBeforeChange() { + assumeTrue(supportsChangeStreams()); + assumeTrue(isAtLeastSixDotZero()); + + CollectionsConfig.Type collectionsConfigType = CollectionsConfig.Type.SINGLE; + testIdentifier = computeTestIdentifier("FullDocBeforeChange", collectionsConfigType); + + testStreamingQuery( + createMongoConfig(collectionsConfigType) + .withOption( + ReadConfig.READ_PREFIX + + ReadConfig.STREAM_LOOKUP_FULL_DOCUMENT_BEFORE_CHANGE_CONFIG, + "required"), + DOCUMENT_BEFORE_CHANGE_SCHEMA, + withSourceDb( + "Create the collection", + (msg, db) -> db.createCollection( + collectionName(), + new CreateCollectionOptions() + .changeStreamPreAndPostImagesOptions( + new ChangeStreamPreAndPostImagesOptions(true)))), + withSource("inserting 0-25", (msg, coll) -> coll.insertMany(createDocuments(0, 25))), + withMemorySink("Expected to see 25 documents", (msg, ds) -> { + List rows = 
ds.collectAsList(); + assertEquals(25, rows.size(), msg); + assertTrue( + rows.stream() + .map(r -> r.getString(r.fieldIndex("fullDocumentBeforeChange"))) + .allMatch(Objects::isNull), + msg); + }), + withSource( + "Updating all", + (msg, coll) -> + coll.updateMany(new BsonDocument(), Updates.set("a", new BsonString("a")))), + withMemorySink( + "Expecting to see 50 documents and the last 25 have fullDocumentBeforeChange", + (msg, ds) -> { + List rows = ds.collectAsList(); + assertEquals(50, rows.size()); + assertTrue( + rows.subList(0, 24).stream() + .map(r -> r.getString(r.fieldIndex("fullDocumentBeforeChange"))) + .allMatch(Objects::isNull), + msg); + assertTrue( + rows.subList(25, 50).stream() + .map(r -> r.getString(r.fieldIndex("fullDocumentBeforeChange"))) + .noneMatch(Objects::isNull), + msg); + })); + } + @ParameterizedTest @ValueSource(strings = {"SINGLE", "MULTIPLE", "ALL"}) void testStreamPublishFullDocumentOnlyHandlesCollectionDrop( @@ -707,6 +764,11 @@ void testReadsWithParseMode() { createStructField("clusterTime", DataTypes.StringType, false), createStructField("fullDocument", DataTypes.StringType, true))); + private static final StructType DOCUMENT_BEFORE_CHANGE_SCHEMA = createStructType(asList( + createStructField("operationType", DataTypes.StringType, false), + createStructField("clusterTime", DataTypes.StringType, false), + createStructField("fullDocumentBeforeChange", DataTypes.StringType, true))); + @SafeVarargs private final void testStreamingQuery( final MongoConfig mongoConfig, diff --git a/src/main/java/com/mongodb/spark/sql/connector/config/ReadConfig.java b/src/main/java/com/mongodb/spark/sql/connector/config/ReadConfig.java index 2c7d6a01..079536b7 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/config/ReadConfig.java +++ b/src/main/java/com/mongodb/spark/sql/connector/config/ReadConfig.java @@ -23,6 +23,7 @@ import static java.util.Collections.unmodifiableList; import com.mongodb.client.model.changestream.FullDocument; +import com.mongodb.client.model.changestream.FullDocumentBeforeChange; import com.mongodb.spark.sql.connector.exceptions.ConfigException; import com.mongodb.spark.sql.connector.read.partitioner.Partitioner; import java.util.HashMap; @@ -288,6 +289,35 @@ static ParseMode fromString(final String userParseMode) { private static final String STREAM_LOOKUP_FULL_DOCUMENT_DEFAULT = FullDocument.DEFAULT.getValue(); + /** + * Streaming full document before change configuration. + * + *
+   * <p>Determines what to return as the pre-image of the document during replace, update, or
+   * delete operations when using a MongoDB Change Stream.
+   *
+   * <p>Only applies if the MongoDB server is configured to capture pre-images. See the server
+   * documentation on "Change streams lookup full document before change" for further details.
+   *
+   * <p>Possible values:
+   *
+   * <ul>
+   *   <li>"default" - Uses the server's default behavior for the fullDocumentBeforeChange field.
+   *   <li>"off" - Do not include the pre-image of the document in the change stream event.
+   *   <li>"whenAvailable" - Include the pre-image of the modified document if available; otherwise, omit it.
+   *   <li>"required" - Include the pre-image, and raise an error if it is not available.
+   * </ul>
+   *
+   * <p>Configuration: {@value}
+   *
+   * <p>
Default: "default" - the server's default behavior for the fullDocumentBeforeChange field. + * @since 10.6 + */ + public static final String STREAM_LOOKUP_FULL_DOCUMENT_BEFORE_CHANGE_CONFIG = + "change.stream.lookup.full.document.before.change"; + + private static final String STREAM_LOOKUP_FULL_DOCUMENT_BEFORE_CHANGE_DEFAULT = + FullDocumentBeforeChange.DEFAULT.getValue(); + enum StreamingStartupMode { LATEST, TIMESTAMP; @@ -492,6 +522,19 @@ public FullDocument getStreamFullDocument() { } } + /** @return the stream full document before change configuration or 'default' if not set. + * @since 10.6 + */ + public FullDocumentBeforeChange getStreamFullDocumentBeforeChange() { + try { + return FullDocumentBeforeChange.fromString(getOrDefault( + STREAM_LOOKUP_FULL_DOCUMENT_BEFORE_CHANGE_CONFIG, + STREAM_LOOKUP_FULL_DOCUMENT_BEFORE_CHANGE_DEFAULT)); + } catch (IllegalArgumentException e) { + throw new ConfigException(e); + } + } + /** @return true if should drop any malformed rows */ public boolean dropMalformed() { return parseMode == ParseMode.DROPMALFORMED; diff --git a/src/main/java/com/mongodb/spark/sql/connector/read/MongoContinuousPartitionReader.java b/src/main/java/com/mongodb/spark/sql/connector/read/MongoContinuousPartitionReader.java index 0ac677ae..9381fc1c 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/read/MongoContinuousPartitionReader.java +++ b/src/main/java/com/mongodb/spark/sql/connector/read/MongoContinuousPartitionReader.java @@ -196,6 +196,7 @@ private MongoChangeStreamCursor getCursor() { } changeStreamIterable .fullDocument(readConfig.getStreamFullDocument()) + .fullDocumentBeforeChange(readConfig.getStreamFullDocumentBeforeChange()) .comment(readConfig.getComment()); changeStreamIterable = lastOffset.applyToChangeStreamIterable(changeStreamIterable); diff --git a/src/main/java/com/mongodb/spark/sql/connector/read/MongoMicroBatchPartitionReader.java b/src/main/java/com/mongodb/spark/sql/connector/read/MongoMicroBatchPartitionReader.java index e7f7467d..1681fc95 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/read/MongoMicroBatchPartitionReader.java +++ b/src/main/java/com/mongodb/spark/sql/connector/read/MongoMicroBatchPartitionReader.java @@ -185,6 +185,7 @@ private MongoChangeStreamCursor getCursor() { } changeStreamIterable .fullDocument(readConfig.getStreamFullDocument()) + .fullDocumentBeforeChange(readConfig.getStreamFullDocumentBeforeChange()) .comment(readConfig.getComment()); if (partition.getStartOffsetTimestamp().getTime() >= 0) { changeStreamIterable.startAtOperationTime(partition.getStartOffsetTimestamp()); diff --git a/src/test/java/com/mongodb/spark/sql/connector/config/MongoConfigTest.java b/src/test/java/com/mongodb/spark/sql/connector/config/MongoConfigTest.java index 0d5df138..a912261e 100644 --- a/src/test/java/com/mongodb/spark/sql/connector/config/MongoConfigTest.java +++ b/src/test/java/com/mongodb/spark/sql/connector/config/MongoConfigTest.java @@ -28,6 +28,7 @@ import com.mongodb.WriteConcern; import com.mongodb.client.model.changestream.FullDocument; +import com.mongodb.client.model.changestream.FullDocumentBeforeChange; import com.mongodb.spark.sql.connector.exceptions.ConfigException; import java.util.HashMap; import java.util.Map; @@ -324,6 +325,29 @@ void testReadConfigStreamFullDocument() { assertEquals(readConfig.getStreamFullDocument(), FullDocument.UPDATE_LOOKUP); } + @Test + void testReadConfigStreamFullDocumentBeforeChange() { + ReadConfig readConfig = MongoConfig.readConfig(CONFIG_MAP); + 
assertEquals(readConfig.getStreamFullDocumentBeforeChange(), FullDocumentBeforeChange.DEFAULT); + + readConfig = + readConfig.withOption(ReadConfig.STREAM_LOOKUP_FULL_DOCUMENT_BEFORE_CHANGE_CONFIG, "off"); + assertEquals(readConfig.getStreamFullDocumentBeforeChange(), FullDocumentBeforeChange.OFF); + + readConfig = readConfig.withOption( + ReadConfig.STREAM_LOOKUP_FULL_DOCUMENT_BEFORE_CHANGE_CONFIG, "whenAvailable"); + assertEquals( + readConfig.getStreamFullDocumentBeforeChange(), FullDocumentBeforeChange.WHEN_AVAILABLE); + + readConfig = readConfig.withOption( + ReadConfig.STREAM_LOOKUP_FULL_DOCUMENT_BEFORE_CHANGE_CONFIG, "required"); + assertEquals(readConfig.getStreamFullDocumentBeforeChange(), FullDocumentBeforeChange.REQUIRED); + + readConfig = readConfig.withOption( + ReadConfig.STREAM_LOOKUP_FULL_DOCUMENT_BEFORE_CHANGE_CONFIG, "INVALID"); + assertThrows(ConfigException.class, readConfig::getStreamFullDocumentBeforeChange); + } + @Test void testReadConfigSchemaHints() { ReadConfig readConfig =
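For context, a minimal usage sketch of the new read option on a structured-streaming query. This is illustrative only and not part of the patch: it assumes connector 10.6+, a MongoDB 6.0+ deployment with changeStreamPreAndPostImages enabled on the source collection, and placeholder URI/database/collection names; the unprefixed option key mirrors the STREAM_LOOKUP_FULL_DOCUMENT_BEFORE_CHANGE_CONFIG constant defined above.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public final class PreImageStreamSketch {
  public static void main(final String[] args) throws Exception {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("fullDocumentBeforeChange-sketch")
        .getOrCreate();

    // Change-event projection; mirrors the schema used by the integration test above.
    StructType schema = new StructType()
        .add("operationType", DataTypes.StringType)
        .add("clusterTime", DataTypes.StringType)
        .add("fullDocumentBeforeChange", DataTypes.StringType);

    Dataset<Row> changes = spark
        .readStream()
        .format("mongodb")
        .schema(schema)
        .option("spark.mongodb.connection.uri", "mongodb://localhost:27017") // placeholder
        .option("spark.mongodb.database", "test") // placeholder
        .option("spark.mongodb.collection", "coll") // placeholder
        // New option introduced by this patch; accepted values: default, off, whenAvailable, required.
        .option("change.stream.lookup.full.document.before.change", "whenAvailable")
        .load();

    changes.writeStream().format("console").start().awaitTermination();
  }
}

With "whenAvailable" the fullDocumentBeforeChange field is simply null when no pre-image exists, whereas "required" makes the change stream raise an error, matching the semantics described in the ReadConfig Javadoc above.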