@@ -23,6 +23,10 @@ import org.apache.hudi.HoodieSparkUtils
 import org.apache.hudi.common.testutils.HoodieTestUtils
 import org.apache.hudi.common.util.StringUtils

+import org.apache.hadoop.fs.{FileSystem, Path => HadoopPath}
+import org.apache.parquet.hadoop.ParquetFileReader
+import org.apache.parquet.hadoop.util.HadoopInputFile
+import org.apache.parquet.schema.{GroupType, MessageType, Type}
 import org.apache.spark.sql.hudi.common.HoodieSparkSqlTestBase


@@ -171,4 +175,166 @@ class TestVariantDataType extends HoodieSparkSqlTestBase {

     spark.sql(s"drop table $tableName")
   }
+
+  test("Test Shredded Variant Write and Read + Validate Parquet Schema after Write") {
+    assume(HoodieSparkUtils.gteqSpark4_0, "Variant type requires Spark 4.0 or higher")
+
+    // Shredding enabled with a forced schema: the parquet files should contain typed_value
+    withRecordType()(withTempDir { tmp =>
+      val tableName = generateTableName
+      spark.sql(
+        s"""
+           |create table $tableName (
+           |  id int,
+           |  name string,
+           |  v variant,
+           |  ts long
+           |) using hudi
+           | location '${tmp.getCanonicalPath}'
+           | tblproperties (
+           |  primaryKey = 'id',
+           |  type = 'cow',
+           |  preCombineField = 'ts'
+           | )
+         """.stripMargin)
+
+      spark.sql("set hoodie.parquet.variant.write.shredding.enabled = true")
+      spark.sql("set hoodie.parquet.variant.allow.reading.shredded = true")
+      spark.sql("set hoodie.parquet.variant.force.shredding.schema.for.test = a int, b string")
+
+      spark.sql(
+        s"""
+           |insert into $tableName values
+           |  (1, 'row1', parse_json('{"a": 1, "b": "hello"}'), 1000)
+         """.stripMargin)
+      checkAnswer(s"select id, name, cast(v as string), ts from $tableName order by id")(
+        Seq(1, "row1", "{\"a\":1,\"b\":\"hello\"}", 1000)
+      )
+
+      // Verify the parquet schema has the shredded structure with typed_value
+      val parquetFiles = listDataParquetFiles(tmp.getCanonicalPath)
+      assert(parquetFiles.nonEmpty, "Should have at least one data parquet file")
+
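+      // A shredded variant group is expected to carry three fields: required metadata,
+      // an optional value field (holding any data not captured by the shredding schema),
+      // and typed_value with the shredded columns.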
+      parquetFiles.foreach { filePath =>
+        val schema = readParquetSchema(filePath)
+        val variantGroup = getFieldAsGroup(schema, "v")
+        assert(groupContainsField(variantGroup, "typed_value"),
+          s"Shredded variant should have typed_value field. Schema:\n$variantGroup")
+        val valueField = variantGroup.getType(variantGroup.getFieldIndex("value"))
+        assert(valueField.getRepetition == Type.Repetition.OPTIONAL,
+          "Shredded variant value field should be OPTIONAL")
+        val metadataField = variantGroup.getType(variantGroup.getFieldIndex("metadata"))
+        assert(metadataField.getRepetition == Type.Repetition.REQUIRED,
+          "Shredded variant metadata field should be REQUIRED")
+      }
+    })
+  }
+
233+ test(" Test Unshredded Variant Write and Read + Validate Parquet Schema after Write" ) {
234+ assume(HoodieSparkUtils .gteqSpark4_0, " Variant type requires Spark 4.0 or higher" )
+    // Shredding disabled: the parquet files should NOT contain typed_value
+    withRecordType()(withTempDir { tmp =>
+      val tableName = generateTableName
+      spark.sql(
+        s"""
+           |create table $tableName (
+           |  id int,
+           |  name string,
+           |  v variant,
+           |  ts long
+           |) using hudi
+           | location '${tmp.getCanonicalPath}'
+           | tblproperties (
+           |  primaryKey = 'id',
+           |  type = 'cow',
+           |  preCombineField = 'ts'
+           | )
+         """.stripMargin)
+
+      spark.sql("set hoodie.parquet.variant.write.shredding.enabled = false")
+
+      spark.sql(
+        s"""
+           |insert into $tableName values
+           |  (1, 'row1', parse_json('{"a": 1, "b": "hello"}'), 1000)
+         """.stripMargin)
+
+      checkAnswer(s"select id, name, cast(v as string), ts from $tableName order by id")(
+        Seq(1, "row1", "{\"a\":1,\"b\":\"hello\"}", 1000)
+      )
+
+      // Verify the parquet schema does NOT have typed_value
+      val parquetFiles = listDataParquetFiles(tmp.getCanonicalPath)
+      assert(parquetFiles.nonEmpty, "Should have at least one data parquet file")
+
+      parquetFiles.foreach { filePath =>
+        val schema = readParquetSchema(filePath)
+        val variantGroup = getFieldAsGroup(schema, "v")
+        assert(!groupContainsField(variantGroup, "typed_value"),
+          s"Non-shredded variant should NOT have typed_value field. Schema:\n$variantGroup")
+        val valueField = variantGroup.getType(variantGroup.getFieldIndex("value"))
+        assert(valueField.getRepetition == Type.Repetition.REQUIRED,
+          "Non-shredded variant value field should be REQUIRED")
+      }
+
+      // Verify data can still be read back for the non-shredded case
+      checkAnswer(s"select id, name, cast(v as string), ts from $tableName order by id")(
+        Seq(1, "row1", "{\"a\":1,\"b\":\"hello\"}", 1000)
+      )
+    })
+  }
+
+  /**
+   * Lists data parquet files in the table directory, excluding Hudi metadata files.
+   */
+  private def listDataParquetFiles(tablePath: String): Seq[String] = {
+    val conf = spark.sparkContext.hadoopConfiguration
+    val fs = FileSystem.get(new HadoopPath(tablePath).toUri, conf)
+    val iter = fs.listFiles(new HadoopPath(tablePath), true)
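+    // Recursive listing (second argument true) also walks partition subdirectories.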
+    val files = scala.collection.mutable.ArrayBuffer[String]()
+    while (iter.hasNext) {
+      val file = iter.next()
+      val path = file.getPath.toString
+      if (path.endsWith(".parquet") && !path.contains(".hoodie")) {
+        files += path
+      }
+    }
+    files.toSeq
+  }
+
+  /**
+   * Reads the Parquet schema (MessageType) from a parquet file.
+   */
+  private def readParquetSchema(filePath: String): MessageType = {
+    val conf = spark.sparkContext.hadoopConfiguration
+    val inputFile = HadoopInputFile.fromPath(new HadoopPath(filePath), conf)
+    val reader = ParquetFileReader.open(inputFile)
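+    // Only the file footer is read to obtain the schema; row-group data is never scanned.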
+    try {
+      reader.getFooter.getFileMetaData.getSchema
+    } finally {
+      reader.close()
+    }
+  }
+
+  /**
+   * Gets a named field from a GroupType (MessageType) and returns it as a GroupType.
+   * Uses getFieldIndex(String) + getType(int) to avoid Scala overload resolution issues.
+   */
+  private def getFieldAsGroup(parent: GroupType, fieldName: String): GroupType = {
+    val idx: Int = parent.getFieldIndex(fieldName)
+    parent.getType(idx).asGroupType()
+  }
+
+  /**
+   * Checks whether a GroupType contains a field with the given name.
+   * Uses try/catch on getFieldIndex to avoid Scala-Java collection converter dependencies.
+   */
+  private def groupContainsField(group: GroupType, fieldName: String): Boolean = {
+    try {
+      group.getFieldIndex(fieldName)
+      true
+    } catch {
+      case _: Exception => false
+    }
+  }
 }