
Commit b02113e

Merge pull request #18 from Ankitp1342/feature/batch-size-one-fix
Fixed bug with batch-size one & some logging improvements
2 parents 92aca37 + a483829 · commit b02113e

File tree

5 files changed: +39 −48 lines changed

pom.xml

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 
     <groupId>com.datastax.spark.example</groupId>
     <artifactId>migrate</artifactId>
-    <version>0.11</version>
+    <version>0.12</version>
     <packaging>jar</packaging>
 
     <properties>

src/main/java/datastax/astra/migrate/AbstractJobSession.java

Lines changed: 27 additions & 25 deletions
@@ -39,7 +39,8 @@ public abstract class AbstractJobSession {
     protected Integer batchSize = 1;
     protected Integer printStatsAfter = 100000;
 
-    protected Boolean writeTimeStampFilter = false;
+    protected Boolean isPreserveTTLWritetime = Boolean.FALSE;
+    protected Boolean writeTimeStampFilter = Boolean.FALSE;
     protected Long minWriteTimeStampFilter = 0l;
     protected Long maxWriteTimeStampFilter = Long.MAX_VALUE;
 
@@ -51,11 +52,9 @@ public abstract class AbstractJobSession {
     protected String sourceKeyspaceTable;
     protected String astraKeyspaceTable;
 
-
     protected Boolean hasRandomPartitioner;
 
     protected AbstractJobSession(CqlSession sourceSession, CqlSession astraSession, SparkConf sparkConf) {
-
         this.sourceSession = sourceSession;
         this.astraSession = astraSession;
 
@@ -72,22 +71,43 @@ protected AbstractJobSession(CqlSession sourceSession, CqlSession astraSession,
         sourceKeyspaceTable = sparkConf.get("spark.migrate.source.keyspaceTable");
         astraKeyspaceTable = sparkConf.get("spark.migrate.astra.keyspaceTable");
 
+        isPreserveTTLWritetime = Boolean.parseBoolean(sparkConf.get("spark.migrate.preserveTTLWriteTime", "false"));
+        if (isPreserveTTLWritetime) {
+            String ttlColsStr = sparkConf.get("spark.migrate.source.ttl.cols");
+            if (null != ttlColsStr && ttlColsStr.trim().length() > 0) {
+                for (String ttlCol : ttlColsStr.split(",")) {
+                    ttlCols.add(Integer.parseInt(ttlCol));
+                }
+            }
+        }
+
         writeTimeStampFilter = Boolean
                 .parseBoolean(sparkConf.get("spark.migrate.source.writeTimeStampFilter", "false"));
-        minWriteTimeStampFilter = new Long(
-                sparkConf.get("spark.migrate.source.minWriteTimeStampFilter", "0"));
-        maxWriteTimeStampFilter = new Long(
-                sparkConf.get("spark.migrate.source.maxWriteTimeStampFilter", "" + Long.MAX_VALUE));
         // batchsize set to 1 if there is a writeFilter
         if (writeTimeStampFilter) {
             batchSize = 1;
+            String writeTimestampColsStr = sparkConf.get("spark.migrate.source.writeTimeStampFilter.cols");
+            if (null != writeTimestampColsStr && writeTimestampColsStr.trim().length() > 0) {
+                for (String writeTimeStampCol : writeTimestampColsStr.split(",")) {
+                    writeTimeStampCols.add(Integer.parseInt(writeTimeStampCol));
+                }
+            }
         }
+
+        minWriteTimeStampFilter = new Long(
+                sparkConf.get("spark.migrate.source.minWriteTimeStampFilter", "0"));
+        maxWriteTimeStampFilter = new Long(
+                sparkConf.get("spark.migrate.source.maxWriteTimeStampFilter", "" + Long.MAX_VALUE));
+
         logger.info(" DEFAULT -- Write Batch Size: " + batchSize);
         logger.info(" DEFAULT -- Source Keyspace Table: " + sourceKeyspaceTable);
         logger.info(" DEFAULT -- Astra Keyspace Table: " + astraKeyspaceTable);
         logger.info(" DEFAULT -- ReadRateLimit: " + readLimiter.getRate());
         logger.info(" DEFAULT -- WriteRateLimit: " + writeLimiter.getRate());
         logger.info(" DEFAULT -- WriteTimestampFilter: " + writeTimeStampFilter);
+        logger.info(" DEFAULT -- WriteTimestampFilterCols: " + writeTimeStampCols);
+        logger.info(" DEFAULT -- isPreserveTTLWritetime: " + isPreserveTTLWritetime);
+        logger.info(" DEFAULT -- TTLCols: " + ttlCols);
 
         hasRandomPartitioner = Boolean.parseBoolean(sparkConf.get("spark.migrate.source.hasRandomPartitioner", "false"));
 
@@ -96,20 +116,6 @@ protected AbstractJobSession(CqlSession sourceSession, CqlSession astraSession,
         counterDeltaMaxIndex = Integer
                 .parseInt(sparkConf.get("spark.migrate.source.counterTable.update.max.counter.index", "0"));
 
-        String writeTimestampColsStr = sparkConf.get("spark.migrate.source.writeTimeStampFilter.cols");
-        if (null != writeTimestampColsStr && writeTimestampColsStr.trim().length() > 0) {
-            for (String writeTimeStampCol : writeTimestampColsStr.split(",")) {
-                writeTimeStampCols.add(Integer.parseInt(writeTimeStampCol));
-            }
-        }
-
-        String ttlColsStr = sparkConf.get("spark.migrate.source.ttl.cols");
-        if (null != ttlColsStr && ttlColsStr.trim().length() > 0) {
-            for (String ttlCol : ttlColsStr.split(",")) {
-                ttlCols.add(Integer.parseInt(ttlCol));
-            }
-        }
-
         String partionKey = sparkConf.get("spark.migrate.query.cols.partitionKey");
         String idCols = sparkConf.get("spark.migrate.query.cols.id");
         idColTypes = getTypes(sparkConf.get("spark.migrate.query.cols.id.types"));
@@ -128,15 +134,13 @@ protected AbstractJobSession(CqlSession sourceSession, CqlSession astraSession,
         }
 
         sourceSelectCondition = sparkConf.get("spark.migrate.query.cols.select.condition", "");
-
         sourceSelectStatement = sourceSession.prepare(
                 "select " + selectCols + " from " + sourceKeyspaceTable + " where token(" + partionKey.trim()
                         + ") >= ? and token(" + partionKey.trim() + ") <= ? " + sourceSelectCondition + " ALLOW FILTERING");
 
         astraSelectStatement = astraSession.prepare(
                 "select " + selectCols + " from " + astraKeyspaceTable
                         + " where " + idBinds);
-
     }
 
     public List<MigrateDataType> getTypes(String types) {
@@ -146,7 +150,6 @@ public List<MigrateDataType> getTypes(String types) {
         }
 
         return dataTypes;
-
     }
 
     public int getLargestTTL(Row sourceRow) {
@@ -177,7 +180,6 @@ public BoundStatement selectFromAstra(PreparedStatement selectStatement, Row sou
     }
 
     public Object getData(MigrateDataType dataType, int index, Row sourceRow) {
-
         if (dataType.typeClass == Map.class) {
             return sourceRow.getMap(index, dataType.subTypes.get(0), dataType.subTypes.get(1));
         } else if (dataType.typeClass == List.class) {
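
A note on configuration (not part of the commit): the constructor above now reads the TTL and writetime column lists only when their corresponding feature flags are enabled, and parses the min/max writetime bounds unconditionally. The minimal sketch below shows how these properties might be supplied when building the job's SparkConf; the keyspace/table names and column indexes are hypothetical placeholders, and only the property keys are taken from the diff above.

import org.apache.spark.SparkConf;

public class MigrationConfSketch {
    public static void main(String[] args) {
        // Property keys come from AbstractJobSession; all values here are hypothetical examples.
        SparkConf sparkConf = new SparkConf()
                .set("spark.migrate.source.keyspaceTable", "source_ks.tbl")            // placeholder
                .set("spark.migrate.astra.keyspaceTable", "target_ks.tbl")             // placeholder
                .set("spark.migrate.preserveTTLWriteTime", "true")
                .set("spark.migrate.source.ttl.cols", "2,3")                           // placeholder column indexes
                .set("spark.migrate.source.writeTimeStampFilter", "true")
                .set("spark.migrate.source.writeTimeStampFilter.cols", "2,3")          // placeholder column indexes
                .set("spark.migrate.source.minWriteTimeStampFilter", "0")
                .set("spark.migrate.source.maxWriteTimeStampFilter", String.valueOf(Long.MAX_VALUE));
        // When spark.migrate.source.writeTimeStampFilter is "true", AbstractJobSession pins batchSize to 1.
        System.out.println(sparkConf.toDebugString());
    }
}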

src/main/java/datastax/astra/migrate/CopyJobSession.java

Lines changed: 0 additions & 20 deletions
@@ -24,7 +24,6 @@ public class CopyJobSession extends AbstractJobSession {
 
     protected List<MigrateDataType> insertColTypes = new ArrayList<MigrateDataType>();
     protected List<Integer> updateSelectMapping = new ArrayList<Integer>();
-    protected Boolean isPreserveTTLWritetime = Boolean.FALSE;
 
     public static CopyJobSession getInstance(CqlSession sourceSession, CqlSession astraSession, SparkConf sparkConf) {
         if (copyJobSession == null) {
@@ -53,26 +52,22 @@ protected CopyJobSession(CqlSession sourceSession, CqlSession astraSession, Spar
             }
             count++;
         }
-        isPreserveTTLWritetime = Boolean.parseBoolean(sparkConf.get("spark.migrate.preserveTTLWriteTime", "false"));
 
         if (isCounterTable) {
             String updateSelectMappingStr = sparkConf.get("spark.migrate.source.counterTable.update.select.index", "0");
             for (String updateSelectIndex : updateSelectMappingStr.split(",")) {
                 updateSelectMapping.add(Integer.parseInt(updateSelectIndex));
             }
 
-
             String counterTableUpdate = sparkConf.get("spark.migrate.source.counterTable.update.cql");
             astraInsertStatement = astraSession.prepare(counterTableUpdate);
         } else {
-
             if (isPreserveTTLWritetime) {
                 astraInsertStatement = astraSession.prepare("insert into " + astraKeyspaceTable + " (" + insertCols + ") VALUES (" + insertBinds + ") using TTL ? and TIMESTAMP ?");
             } else {
                 astraInsertStatement = astraSession.prepare("insert into " + astraKeyspaceTable + " (" + insertCols + ") VALUES (" + insertBinds + ")");
             }
         }
-
     }
 
     public void getDataAndInsert(BigInteger min, BigInteger max) {
@@ -81,9 +76,7 @@ public void getDataAndInsert(BigInteger min, BigInteger max) {
         for (int retryCount = 1; retryCount <= maxAttempts; retryCount++) {
 
             try {
-
                 ResultSet resultSet = sourceSession.execute(sourceSelectStatement.bind(hasRandomPartitioner? min : min.longValueExact(), hasRandomPartitioner? max : max.longValueExact()));
-
                 Collection<CompletionStage<AsyncResultSet>> writeResults = new ArrayList<CompletionStage<AsyncResultSet>>();
 
                 // cannot do batching if the writeFilter is greater than 0 or
@@ -99,7 +92,6 @@ public void getDataAndInsert(BigInteger min, BigInteger max) {
                                 || sourceWriteTimeStamp > maxWriteTimeStampFilter) {
                             continue;
                         }
-
                     }
 
                     writeLimiter.acquire(1);
@@ -118,41 +110,33 @@ public void getDataAndInsert(BigInteger min, BigInteger max) {
                         CompletionStage<AsyncResultSet> astraWriteResultSet = astraSession
                                 .executeAsync(bindInsert(astraInsertStatement, sourceRow, astraRow));
                         writeResults.add(astraWriteResultSet);
-
                     } else {
                         CompletionStage<AsyncResultSet> astraWriteResultSet = astraSession
                                 .executeAsync(bindInsert(astraInsertStatement, sourceRow));
                         writeResults.add(astraWriteResultSet);
                     }
-
                     if (writeResults.size() > 1000) {
                         iterateAndClearWriteResults(writeResults, 1);
                     }
                 }
 
                 // clear the write resultset in-case it didnt mod at 1000 above
                 iterateAndClearWriteResults(writeResults, 1);
-
             } else {
-                //
                 BatchStatement batchStatement = BatchStatement.newInstance(BatchType.UNLOGGED);
                 for (Row row : resultSet) {
                     readLimiter.acquire(1);
                     writeLimiter.acquire(1);
                     if (readCounter.incrementAndGet() % 1000 == 0) {
                         logger.info("TreadID: " + Thread.currentThread().getId() + " Read Record Count: " + readCounter.get());
                     }
-
                     batchStatement = batchStatement.add(bindInsert(astraInsertStatement, row));
 
-
                     // if batch threshold is met, send the writes and clear the batch
                     if (batchStatement.size() >= batchSize) {
-
                         CompletionStage<AsyncResultSet> writeResultSet = astraSession.executeAsync(batchStatement);
                         writeResults.add(writeResultSet);
                         batchStatement = BatchStatement.newInstance(BatchType.UNLOGGED);
-
                     }
 
                     if (writeResults.size() * batchSize > 1000) {
@@ -163,7 +147,6 @@ public void getDataAndInsert(BigInteger min, BigInteger max) {
                 // clear the write resultset in-case it didnt mod at 1000 above
                 iterateAndClearWriteResults(writeResults, batchSize);
 
-
                 // if there are any pending writes because the batchSize threshold was not met, then write and clear them
                 if (batchStatement.size() > 0) {
                     CompletionStage<AsyncResultSet> writeResultSet = astraSession.executeAsync(batchStatement);
@@ -187,7 +170,6 @@ public void getDataAndInsert(BigInteger min, BigInteger max) {
 
     }
 
-
     private void iterateAndClearWriteResults(Collection<CompletionStage<AsyncResultSet>> writeResults, int incrementBy) throws Exception{
         for (CompletionStage<AsyncResultSet> writeResult : writeResults) {
             //wait for the writes to complete for the batch. The Retry policy, if defined, should retry the write on timeouts.
@@ -199,8 +181,6 @@ private void iterateAndClearWriteResults(Collection<CompletionStage<AsyncResultS
         writeResults.clear();
     }
 
-
-
     public BoundStatement bindInsert(PreparedStatement insertStatement, Row sourceRow) {
         return bindInsert(insertStatement, sourceRow, null);
     }
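
The getDataAndInsert flow above can be summarized as follows: when batching is effectively disabled (batch size 1, for example because a writetime filter or a counter table is in play), every row gets its own async write; otherwise rows accumulate into unlogged batches of batchSize and are flushed when the threshold is met, with any remainder flushed at the end. The sketch below is a simplified, standalone paraphrase of that decision, not the class itself; SourceRow, writeOne, and writeBatch are hypothetical stand-ins for the driver Row and the executeAsync calls in the diff.

import java.util.ArrayList;
import java.util.List;

public class BatchingSketch {

    // Hypothetical stand-in for the driver's Row, carrying only the writetime used by the filter.
    record SourceRow(long writeTime) {}

    // Hypothetical sink; in the real job these are astraSession.executeAsync(...) calls.
    interface Writer {
        void writeOne(SourceRow row);
        void writeBatch(List<SourceRow> rows);
    }

    static void copyRows(Iterable<SourceRow> rows, int batchSize, boolean writeTimeStampFilter,
                         long minWriteTime, long maxWriteTime, Writer writer) {
        List<SourceRow> batch = new ArrayList<>();
        for (SourceRow row : rows) {
            // Rows outside the writetime window are skipped, mirroring the filter branch in the diff.
            if (writeTimeStampFilter && (row.writeTime() < minWriteTime || row.writeTime() > maxWriteTime)) {
                continue;
            }
            if (batchSize <= 1) {
                // Batch-size-one path: one statement per row.
                writer.writeOne(row);
            } else {
                batch.add(row);
                if (batch.size() >= batchSize) {
                    // Flush a full unlogged batch and start a new one.
                    writer.writeBatch(new ArrayList<>(batch));
                    batch.clear();
                }
            }
        }
        // Flush any remainder smaller than batchSize so no pending writes are lost.
        if (!batch.isEmpty()) {
            writer.writeBatch(batch);
        }
    }
}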

src/main/scala/datastax/astra/migrate/DiffData.scala

Lines changed: 6 additions & 1 deletion
@@ -5,6 +5,7 @@ import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata
 import com.datastax.spark.connector._
 import com.datastax.spark.connector.cql.CassandraConnector
 import datastax.astra.migrate.Migrate.{astraPassword, astraReadConsistencyLevel, astraScbPath, astraUsername, sc, sourceHost, sourcePassword, sourceReadConsistencyLevel, sourceUsername}
+import org.apache.log4j.Logger
 import org.apache.spark.sql.{SaveMode, SparkSession}
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.cassandra._
@@ -17,6 +18,8 @@ import java.math.BigInteger
 
 object DiffData extends App {
 
+  val logger = Logger.getLogger(this.getClass.getName)
+
   val spark = SparkSession.builder
     .appName("Datastax Data Validation")
     .getOrCreate()
@@ -46,7 +49,7 @@ object DiffData extends App {
   val splitSize = sc.getConf.get("spark.migrate.splitSize","10000")
 
 
-  println("Started Data Validation App")
+  logger.info("Started Data Validation App")
 
   val isBeta = sc.getConf.get("spark.migrate.beta","false")
   val isCassandraToCassandra = sc.getConf.get("spark.migrate.ctoc", "false")
@@ -90,6 +93,8 @@ object DiffData extends App {
   private def diffTable(sourceConnection: CassandraConnector, astraConnection: CassandraConnector, minPartition:BigInteger, maxPartition:BigInteger) = {
     val partitions = SplitPartitions.getRandomSubPartitions(BigInteger.valueOf(Long.parseLong(splitSize)), minPartition, maxPartition)
     val parts = sc.parallelize(partitions.toSeq,partitions.size);
+
+    logger.info("Spark parallelize created : " + parts.count() + " parts!");
     parts.foreach(part => {
       sourceConnection.withSessionDo(sourceSession =>
         astraConnection.withSessionDo(astraSession =>

src/main/scala/datastax/astra/migrate/Migrate.scala

Lines changed: 5 additions & 1 deletion
@@ -4,6 +4,7 @@ import com.datastax.oss.driver.api.core.{CqlIdentifier, CqlSession}
 import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata
 import com.datastax.spark.connector._
 import com.datastax.spark.connector.cql.CassandraConnector
+import org.apache.log4j.Logger
 import org.apache.spark.sql.{SaveMode, SparkSession}
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.cassandra._
@@ -18,6 +19,8 @@ import collection.JavaConversions._
 // http://www.russellspitzer.com/2016/02/16/Multiple-Clusters-SparkSql-Cassandra/
 
 object Migrate extends App {
+  val logger = Logger.getLogger(this.getClass.getName)
+
   val spark = SparkSession.builder
     .appName("Datastax Data Migration")
     .getOrCreate()
@@ -46,7 +49,7 @@ object Migrate extends App {
   val astraReadConsistencyLevel = sc.getConf.get("spark.cassandra.astra.read.consistency.level","LOCAL_QUORUM")
 
 
-  println("Started Migration App")
+  logger.info("Started Migration App")
 
   val isBeta = sc.getConf.get("spark.migrate.beta","false")
 
@@ -87,6 +90,7 @@ object Migrate extends App {
 
     val partitions = SplitPartitions.getRandomSubPartitions(BigInteger.valueOf(Long.parseLong(splitSize)), minPartition, maxPartition)
     val parts = sc.parallelize(partitions.toSeq,partitions.size);
+    logger.info("Spark parallelize created : " + parts.count() + " parts!");
     parts.foreach(part => {
       sourceConnection.withSessionDo(sourceSession => astraConnection.withSessionDo(astraSession=> CopyJobSession.getInstance(sourceSession,astraSession, sc.getConf).getDataAndInsert(part.getMin, part.getMax)))
     })
