@@ -21,10 +21,10 @@ import org.apache.gluten.config.GlutenConfig
2121import org .apache .gluten .expression ._
2222import org .apache .gluten .extension .columnar .transition .Convention
2323import org .apache .gluten .metrics .{GlutenTimeMetric , MetricsUpdater }
24- import org .apache .gluten .substrait .`type` .{TypeBuilder , TypeNode }
2524import org .apache .gluten .substrait .SubstraitContext
25+ import org .apache .gluten .substrait .expression .{ExpressionBuilder , ExpressionNode }
2626import org .apache .gluten .substrait .plan .{PlanBuilder , PlanNode }
27- import org .apache .gluten .substrait .rel .{LocalFilesNode , RelNode , SplitInfo }
27+ import org .apache .gluten .substrait .rel .{LocalFilesNode , RelBuilder , RelNode , SplitInfo }
2828import org .apache .gluten .substrait .rel .LocalFilesNode .ReadFileFormat
2929import org .apache .gluten .utils .SubstraitPlanPrinterUtil
3030
@@ -172,24 +172,53 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
172172 @ transient
173173 private var wholeStageTransformerContext : Option [WholeStageTransformContext ] = None
174174
175- private var outputSchemaForPlan : Option [TypeNode ] = None
175+ private var expectedOutputForPlan : Option [Seq [ Attribute ] ] = None
176176
177- private def inferSchemaFromAttributes (attrs : Seq [Attribute ]): TypeNode = {
178- val outputTypeNodeList = new java.util.ArrayList [TypeNode ]()
179- for (attr <- attrs) {
180- outputTypeNodeList.add(ConverterUtils .getTypeNode(attr.dataType, attr.nullable))
/**
 * Records the attributes this stage's plan is expected to output.
 *
 * Only the first call has any effect: once a value has been stored, later calls leave it
 * untouched.
 *
 * @param expectOutput
 *   the expected output attributes to remember
 */
def setOutputSchemaForPlan(expectOutput: Seq[Attribute]): Unit = {
  // Fixes issue-1874: store expected output attributes for generating a ProjectRel with casts.
  if (expectedOutputForPlan.isEmpty) {
    expectedOutputForPlan = Some(expectOutput)
  }
}
185185
186- def setOutputSchemaForPlan (expectOutput : Seq [Attribute ]): Unit = {
187- if (outputSchemaForPlan.isDefined) {
188- return
/**
 * Wraps `input` in a ProjectRel that casts every column whose data type or nullability differs
 * from the expected output. Columns that already match are forwarded as plain field references.
 * When no column needs a cast, `input` is returned unchanged and no ProjectRel is created.
 *
 * Columns are paired by position, so both attribute lists must have the same length.
 *
 * @param input
 *   the relation whose output columns are described by `inputAttrs`
 * @param inputAttrs
 *   attributes currently produced by `input`, in column order
 * @param expectedAttrs
 *   attributes the plan is expected to output, in column order
 * @param substraitContext
 *   context handed to the ProjectRel builder
 * @return
 *   `input` itself when every column already matches, otherwise a ProjectRel over `input`
 * @throws IllegalArgumentException
 *   if the two attribute lists differ in length
 */
private def createOutputCastProjectRel(
    input: RelNode,
    inputAttrs: Seq[Attribute],
    expectedAttrs: Seq[Attribute],
    substraitContext: SubstraitContext): RelNode = {
  // Positional pairing only makes sense for equal arity. Without this check a shorter
  // expected list fails with an opaque index error, and a longer one is silently truncated.
  require(
    inputAttrs.length == expectedAttrs.length,
    s"Expected output has ${expectedAttrs.length} column(s) but the child plan produces " +
      s"${inputAttrs.length} column(s)"
  )

  // A column needs a cast when either its data type or its nullability differs.
  def differs(inputAttr: Attribute, expectedAttr: Attribute): Boolean =
    inputAttr.dataType != expectedAttr.dataType || inputAttr.nullable != expectedAttr.nullable

  val columns = inputAttrs.zip(expectedAttrs)

  // One expression per column: a cast to the expected type, or a pass-through field reference.
  val projections = columns.zipWithIndex.map {
    case ((inputAttr, expectedAttr), ordinal) =>
      val fieldRef = ExpressionBuilder.makeSelection(ordinal)
      val projected: ExpressionNode =
        if (differs(inputAttr, expectedAttr)) {
          val targetType =
            ConverterUtils.getTypeNode(expectedAttr.dataType, expectedAttr.nullable)
          // NOTE(review): the meaning of the trailing `false` flag is defined by
          // ExpressionBuilder.makeCast, which is not visible here; it is kept as-is.
          ExpressionBuilder.makeCast(targetType, fieldRef, false)
        } else {
          fieldRef
        }
      projected
  }

  // Only create a ProjectRel if casts are actually needed.
  if (columns.exists { case (inputAttr, expectedAttr) => differs(inputAttr, expectedAttr) }) {
    val castExpressions = new java.util.ArrayList[ExpressionNode](projections.length)
    projections.foreach(expression => castExpressions.add(expression))
    // Use emitStartIndex = 0 to emit only the projected expressions (not input + expressions).
    RelBuilder.makeProjectRel(input, castExpressions, substraitContext, -1L, 0)
  } else {
    input
  }
}
194223
195224 def substraitPlan : PlanNode = {
@@ -241,21 +270,27 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
241270 throw new IllegalStateException (s " WholeStageTransformer can't do Transform on $child" )
242271 }
243272
244- val outNames = childCtx.outputAttributes.map(ConverterUtils .genColumnNameWithExprId).asJava
245-
246- val planNode = if (BackendsApiManager .getSettings.needOutputSchemaForPlan()) {
247- val outputSchema =
248- outputSchemaForPlan.getOrElse(inferSchemaFromAttributes(childCtx.outputAttributes))
273+ val (finalRoot, finalOutputAttrs) =
274+ if (BackendsApiManager .getSettings.needOutputSchemaForPlan()) {
275+ // If expected output schema differs from child's output, wrap in a ProjectRel with casts.
276+ // This fixes issue-1874 by explicitly converting types (including nullability) in the plan.
277+ expectedOutputForPlan match {
278+ case Some (expectedAttrs) =>
279+ val projectRel = createOutputCastProjectRel(
280+ childCtx.root,
281+ childCtx.outputAttributes,
282+ expectedAttrs,
283+ substraitContext)
284+ (projectRel, expectedAttrs)
285+ case None =>
286+ (childCtx.root, childCtx.outputAttributes)
287+ }
288+ } else {
289+ (childCtx.root, childCtx.outputAttributes)
290+ }
249291
250- PlanBuilder .makePlan(
251- substraitContext,
252- Lists .newArrayList(childCtx.root),
253- outNames,
254- outputSchema,
255- null )
256- } else {
257- PlanBuilder .makePlan(substraitContext, Lists .newArrayList(childCtx.root), outNames)
258- }
292+ val outNames = finalOutputAttrs.map(ConverterUtils .genColumnNameWithExprId).asJava
293+ val planNode = PlanBuilder .makePlan(substraitContext, Lists .newArrayList(finalRoot), outNames)
259294
260295 WholeStageTransformContext (planNode, substraitContext, isCudf)
261296 }
0 commit comments