Skip to content

Commit 1472db1

Browse files
Writing field id when writing iceberg's data file (#14328)
Fixes nvbugs-5894334. ### Description When generating an iceberg data file, we should write the field id into it so that manifest metrics can be correctly produced. ### Checklists - [x] This PR has added documentation for new or modified features or behaviors. - [x] This PR has added new tests or modified existing tests to cover new code paths. (Please explain in the PR description how the new code paths are tested, such as names of the new/existing tests that cover them.) - [ ] Performance testing has been performed and its results are added in the PR description. Or, an issue has been filed with a link in the PR description. --------- Signed-off-by: Ray Liu <liurenjie2008@gmail.com>
1 parent 7a49f8d commit 1472db1

File tree

2 files changed

+82
-3
lines changed

2 files changed

+82
-3
lines changed

iceberg/common/src/main/scala/org/apache/iceberg/spark/source/GpuSparkWrite.scala

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2025, NVIDIA CORPORATION.
2+
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -25,11 +25,12 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableSeq
2525
import com.nvidia.spark.rapids.SpillPriorities.ACTIVE_ON_DECK_PRIORITY
2626
import com.nvidia.spark.rapids.fileio.iceberg.IcebergFileIO
2727
import com.nvidia.spark.rapids.iceberg.GpuIcebergSpecPartitioner
28+
import com.nvidia.spark.rapids.shims.parquet.ParquetFieldIdShims
2829
import org.apache.hadoop.mapreduce.Job
2930
import org.apache.hadoop.shaded.org.apache.commons.lang3.reflect.{FieldUtils, MethodUtils}
3031
import org.apache.iceberg._
3132
import org.apache.iceberg.io._
32-
import org.apache.iceberg.spark.{Spark3Util, SparkSchemaUtil}
33+
import org.apache.iceberg.spark.{GpuTypeToSparkType, Spark3Util, SparkSchemaUtil}
3334
import org.apache.iceberg.spark.functions.{GpuFieldTransform, GpuTransform}
3435
import org.apache.iceberg.spark.source.GpuWriteContext.positionDeleteSparkType
3536
import org.apache.iceberg.spark.source.SparkWrite.TaskCommit
@@ -42,8 +43,10 @@ import org.apache.spark.sql.connector.distributions.Distribution
4243
import org.apache.spark.sql.connector.expressions.SortOrder
4344
import org.apache.spark.sql.connector.write.{DataWriter, _}
4445
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
46+
import org.apache.spark.sql.execution.SparkPlan
4547
import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
4648
import org.apache.spark.sql.rapids.GpuWriteJobStatsTracker
49+
import org.apache.spark.sql.rapids.shims.SparkSessionUtils
4750
import org.apache.spark.sql.types.StructType
4851
import org.apache.spark.sql.vectorized.ColumnarBatch
4952
import org.apache.spark.util.SerializableConfiguration
@@ -104,7 +107,12 @@ class GpuSparkWrite(cpu: SparkWrite) extends GpuWrite with RequiresDistributionA
104107
val outputSpecId = FieldUtils.readField(cpu, "outputSpecId", true).asInstanceOf[Int]
105108
val targetFileSize = FieldUtils.readField(cpu, "targetFileSize", true).asInstanceOf[Long]
106109
val writeSchema = FieldUtils.readField(cpu, "writeSchema", true).asInstanceOf[Schema]
107-
val dsSchema = FieldUtils.readField(cpu, "dsSchema", true).asInstanceOf[StructType]
110+
// Convert writeSchema to Spark StructType with Iceberg field IDs (PARQUET:field_id).
111+
// The CPU path uses Iceberg's own Parquet writer which natively embeds field IDs, but
112+
// the GPU path uses Spark's Parquet infrastructure which requires field IDs in the
113+
// StructType metadata. Without them, Iceberg's ParquetMetrics cannot extract file-level
114+
// statistics, causing StrictMetricsEvaluator to fail during overwrite validation.
115+
val dsSchema = GpuTypeToSparkType.toSparkType(writeSchema)
108116
val useFanout = FieldUtils.readField(cpu, "useFanoutWriter", true).asInstanceOf[Boolean]
109117
val writeProps = FieldUtils.readField(cpu, "writeProperties", true)
110118
.asInstanceOf[java.util.Map[String, String]]
@@ -115,6 +123,7 @@ class GpuSparkWrite(cpu: SparkWrite) extends GpuWrite with RequiresDistributionA
115123
}
116124

117125
val hadoopConf = sparkContext.hadoopConfiguration
126+
118127
val job = {
119128
val tmpJob = Job.getInstance(hadoopConf)
120129
tmpJob.setOutputKeyClass(classOf[Void])
@@ -180,6 +189,16 @@ object GpuSparkWrite {
180189
partitionSpec: PartitionSpec,
181190
meta: SparkPlanMeta[_]): Unit = {
182191

192+
// Iceberg requires Parquet field IDs for correct file-level metrics. Without them,
193+
// StrictMetricsEvaluator fails during overwrite validation.
194+
val spark = SparkSessionUtils.sessionFromPlan(meta.wrapped.asInstanceOf[SparkPlan])
195+
val hadoopConf = spark.sparkContext.hadoopConfiguration
196+
val sqlConf = spark.sessionState.conf
197+
if (!ParquetFieldIdShims.getParquetIdWriteEnabled(hadoopConf, sqlConf)) {
198+
meta.willNotWorkOnGpu("Iceberg requires Parquet field IDs to be written for correct " +
199+
"file-level metrics. Set spark.sql.parquet.fieldId.write.enabled=true")
200+
}
201+
183202
// Check file format support
184203
if (dataFormat.exists(!_.equals(FileFormat.PARQUET))) {
185204
meta.willNotWorkOnGpu(s"GpuSparkWrite only supports Parquet, but got: ${dataFormat.get}")
@@ -292,6 +311,7 @@ object GpuSparkWrite {
292311
def convert(cpuWrite: Write): GpuSparkWrite = {
293312
new GpuSparkWrite(cpuWrite.asInstanceOf[SparkWrite])
294313
}
314+
295315
}
296316

297317
class GpuWriterFactory(val tableBroadcast: Broadcast[Table],

integration_tests/src/main/python/iceberg/iceberg_overwrite_static_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import Callable, Any
1515

1616
import pytest
17+
from pyspark.sql import functions as F
1718

1819
from asserts import assert_equal_with_local_sort, assert_gpu_fallback_collect
1920
from conftest import is_iceberg_remote_catalog
@@ -404,3 +405,61 @@ def overwrite_data(spark, table_name):
404405
cpu_data = with_cpu_session(lambda spark: spark.table(cpu_table_name).collect())
405406
gpu_data = with_cpu_session(lambda spark: spark.table(gpu_table_name).collect())
406407
assert_equal_with_local_sort(cpu_data, gpu_data)
408+
409+
410+
@iceberg
411+
@ignore_order(local=True)
412+
@allow_non_gpu('ShuffleExchangeExec')
413+
@pytest.mark.skipif(is_iceberg_remote_catalog(), reason="Skip for remote catalog to reduce test time")
414+
def test_insert_overwrite_static_df_api_truncate_string(spark_tmp_table_factory):
415+
"""Test static overwrite via DataFrame writeTo().overwrite() API with truncate(5, string_col)
416+
partitioning. Verifies GPU writes produce Parquet files with correct Iceberg field IDs
417+
so that file-level statistics are available for overwrite validation.
418+
"""
419+
truncate_width = 5
420+
str_length = truncate_width - 2
421+
prefix = "T" * str_length
422+
partition_col_sql = f"truncate({truncate_width}, _c6)"
423+
partition_filter = f"_c6 >= '{prefix}10' AND _c6 < '{prefix}20'"
424+
425+
table_prop = {"format-version": "2",
426+
"write.format.default": "parquet"}
427+
428+
conf = copy_and_update(iceberg_static_overwrite_conf, {
429+
"spark.sql.adaptive.enabled": "true",
430+
"spark.sql.adaptive.coalescePartitions.enabled": "true",
431+
})
432+
433+
# Use standard iceberg schema but override _c6 with a constrained string gen
434+
# to produce predictable truncate partitions for the range filter
435+
gen_list = list(zip(iceberg_base_table_cols, iceberg_gens_list))
436+
gen_list[6] = ('_c6', StringGen(pattern=f'{prefix}[1-9][0-9][A-Z]{{3}}'))
437+
438+
base_table_name = get_full_table_name(spark_tmp_table_factory)
439+
cpu_table_name = f"{base_table_name}_cpu"
440+
gpu_table_name = f"{base_table_name}_gpu"
441+
442+
def create_table_with_ctas(spark, table_name):
443+
df = gen_df(spark, gen_list, seed=INITIAL_INSERT_SEED)
444+
view_name = spark_tmp_table_factory.get()
445+
df.createOrReplaceTempView(view_name)
446+
props_sql = ", ".join([f"'{k}' = '{v}'" for k, v in table_prop.items()])
447+
spark.sql(f"CREATE TABLE {table_name} USING ICEBERG "
448+
f"PARTITIONED BY ({partition_col_sql}) "
449+
f"TBLPROPERTIES ({props_sql}) "
450+
f"AS SELECT * FROM {view_name}")
451+
452+
with_cpu_session(lambda spark: create_table_with_ctas(spark, cpu_table_name), conf=conf)
453+
with_gpu_session(lambda spark: create_table_with_ctas(spark, gpu_table_name), conf=conf)
454+
455+
def overwrite_data(spark, table_name):
456+
df = gen_df(spark, gen_list, seed=INITIAL_INSERT_SEED + 1)
457+
filtered_df = df.filter(partition_filter)
458+
filtered_df.writeTo(table_name).overwrite(F.expr(partition_filter))
459+
460+
with_cpu_session(lambda spark: overwrite_data(spark, cpu_table_name), conf=conf)
461+
with_gpu_session(lambda spark: overwrite_data(spark, gpu_table_name), conf=conf)
462+
463+
cpu_data = with_cpu_session(lambda spark: spark.table(cpu_table_name).collect())
464+
gpu_data = with_cpu_session(lambda spark: spark.table(gpu_table_name).collect())
465+
assert_equal_with_local_sort(cpu_data, gpu_data)

0 commit comments

Comments (0)