diff --git a/spark-job/pom.xml b/spark-job/pom.xml
index c7df5bd..cd1603b 100644
--- a/spark-job/pom.xml
+++ b/spark-job/pom.xml
@@ -67,5 +67,12 @@
             <artifactId>junit</artifactId>
         </dependency>
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-core</artifactId>
+            <version>3.5.10</version>
+            <scope>test</scope>
+        </dependency>
+
diff --git a/spark-job/src/main/java/org/apache/cassandra/diff/DiffJob.java b/spark-job/src/main/java/org/apache/cassandra/diff/DiffJob.java
index d744bff..2632c09 100644
--- a/spark-job/src/main/java/org/apache/cassandra/diff/DiffJob.java
+++ b/spark-job/src/main/java/org/apache/cassandra/diff/DiffJob.java
@@ -122,9 +122,12 @@ public void run(JobConfiguration configuration, JavaSparkContext sc) {
         ClusterProvider metadataProvider = ClusterProvider.getProvider(configuration.clusterConfig("metadata"), "metadata");
         JobMetadataDb.JobLifeCycle job = null;
         UUID jobId = null;
-        try (Cluster metadataCluster = metadataProvider.getCluster();
-             Session metadataSession = metadataCluster.connect()) {
+        Cluster metadataCluster = null;
+        Session metadataSession = null;
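+        // Managed manually (rather than via try-with-resources) and closed explicitly once the job has finished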
+        try {
+            metadataCluster = metadataProvider.getCluster();
+            metadataSession = metadataCluster.connect();
             RetryStrategyProvider retryStrategyProvider = RetryStrategyProvider.create(configuration.retryOptions());
             MetadataKeyspaceOptions metadataOptions = configuration.metadataOptions();
             JobMetadataDb.Schema.maybeInitialize(metadataSession, metadataOptions, retryStrategyProvider);
@@ -197,18 +200,32 @@ public void run(JobConfiguration configuration, JavaSparkContext sc) {
                 Differ.shutdown();
                 JobMetadataDb.ProgressTracker.resetStatements();
             }
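+            // Explicitly release the metadata cluster and session now that the job lifecycle is over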
+            if (metadataCluster != null) {
+                metadataCluster.close();
+            }
+            if (metadataSession != null) {
+                metadataSession.close();
+            }
+
         }
     }
 
-    private static Params getJobParams(JobMetadataDb.JobLifeCycle job, JobConfiguration conf, List<KeyspaceTablePair> keyspaceTables) {
+    @VisibleForTesting
+    static Params getJobParams(JobMetadataDb.JobLifeCycle job, JobConfiguration conf, List<KeyspaceTablePair> keyspaceTables) {
         if (conf.jobId().isPresent()) {
-            return job.getJobParams(conf.jobId().get());
-        } else {
-            return new Params(UUID.randomUUID(),
-                              keyspaceTables,
-                              conf.buckets(),
-                              conf.splits());
+            // When a job_id is passed as a config property for the first time, there is no metadata associated
+            // with that job_id in the metadata table yet, so only return the stored params when they are non-null.
+            // Otherwise fall through and build new params that reuse the provided job_id.
+            final Params jobParams = job.getJobParams(conf.jobId().get());
+            if (jobParams != null) {
+                return jobParams;
+            }
         }
+        final UUID jobId = conf.jobId().isPresent() ? conf.jobId().get() : UUID.randomUUID();
+        return new Params(jobId,
+                          keyspaceTables,
+                          conf.buckets(),
+                          conf.splits());
     }
 
     private static List<Split> getSplits(JobConfiguration config, TokenHelper tokenHelper) {
diff --git a/spark-job/src/main/java/org/apache/cassandra/diff/JobMetadataDb.java b/spark-job/src/main/java/org/apache/cassandra/diff/JobMetadataDb.java
index bef173a..71a802f 100644
--- a/spark-job/src/main/java/org/apache/cassandra/diff/JobMetadataDb.java
+++ b/spark-job/src/main/java/org/apache/cassandra/diff/JobMetadataDb.java
@@ -369,7 +369,7 @@ public void initializeJob(DiffJob.Params params,
                                         metadataKeyspace, Schema.RUNNING_JOBS),
                               params.jobId);
         if (!rs.one().getBool("[applied]")) {
-            logger.info("Aborting due to inability to mark job as running. " +
+            logger.info("Could not mark job as running. " +
                         "Did a previous run of job id {} fail non-gracefully?",
                         params.jobId);
             throw new RuntimeException("Unable to mark job running, aborting");
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java
index 9082970..1bf656d 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java
@@ -20,11 +20,18 @@
 package org.apache.cassandra.diff;
 
 import java.math.BigInteger;
+import java.util.ArrayList;
 import java.util.List;
+import java.util.Optional;
+import java.util.UUID;
 
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class DiffJobTest
 {
@@ -39,6 +46,37 @@ public void testSplitsRandom()
         splitTestHelper(TokenHelper.forPartitioner("RandomPartitioner"));
     }
 
+    @Test
+    public void testGetJobParamsWithJobIdProvidedShouldReturnNonNullConfigParams() {
+        final MockConfig mockConfig = new MockConfig();
+        final JobMetadataDb.JobLifeCycle mockJob = mock(JobMetadataDb.JobLifeCycle.class);
+        final List<KeyspaceTablePair> keyspaceTablePairs = new ArrayList<>();
+        final DiffJob.Params params = DiffJob.getJobParams(mockJob, mockConfig, keyspaceTablePairs);
+        assertNotNull(params);
+    }
+
+    @Test
+    public void testGetJobParamsDuringRetryShouldReturnPreviousParams() {
+        final MockConfig mockConfig = new MockConfig();
+        final JobMetadataDb.JobLifeCycle mockJob = mock(JobMetadataDb.JobLifeCycle.class);
+        final DiffJob.Params mockParams = mock(DiffJob.Params.class);
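+        // Stub the metadata store so params for the supplied job id look like they were persisted by a previous run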
+        when(mockJob.getJobParams(any())).thenAnswer(invocationOnMock -> mockParams);
+        final List<KeyspaceTablePair> keyspaceTablePairs = new ArrayList<>();
+        final DiffJob.Params params = DiffJob.getJobParams(mockJob, mockConfig, keyspaceTablePairs);
+        assertEquals(mockParams, params);
+    }
+
+    @Test
+    public void testGetJobParamsWithNoJobId() {
+        final MockConfig mockConfig = mock(MockConfig.class);
+        when(mockConfig.jobId()).thenReturn(Optional.empty());
+
+        final JobMetadataDb.JobLifeCycle mockJob = mock(JobMetadataDb.JobLifeCycle.class);
+        final List<KeyspaceTablePair> keyspaceTablePairs = new ArrayList<>();
+        final DiffJob.Params params = DiffJob.getJobParams(mockJob, mockConfig, keyspaceTablePairs);
+        assertNotNull(params.jobId);
+    }
+
     private void splitTestHelper(TokenHelper tokens)
     {
         List<DiffJob.Split> splits = DiffJob.calculateSplits(50, 1, tokens);
@@ -54,4 +92,21 @@ private void splitTestHelper(TokenHelper tokens)
         for (int i = 0; i < splits.size(); i++)
             assertEquals(i, splits.get(i).splitNumber);
     }
+
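+    /** Minimal JobConfiguration stub: two buckets, two splits, and a job id that is always present (a new random UUID per call). */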
+    private class MockConfig extends AbstractMockJobConfiguration {
+        @Override
+        public int splits() {
+            return 2;
+        }
+
+        @Override
+        public int buckets() {
+            return 2;
+        }
+
+        @Override
+        public Optional<UUID> jobId() {
+            return Optional.of(UUID.randomUUID());
+        }
+    }
}