@@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpoint
import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId}
import org.apache.spark.sql.execution.streaming.state.OfflineStateRepartitionErrors
import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.streaming.TimeMode
import org.apache.spark.sql.types.StructType
@@ -66,7 +67,10 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
properties: util.Map[String, String]): Table = {
val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
StateSourceOptions.apply(session, hadoopConf, properties))
val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
// Build the SQLConf for the batch we are reading, using the confs recorded in the offset log
val batchSqlConf =
buildSqlConfForBatch(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
val stateConf = StateStoreConf(batchSqlConf)
// We only support RocksDB because the repartition work that this option
// is built for only supports RocksDB
if (sourceOptions.internalOnlyReadAllColumnFamilies
@@ -86,7 +90,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
NoPrefixKeyStateEncoderSpec(keySchema)
}

new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
new StateTable(session, schema, sourceOptions, stateConf,
batchSqlConf.getConf(SQLConf.STATEFUL_SHUFFLE_PARTITIONS_INTERNAL).get, keyStateEncoderSpec,
stateStoreReaderInfo.transformWithStateVariableInfoOpt,
stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
stateStoreReaderInfo.stateSchemaProviderOpt,
@@ -147,7 +152,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
sourceOptions.operatorId)
}

private def buildStateStoreConf(checkpointLocation: String, batchId: Long): StateStoreConf = {
private def buildSqlConfForBatch(
checkpointLocation: String,
batchId: Long): SQLConf = {
val offsetLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).offsetLog
offsetLog.get(batchId) match {
case Some(value) =>
@@ -157,7 +164,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging

val clonedSqlConf = session.sessionState.conf.clone()
OffsetSeqMetadata.setSessionConf(metadata, clonedSqlConf)
StateStoreConf(clonedSqlConf)
clonedSqlConf

case _ =>
throw StateDataSourceErrors.offsetLogUnavailable(batchId, checkpointLocation)
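Note: the StateDataSource change above boils down to the following minimal Scala sketch, assuming session, checkpointLocation, and batchId are in scope as in the surrounding class and that the imports added in the diff are available (names mirror the diff; the extraction of metadata from the OffsetSeq is an assumption based on the surrounding lines):

// Rebuild the SQLConf the batch originally ran with from the offset log, then derive
// both the StateStoreConf and the batch's stateful shuffle partition count from it.
val offsetLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).offsetLog
val batchSqlConf = offsetLog.get(batchId) match {
  case Some(offsetSeq) =>
    // Assumption: offsetSeq.metadata carries the confs recorded for this batch.
    val metadata = offsetSeq.metadata.get
    val clonedSqlConf = session.sessionState.conf.clone()
    OffsetSeqMetadata.setSessionConf(metadata, clonedSqlConf)
    clonedSqlConf
  case _ =>
    throw StateDataSourceErrors.offsetLogUnavailable(batchId, checkpointLocation)
}
val stateConf = StateStoreConf(batchSqlConf)
// The number of stateful shuffle partitions the batch ran with; used to bound the
// partition directories that are read.
val batchNumPartitions =
  batchSqlConf.getConf(SQLConf.STATEFUL_SHUFFLE_PARTITIONS_INTERNAL).get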
@@ -42,14 +42,16 @@ class StateScanBuilder(
schema: StructType,
sourceOptions: StateSourceOptions,
stateStoreConf: StateStoreConf,
batchNumPartitions: Int,
keyStateEncoderSpec: KeyStateEncoderSpec,
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String],
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]) extends ScanBuilder {
override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf,
keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
batchNumPartitions, keyStateEncoderSpec,
stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
joinColFamilyOpt, allColumnFamiliesReaderInfo)
}

@@ -65,6 +67,7 @@ class StateScan(
schema: StructType,
sourceOptions: StateSourceOptions,
stateStoreConf: StateStoreConf,
batchNumPartitions: Int,
keyStateEncoderSpec: KeyStateEncoderSpec,
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
@@ -85,7 +88,11 @@ class StateScan(
val partitions = fs.listStatus(stateCheckpointPartitionsLocation, new PathFilter() {
override def accept(path: Path): Boolean = {
fs.getFileStatus(path).isDirectory &&
Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0 &&
// Since we now support state repartitioning, a later batch may have increased the number of
// partitions, and hence the number of partition directories. So we only keep the partition
// dirs that belong to the number of partitions used by this batch.
path.getName.toInt < batchNumPartitions
}
})

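To make the new bound concrete, here is a small self-contained sketch of just the filter predicate (the directory names and the partition count of 10 are hypothetical; the real code applies this check inside the Hadoop PathFilter shown above):

import scala.util.Try

// Hypothetical listing of a state operator directory: a later offline repartition
// grew the store to 20 partitions, but the batch being read ran with 10.
val dirNames = (0 until 20).map(_.toString) :+ "_metadata"
val batchNumPartitions = 10

def belongsToBatch(name: String): Boolean =
  Try(name.toInt).isSuccess && name.toInt >= 0 && name.toInt < batchNumPartitions

val visible = dirNames.filter(belongsToBatch)
// visible contains "0" through "9"; "10".."19" and "_metadata" are filtered out.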
@@ -41,6 +41,7 @@ class StateTable(
override val schema: StructType,
sourceOptions: StateSourceOptions,
stateConf: StateStoreConf,
batchNumPartitions: Int,
keyStateEncoderSpec: KeyStateEncoderSpec,
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
@@ -86,7 +87,8 @@ class StateTable(
override def capabilities(): util.Set[TableCapability] = CAPABILITY

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
new StateScanBuilder(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
new StateScanBuilder(session, schema, sourceOptions, stateConf,
batchNumPartitions, keyStateEncoderSpec,
stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
joinColFamilyOpt, allColumnFamiliesReaderInfo)

@@ -17,15 +17,19 @@

package org.apache.spark.sql.execution.streaming.state

import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
import org.apache.spark.sql.execution.streaming.checkpointing.{CommitMetadata, OffsetMap, OffsetSeq, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataBase}
import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata
import org.apache.spark.sql.execution.streaming.runtime.{StreamingCheckpointConstants, StreamingQueryCheckpointMetadata}
import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils
import org.apache.spark.util.{SerializableConfiguration, Utils}

/**
* Runs repartitioning for the state stores used by a streaming query.
@@ -78,9 +82,22 @@ class OfflineStateRepartitionRunner(

val newBatchId = createNewBatchIfNeeded(lastBatchId, lastCommittedBatchId)

// todo(SPARK-54365): Do the repartitioning here, in subsequent PR

// todo(SPARK-54365): update operator metadata in subsequent PR.
val stateRepartitionFunc = (stateDf: DataFrame) => {
// Repartition the state by the partition key
stateDf.repartition(numPartitions, col("partition_key"))
}
val rewriter = new StateRewriter(
sparkSession,
readBatchId = lastCommittedBatchId,
writeBatchId = newBatchId,
resolvedCpLocation,
hadoopConf,
transformFunc = Some(stateRepartitionFunc),
writeCheckpointMetadata = Some(checkpointMetadata)
)
rewriter.run()

updateNumPartitionsInOperatorMetadata(newBatchId, readBatchId = lastCommittedBatchId)

// Commit the repartition batch
commitBatch(newBatchId, lastCommittedBatchId)
@@ -229,6 +246,49 @@ class OfflineStateRepartitionRunner(
newBatchId
}

private def updateNumPartitionsInOperatorMetadata(
newBatchId: Long,
readBatchId: Long): Unit = {
val stateMetadataReader = new StateMetadataPartitionReader(
resolvedCpLocation,
new SerializableConfiguration(hadoopConf),
readBatchId)

val allOperatorsMetadata = stateMetadataReader.allOperatorStateMetadata
assert(allOperatorsMetadata.nonEmpty, "Operator metadata shouldn't be empty")

val stateRootLocation = new Path(
resolvedCpLocation, StreamingCheckpointConstants.DIR_NAME_STATE).toString

allOperatorsMetadata.foreach { opMetadata =>
opMetadata match {
// We only update shuffle partitions for v2 operator metadata since it is versioned.
// For v1, we don't update it since there is only one metadata file.
case v2: OperatorStateMetadataV2 =>
// update for each state store
val updatedStoreInfo = v2.stateStoreInfo.map { stateStore =>
stateStore.copy(numPartitions = numPartitions)
}
val updatedMetadata = v2.copy(stateStoreInfo = updatedStoreInfo)
// write the updated metadata
val metadataWriter = OperatorStateMetadataWriter.createWriter(
new Path(stateRootLocation, updatedMetadata.operatorInfo.operatorId.toString),
hadoopConf,
updatedMetadata.version,
Some(newBatchId))
metadataWriter.write(updatedMetadata)

logInfo(log"Updated operator metadata for " +
log"operator=${MDC(OP_TYPE, updatedMetadata.operatorInfo.operatorName)}, " +
log"numStateStores=${MDC(COUNT, updatedMetadata.stateStoreInfo.length)}")
case v =>
logInfo(log"Skipping operator metadata update for " +
log"operator=${MDC(OP_TYPE, v.operatorInfo.operatorName)}, " +
log"since metadata version(${MDC(FILE_VERSION, v.version)}) is not versioned")
}
}
}

private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit = {
val latestCommit = checkpointMetadata.commitLog.get(lastCommittedBatchId).get

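For orientation, the repartition flow added above reduces to roughly this sketch (argument names mirror the diff; StateRewriter's full signature is not visible in this excerpt, so treat the call as an approximation):

// 1. Shuffle every state row to its new partition, keyed by the partition key column.
val stateRepartitionFunc = (stateDf: DataFrame) =>
  stateDf.repartition(numPartitions, col("partition_key"))

// 2. Rewrite the state of the last committed batch into the new batch, applying the shuffle.
new StateRewriter(
  sparkSession,
  readBatchId = lastCommittedBatchId,
  writeBatchId = newBatchId,
  resolvedCpLocation,
  hadoopConf,
  transformFunc = Some(stateRepartitionFunc),
  writeCheckpointMetadata = Some(checkpointMetadata)
).run()

// 3. Record the new partition count in the versioned (v2) operator metadata, then commit.
updateNumPartitionsInOperatorMetadata(newBatchId, readBatchId = lastCommittedBatchId)
commitBatch(newBatchId, lastCommittedBatchId)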
@@ -201,6 +201,14 @@ class RocksDBFileManager(
}
}

private def createDfsRootDirIfNotExist(): Unit = {
if (!rootDirChecked) {
val rootDir = new Path(dfsRootDir)
if (!fm.exists(rootDir)) fm.mkdirs(rootDir)
rootDirChecked = true
}
}

def getChangeLogWriter(
version: Long,
useColumnFamilies: Boolean = false,
@@ -209,11 +217,7 @@
): StateStoreChangelogWriter = {
try {
val changelogFile = dfsChangelogFile(version, checkpointUniqueId)
if (!rootDirChecked) {
val rootDir = new Path(dfsRootDir)
if (!fm.exists(rootDir)) fm.mkdirs(rootDir)
rootDirChecked = true
}
createDfsRootDirIfNotExist()

val enableStateStoreCheckpointIds = checkpointUniqueId.isDefined
val changelogVersion = getChangelogWriterVersion(
@@ -332,11 +336,7 @@ class RocksDBFileManager(
// CheckpointFileManager.createAtomic API which doesn't auto-initialize parent directories.
// Moreover, once we disable to track the number of keys, in which the numKeys is -1, we
// still need to create the initial dfs root directory anyway.
if (!rootDirChecked) {
val path = new Path(dfsRootDir)
if (!fm.exists(path)) fm.mkdirs(path)
rootDirChecked = true
}
createDfsRootDirIfNotExist()
}
zipToDfsFile(localOtherFiles :+ metadataFile,
dfsBatchZipFile(version, checkpointUniqueId), verifyNonEmptyFilesInZip)
@@ -372,6 +372,7 @@
val metadata = if (version == 0) {
if (localDir.exists) Utils.deleteRecursively(localDir)
Utils.createDirectory(localDir)
createDfsRootDirIfNotExist()
// Since we cleared the local dir, we should also clear the local file mapping
rocksDBFileMapping.clear()
RocksDBCheckpointMetadata(Seq.empty, 0)
@@ -404,11 +405,7 @@
// Return if there is a snapshot file at the corresponding version
// and optionally with checkpointunique id, e.g. version.zip or version_uniqueId.zip
def existsSnapshotFile(version: Long, checkpointUniqueId: Option[String] = None): Boolean = {
if (!rootDirChecked) {
val path = new Path(dfsRootDir)
if (!fm.exists(path)) fm.mkdirs(path)
rootDirChecked = true
}
createDfsRootDirIfNotExist()
fm.exists(dfsBatchZipFile(version, checkpointUniqueId))
}
