diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala
index 1798d32d8a2c9..fefe54c7bae96 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala
@@ -75,4 +75,110 @@ final class StaxXmlGeneratorSuite extends SharedSparkSession {
     assert(df.collect().toSeq === newDf.collect().toSeq)
   }
 
+  // SPARK-45414: Test for string tag content misplacement issue (found in spark-xml library)
+  // with mixed column types
+  test("SPARK-45414: write mixed types with string columns between and after nested types") {
+    import org.apache.spark.sql.Row
+
+    // Create a schema with mixed types: struct, array, and string columns
+    // This reproduces the scenario from SPARK-45414 where string content gets misplaced
+    val schema = StructType(
+      Seq(
+        StructField("id", IntegerType, nullable = false),
+        StructField(
+          "metadata",
+          StructType(Seq(StructField("version", StringType), StructField("timestamp", LongType))),
+          nullable = true),
+        StructField("description", StringType, nullable = true), // String between nested types
+        StructField("tags", ArrayType(StringType), nullable = true),
+        StructField("color", StringType, nullable = true), // String at the end
+        StructField("numbers", ArrayType(IntegerType), nullable = true)))
+
+    val data = Seq(
+      Row(1, Row("v1.0", 1000L), "MyDescription", Array("tag1", "tag2"), "Red", Array(1, 2, 3)),
+      Row(2, Row("v2.0", 2000L), "AnotherDescription", Array("tag3"), "Blue", Array(4, 5)))
+
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+
+    // Write to XML
+    val targetFile =
+      Files.createTempDirectory("StaxXmlGeneratorSuite").resolve("mixed-types.xml").toString
+    df.write.option("rowTag", "item").xml(targetFile)
+
+    // Read back and verify the content is in correct XML tags
+    val readDf = spark.read.option("rowTag", "item").schema(schema).xml(targetFile)
+    val results = readDf.collect()
+
+    // Verify structure is preserved
+    assert(results.length === 2)
+
+    // Verify first row - ensure no data misplacement
+    assert(results(0).getAs[Int]("id") === 1)
+    val metadata1 = results(0).getAs[Row]("metadata")
+    assert(metadata1.getAs[String]("version") === "v1.0")
+    assert(metadata1.getAs[Long]("timestamp") === 1000L)
+    // Critical: ensure "MyDescription" is in description field, not in tags or color
+    assert(results(0).getAs[String]("description") === "MyDescription")
+    assert(results(0).getAs[Seq[String]]("tags") === Seq("tag1", "tag2"))
+    // Critical: ensure "Red" is in color field, not misplaced
+    assert(results(0).getAs[String]("color") === "Red")
+    assert(results(0).getAs[Seq[Int]]("numbers") === Seq(1, 2, 3))
+
+    // Verify second row
+    assert(results(1).getAs[Int]("id") === 2)
+    val metadata2 = results(1).getAs[Row]("metadata")
+    assert(metadata2.getAs[String]("version") === "v2.0")
+    assert(metadata2.getAs[Long]("timestamp") === 2000L)
+    assert(results(1).getAs[String]("description") === "AnotherDescription")
+    assert(results(1).getAs[Seq[String]]("tags") === Seq("tag3"))
+    assert(results(1).getAs[String]("color") === "Blue")
+    assert(results(1).getAs[Seq[Int]]("numbers") === Seq(4, 5))
+  }
+
+  // SPARK-45414: Test with attributes mixed with elements
+  test("SPARK-45414: write mixed types with attributes and string elements") {
+    import org.apache.spark.sql.Row
+
+    // Schema with attributes (using _ prefix) and string elements
+    val schema = StructType(
+      Seq(
+        StructField("_id", IntegerType, nullable = false), // attribute
+        StructField(
+          "nested",
+          StructType(
+            Seq(
+              StructField("_attr1", StringType), // attribute
+              StructField("value", StringType))),
+          nullable = true),
+        StructField("description", StringType, nullable = true), // element
+        StructField("items", ArrayType(IntegerType), nullable = true),
+        StructField("name", StringType, nullable = true) // element at end
+      ))
+
+    val data = Seq(Row(100, Row("attrValue", "nestedValue"), "DescText", Array(1, 2), "ItemName"))
+
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+
+    val targetFile =
+      Files.createTempDirectory("StaxXmlGeneratorSuite").resolve("mixed-attrs.xml").toString
+    df.write.option("rowTag", "record").option("attributePrefix", "_").xml(targetFile)
+
+    val readDf = spark.read
+      .option("rowTag", "record")
+      .option("attributePrefix", "_")
+      .schema(schema)
+      .xml(targetFile)
+    val results = readDf.collect()
+
+    assert(results.length === 1)
+    assert(results(0).getAs[Int]("_id") === 100)
+    val nested = results(0).getAs[Row]("nested")
+    assert(nested.getAs[String]("_attr1") === "attrValue")
+    assert(nested.getAs[String]("value") === "nestedValue")
+    // Critical: ensure string elements are not misplaced
+    assert(results(0).getAs[String]("description") === "DescText")
+    assert(results(0).getAs[Seq[Int]]("items") === Seq(1, 2))
+    assert(results(0).getAs[String]("name") === "ItemName")
+  }
+
 }