diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 897bb8339197c..6a143edfbc448 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpoint import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} import org.apache.spark.sql.execution.streaming.state.OfflineStateRepartitionErrors import org.apache.spark.sql.execution.streaming.utils.StreamingUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.streaming.TimeMode import org.apache.spark.sql.types.StructType @@ -66,7 +67,10 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging properties: util.Map[String, String]): Table = { val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf, StateSourceOptions.apply(session, hadoopConf, properties)) - val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId) + // Build the sql conf for the batch we are reading using confs in the offsetlog + val batchSqlConf = + buildSqlConfForBatch(sourceOptions.resolvedCpLocation, sourceOptions.batchId) + val stateConf = StateStoreConf(batchSqlConf) // We only support RocksDB because the repartition work that this option // is built for only supports RocksDB if (sourceOptions.internalOnlyReadAllColumnFamilies @@ -86,7 +90,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging NoPrefixKeyStateEncoderSpec(keySchema) } - new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec, + new StateTable(session, schema, sourceOptions, stateConf, + batchSqlConf.getConf(SQLConf.STATEFUL_SHUFFLE_PARTITIONS_INTERNAL).get, keyStateEncoderSpec, stateStoreReaderInfo.transformWithStateVariableInfoOpt, stateStoreReaderInfo.stateStoreColFamilySchemaOpt, stateStoreReaderInfo.stateSchemaProviderOpt, @@ -147,7 +152,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging sourceOptions.operatorId) } - private def buildStateStoreConf(checkpointLocation: String, batchId: Long): StateStoreConf = { + private def buildSqlConfForBatch( + checkpointLocation: String, + batchId: Long): SQLConf = { val offsetLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).offsetLog offsetLog.get(batchId) match { case Some(value) => @@ -157,7 +164,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging val clonedSqlConf = session.sessionState.conf.clone() OffsetSeqMetadata.setSessionConf(metadata, clonedSqlConf) - StateStoreConf(clonedSqlConf) + clonedSqlConf case _ => throw StateDataSourceErrors.offsetLogUnavailable(batchId, checkpointLocation) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala index 1d7e7f709a6e4..c3056767ee4b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala @@ -42,6 +42,7 @@ class StateScanBuilder( schema: StructType, sourceOptions: StateSourceOptions, stateStoreConf: StateStoreConf, + batchNumPartitions: Int, keyStateEncoderSpec: KeyStateEncoderSpec, stateVariableInfoOpt: Option[TransformWithStateVariableInfo], stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], @@ -49,7 +50,8 @@ class StateScanBuilder( joinColFamilyOpt: Option[String], allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]) extends ScanBuilder { override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf, - keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, + batchNumPartitions, keyStateEncoderSpec, + stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt, allColumnFamiliesReaderInfo) } @@ -65,6 +67,7 @@ class StateScan( schema: StructType, sourceOptions: StateSourceOptions, stateStoreConf: StateStoreConf, + batchNumPartitions: Int, keyStateEncoderSpec: KeyStateEncoderSpec, stateVariableInfoOpt: Option[TransformWithStateVariableInfo], stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], @@ -85,7 +88,11 @@ class StateScan( val partitions = fs.listStatus(stateCheckpointPartitionsLocation, new PathFilter() { override def accept(path: Path): Boolean = { fs.getFileStatus(path).isDirectory && - Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0 + Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0 && + // Since we now support state repartitioning, it is possible that a future batch has + // increased the number of partitions, hence increased the number of partition directories. + // So we only want partition dirs for the number of partitions in this batch. + path.getName.toInt < batchNumPartitions } }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala index e945b803d45bc..43f2f28c6b954 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala @@ -41,6 +41,7 @@ class StateTable( override val schema: StructType, sourceOptions: StateSourceOptions, stateConf: StateStoreConf, + batchNumPartitions: Int, keyStateEncoderSpec: KeyStateEncoderSpec, stateVariableInfoOpt: Option[TransformWithStateVariableInfo], stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], @@ -86,7 +87,8 @@ class StateTable( override def capabilities(): util.Set[TableCapability] = CAPABILITY override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = - new StateScanBuilder(session, schema, sourceOptions, stateConf, keyStateEncoderSpec, + new StateScanBuilder(session, schema, sourceOptions, stateConf, + batchNumPartitions, keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt, allColumnFamiliesReaderInfo) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala index 02c8e85986d0a..19f75eb385f50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala @@ -17,15 +17,19 @@ package org.apache.spark.sql.execution.streaming.state +import org.apache.hadoop.fs.Path + import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys._ -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader import org.apache.spark.sql.execution.streaming.checkpointing.{CommitMetadata, OffsetMap, OffsetSeq, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataBase} -import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata +import org.apache.spark.sql.execution.streaming.runtime.{StreamingCheckpointConstants, StreamingQueryCheckpointMetadata} import org.apache.spark.sql.execution.streaming.utils.StreamingUtils +import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} /** * Runs repartitioning for the state stores used by a streaming query. @@ -78,9 +82,22 @@ class OfflineStateRepartitionRunner( val newBatchId = createNewBatchIfNeeded(lastBatchId, lastCommittedBatchId) - // todo(SPARK-54365): Do the repartitioning here, in subsequent PR - - // todo(SPARK-54365): update operator metadata in subsequent PR. + val stateRepartitionFunc = (stateDf: DataFrame) => { + // Repartition the state by the partition key + stateDf.repartition(numPartitions, col("partition_key")) + } + val rewriter = new StateRewriter( + sparkSession, + readBatchId = lastCommittedBatchId, + writeBatchId = newBatchId, + resolvedCpLocation, + hadoopConf, + transformFunc = Some(stateRepartitionFunc), + writeCheckpointMetadata = Some(checkpointMetadata) + ) + rewriter.run() + + updateNumPartitionsInOperatorMetadata(newBatchId, readBatchId = lastCommittedBatchId) // Commit the repartition batch commitBatch(newBatchId, lastCommittedBatchId) @@ -229,6 +246,49 @@ class OfflineStateRepartitionRunner( newBatchId } + private def updateNumPartitionsInOperatorMetadata( + newBatchId: Long, + readBatchId: Long): Unit = { + val stateMetadataReader = new StateMetadataPartitionReader( + resolvedCpLocation, + new SerializableConfiguration(hadoopConf), + readBatchId) + + val allOperatorsMetadata = stateMetadataReader.allOperatorStateMetadata + assert(allOperatorsMetadata.nonEmpty, "Operator metadata shouldn't be empty") + + val stateRootLocation = new Path( + resolvedCpLocation, StreamingCheckpointConstants.DIR_NAME_STATE).toString + + allOperatorsMetadata.foreach { opMetadata => + opMetadata match { + // We would only update shuffle partitions for v2 op metadata since it is versioned. + // For v1, we wouldn't update it since there is only one metadata file. + case v2: OperatorStateMetadataV2 => + // update for each state store + val updatedStoreInfo = v2.stateStoreInfo.map { stateStore => + stateStore.copy(numPartitions = numPartitions) + } + val updatedMetadata = v2.copy(stateStoreInfo = updatedStoreInfo) + // write the updated metadata + val metadataWriter = OperatorStateMetadataWriter.createWriter( + new Path(stateRootLocation, updatedMetadata.operatorInfo.operatorId.toString), + hadoopConf, + updatedMetadata.version, + Some(newBatchId)) + metadataWriter.write(updatedMetadata) + + logInfo(log"Updated operator metadata for " + + log"operator=${MDC(OP_TYPE, updatedMetadata.operatorInfo.operatorName)}, " + + log"numStateStores=${MDC(COUNT, updatedMetadata.stateStoreInfo.length)}") + case v => + logInfo(log"Skipping operator metadata update for " + + log"operator=${MDC(OP_TYPE, v.operatorInfo.operatorName)}, " + + log"since metadata version(${MDC(FILE_VERSION, v.version)}) is not versioned") + } + } + } + private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit = { val latestCommit = checkpointMetadata.commitLog.get(lastCommittedBatchId).get diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala index f67d80679d511..75e459b265118 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala @@ -201,6 +201,14 @@ class RocksDBFileManager( } } + private def createDfsRootDirIfNotExist(): Unit = { + if (!rootDirChecked) { + val rootDir = new Path(dfsRootDir) + if (!fm.exists(rootDir)) fm.mkdirs(rootDir) + rootDirChecked = true + } + } + def getChangeLogWriter( version: Long, useColumnFamilies: Boolean = false, @@ -209,11 +217,7 @@ class RocksDBFileManager( ): StateStoreChangelogWriter = { try { val changelogFile = dfsChangelogFile(version, checkpointUniqueId) - if (!rootDirChecked) { - val rootDir = new Path(dfsRootDir) - if (!fm.exists(rootDir)) fm.mkdirs(rootDir) - rootDirChecked = true - } + createDfsRootDirIfNotExist() val enableStateStoreCheckpointIds = checkpointUniqueId.isDefined val changelogVersion = getChangelogWriterVersion( @@ -332,11 +336,7 @@ class RocksDBFileManager( // CheckpointFileManager.createAtomic API which doesn't auto-initialize parent directories. // Moreover, once we disable to track the number of keys, in which the numKeys is -1, we // still need to create the initial dfs root directory anyway. - if (!rootDirChecked) { - val path = new Path(dfsRootDir) - if (!fm.exists(path)) fm.mkdirs(path) - rootDirChecked = true - } + createDfsRootDirIfNotExist() } zipToDfsFile(localOtherFiles :+ metadataFile, dfsBatchZipFile(version, checkpointUniqueId), verifyNonEmptyFilesInZip) @@ -372,6 +372,7 @@ class RocksDBFileManager( val metadata = if (version == 0) { if (localDir.exists) Utils.deleteRecursively(localDir) Utils.createDirectory(localDir) + createDfsRootDirIfNotExist() // Since we cleared the local dir, we should also clear the local file mapping rocksDBFileMapping.clear() RocksDBCheckpointMetadata(Seq.empty, 0) @@ -404,11 +405,7 @@ class RocksDBFileManager( // Return if there is a snapshot file at the corresponding version // and optionally with checkpointunique id, e.g. version.zip or version_uniqueId.zip def existsSnapshotFile(version: Long, checkpointUniqueId: Option[String] = None): Boolean = { - if (!rootDirChecked) { - val path = new Path(dfsRootDir) - if (!fm.exists(path)) fm.mkdirs(path) - rootDirChecked = true - } + createDfsRootDirIfNotExist() fm.exists(dfsBatchZipFile(version, checkpointUniqueId)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala index 860e7a1ab2e45..1ab581d79437f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala @@ -17,18 +17,26 @@ package org.apache.spark.sql.execution.streaming.state +import scala.util.Try + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata} import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryCheckpointMetadata} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ +import org.apache.spark.util.SerializableConfiguration /** * Test for offline state repartitioning. This tests that repartition behaves as expected * for different scenarios. */ -class OfflineStateRepartitionSuite extends StreamTest { +class OfflineStateRepartitionSuite extends StreamTest + with AlsoTestWithRocksDBFeatures { import testImplicits._ import OfflineStateRepartitionUtils._ + import OfflineStateRepartitionTestUtils._ test("Fail if empty checkpoint directory") { withTempDir { dir => @@ -101,7 +109,8 @@ class OfflineStateRepartitionSuite extends StreamTest { test("Repartition: success, failure, retry") { withTempDir { dir => val originalPartitions = 3 - val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath) + val input = MemoryStream[Int] + val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input) val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath) // Shouldn't be seen as a repartition batch assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, dir.getAbsolutePath)) @@ -124,8 +133,9 @@ class OfflineStateRepartitionSuite extends StreamTest { val newPartitions = originalPartitions + 1 spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions) val repartitionBatchId = batchId + 1 + val hadoopConf = spark.sessionState.newHadoopConf() verifyRepartitionBatch( - repartitionBatchId, checkpointMetadata, dir.getAbsolutePath, newPartitions) + repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions) // Now delete the repartition commit to simulate a failed repartition attempt. // This will delete all the commits after the batchId. @@ -150,7 +160,17 @@ class OfflineStateRepartitionSuite extends StreamTest { // Retrying with the same numPartitions should work spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions) verifyRepartitionBatch( - repartitionBatchId, checkpointMetadata, dir.getAbsolutePath, newPartitions) + repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions) + + // Repartition with way more partitions, to verify that empty partitions are properly created + val morePartitions = newPartitions * 3 + spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, morePartitions) + verifyRepartitionBatch( + repartitionBatchId + 1, checkpointMetadata, hadoopConf, + dir.getAbsolutePath, morePartitions) + + // Restart the query to make sure it can start after repartitioning + runSimpleStreamQuery(morePartitions, dir.getAbsolutePath, input) } } @@ -188,6 +208,7 @@ class OfflineStateRepartitionSuite extends StreamTest { verifyRepartitionBatch( lastBatchId + 1, checkpointMetadata, + spark.sessionState.newHadoopConf(), dir.getAbsolutePath, originalPartitions + 1, // Repartition should be based on the first batch, since we skipped the others @@ -197,18 +218,21 @@ class OfflineStateRepartitionSuite extends StreamTest { test("Consecutive repartition") { withTempDir { dir => - val originalPartitions = 3 - val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath) + val originalPartitions = 5 + val input = MemoryStream[Int] + val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input) val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath) + val hadoopConf = spark.sessionState.newHadoopConf() // decrease - spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions - 1) + spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions - 3) verifyRepartitionBatch( batchId + 1, checkpointMetadata, + hadoopConf, dir.getAbsolutePath, - originalPartitions - 1 + originalPartitions - 3 ) // increase @@ -216,9 +240,13 @@ class OfflineStateRepartitionSuite extends StreamTest { verifyRepartitionBatch( batchId + 2, checkpointMetadata, + hadoopConf, dir.getAbsolutePath, originalPartitions + 1 ) + + // Restart the query to make sure it can start after repartitioning + runSimpleStreamQuery(originalPartitions + 1, dir.getAbsolutePath, input) } } @@ -231,31 +259,56 @@ class OfflineStateRepartitionSuite extends StreamTest { SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) var committedBatchId: Long = -1 - testStream(input.toDF().groupBy().count(), outputMode = OutputMode.Update)( - StartStream(checkpointLocation = checkpointLocation, additionalConfs = conf), - AddData(input, 1, 2, 3), - ProcessAllAvailable(), - Execute { query => - committedBatchId = Option(query.lastProgress).map(_.batchId).getOrElse(-1) - } - ) + // Set the confs before starting the stream + withSQLConf(conf.toSeq: _*) { + testStream(input.toDF().groupBy("value").count(), outputMode = OutputMode.Update)( + StartStream(checkpointLocation = checkpointLocation), + AddData(input, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + committedBatchId = Option(query.lastProgress).map(_.batchId).getOrElse(-1) + } + ) + } assert(committedBatchId >= 0, "No batch was committed in the streaming query") committedBatchId } +} - private def verifyRepartitionBatch( +object OfflineStateRepartitionTestUtils { + import OfflineStateRepartitionUtils._ + + def verifyRepartitionBatch( batchId: Long, checkpointMetadata: StreamingQueryCheckpointMetadata, + hadoopConf: Configuration, checkpointLocation: String, expectedShufflePartitions: Int, baseBatchId: Option[Long] = None): Unit = { // Should be seen as a repartition batch assert(isRepartitionBatch(batchId, checkpointMetadata.offsetLog, checkpointLocation)) + // When failed batches are skipped, then repartition can be based + // on an older batch and not batchId - 1. + val previousBatchId = baseBatchId.getOrElse(batchId - 1) + + verifyOffsetAndCommitLog( + batchId, previousBatchId, expectedShufflePartitions, checkpointMetadata) + verifyPartitionDirs(checkpointLocation, expectedShufflePartitions) + verifyOperatorMetadata( + batchId, previousBatchId, checkpointLocation, expectedShufflePartitions, hadoopConf) + } + + private def verifyOffsetAndCommitLog( + repartitionBatchId: Long, + previousBatchId: Long, + expectedShufflePartitions: Int, + checkpointMetadata: StreamingQueryCheckpointMetadata): Unit = { // Verify the repartition batch val lastBatchId = checkpointMetadata.offsetLog.getLatestBatchId().get - assert(lastBatchId == batchId) + assert(lastBatchId == repartitionBatchId, + "The latest batch in offset log should be the repartition batch") val lastBatch = checkpointMetadata.offsetLog.get(lastBatchId).get val lastBatchShufflePartitions = getShufflePartitions(lastBatch.metadataOpt.get).get @@ -263,18 +316,16 @@ class OfflineStateRepartitionSuite extends StreamTest { // Verify the commit log val lastCommitId = checkpointMetadata.commitLog.getLatestBatchId().get - assert(lastCommitId == batchId) + assert(lastCommitId == repartitionBatchId, + "The latest batch in commit log should be the repartition batch") // verify that the offset seq is the same between repartition batch and // the batch the repartition is based on except for the shuffle partitions. - // When failed batches are skipped, then repartition can be based - // on an older batch and not batchId - 1. - val previousBatchId = baseBatchId.getOrElse(batchId - 1) val previousBatch = checkpointMetadata.offsetLog.get(previousBatchId).get // Verify offsets are identical assert(lastBatch.offsets == previousBatch.offsets, - s"Offsets should be identical between batch $previousBatchId and $batchId") + s"Offsets should be identical between batch $previousBatchId and $repartitionBatchId") // Verify metadata is the same except for shuffle partitions config (lastBatch.metadataOpt, previousBatch.metadataOpt) match { @@ -298,7 +349,109 @@ class OfflineStateRepartitionSuite extends StreamTest { getShufflePartitions(lastMetadata).get != getShufflePartitions(previousMetadata).get, "Shuffle partitions should be different between batches") case _ => - fail("Both batches should have metadata") + assert(false, "Both batches should have metadata") + } + } + + // verify number of partition dirs in state dir + private def verifyPartitionDirs( + checkpointLocation: String, + expectedShufflePartitions: Int): Unit = { + val stateDir = new java.io.File(checkpointLocation, "state") + + def numDirs(file: java.io.File): Int = { + file.listFiles() + .filter(d => d.isDirectory && Try(d.getName.toInt).isSuccess) + .length + } + + val numOperators = numDirs(stateDir) + for (op <- 0 until numOperators) { + val partitionsDir = new java.io.File(stateDir, s"$op") + val numPartitions = numDirs(partitionsDir) + // Doing <= in case of reduced number of partitions + assert(expectedShufflePartitions <= numPartitions, + s"Expected atleast $expectedShufflePartitions partition dirs for operator $op," + + s" but found $numPartitions") + } + } + + private def verifyOperatorMetadata( + repartitionBatchId: Long, + baseBatchId: Long, + checkpointLocation: String, + expectedShufflePartitions: Int, + hadoopConf: Configuration): Unit = { + val serializableConf = new SerializableConfiguration(hadoopConf) + + // Read operator metadata for both batches + val baseMetadataReader = new StateMetadataPartitionReader( + checkpointLocation, serializableConf, baseBatchId) + val repartitionMetadataReader = new StateMetadataPartitionReader( + checkpointLocation, serializableConf, repartitionBatchId) + + val baseOperatorsMetadata = baseMetadataReader.allOperatorStateMetadata + val repartitionOperatorsMetadata = repartitionMetadataReader.allOperatorStateMetadata + + assert(baseOperatorsMetadata.nonEmpty, "Base batch should have operator metadata") + assert(repartitionOperatorsMetadata.nonEmpty, "Repartition batch should have operator metadata") + assert(baseOperatorsMetadata.length == repartitionOperatorsMetadata.length, + "Both batches should have the same number of operators") + + // Verify each operator's metadata + baseOperatorsMetadata.zip(repartitionOperatorsMetadata).foreach { + case (baseOp, repartitionOp) => + // Verify both are of the same type + assert(baseOp.getClass == repartitionOp.getClass, + s"Metadata types should match: base=${baseOp.getClass.getSimpleName}, " + + s"repartition=${repartitionOp.getClass.getSimpleName}") + + (baseOp, repartitionOp) match { + case (baseV2: OperatorStateMetadataV2, repartitionV2: OperatorStateMetadataV2) => + // Verify operator info is the same + assert(baseV2.operatorInfo == repartitionV2.operatorInfo, + s"Operator info should match: base=${baseV2.operatorInfo}, " + + s"repartition=${repartitionV2.operatorInfo}") + + // Verify operator properties JSON is the same + assert(baseV2.operatorPropertiesJson == repartitionV2.operatorPropertiesJson, + "Operator properties JSON should match") + + // Verify state store info (except numPartitions) + assert(baseV2.stateStoreInfo.length == repartitionV2.stateStoreInfo.length, + "Should have same number of state stores") + + baseV2.stateStoreInfo.zip(repartitionV2.stateStoreInfo).foreach { + case (baseStore, repartitionStore) => + assert(baseStore.storeName == repartitionStore.storeName, + s"Store name should match: ${baseStore.storeName} " + + s"vs ${repartitionStore.storeName}") + assert(baseStore.numColsPrefixKey == repartitionStore.numColsPrefixKey, + "numColsPrefixKey should match") + // Schema file paths should be the same (they reference the same schema files) + assert(baseStore.stateSchemaFilePaths == repartitionStore.stateSchemaFilePaths, + "State schema file paths should match") + assert(baseStore.numPartitions != repartitionStore.numPartitions, + "numPartitions shouldn't be the same") + // Verify numPartitions is updated to expectedShufflePartitions + assert(repartitionStore.numPartitions == expectedShufflePartitions, + s"Repartition batch numPartitions should be $expectedShufflePartitions, " + + s"but found ${repartitionStore.numPartitions}") + } + + case (baseV1: OperatorStateMetadataV1, repartitionV1: OperatorStateMetadataV1) => + // For v1, since we didn't update it, then it should be the same. + // Can't use == directly because Array uses reference equality + assert(baseV1.operatorInfo == repartitionV1.operatorInfo, + "V1 operator info should be the same") + assert(baseV1.stateStoreInfo.sameElements(repartitionV1.stateStoreInfo), + "V1 state store info should be the same") + + case _ => + assert(false, + s"Unexpected metadata types: base=${baseOp.getClass.getSimpleName}, " + + s"repartition=${repartitionOp.getClass.getSimpleName}") + } } } }