Commit 47bc3bc

parquet_writer
1 parent f249f97 commit 47bc3bc

2 files changed: +63 -3 lines changed

spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala

Lines changed: 14 additions & 2 deletions
@@ -24,11 +24,13 @@ import java.util.Locale
 import scala.jdk.CollectionConverters._

 import org.apache.spark.SparkException
+import org.apache.spark.sql.{SaveMode, SparkSession}
 import org.apache.spark.sql.comet.{CometNativeExec, CometNativeWriteExec}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, WriteFilesExec}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode

 import org.apache.comet.{CometConf, ConfigEntry, DataTypeSupport}
 import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -61,6 +63,10 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
       return Unsupported(Some("Bucketed writes are not supported"))
     }

+    if (SQLConf.get.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) {
+      return Unsupported(Some("Dynamic partition overwrite is not supported"))
+    }
+
     if (cmd.partitionColumns.nonEmpty || cmd.staticPartitions.nonEmpty) {
      return Incompatible(Some("Partitioned writes are highly experimental"))
     }
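
With this check in place, enabling Spark's dynamic partition overwrite should route the write back to Spark's built-in Parquet writer instead of the Comet native writer. A minimal sketch of a write that the new check would reject, assuming a local session and an illustrative output path (the Comet write configs shown in the test suite below would also need to be enabled for the native path to be considered at all):

    import org.apache.spark.sql.{SaveMode, SparkSession}
    import org.apache.spark.sql.functions.col

    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // Dynamic partition overwrite is a session-level Spark setting; with it
    // set to "dynamic", the serde check added above returns Unsupported and
    // the plan stays on Spark's own writer.
    spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

    spark
      .range(100)
      .withColumn("p", col("id") % 10)
      .write
      .mode(SaveMode.Overwrite)
      .partitionBy("p")
      .parquet("/tmp/comet_dynamic_overwrite_demo") // illustrative path
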
@@ -158,6 +164,14 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
     val cmd = op.cmd.asInstanceOf[InsertIntoHadoopFsRelationCommand]
     val outputPath = cmd.outputPath.toString

+    // TODO : support dynamic partition overwrite
+    if (cmd.mode == SaveMode.Overwrite) {
+      val fs = cmd.outputPath.getFileSystem(SparkSession.active.sparkContext.hadoopConfiguration)
+      if (fs.exists(cmd.outputPath)) {
+        fs.delete(cmd.outputPath, true)
+      }
+    }
+
     // Get the child plan from the WriteFilesExec or use the child directly
     val childPlan = op.child match {
       case writeFiles: WriteFilesExec =>
@@ -168,8 +182,6 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
         other
     }

-    val isDynamicOverWriteMode = cmd.partitionColumns.nonEmpty
-
     // Create FileCommitProtocol for atomic writes
     val jobId = java.util.UUID.randomUUID().toString
     val committer =
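
The new SaveMode.Overwrite branch clears the target directory before the native write runs, so overwriting the same path twice should leave only the latest data, matching Spark's behaviour for non-partitioned overwrites. A rough usage sketch, with an illustrative path and toy data:

    import org.apache.spark.sql.{SaveMode, SparkSession}

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    val out = "/tmp/comet_overwrite_demo" // illustrative path

    // First write creates the directory.
    spark.range(10).toDF("id").write.mode(SaveMode.Overwrite).parquet(out)

    // Second overwrite: the branch above deletes the existing directory
    // (fs.delete(path, recursive = true)) before writing, so only the
    // 5 new rows remain.
    spark.range(5).toDF("id").write.mode(SaveMode.Overwrite).parquet(out)
    assert(spark.read.parquet(out).count() == 5)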

spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala

Lines changed: 49 additions & 1 deletion
@@ -27,7 +27,7 @@ import org.apache.spark.sql.{CometTestBase, DataFrame}
 import org.apache.spark.sql.comet.{CometNativeScanExec, CometNativeWriteExec}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.internal.SQLConf

 import org.apache.comet.CometConf
@@ -273,4 +273,52 @@ class CometParquetWriterSuite extends CometTestBase {
       }
     }
   }
+
+  test("partitioned write - data correctness per partition") {
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output").getAbsolutePath
+
+      withTempPath { inputDir =>
+        val inputPath = createTestData(inputDir)
+
+        withSQLConf(
+          CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+          CometConf.getOperatorAllowIncompatConfigKey(
+            classOf[DataWritingCommandExec]) -> "true") {
+
+          val inputDf = spark.read.parquet(inputPath).filter(col("c1") <= lit(10))
+          val partCols = inputDf.columns.take(2)
+          val col1 = partCols(0)
+          val col2 = partCols(1)
+
+          inputDf.write.partitionBy(partCols: _*).parquet(outputPath)
+
+          // unique combinations
+          val combinations = inputDf
+            .select(partCols.head, partCols.last)
+            .distinct()
+            .collect()
+            .map(r => (r.getBoolean(0), r.getByte(1)))
+
+          combinations.foreach { tuple =>
+            val val1 = tuple._1
+            val val2 = tuple._2
+
+            val partitionPath = s"$outputPath/${partCols.head}=$val1/${partCols.last}=$val2"
+
+            val actualDf = spark.read.parquet(partitionPath)
+            val expectedDf = inputDf
+              .filter(col(col1) === val1)
+              .filter(col(col2) === val2)
+              .drop(col1, col2)
+
+            checkAnswer(actualDf, expectedDf)
+          }
+
+          // Verify total count as well
+          checkAnswer(spark.read.parquet(outputPath), inputDf)
+        }
+      }
+    }
+  }
 }
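
The test leans on the Hive-style layout that partitionBy produces: one col=value subdirectory per distinct partition value, with the partition columns dropped from the data files themselves, which is why reading a single partition directory and comparing it against a filtered-and-dropped view of the input is a valid per-partition check. A small standalone sketch of that layout, with illustrative column names and path:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    val out = "/tmp/comet_partitioned_demo" // illustrative path

    val df = spark.range(20).toDF("id").withColumn("bucket", col("id") % 2)
    df.write.partitionBy("bucket").parquet(out)

    // Each partition value lands in its own directory, e.g. .../bucket=0, and
    // reading that directory back yields only the non-partition columns.
    val part0 = spark.read.parquet(s"$out/bucket=0")
    assert(part0.columns.sameElements(Array("id")))
    assert(part0.count() == df.filter(col("bucket") === 0).count())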
