Skip to content

Commit 252f3e2

Browse files
committed
feat: remove output_schema in favor of explicit project with casts
1 parent 0cdcf10 commit 252f3e2

File tree

6 files changed

+66
-109
lines changed

6 files changed

+66
-109
lines changed

cpp-ch/local-engine/Parser/SerializedPlanParser.cpp

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -121,72 +121,6 @@ void SerializedPlanParser::adjustOutput(const DB::QueryPlanPtr & query_plan, con
121121
aliases.emplace_back(DB::NameWithAlias(input_iter->name, *output_name));
122122
});
123123
}
124-
125-
// Fixes issue-1874: keep the nullability as expected.
126-
const auto & output_schema = root_rel.root().output_schema();
127-
if (output_schema.types_size())
128-
{
129-
const auto & origin_header = *query_plan->getCurrentHeader();
130-
const auto & origin_columns = origin_header.getColumnsWithTypeAndName();
131-
132-
if (static_cast<size_t>(output_schema.types_size()) != origin_columns.size())
133-
{
134-
debug::dumpPlan(*query_plan, "clickhouse plan", true);
135-
debug::dumpMessage(plan, "substrait::Plan", true);
136-
throw Exception(
137-
ErrorCodes::LOGICAL_ERROR,
138-
"Mismatch result columns size. plan column size {}, substrait plan output schema size {}, substrait plan name size {}.",
139-
origin_columns.size(),
140-
output_schema.types_size(),
141-
root_rel.root().names_size());
142-
}
143-
144-
bool need_final_project = false;
145-
ColumnsWithTypeAndName final_columns;
146-
for (int i = 0; i < output_schema.types_size(); ++i)
147-
{
148-
const auto & origin_column = origin_columns[i];
149-
const auto & origin_type = origin_column.type;
150-
auto final_type = TypeParser::parseType(output_schema.types(i));
151-
152-
/// Intermediate aggregate data is special, no check here.
153-
if (typeid_cast<const DataTypeAggregateFunction *>(origin_column.type.get()) || origin_type->equals(*final_type))
154-
final_columns.push_back(origin_column);
155-
else
156-
{
157-
need_final_project = true;
158-
if (origin_column.column && isColumnConst(*origin_column.column))
159-
{
160-
/// For const column, we need to cast it individually. Otherwise, the const column will be converted to full column in
161-
/// ActionsDAG::makeConvertingActions.
162-
/// Note: creating final_column with Field of origin_column will cause Exception in some cases.
163-
const DB::ContextPtr context = DB::CurrentThread::get().getQueryContext();
164-
const FunctionOverloadResolverPtr & cast_resolver = FunctionFactory::instance().get("CAST", context);
165-
const DataTypePtr string_type = std::make_shared<DataTypeString>();
166-
ColumnWithTypeAndName to_type_column = {string_type->createColumnConst(1, final_type->getName()), string_type, "__cast_const__"};
167-
FunctionBasePtr cast_function = cast_resolver->build({origin_column, to_type_column});
168-
ColumnPtr const_col = ColumnConst::create(cast_function->execute({origin_column, to_type_column}, final_type, 1, false), 1);
169-
ColumnWithTypeAndName final_column(const_col, final_type, origin_column.name);
170-
final_columns.emplace_back(std::move(final_column));
171-
}
172-
else
173-
{
174-
ColumnWithTypeAndName final_column(final_type->createColumn(), final_type, origin_column.name);
175-
final_columns.emplace_back(std::move(final_column));
176-
}
177-
}
178-
}
179-
180-
if (need_final_project)
181-
{
182-
ActionsDAG final_project
183-
= ActionsDAG::makeConvertingActions(origin_columns, final_columns, ActionsDAG::MatchColumnsMode::Position, true);
184-
QueryPlanStepPtr final_project_step
185-
= std::make_unique<ExpressionStep>(query_plan->getCurrentHeader(), std::move(final_project));
186-
final_project_step->setStepDescription("Project for output schema");
187-
query_plan->addStep(std::move(final_project_step));
188-
}
189-
}
190124
}
191125

192126
QueryPlanPtr SerializedPlanParser::parse(const substrait::Plan & plan)

docs/developers/SubstraitModifications.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ alternatives like `AdvancedExtension` could be considered.
2121
* Changed join type `JOIN_TYPE_SEMI` to `JOIN_TYPE_LEFT_SEMI` and `JOIN_TYPE_RIGHT_SEMI`([#408](https://github.com/apache/incubator-gluten/pull/408)).
2222
* Added `WindowRel`, added `column_name` and `window_type` in `WindowFunction`,
2323
changed `Unbounded` in `WindowFunction` into `Unbounded_Preceding` and `Unbounded_Following`, and added WindowType([#485](https://github.com/apache/incubator-gluten/pull/485)).
24-
* Added `output_schema` in RelRoot([#1901](https://github.com/apache/incubator-gluten/pull/1901)).
2524
* Added `ExpandRel`([#1361](https://github.com/apache/incubator-gluten/pull/1361)).
2625
* Added `GenerateRel`([#574](https://github.com/apache/incubator-gluten/pull/574)).
2726
* Added `PartitionColumn` in `LocalFiles`([#2405](https://github.com/apache/incubator-gluten/pull/2405)).

gluten-substrait/src/main/java/org/apache/gluten/substrait/plan/PlanBuilder.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import org.apache.gluten.substrait.extensions.ExtensionBuilder;
2222
import org.apache.gluten.substrait.extensions.FunctionMappingNode;
2323
import org.apache.gluten.substrait.rel.RelNode;
24-
import org.apache.gluten.substrait.type.TypeNode;
2524

2625
import com.google.common.base.Preconditions;
2726

@@ -39,21 +38,19 @@ public static PlanNode makePlan(
3938
List<FunctionMappingNode> mappingNodes,
4039
List<RelNode> relNodes,
4140
List<String> outNames,
42-
TypeNode outputSchema,
4341
AdvancedExtensionNode extension) {
44-
return new PlanNode(mappingNodes, relNodes, outNames, outputSchema, extension);
42+
return new PlanNode(mappingNodes, relNodes, outNames, extension);
4543
}
4644

4745
public static PlanNode makePlan(
4846
SubstraitContext subCtx, List<RelNode> relNodes, List<String> outNames) {
49-
return makePlan(subCtx, relNodes, outNames, null, null);
47+
return makePlan(subCtx, relNodes, outNames, null);
5048
}
5149

5250
public static PlanNode makePlan(
5351
SubstraitContext subCtx,
5452
List<RelNode> relNodes,
5553
List<String> outNames,
56-
TypeNode outputSchema,
5754
AdvancedExtensionNode extension) {
5855
Preconditions.checkNotNull(
5956
subCtx, "Cannot execute doTransform due to the SubstraitContext is null.");
@@ -64,7 +61,7 @@ public static PlanNode makePlan(
6461
ExtensionBuilder.makeFunctionMapping(entry.getKey(), entry.getValue());
6562
mappingNodes.add(mappingNode);
6663
}
67-
return makePlan(mappingNodes, relNodes, outNames, outputSchema, extension);
64+
return makePlan(mappingNodes, relNodes, outNames, extension);
6865
}
6966

7067
public static PlanNode makePlan(SubstraitContext subCtx, ArrayList<RelNode> relNodes) {

gluten-substrait/src/main/java/org/apache/gluten/substrait/plan/PlanNode.java

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.apache.gluten.substrait.extensions.AdvancedExtensionNode;
2020
import org.apache.gluten.substrait.extensions.FunctionMappingNode;
2121
import org.apache.gluten.substrait.rel.RelNode;
22-
import org.apache.gluten.substrait.type.TypeNode;
2322

2423
import io.substrait.proto.Plan;
2524
import io.substrait.proto.PlanRel;
@@ -33,19 +32,16 @@ public class PlanNode implements Serializable {
3332
private final List<RelNode> relNodes;
3433
private final List<String> outNames;
3534

36-
private TypeNode outputSchema = null;
3735
private AdvancedExtensionNode extension = null;
3836

3937
PlanNode(
4038
List<FunctionMappingNode> mappingNodes,
4139
List<RelNode> relNodes,
4240
List<String> outNames,
43-
TypeNode outputSchema,
4441
AdvancedExtensionNode extension) {
4542
this.mappingNodes = mappingNodes;
4643
this.relNodes = relNodes;
4744
this.outNames = outNames;
48-
this.outputSchema = outputSchema;
4945
this.extension = extension;
5046
}
5147

@@ -64,9 +60,6 @@ public Plan toProtobuf() {
6460
for (String name : outNames) {
6561
relRootBuilder.addNames(name);
6662
}
67-
if (outputSchema != null) {
68-
relRootBuilder.setOutputSchema(outputSchema.toProtobuf().getStruct());
69-
}
7063
planRelBuilder.setRoot(relRootBuilder.build());
7164

7265
planBuilder.addRelations(planRelBuilder.build());

gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,6 @@ message RelRoot {
525525
Rel input = 1;
526526
// Field names in depth-first order
527527
repeated string names = 2;
528-
Type.Struct output_schema = 3;
529528
}
530529

531530
// A relation (used internally in a plan)

gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala

Lines changed: 63 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ import org.apache.gluten.config.GlutenConfig
2121
import org.apache.gluten.expression._
2222
import org.apache.gluten.extension.columnar.transition.Convention
2323
import org.apache.gluten.metrics.{GlutenTimeMetric, MetricsUpdater}
24-
import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode}
2524
import org.apache.gluten.substrait.SubstraitContext
25+
import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode}
2626
import org.apache.gluten.substrait.plan.{PlanBuilder, PlanNode}
27-
import org.apache.gluten.substrait.rel.{LocalFilesNode, RelNode, SplitInfo}
27+
import org.apache.gluten.substrait.rel.{LocalFilesNode, RelBuilder, RelNode, SplitInfo}
2828
import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
2929
import org.apache.gluten.utils.SubstraitPlanPrinterUtil
3030

@@ -172,24 +172,53 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
172172
@transient
173173
private var wholeStageTransformerContext: Option[WholeStageTransformContext] = None
174174

175-
private var outputSchemaForPlan: Option[TypeNode] = None
175+
private var expectedOutputForPlan: Option[Seq[Attribute]] = None
176176

177-
private def inferSchemaFromAttributes(attrs: Seq[Attribute]): TypeNode = {
178-
val outputTypeNodeList = new java.util.ArrayList[TypeNode]()
179-
for (attr <- attrs) {
180-
outputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
177+
def setOutputSchemaForPlan(expectOutput: Seq[Attribute]): Unit = {
178+
if (expectedOutputForPlan.isDefined) {
179+
return
181180
}
182181

183-
TypeBuilder.makeStruct(false, outputTypeNodeList)
182+
// Fixes issue-1874: store expected output attributes for generating a ProjectRel with casts.
183+
expectedOutputForPlan = Some(expectOutput)
184184
}
185185

186-
def setOutputSchemaForPlan(expectOutput: Seq[Attribute]): Unit = {
187-
if (outputSchemaForPlan.isDefined) {
188-
return
186+
/**
187+
* Creates a ProjectRel that casts each input column to the expected output type. This is used to
188+
* enforce nullability and type constraints when the child plan's output may not match the
189+
* expected schema (e.g., in union operations). Returns the input unchanged if no casts are
190+
* needed.
191+
*/
192+
private def createOutputCastProjectRel(
193+
input: RelNode,
194+
inputAttrs: Seq[Attribute],
195+
expectedAttrs: Seq[Attribute],
196+
substraitContext: SubstraitContext): RelNode = {
197+
val castExpressions = new java.util.ArrayList[ExpressionNode]()
198+
var needsCast = false
199+
for (i <- inputAttrs.indices) {
200+
val inputAttr = inputAttrs(i)
201+
val expectedAttr = expectedAttrs(i)
202+
val fieldRef = ExpressionBuilder.makeSelection(i)
203+
// If types differ (including nullability), add a cast; otherwise pass through.
204+
if (
205+
inputAttr.dataType != expectedAttr.dataType ||
206+
inputAttr.nullable != expectedAttr.nullable
207+
) {
208+
val targetType = ConverterUtils.getTypeNode(expectedAttr.dataType, expectedAttr.nullable)
209+
castExpressions.add(ExpressionBuilder.makeCast(targetType, fieldRef, false))
210+
needsCast = true
211+
} else {
212+
castExpressions.add(fieldRef)
213+
}
214+
}
215+
// Only create a ProjectRel if casts are actually needed.
216+
if (needsCast) {
217+
// Use emitStartIndex = 0 to emit only the projected expressions (not input + expressions).
218+
RelBuilder.makeProjectRel(input, castExpressions, substraitContext, -1L, 0)
219+
} else {
220+
input
189221
}
190-
191-
// Fixes issue-1874
192-
outputSchemaForPlan = Some(inferSchemaFromAttributes(expectOutput))
193222
}
194223

195224
def substraitPlan: PlanNode = {
@@ -241,21 +270,27 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
241270
throw new IllegalStateException(s"WholeStageTransformer can't do Transform on $child")
242271
}
243272

244-
val outNames = childCtx.outputAttributes.map(ConverterUtils.genColumnNameWithExprId).asJava
245-
246-
val planNode = if (BackendsApiManager.getSettings.needOutputSchemaForPlan()) {
247-
val outputSchema =
248-
outputSchemaForPlan.getOrElse(inferSchemaFromAttributes(childCtx.outputAttributes))
273+
val (finalRoot, finalOutputAttrs) =
274+
if (BackendsApiManager.getSettings.needOutputSchemaForPlan()) {
275+
// If expected output schema differs from child's output, wrap in a ProjectRel with casts.
276+
// This fixes issue-1874 by explicitly converting types (including nullability) in the plan.
277+
expectedOutputForPlan match {
278+
case Some(expectedAttrs) =>
279+
val projectRel = createOutputCastProjectRel(
280+
childCtx.root,
281+
childCtx.outputAttributes,
282+
expectedAttrs,
283+
substraitContext)
284+
(projectRel, expectedAttrs)
285+
case None =>
286+
(childCtx.root, childCtx.outputAttributes)
287+
}
288+
} else {
289+
(childCtx.root, childCtx.outputAttributes)
290+
}
249291

250-
PlanBuilder.makePlan(
251-
substraitContext,
252-
Lists.newArrayList(childCtx.root),
253-
outNames,
254-
outputSchema,
255-
null)
256-
} else {
257-
PlanBuilder.makePlan(substraitContext, Lists.newArrayList(childCtx.root), outNames)
258-
}
292+
val outNames = finalOutputAttrs.map(ConverterUtils.genColumnNameWithExprId).asJava
293+
val planNode = PlanBuilder.makePlan(substraitContext, Lists.newArrayList(finalRoot), outNames)
259294

260295
WholeStageTransformContext(planNode, substraitContext, isCudf)
261296
}

0 commit comments

Comments
 (0)