diff --git a/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java b/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java
index 7a20b30..8d74de8 100644
--- a/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java
+++ b/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java
@@ -87,6 +87,13 @@ default SpecificTokens specificTokens() {
 
     MetadataKeyspaceOptions metadataOptions();
 
+    /**
+     * Sampling probability ranges from 0-1 which decides how many partitions are to be diffed using probabilistic diff
+     * default value is 1 which means all the partitions are diffed
+     * @return partitionSamplingProbability
+     */
+    double partitionSamplingProbability();
+
     /**
      * Contains the options that specify the retry strategy for retrieving data at the application level.
      * Note that it is different than cassandra java driver's {@link com.datastax.driver.core.policies.RetryPolicy},
diff --git a/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java b/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java
index 359466a..7d60403 100644
--- a/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java
+++ b/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java
@@ -48,6 +48,7 @@ public class YamlJobConfiguration implements JobConfiguration {
     public String specific_tokens = null;
     public String disallowed_tokens = null;
     public RetryOptions retry_options;
+    public double partition_sampling_probability = 1;
 
     public static YamlJobConfiguration load(InputStream inputStream) {
         Yaml yaml = new Yaml(new CustomClassLoaderConstructor(YamlJobConfiguration.class,
@@ -103,6 +104,11 @@ public MetadataKeyspaceOptions metadataOptions() {
         return metadata_options;
     }
 
+    @Override
+    public double partitionSamplingProbability() {
+        return partition_sampling_probability;
+    }
+
     public RetryOptions retryOptions() {
         return retry_options;
     }
@@ -130,6 +136,7 @@ public String toString() {
                ", keyspace_tables=" + keyspace_tables +
                ", buckets=" + buckets +
                ", rate_limit=" + rate_limit +
+               ", partition_sampling_probability=" + partition_sampling_probability +
                ", job_id='" + job_id + '\'' +
                ", token_scan_fetch_size=" + token_scan_fetch_size +
                ", partition_read_fetch_size=" + partition_read_fetch_size +
diff --git a/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java b/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java
index cf1c9a5..7415d1a 100644
--- a/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java
+++ b/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java
@@ -27,10 +27,12 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.UUID;
 import java.util.concurrent.Callable;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
+import java.util.function.Predicate;
 import java.util.stream.Collectors;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -63,6 +65,7 @@ public class Differ implements Serializable
     private final double reverseReadProbability;
     private final SpecificTokens specificTokens;
     private final RetryStrategyProvider retryStrategyProvider;
+    private final double partitionSamplingProbability;
 
     private static DiffCluster srcDiffCluster;
     private static DiffCluster targetDiffCluster;
@@ -103,6 +106,7 @@ public Differ(JobConfiguration config,
         this.reverseReadProbability = config.reverseReadProbability();
         this.specificTokens = config.specificTokens();
         this.retryStrategyProvider = retryStrategyProvider;
+        this.partitionSamplingProbability = config.partitionSamplingProbability();
 
         synchronized (Differ.class) {
             /*
@@ -225,12 +229,31 @@ public RangeStats diffTable(final DiffContext context,
                                                                     mismatchReporter,
                                                                     journal,
                                                                     COMPARISON_EXECUTOR);
-
-        final RangeStats tableStats = rangeComparator.compare(sourceKeys, targetKeys, partitionTaskProvider);
+        final Predicate<PartitionKey> partitionSamplingFunction = shouldIncludePartition(jobId, partitionSamplingProbability);
+        final RangeStats tableStats = rangeComparator.compare(sourceKeys, targetKeys, partitionTaskProvider, partitionSamplingFunction);
         logger.debug("Table [{}] stats - ({})", context.table.getTable(), tableStats);
         return tableStats;
     }
 
+    // Returns a function which decides if we should include a partition for diffing
+    // Uses probability for sampling.
+    @VisibleForTesting
+    static Predicate<PartitionKey> shouldIncludePartition(final UUID jobId, final double partitionSamplingProbability) {
+        if (partitionSamplingProbability > 1 || partitionSamplingProbability <= 0) {
+            final String message = "Invalid partition sampling property "
+                                   + partitionSamplingProbability
+                                   + ", it should be between 0 and 1";
+            logger.error(message);
+            throw new IllegalArgumentException(message);
+        }
+        if (partitionSamplingProbability == 1) {
+            return partitionKey -> true;
+        } else {
+            final Random random = new Random(jobId.hashCode());
+            return partitionKey -> random.nextDouble() <= partitionSamplingProbability;
+        }
+    }
+
     private Iterator<Row> fetchRows(DiffContext context, PartitionKey key, boolean shouldReverse, DiffCluster.Type type) {
         Callable<Iterator<Row>> rows = () -> type == DiffCluster.Type.SOURCE
                                              ? context.source.getPartition(context.table, key, shouldReverse)
diff --git a/spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java b/spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java
index 5d6710e..280fbd5 100644
--- a/spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java
+++ b/spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java
@@ -27,6 +27,7 @@
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Function;
+import java.util.function.Predicate;
 
 import com.google.common.base.Verify;
 import org.slf4j.Logger;
@@ -57,6 +58,22 @@ public RangeComparator(DiffContext context,
     public RangeStats compare(Iterator<PartitionKey> sourceKeys,
                               Iterator<PartitionKey> targetKeys,
                               Function<PartitionKey, PartitionComparator> partitionTaskProvider) {
+        return compare(sourceKeys,targetKeys,partitionTaskProvider, partitionKey -> true);
+    }
+
+    /**
+     * Compares partitions in src and target clusters.
+     *
+     * @param sourceKeys partition keys in the source cluster
+     * @param targetKeys partition keys in the target cluster
+     * @param partitionTaskProvider comparision task
+     * @param partitionSampler samples partitions based on the probability for probabilistic diff
+     * @return stats about the diff
+     */
+    public RangeStats compare(Iterator<PartitionKey> sourceKeys,
+                              Iterator<PartitionKey> targetKeys,
+                              Function<PartitionKey, PartitionComparator> partitionTaskProvider,
+                              Predicate<PartitionKey> partitionSampler) {
         final RangeStats rangeStats = RangeStats.newStats();
 
         // We can catch this condition earlier, but it doesn't hurt to also check here
@@ -115,11 +132,16 @@ public RangeStats compare(Iterator<PartitionKey> sourceKeys,
                 BigInteger token = sourceKey.getTokenAsBigInteger();
 
                 try {
-                    PartitionComparator comparisonTask = partitionTaskProvider.apply(sourceKey);
-                    comparisonExecutor.submit(comparisonTask,
-                                              onSuccess(rangeStats, partitionCount, token, highestTokenSeen, mismatchReporter, journal),
-                                              onError(rangeStats, token, errorReporter),
-                                              phaser);
+                    // Use probabilisticPartitionSampler for sampling partitions, skip partition
+                    // if the sampler returns false otherwise run diff on that partition
+                    if (partitionSampler.test(sourceKey)) {
+                        PartitionComparator comparisonTask = partitionTaskProvider.apply(sourceKey);
+                        comparisonExecutor.submit(comparisonTask,
+                                                  onSuccess(rangeStats, partitionCount, token, highestTokenSeen, mismatchReporter, journal),
+                                                  onError(rangeStats, token, errorReporter),
+                                                  phaser);
+                    }
+
                 } catch (Throwable t) {
                     // Handle errors thrown when creating the comparison task. This should trap timeouts and
                     // unavailables occurring when performing the initial query to read the full partition.
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java
index 1bf656d..49c1f11 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java
@@ -108,5 +108,10 @@ public int buckets() {
         public Optional<UUID> jobId() {
             return Optional.of(UUID.randomUUID());
         }
+
+        @Override
+        public double partitionSamplingProbability() {
+            return 1;
+        }
     }
 }
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java
index e588575..ea59006 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java
@@ -21,16 +21,65 @@
 import java.math.BigInteger;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.Function;
+import java.util.function.Predicate;
 
 import com.google.common.base.VerifyException;
 import com.google.common.collect.Lists;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
 
 public class DifferTest {
+    @Rule
+    public ExpectedException expectedException = ExpectedException.none();
+
+    @Test
+    public void testIncludeAllPartitions() {
+        final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
+        final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+        assertTrue(Differ.shouldIncludePartition(uuid, 1).test(testKey));
+    }
+
+    @Test
+    public void shouldIncludePartitionWithProbabilityInvalidProbability() {
+        final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
+        final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+        expectedException.expect(IllegalArgumentException.class);
+        expectedException.expectMessage("Invalid partition sampling property -1.0, it should be between 0 and 1");
+        Differ.shouldIncludePartition(uuid, -1).test(testKey);
+    }
+
+    @Test
+    public void shouldIncludePartitionWithProbabilityHalf() {
+        final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
+        int count = 0;
+        final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+        final Predicate<PartitionKey> partitionSampler = Differ.shouldIncludePartition(uuid, 0.5);
+        for (int i = 0; i < 20; i++) {
+            if (partitionSampler.test(testKey)) {
+                count++;
+            }
+        }
+        assertTrue(count <= 15);
+        assertTrue(count >= 5);
+    }
+
+    @Test
+    public void shouldIncludePartitionShouldGenerateSameSequenceForGivenJobId() {
+        final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+        final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
+        final Predicate<PartitionKey> partitionSampler1 = Differ.shouldIncludePartition(uuid, 0.5);
+        final Predicate<PartitionKey> partitionSampler2 = Differ.shouldIncludePartition(uuid, 0.5);
+        for (int i = 0; i < 10; i++) {
+            assertEquals(partitionSampler2.test(testKey), partitionSampler1.test(testKey));
+        }
+    }
 
     @Test(expected = VerifyException.class)
     public void rejectNullStartOfRange() {
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java
index fd2926b..e09f68f 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java
@@ -56,6 +56,38 @@ public class RangeComparatorTest {
     private ComparisonExecutor executor = ComparisonExecutor.newExecutor(1, new MetricRegistry());
     private RetryStrategyProvider mockRetryStrategyFactory = RetryStrategyProvider.create(null); // create a NoRetry provider
 
+    @Test
+    public void probabilisticDiffIncludeAllPartitions() {
+        RangeComparator comparator = comparator(context(0L, 100L));
+        RangeStats stats = comparator.compare(keys(0, 1, 2, 3, 4, 5, 6), keys(0,1, 2, 3, 4, 5, 7), this::alwaysMatch);
+        assertFalse(stats.isEmpty());
+        assertEquals(1, stats.getOnlyInSource());
+        assertEquals(1, stats.getOnlyInTarget());
+        assertEquals(6, stats.getMatchedPartitions());
+        assertReported(6, MismatchType.ONLY_IN_SOURCE, mismatches);
+        assertReported(7, MismatchType.ONLY_IN_TARGET, mismatches);
+        assertNothingReported(errors, journal);
+        assertCompared(0, 1, 2, 3, 4, 5);
+    }
+
+    @Test
+    public void probabilisticDiffProbabilityHalf() {
+        RangeComparator comparator = comparator(context(0L, 100L));
+        RangeStats stats = comparator.compare(keys(0, 1, 2, 3, 4, 5, 6),
+                                              keys(0, 1, 2, 3, 4, 5, 7),
+                                              this::alwaysMatch,
+                                              key -> key.getTokenAsBigInteger().intValue() % 2 == 0);
+        assertFalse(stats.isEmpty());
+        assertEquals(1, stats.getOnlyInSource());
+        assertEquals(1, stats.getOnlyInTarget());
+        assertEquals(3, stats.getMatchedPartitions());
+        assertReported(6, MismatchType.ONLY_IN_SOURCE, mismatches);
+        assertReported(7, MismatchType.ONLY_IN_TARGET, mismatches);
+        assertNothingReported(errors, journal);
+        assertCompared(0, 2, 4);
+    }
+
+
     @Test
     public void emptyRange() {
         RangeComparator comparator = comparator(context(100L, 100L));
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java
index 17dc67c..b94d22c 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java
@@ -29,6 +29,11 @@ private class MockConfig extends AbstractMockJobConfiguration {
         public List<String> disallowedKeyspaces() {
             return disallowedKeyspaces;
         }
+
+        @Override
+        public double partitionSamplingProbability() {
+            return 1;
+        }
     }
 
     @Test