diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index ac8e891c5403a..dc9b832755619 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5575,6 +5575,32 @@ }, "sqlState" : "42616" }, + "STATE_REWRITER_INVALID_CHECKPOINT" : { + "message" : [ + "The state rewrite checkpoint location '<checkpointLocation>' is in an invalid state." + ], + "subClass" : { + "MISSING_KEY_ENCODER_SPEC" : { + "message" : [ + "Key state encoder spec is expected for column family '<colFamilyName>' but was not found.", + "This is likely a bug, please report it." + ] + }, + "MISSING_OPERATOR_METADATA" : { + "message" : [ + "No stateful operator metadata was found for batch <batchId>.", + "Ensure that the checkpoint is for a stateful streaming query and the query ran on a Spark version that supports operator metadata (Spark 4.0+)." + ] + }, + "UNSUPPORTED_STATE_STORE_METADATA_VERSION" : { + "message" : [ + "Unsupported state store metadata version encountered.", + "Only StateStoreMetadataV1 and StateStoreMetadataV2 are supported." + ] + } + }, + "sqlState" : "55019" + }, "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : { "message" : [ "Failed to create column family with unsupported starting character and name=<colFamilyName>." diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index c34545216fdaf..6b2295da03b99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{BufferedReader, InputStreamReader} import java.nio.charset.StandardCharsets +import scala.collection.immutable.ArraySeq import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration @@ -80,12 +81,17 @@ trait OperatorStateMetadata { def version: Int def operatorInfo: OperatorInfo + + def stateStoresMetadata: Seq[StateStoreMetadata] } case class OperatorStateMetadataV1( operatorInfo: OperatorInfoV1, stateStoreInfo: Array[StateStoreMetadataV1]) extends OperatorStateMetadata { override def version: Int = 1 + + override def stateStoresMetadata: Seq[StateStoreMetadata] = + ArraySeq.unsafeWrapArray(stateStoreInfo) } case class OperatorStateMetadataV2( @@ -93,6 +99,9 @@ case class OperatorStateMetadataV2( stateStoreInfo: Array[StateStoreMetadataV2], operatorPropertiesJson: String) extends OperatorStateMetadata { override def version: Int = 2 + + override def stateStoresMetadata: Seq[StateStoreMetadata] = + ArraySeq.unsafeWrapArray(stateStoreInfo) } object OperatorStateMetadataUtils extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala index aac13d3f69f9c..3df97d3adc0e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala @@ -51,7 +51,7 @@ class StatePartitionAllColumnFamiliesWriter( hadoopConf: Configuration, partitionId: Int, targetCpLocation: String, - operatorId: Int, + operatorId: 
Long, storeName: String, currentBatchId: Long, colFamilyToWriterInfoMap: Map[String, StatePartitionWriterColumnFamilyInfo], @@ -153,6 +153,7 @@ class StatePartitionAllColumnFamiliesWriter( if (!stateStore.hasCommitted) { stateStore.abort() } + provider.close() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala new file mode 100644 index 0000000000000..da28a3c907f7b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.UUID + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.{SparkIllegalStateException, TaskContext} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.Logging +import org.apache.spark.internal.LogKeys._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions +import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader +import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqMetadata +import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils +import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateOperatorProperties, TransformWithStateVariableInfo} +import org.apache.spark.sql.execution.streaming.runtime.{StreamingCheckpointConstants, StreamingQueryCheckpointMetadata} +import org.apache.spark.sql.execution.streaming.state.{StatePartitionAllColumnFamiliesWriter, StateSchemaCompatibilityChecker} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.{SerializableConfiguration, Utils} + +/** + * State Rewriter is used to rewrite the state stores for a stateful streaming query. + * It reads state from a checkpoint location, optionally applies a transformation to the state, + * and then writes the state back to a (possibly different) checkpoint location for a new batch ID. + * + * An example use case is offline state repartitioning. + * It can also be used to support other use cases. + * + * @param sparkSession The active Spark session. + * @param readBatchId The batch ID for reading state. + * @param writeBatchId The batch ID to which the (transformed) state will be written. + * @param resolvedCheckpointLocation The resolved checkpoint path where state will be written. + * @param hadoopConf Hadoop configuration for file system operations. 
+ * @param readResolvedCheckpointLocation Optional separate checkpoint location to read state from. + * If None, reads from resolvedCheckpointLocation. + * @param transformFunc Optional transformation function applied to each operator's state + * DataFrame. If None, state is written as-is. + * @param writeCheckpointMetadata Optional checkpoint metadata for the resolvedCheckpointLocation. + * If None, will create a new one for resolvedCheckpointLocation. + * Helps us to reuse already cached checkpoint log entries, + * instead of starting from scratch. + */ +class StateRewriter( + sparkSession: SparkSession, + readBatchId: Long, + writeBatchId: Long, + resolvedCheckpointLocation: String, + hadoopConf: Configuration, + readResolvedCheckpointLocation: Option[String] = None, + transformFunc: Option[DataFrame => DataFrame] = None, + writeCheckpointMetadata: Option[StreamingQueryCheckpointMetadata] = None +) extends Logging { + require(readResolvedCheckpointLocation.isDefined || readBatchId < writeBatchId, + s"Read batch id $readBatchId must be less than write batch id $writeBatchId " + + "when reading and writing to the same checkpoint location") + + // If a different location was specified for reading state, use it. + // Else, use the same location for reading and writing state. + private val checkpointLocationForRead = + readResolvedCheckpointLocation.getOrElse(resolvedCheckpointLocation) + private val stateRootLocation = new Path( + resolvedCheckpointLocation, StreamingCheckpointConstants.DIR_NAME_STATE).toString + + def run(): Unit = { + logInfo(log"Starting state rewrite for " + + log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}, " + + log"readCheckpointLocation=" + + log"${MDC(CHECKPOINT_LOCATION, readResolvedCheckpointLocation.getOrElse(""))}, " + + log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " + + log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}") + + val (_, timeTakenMs) = Utils.timeTakenMs { + runInternal() + } + + logInfo(log"State rewrite completed in ${MDC(DURATION, timeTakenMs)} ms for " + + log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}") + } + + private def runInternal(): Unit = { + try { + val stateMetadataReader = new StateMetadataPartitionReader( + resolvedCheckpointLocation, + new SerializableConfiguration(hadoopConf), + readBatchId) + + val allOperatorsMetadata = stateMetadataReader.allOperatorStateMetadata + if (allOperatorsMetadata.isEmpty) { + // It's possible that the query is stateless + // or ran on an older Spark version without operator metadata + throw StateRewriterErrors.missingOperatorMetadataError( + resolvedCheckpointLocation, readBatchId) + } + + // Use the same conf as in the offset log to create the store conf, + // to make sure the state is written with the right conf. 
+ val (storeConf, sqlConf) = createConfsFromOffsetLog() + // SQLConf doesn't serialize properly (reader becomes null), so extract as Map + val sqlConfEntries: Map[String, String] = sqlConf.getAllConfs + + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + val hadoopConfBroadcast = + SerializableConfiguration.broadcast(sparkSession.sparkContext, hadoopConf) + + // Do rewrite for each operator + // We can potentially parallelize this, but for now, do sequentially + allOperatorsMetadata.foreach { opMetadata => + val stateStoresMetadata = opMetadata.stateStoresMetadata + assert(!stateStoresMetadata.isEmpty, + s"Operator ${opMetadata.operatorInfo.operatorName} has no state stores") + + val storeToSchemaFilesMap = getStoreToSchemaFilesMap(opMetadata) + val stateVarsIfTws = getStateVariablesIfTWS(opMetadata) + + // Rewrite each state store of the operator + stateStoresMetadata.foreach { stateStoreMetadata => + rewriteStore( + opMetadata, + stateStoreMetadata, + storeConf, + hadoopConfBroadcast, + storeToSchemaFilesMap(stateStoreMetadata.storeName), + stateVarsIfTws, + sqlConfEntries + ) + } + } + } catch { + case e: Throwable => + logError(log"State rewrite failed for " + + log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}, " + + log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " + + log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}", e) + throw e + } + } + + private def rewriteStore( + opMetadata: OperatorStateMetadata, + stateStoreMetadata: StateStoreMetadata, + storeConf: StateStoreConf, + hadoopConfBroadcast: Broadcast[SerializableConfiguration], + storeSchemaFiles: List[Path], + stateVarsIfTws: Map[String, TransformWithStateVariableInfo], + sqlConfEntries: Map[String, String] + ): Unit = { + // Read state + val stateDf = sparkSession.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointLocationForRead) + .option(StateSourceOptions.BATCH_ID, readBatchId) + .option(StateSourceOptions.OPERATOR_ID, opMetadata.operatorInfo.operatorId) + .option(StateSourceOptions.STORE_NAME, stateStoreMetadata.storeName) + .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") + .load() + + // Run the caller state transformation func if provided + // Otherwise, use the state as is + val updatedStateDf = transformFunc.map(func => func(stateDf)).getOrElse(stateDf) + require(updatedStateDf.schema == stateDf.schema, + s"State transformation function must return a DataFrame with the same schema " + + s"as the original state DataFrame. Original schema: ${stateDf.schema}, " + + s"Updated schema: ${updatedStateDf.schema}") + + val schemaProvider = createStoreSchemaProviderIfTWS( + opMetadata.operatorInfo.operatorName, + storeSchemaFiles + ) + val writerColFamilyInfoMap = getWriterColFamilyInfoMap( + opMetadata.operatorInfo.operatorId, + stateStoreMetadata, + storeSchemaFiles, + stateVarsIfTws + ) + + logInfo(log"Writing new state for " + + log"operator=${MDC(OP_TYPE, opMetadata.operatorInfo.operatorName)}, " + + log"stateStore=${MDC(STATE_NAME, stateStoreMetadata.storeName)}, " + + log"numColumnFamilies=${MDC(COUNT, writerColFamilyInfoMap.size)}, " + + log"numSchemaFiles=${MDC(NUM_FILES, storeSchemaFiles.size)}, " + + log"for new batch=${MDC(BATCH_ID, writeBatchId)}, " + + log"for checkpoint=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}") + + // Write state for each partition on the executor. + // Setting this as local val, + // to avoid serializing the entire Rewriter object per partition. 
+ val targetCheckpointLocation = resolvedCheckpointLocation + val currentBatchId = writeBatchId + updatedStateDf.queryExecution.toRdd.foreachPartition { partitionIter => + // Recreate SQLConf on executor from serialized entries + val executorSqlConf = new SQLConf() + sqlConfEntries.foreach { case (k, v) => executorSqlConf.setConfString(k, v) } + + val partitionWriter = new StatePartitionAllColumnFamiliesWriter( + storeConf, + hadoopConfBroadcast.value.value, + TaskContext.get().partitionId(), + targetCheckpointLocation, + opMetadata.operatorInfo.operatorId, + stateStoreMetadata.storeName, + currentBatchId, + writerColFamilyInfoMap, + opMetadata.operatorInfo.operatorName, + schemaProvider, + executorSqlConf + ) + + partitionWriter.write(partitionIter) + } + } + + /** Create the store and sql confs from the conf written in the offset log */ + private def createConfsFromOffsetLog(): (StateStoreConf, SQLConf) = { + val offsetLog = writeCheckpointMetadata.getOrElse( + new StreamingQueryCheckpointMetadata(sparkSession, resolvedCheckpointLocation)).offsetLog + + // We want to use the same confs written in the offset log for the new batch + val offsetSeq = offsetLog.get(writeBatchId) + require(offsetSeq.isDefined, s"Offset seq must be present for the new batch $writeBatchId") + val metadata = offsetSeq.get.metadataOpt + require(metadata.isDefined, s"Metadata must be present for the new batch $writeBatchId") + + val clonedSqlConf = sparkSession.sessionState.conf.clone() + OffsetSeqMetadata.setSessionConf(metadata.get, clonedSqlConf) + (StateStoreConf(clonedSqlConf), clonedSqlConf) + } + + /** Get the map of state store name to schema files, for an operator */ + private def getStoreToSchemaFilesMap( + opMetadata: OperatorStateMetadata): Map[String, List[Path]] = { + opMetadata.stateStoresMetadata.map { storeMetadata => + val schemaFiles = storeMetadata match { + // No schema files for v1. 
It has a fixed/known schema file path + case _: StateStoreMetadataV1 => List.empty[Path] + case v2: StateStoreMetadataV2 => v2.stateSchemaFilePaths.map(new Path(_)) + case _ => + throw StateRewriterErrors.unsupportedStateStoreMetadataVersionError( + resolvedCheckpointLocation) + } + storeMetadata.storeName -> schemaFiles + }.toMap + } + + private def getWriterColFamilyInfoMap( + operatorId: Long, + storeMetadata: StateStoreMetadata, + schemaFiles: List[Path], + twsStateVariables: Map[String, TransformWithStateVariableInfo] = Map.empty + ): Map[String, StatePartitionWriterColumnFamilyInfo] = { + getLatestColFamilyToSchemaMap(operatorId, storeMetadata, schemaFiles) + .map { case (colFamilyName, schema) => + colFamilyName -> StatePartitionWriterColumnFamilyInfo(schema, + useMultipleValuesPerKey = twsStateVariables.get(colFamilyName) + .map(_.stateVariableType == StateVariableType.ListState).getOrElse(false)) + } + } + + private def getLatestColFamilyToSchemaMap( + operatorId: Long, + storeMetadata: StateStoreMetadata, + schemaFiles: List[Path]): Map[String, StateStoreColFamilySchema] = { + val storeId = new StateStoreId( + stateRootLocation, + operatorId, + StateStore.PARTITION_ID_TO_CHECK_SCHEMA, + storeMetadata.storeName) + // using a placeholder runId since we are not running a streaming query + val providerId = new StateStoreProviderId(storeId, queryRunId = UUID.randomUUID()) + val manager = new StateSchemaCompatibilityChecker(providerId, hadoopConf, + oldSchemaFilePaths = schemaFiles) + // Read the latest state schema from the provided path for v2 or from the dedicated path + // for v1 + manager + .readSchemaFile() + .map { schema => + schema.colFamilyName -> createKeyEncoderSpecIfAbsent(schema, storeMetadata) }.toMap + } + + private def createKeyEncoderSpecIfAbsent( + colFamilySchema: StateStoreColFamilySchema, + storeMetadata: StateStoreMetadata): StateStoreColFamilySchema = { + colFamilySchema.keyStateEncoderSpec match { + case Some(encoderSpec) => colFamilySchema + case None if storeMetadata.isInstanceOf[StateStoreMetadataV1] => + // Create the spec if missing for v1 metadata + if (storeMetadata.numColsPrefixKey > 0) { + colFamilySchema.copy(keyStateEncoderSpec = + Some(PrefixKeyScanStateEncoderSpec(colFamilySchema.keySchema, + storeMetadata.numColsPrefixKey))) + } else { + colFamilySchema.copy(keyStateEncoderSpec = + Some(NoPrefixKeyStateEncoderSpec(colFamilySchema.keySchema))) + } + case _ => + // Key encoder spec is expected in v2 metadata + throw StateRewriterErrors.missingKeyEncoderSpecError( + resolvedCheckpointLocation, colFamilySchema.colFamilyName) + } + } + + private def getStateVariablesIfTWS( + opMetadata: OperatorStateMetadata): Map[String, TransformWithStateVariableInfo] = { + if (StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES + .contains(opMetadata.operatorInfo.operatorName)) { + val operatorProperties = TransformWithStateOperatorProperties.fromJson( + opMetadata.asInstanceOf[OperatorStateMetadataV2].operatorPropertiesJson) + operatorProperties.stateVariables.map(s => s.stateName -> s).toMap + } else { + Map.empty + } + } + + // Needed only for schema evolution for TWS + private def createStoreSchemaProviderIfTWS( + opName: String, + schemaFiles: List[Path]): Option[StateSchemaProvider] = { + if (StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(opName)) { + val schemaMetadata = StateSchemaMetadata.createStateSchemaMetadata( + stateRootLocation, hadoopConf, schemaFiles.map(_.toString)) + Some(new InMemoryStateSchemaProvider(schemaMetadata)) + } else { 
+ None + } + } +} + +/** + * Errors thrown by StateRewriter. + */ +private[state] object StateRewriterErrors { + def missingKeyEncoderSpecError( + checkpointLocation: String, + colFamilyName: String): StateRewriterInvalidCheckpointError = { + new StateRewriterMissingKeyEncoderSpecError(checkpointLocation, colFamilyName) + } + + def missingOperatorMetadataError( + checkpointLocation: String, + batchId: Long): StateRewriterInvalidCheckpointError = { + new StateRewriterMissingOperatorMetadataError(checkpointLocation, batchId) + } + + def unsupportedStateStoreMetadataVersionError( + checkpointLocation: String): StateRewriterInvalidCheckpointError = { + new StateRewriterUnsupportedStoreMetadataVersionError(checkpointLocation) + } +} + +/** + * Base class for exceptions thrown when the checkpoint location is in an invalid state + * for state rewriting. + */ +private[state] abstract class StateRewriterInvalidCheckpointError( + checkpointLocation: String, + subClass: String, + messageParameters: Map[String, String], + cause: Throwable = null) + extends SparkIllegalStateException( + errorClass = s"STATE_REWRITER_INVALID_CHECKPOINT.$subClass", + messageParameters = Map("checkpointLocation" -> checkpointLocation) ++ messageParameters, + cause = cause) + +private[state] class StateRewriterMissingKeyEncoderSpecError( + checkpointLocation: String, + colFamilyName: String) + extends StateRewriterInvalidCheckpointError( + checkpointLocation, + subClass = "MISSING_KEY_ENCODER_SPEC", + messageParameters = Map("colFamilyName" -> colFamilyName)) + +private[state] class StateRewriterMissingOperatorMetadataError( + checkpointLocation: String, + batchId: Long) + extends StateRewriterInvalidCheckpointError( + checkpointLocation, + subClass = "MISSING_OPERATOR_METADATA", + messageParameters = Map("batchId" -> batchId.toString)) + +private[state] class StateRewriterUnsupportedStoreMetadataVersionError( + checkpointLocation: String) + extends StateRewriterInvalidCheckpointError( + checkpointLocation, + subClass = "UNSUPPORTED_STATE_STORE_METADATA_VERSION", + messageParameters = Map.empty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala index 6a1f66262d5f0..d8a3bbb65af26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala @@ -666,7 +666,7 @@ object SessionWindowTestUtils { */ object StreamStreamJoinTestUtils { // All state store names from SymmetricHashJoinStateManager - private val allStoreNames: Seq[String] = + val allStoreNames: Seq[String] = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) // Column family names for keyToNumValues stores (derived from allStateStoreNames) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala index 9501e4e9e36b2..e495db499bfe6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala @@ -20,10 +20,8 @@ import 
java.io.File import java.sql.Timestamp import java.time.Duration -import org.apache.spark.TaskContext import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.v2.state.{CompositeKeyAggregationTestUtils, DropDuplicatesTestUtils, FlatMapGroupsWithStateTestUtils, SessionWindowTestUtils, SimpleAggregationTestUtils, StateDataSourceTestBase, StateSourceOptions, StreamStreamJoinTestUtils} +import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceTestBase, StateSourceOptions, StreamStreamJoinTestUtils} import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryCheckpointMetadata} @@ -34,7 +32,6 @@ import org.apache.spark.sql.streaming.{InputEvent, ListStateTTLProcessor, MapInp import org.apache.spark.sql.streaming.util.{StreamManualClock, TTLProcessorUtils} import org.apache.spark.sql.streaming.util.{EventTimeTimerProcessor, MultiStateVarProcessor, MultiStateVarProcessorTestUtils, TimerTestUtils} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.SerializableConfiguration /** * Test suite for StatePartitionAllColumnFamiliesWriter. @@ -51,68 +48,33 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "2") } - /** - * Helper method to create a StateSchemaProvider from column family schema map. - */ - private def createStateSchemaProvider( - columnFamilyToSchemaMap: Map[String, StatePartitionWriterColumnFamilyInfo] - ): StateSchemaProvider = { - val testSchemaProvider = new TestStateSchemaProvider() - columnFamilyToSchemaMap.foreach { case (cfName, cfInfo) => - testSchemaProvider.captureSchema( - colFamilyName = cfName, - keySchema = cfInfo.schema.keySchema, - valueSchema = cfInfo.schema.valueSchema, - keySchemaId = cfInfo.schema.keySchemaId, - valueSchemaId = cfInfo.schema.valueSchemaId - ) - } - testSchemaProvider - } - /** * Common helper method to perform round-trip test: read state bytes from source, * write to target, and verify target matches source. 
* * @param sourceDir Source checkpoint directory * @param targetDir Target checkpoint directory - * @param columnFamilyToSchemaMap Map of column family names to their schemas - * @param storeName Optional store name (for stream-stream join which has multiple stores) - * @param columnFamilyToSelectExprs Map of column family names to custom selectExprs - * @param columnFamilyToStateSourceOptions Map of column family names to state source options + * @param storeToColumnFamilies Optional store name to its column families + * @param storeToColumnFamilyToSelectExprs Map store name to per column family custom selectExprs + * @param storeToColumnFamilyToStateSourceOptions Map store name to per column family + * state source options */ private def performRoundTripTest( sourceDir: String, targetDir: String, - columnFamilyToSchemaMap: Map[String, StatePartitionWriterColumnFamilyInfo], - storeName: Option[String] = None, - columnFamilyToSelectExprs: Map[String, Seq[String]] = Map.empty, - columnFamilyToStateSourceOptions: Map[String, Map[String, String]] = Map.empty, + storeToColumnFamilies: Map[String, List[String]] = + Map(StateStoreId.DEFAULT_STORE_NAME -> List(StateStore.DEFAULT_COL_FAMILY_NAME)), + storeToColumnFamilyToSelectExprs: Map[String, Map[String, Seq[String]]] = Map.empty, + storeToColumnFamilyToStateSourceOptions: Map[String, Map[String, Map[String, String]]] = + Map.empty, operatorName: String): Unit = { - - val columnFamiliesToValidate: Seq[String] = if (columnFamilyToSchemaMap.size > 1) { - columnFamilyToSchemaMap.keys.toSeq - } else { - Seq(StateStore.DEFAULT_COL_FAMILY_NAME) - } - - // Step 1: Read from source using AllColumnFamiliesReader (raw bytes) - val sourceBytesReader = spark.read - .format("statestore") - .option(StateSourceOptions.PATH, sourceDir) - .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") - val sourceBytesData = (storeName match { - case Some(name) => sourceBytesReader.option(StateSourceOptions.STORE_NAME, name) - case None => sourceBytesReader - }).load() - - // Verify schema of raw bytes - val schema = sourceBytesData.schema - assert(schema.fieldNames === Array( - "partition_key", "key_bytes", "value_bytes", "column_family_name")) - - // Step 2: Write raw bytes to target checkpoint location val hadoopConf = spark.sessionState.newHadoopConf() + val sourceCpLocation = StreamingUtils.resolvedCheckpointLocation( + hadoopConf, sourceDir) + val sourceCheckpointMetadata = new StreamingQueryCheckpointMetadata( + spark, sourceCpLocation) + val readBatchId = sourceCheckpointMetadata.commitLog.getLatestBatchId().get + val targetCpLocation = StreamingUtils.resolvedCheckpointLocation( hadoopConf, targetDir) val targetCheckpointMetadata = new StreamingQueryCheckpointMetadata( @@ -120,67 +82,56 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase // increase offsetCheckpoint val lastBatch = targetCheckpointMetadata.commitLog.getLatestBatchId().get val targetOffsetSeq = targetCheckpointMetadata.offsetLog.get(lastBatch).get - val currentBatchId = lastBatch + 1 - targetCheckpointMetadata.offsetLog.add(currentBatchId, targetOffsetSeq) - - val storeConf: StateStoreConf = StateStoreConf(spark.sessionState.conf) - val serializableHadoopConf = new SerializableConfiguration(hadoopConf) - - // Create StateSchemaProvider if needed (for Avro encoding) - val stateSchemaProvider = if (storeConf.stateStoreEncodingFormat == "avro") { - Some(createStateSchemaProvider(columnFamilyToSchemaMap)) - } else { - None - } - val baseConfs: 
Map[String, String] = spark.sessionState.conf.getAllConfs - val putPartitionFunc: Iterator[InternalRow] => Unit = partition => { - val newConf = new SQLConf - baseConfs.foreach { case (k, v) => - newConf.setConfString(k, v) - } - val allCFWriter = new StatePartitionAllColumnFamiliesWriter( - storeConf, - serializableHadoopConf.value, - TaskContext.getPartitionId(), - targetCpLocation, - 0, - storeName.getOrElse(StateStoreId.DEFAULT_STORE_NAME), - currentBatchId, - columnFamilyToSchemaMap, - operatorName, - stateSchemaProvider, - newConf - ) - allCFWriter.write(partition) - } - sourceBytesData.queryExecution.toRdd.foreachPartition(putPartitionFunc) + val writeBatchId = lastBatch + 1 + targetCheckpointMetadata.offsetLog.add(writeBatchId, targetOffsetSeq) + + val rewriter = new StateRewriter( + spark, + readBatchId, + writeBatchId, + targetCpLocation, + hadoopConf, + readResolvedCheckpointLocation = Some(sourceCpLocation), + transformFunc = None, + writeCheckpointMetadata = Some(targetCheckpointMetadata) + ) + rewriter.run() // Commit to commitLog val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get - targetCheckpointMetadata.commitLog.add(currentBatchId, latestCommit) - val versionToCheck = currentBatchId + 1 - val storeNamePath = s"state/0/0${storeName.fold("")("/" + _)}" - assert(!checkpointFileExists(new File(targetDir, storeNamePath), versionToCheck, ".changelog")) - assert(checkpointFileExists(new File(targetDir, storeNamePath), versionToCheck, ".zip")) + targetCheckpointMetadata.commitLog.add(writeBatchId, latestCommit) + val versionToCheck = writeBatchId + 1 + + storeToColumnFamilies.foreach { case (storeName, columnFamilies) => + val storeNamePath = if (storeName == StateStoreId.DEFAULT_STORE_NAME) { + "state/0/0" + } else { + s"state/0/0/$storeName" + } + assert(!checkpointFileExists(new File(targetDir, storeNamePath), + versionToCheck, ".changelog")) + assert(checkpointFileExists(new File(targetDir, storeNamePath), versionToCheck, ".zip")) - // Step 3: Validate by reading from both source and target using normal reader" - // Default selectExprs for most column families - val defaultSelectExprs = Seq("key", "value", "partition_id") + // Validate by reading from both source and target using normal reader" + // Default selectExprs for most column families + val defaultSelectExprs = Seq("key", "value", "partition_id") - columnFamiliesToValidate + columnFamilies // filtering out "default" for TWS operator because it doesn't contain any data .filter(cfName => !(cfName == StateStore.DEFAULT_COL_FAMILY_NAME && StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName) )) .foreach { cfName => - val selectExprs = columnFamilyToSelectExprs.getOrElse(cfName, defaultSelectExprs) - val readerOptions = columnFamilyToStateSourceOptions.getOrElse(cfName, Map.empty) + val selectExprs = storeToColumnFamilyToSelectExprs.getOrElse(storeName, Map.empty) + .getOrElse(cfName, defaultSelectExprs) + val readerOptions = storeToColumnFamilyToStateSourceOptions.getOrElse(storeName, Map.empty) + .getOrElse(cfName, Map.empty) def readNormalData(dir: String): Array[Row] = { var reader = spark.read .format("statestore") .option(StateSourceOptions.PATH, dir) - .option(StateSourceOptions.STORE_NAME, storeName.orNull) + .option(StateSourceOptions.STORE_NAME, storeName) readerOptions.foreach { case (k, v) => reader = reader.option(k, v) } reader.load() .selectExpr(selectExprs: _*) @@ -192,6 +143,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase 
validateDataMatches(sourceNormalData, targetNormalData) } + } } /** @@ -308,15 +260,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase ) ) - // Step 2: Define schemas based on state version - val metadata = SimpleAggregationTestUtils.getSchemasWithMetadata(stateVersion) - // Perform round-trip test using common helper performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - createSingleColumnFamilySchemaMap( - metadata.keySchema, metadata.valueSchema, metadata.encoderSpec), operatorName = StatefulOperatorsUtils.STATE_STORE_SAVE_EXEC_OP_NAME ) } @@ -349,15 +296,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase ) ) - // Step 2: Define schemas based on state version for composite key - val metadata = CompositeKeyAggregationTestUtils.getSchemasWithMetadata(stateVersion) - // Perform round-trip test using common helper performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - createSingleColumnFamilySchemaMap( - metadata.keySchema, metadata.valueSchema, metadata.encoderSpec), operatorName = StatefulOperatorsUtils.STATE_STORE_SAVE_EXEC_OP_NAME ) } @@ -396,36 +338,15 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase ) // Step 2: Test all 4 state stores created by stream-stream join - // Test keyToNumValues stores (both left and right) - StreamStreamJoinTestUtils.KEY_TO_NUM_VALUES_ALL.foreach { storeName => - val metadata = StreamStreamJoinTestUtils.getKeyToNumValuesSchemasWithMetadata() - - // Perform round-trip test using common helper - performRoundTripTest( - sourceDir.getAbsolutePath, - targetDir.getAbsolutePath, - createSingleColumnFamilySchemaMap( - metadata.keySchema, metadata.valueSchema, metadata.encoderSpec), - storeName = Some(storeName), - operatorName = StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME - ) - } - - // Test keyWithIndexToValue stores (both left and right) - StreamStreamJoinTestUtils.KEY_WITH_INDEX_ALL.foreach { storeName => - val metadata = - StreamStreamJoinTestUtils.getKeyWithIndexToValueSchemasWithMetadata(stateVersion) - - // Perform round-trip test using common helper - performRoundTripTest( - sourceDir.getAbsolutePath, - targetDir.getAbsolutePath, - createSingleColumnFamilySchemaMap( - metadata.keySchema, metadata.valueSchema, metadata.encoderSpec), - storeName = Some(storeName), - operatorName = StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME - ) - } + val storeToColumnFamilies = StreamStreamJoinTestUtils.allStoreNames + .map(s => s -> List(StateStore.DEFAULT_COL_FAMILY_NAME)).toMap + // Perform round-trip test using common helper + performRoundTripTest( + sourceDir.getAbsolutePath, + targetDir.getAbsolutePath, + storeToColumnFamilies, + operatorName = StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME + ) } } } @@ -458,15 +379,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase CheckLastBatch(("a", 1, 0, false)) ) - // Step 2: Define schemas for flatMapGroupsWithState - val metadata = FlatMapGroupsWithStateTestUtils.getSchemasWithMetadata(stateVersion) - // Perform round-trip test using common helper performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - createSingleColumnFamilySchemaMap( - metadata.keySchema, metadata.valueSchema, metadata.encoderSpec), operatorName = StatefulOperatorsUtils.FLAT_MAP_GROUPS_WITH_STATE_EXEC_OP_NAME ) } @@ -474,7 +390,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase } } - 
/** * Helper method to build timer column family schemas and options for * RunningCountStatefulProcessorWithProcTimeTimer and EventTimeTimerProcessor @@ -564,16 +479,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase CheckAnswer(("a", 1)) ) - // Step 2: Define schemas for dropDuplicatesWithinWatermark - val metadata = - DropDuplicatesTestUtils.getDropDuplicatesWithinWatermarkSchemasWithMetadata() - // Perform round-trip test using common helper performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - createSingleColumnFamilySchemaMap( - metadata.keySchema, metadata.valueSchema, metadata.encoderSpec), operatorName = StatefulOperatorsUtils.DEDUPLICATE_WITHIN_WATERMARK_EXEC_OP_NAME ) } @@ -595,16 +504,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase CheckAnswer(("a", 1)) ) - // Step 2: Define schemas for dropDuplicates with column specified - val metadata = - DropDuplicatesTestUtils.getDropDuplicatesWithColumnSchemasWithMetadata() - // Perform round-trip test using common helper performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - createSingleColumnFamilySchemaMap( - metadata.keySchema, metadata.valueSchema, metadata.encoderSpec), operatorName = StatefulOperatorsUtils.DEDUPLICATE_EXEC_OP_NAME ) } @@ -629,16 +532,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase StopStream ) - // Step 2: Define schemas for session window aggregation - val (keySchema, valueSchema) = SessionWindowTestUtils.getSchemas() - // Session window aggregation uses prefix key scanning where sessionId is the prefix - val keyStateEncoderSpec = PrefixKeyScanStateEncoderSpec(keySchema, 1) - // Perform round-trip test using common helper performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - createSingleColumnFamilySchemaMap(keySchema, valueSchema, keyStateEncoderSpec), operatorName = StatefulOperatorsUtils.SESSION_WINDOW_STATE_STORE_SAVE_EXEC_OP_NAME ) } @@ -660,15 +557,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase assertNumStateRows(total = 6, updated = 6) ) - // Step 2: Define schemas for dropDuplicates (state version 2) - val metadata = DropDuplicatesTestUtils.getDropDuplicatesSchemasWithMetadata() - // Perform round-trip test using common helper performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - createSingleColumnFamilySchemaMap( - metadata.keySchema, metadata.valueSchema, metadata.encoderSpec), operatorName = StatefulOperatorsUtils.DEDUPLICATE_EXEC_OP_NAME ) } @@ -713,15 +605,15 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase runQuery(sourceDir.getAbsolutePath, roundsOfData = 2) runQuery(targetDir.getAbsolutePath, roundsOfData = 1) - val allColFamilyNames = StreamStreamJoinTestUtils.KEY_TO_NUM_VALUES_ALL ++ - StreamStreamJoinTestUtils.KEY_WITH_INDEX_ALL + val allColFamilyNames = StreamStreamJoinTestUtils.allStoreNames.toList performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - getJoinV3ColumnSchemaMap(), - columnFamilyToStateSourceOptions = allColFamilyNames.map { - colName => colName -> Map(StateSourceOptions.STORE_NAME -> colName) - }.toMap, + storeToColumnFamilies = Map(StateStoreId.DEFAULT_STORE_NAME -> allColFamilyNames), + storeToColumnFamilyToStateSourceOptions = + Map(StateStoreId.DEFAULT_STORE_NAME -> allColFamilyNames.map { + cfName => cfName -> Map(StateSourceOptions.STORE_NAME -> cfName) + }.toMap), 
operatorName = StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME ) } @@ -770,14 +662,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase runQuery(targetDir.getAbsolutePath, 1) val schemas = MultiStateVarProcessorTestUtils.getSchemasWithMetadata() - val columnFamilyToSchemaMap = schemas.map { case (cfName, metadata) => - cfName -> createColFamilyInfo( - metadata.keySchema, - metadata.valueSchema, - metadata.encoderSpec, - cfName, - metadata.useMultipleValuePerKey) - } val columnFamilyToSelectExprs = MultiStateVarProcessorTestUtils .getColumnFamilyToSelectExprs() @@ -799,9 +683,11 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - columnFamilyToSchemaMap, - columnFamilyToSelectExprs = columnFamilyToSelectExprs, - columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions, + storeToColumnFamilies = Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList), + storeToColumnFamilyToSelectExprs = + Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToSelectExprs), + storeToColumnFamilyToStateSourceOptions = + Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToStateSourceOptions), operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME ) } @@ -842,9 +728,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - schemaMap, - columnFamilyToSelectExprs = selectExprs, - columnFamilyToStateSourceOptions = stateSourceOptions, + storeToColumnFamilies = + Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList), + storeToColumnFamilyToSelectExprs = + Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs), + storeToColumnFamilyToStateSourceOptions = + Map(StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions), operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME ) } @@ -890,9 +779,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - schemaMap, - columnFamilyToSelectExprs = selectExprs, - columnFamilyToStateSourceOptions = sourceOptions, + storeToColumnFamilies = + Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList), + storeToColumnFamilyToSelectExprs = + Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs), + storeToColumnFamilyToStateSourceOptions = + Map(StateStoreId.DEFAULT_STORE_NAME -> sourceOptions), operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME ) } @@ -933,14 +825,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase ) val schemas = TTLProcessorUtils.getListStateTTLSchemasWithMetadata() - val columnFamilyToSchemaMap = schemas.map { case (cfName, metadata) => - cfName -> createColFamilyInfo( - metadata.keySchema, - metadata.valueSchema, - metadata.encoderSpec, - cfName, - metadata.useMultipleValuePerKey) - } val columnFamilyToSelectExprs = Map( TTLProcessorUtils.LIST_STATE -> TTLProcessorUtils.getTTLSelectExpressions( @@ -963,9 +847,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - columnFamilyToSchemaMap, - columnFamilyToSelectExprs = columnFamilyToSelectExprs, - columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions, + storeToColumnFamilies = + Map(StateStoreId.DEFAULT_STORE_NAME -> 
schemas.keys.toList), + storeToColumnFamilyToSelectExprs = + Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToSelectExprs), + storeToColumnFamilyToStateSourceOptions = + Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToStateSourceOptions), operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME ) } @@ -1006,14 +893,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase ) val schemas = TTLProcessorUtils.getMapStateTTLSchemasWithMetadata() - val columnFamilyToSchemaMap = schemas.map { case (cfName, metadata) => - cfName -> createColFamilyInfo( - metadata.keySchema, - metadata.valueSchema, - metadata.encoderSpec, - cfName, - metadata.useMultipleValuePerKey) - } val columnFamilyToSelectExprs = Map( TTLProcessorUtils.MAP_STATE -> TTLProcessorUtils.getTTLSelectExpressions( @@ -1027,9 +906,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - columnFamilyToSchemaMap, - columnFamilyToSelectExprs = columnFamilyToSelectExprs, - columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions, + storeToColumnFamilies = + Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList), + storeToColumnFamilyToSelectExprs = + Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToSelectExprs), + storeToColumnFamilyToStateSourceOptions = + Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToStateSourceOptions), operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME ) } @@ -1071,14 +953,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase ) val schemas = TTLProcessorUtils.getValueStateTTLSchemasWithMetadata() - val columnFamilyToSchemaMap = schemas.map { case (cfName, metadata) => - cfName -> createColFamilyInfo( - metadata.keySchema, - metadata.valueSchema, - metadata.encoderSpec, - cfName, - metadata.useMultipleValuePerKey) - } val columnFamilyToStateSourceOptions = schemas.keys.map { cfName => cfName -> Map(StateSourceOptions.STATE_VAR_NAME -> cfName) @@ -1087,8 +961,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase performRoundTripTest( sourceDir.getAbsolutePath, targetDir.getAbsolutePath, - columnFamilyToSchemaMap, - columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions, + storeToColumnFamilies = + Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList), + storeToColumnFamilyToStateSourceOptions = + Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToStateSourceOptions), operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME ) }
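
For reviewers who want to see the end-to-end flow outside of the test helper, below is a minimal, illustrative sketch (not part of this patch) of driving the new StateRewriter the same way performRoundTripTest does. The object and method names (StateRewriteExample, copyLatestState) are made up for illustration; the classes, constructors, and commit/offset log calls are the ones added or exercised in this diff, and both checkpoint locations are assumed to be already resolved.

import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata
import org.apache.spark.sql.execution.streaming.state.StateRewriter

// Hypothetical driver; mirrors the flow of performRoundTripTest in the suite above.
object StateRewriteExample {

  /**
   * Copies the latest committed state of `sourceCpLocation` into a new batch of
   * `targetCpLocation`, without transforming it. Both locations are assumed to be
   * already-resolved checkpoint paths of stateful streaming queries.
   */
  def copyLatestState(
      spark: SparkSession,
      sourceCpLocation: String,
      targetCpLocation: String): Unit = {
    val hadoopConf: Configuration = spark.sessionState.newHadoopConf()

    // Read the last committed batch of the source query.
    val sourceMetadata = new StreamingQueryCheckpointMetadata(spark, sourceCpLocation)
    val readBatchId = sourceMetadata.commitLog.getLatestBatchId().get

    // Reserve a new batch in the target checkpoint by replaying its latest offsets.
    val targetMetadata = new StreamingQueryCheckpointMetadata(spark, targetCpLocation)
    val lastBatch = targetMetadata.commitLog.getLatestBatchId().get
    val writeBatchId = lastBatch + 1
    targetMetadata.offsetLog.add(writeBatchId, targetMetadata.offsetLog.get(lastBatch).get)

    // Rewrite the state as-is (no transformFunc) into the new batch of the target checkpoint.
    new StateRewriter(
      spark,
      readBatchId,
      writeBatchId,
      targetCpLocation,
      hadoopConf,
      readResolvedCheckpointLocation = Some(sourceCpLocation),
      writeCheckpointMetadata = Some(targetMetadata)
    ).run()

    // Mark the new batch as committed, matching what the suite above does.
    targetMetadata.commitLog.add(writeBatchId, targetMetadata.commitLog.get(lastBatch).get)
  }
}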