Skip to content

Commit 19716c5

Browse files
committed
feat: remove output_schema in favor of explicit project with casts
1 parent 2ad67a5 commit 19716c5

File tree

5 files changed

+58
-108
lines changed

5 files changed

+58
-108
lines changed

cpp-ch/local-engine/Parser/SerializedPlanParser.cpp

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -121,72 +121,6 @@ void SerializedPlanParser::adjustOutput(const DB::QueryPlanPtr & query_plan, con
121121
aliases.emplace_back(DB::NameWithAlias(input_iter->name, *output_name));
122122
});
123123
}
124-
125-
// fixes: issue-1874, to keep the nullability as expected.
126-
const auto & output_schema = root_rel.root().output_schema();
127-
if (output_schema.types_size())
128-
{
129-
const auto & origin_header = *query_plan->getCurrentHeader();
130-
const auto & origin_columns = origin_header.getColumnsWithTypeAndName();
131-
132-
if (static_cast<size_t>(output_schema.types_size()) != origin_columns.size())
133-
{
134-
debug::dumpPlan(*query_plan, "clickhouse plan", true);
135-
debug::dumpMessage(plan, "substrait::Plan", true);
136-
throw Exception(
137-
ErrorCodes::LOGICAL_ERROR,
138-
"Missmatch result columns size. plan column size {}, subtrait plan output schema size {}, subtrait plan name size {}.",
139-
origin_columns.size(),
140-
output_schema.types_size(),
141-
root_rel.root().names_size());
142-
}
143-
144-
bool need_final_project = false;
145-
ColumnsWithTypeAndName final_columns;
146-
for (int i = 0; i < output_schema.types_size(); ++i)
147-
{
148-
const auto & origin_column = origin_columns[i];
149-
const auto & origin_type = origin_column.type;
150-
auto final_type = TypeParser::parseType(output_schema.types(i));
151-
152-
/// Intermediate aggregate data is special, no check here.
153-
if (typeid_cast<const DataTypeAggregateFunction *>(origin_column.type.get()) || origin_type->equals(*final_type))
154-
final_columns.push_back(origin_column);
155-
else
156-
{
157-
need_final_project = true;
158-
if (origin_column.column && isColumnConst(*origin_column.column))
159-
{
160-
/// For const column, we need to cast it individually. Otherwise, the const column will be converted to full column in
161-
/// ActionsDAG::makeConvertingActions.
162-
/// Note: creating fianl_column with Field of origin_column will cause Exception in some case.
163-
const DB::ContextPtr context = DB::CurrentThread::get().getQueryContext();
164-
const FunctionOverloadResolverPtr & cast_resolver = FunctionFactory::instance().get("CAST", context);
165-
const DataTypePtr string_type = std::make_shared<DataTypeString>();
166-
ColumnWithTypeAndName to_type_column = {string_type->createColumnConst(1, final_type->getName()), string_type, "__cast_const__"};
167-
FunctionBasePtr cast_function = cast_resolver->build({origin_column, to_type_column});
168-
ColumnPtr const_col = ColumnConst::create(cast_function->execute({origin_column, to_type_column}, final_type, 1, false), 1);
169-
ColumnWithTypeAndName final_column(const_col, final_type, origin_column.name);
170-
final_columns.emplace_back(std::move(final_column));
171-
}
172-
else
173-
{
174-
ColumnWithTypeAndName final_column(final_type->createColumn(), final_type, origin_column.name);
175-
final_columns.emplace_back(std::move(final_column));
176-
}
177-
}
178-
}
179-
180-
if (need_final_project)
181-
{
182-
ActionsDAG final_project
183-
= ActionsDAG::makeConvertingActions(origin_columns, final_columns, ActionsDAG::MatchColumnsMode::Position, true);
184-
QueryPlanStepPtr final_project_step
185-
= std::make_unique<ExpressionStep>(query_plan->getCurrentHeader(), std::move(final_project));
186-
final_project_step->setStepDescription("Project for output schema");
187-
query_plan->addStep(std::move(final_project_step));
188-
}
189-
}
190124
}
191125

192126
QueryPlanPtr SerializedPlanParser::parse(const substrait::Plan & plan)

gluten-substrait/src/main/java/org/apache/gluten/substrait/plan/PlanBuilder.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import org.apache.gluten.substrait.extensions.ExtensionBuilder;
2222
import org.apache.gluten.substrait.extensions.FunctionMappingNode;
2323
import org.apache.gluten.substrait.rel.RelNode;
24-
import org.apache.gluten.substrait.type.TypeNode;
2524

2625
import com.google.common.base.Preconditions;
2726

@@ -39,21 +38,19 @@ public static PlanNode makePlan(
3938
List<FunctionMappingNode> mappingNodes,
4039
List<RelNode> relNodes,
4140
List<String> outNames,
42-
TypeNode outputSchema,
4341
AdvancedExtensionNode extension) {
44-
return new PlanNode(mappingNodes, relNodes, outNames, outputSchema, extension);
42+
return new PlanNode(mappingNodes, relNodes, outNames, extension);
4543
}
4644

4745
public static PlanNode makePlan(
4846
SubstraitContext subCtx, List<RelNode> relNodes, List<String> outNames) {
49-
return makePlan(subCtx, relNodes, outNames, null, null);
47+
return makePlan(subCtx, relNodes, outNames, null);
5048
}
5149

5250
public static PlanNode makePlan(
5351
SubstraitContext subCtx,
5452
List<RelNode> relNodes,
5553
List<String> outNames,
56-
TypeNode outputSchema,
5754
AdvancedExtensionNode extension) {
5855
Preconditions.checkNotNull(
5956
subCtx, "Cannot execute doTransform due to the SubstraitContext is null.");
@@ -64,7 +61,7 @@ public static PlanNode makePlan(
6461
ExtensionBuilder.makeFunctionMapping(entry.getKey(), entry.getValue());
6562
mappingNodes.add(mappingNode);
6663
}
67-
return makePlan(mappingNodes, relNodes, outNames, outputSchema, extension);
64+
return makePlan(mappingNodes, relNodes, outNames, extension);
6865
}
6966

7067
public static PlanNode makePlan(SubstraitContext subCtx, ArrayList<RelNode> relNodes) {

gluten-substrait/src/main/java/org/apache/gluten/substrait/plan/PlanNode.java

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.apache.gluten.substrait.extensions.AdvancedExtensionNode;
2020
import org.apache.gluten.substrait.extensions.FunctionMappingNode;
2121
import org.apache.gluten.substrait.rel.RelNode;
22-
import org.apache.gluten.substrait.type.TypeNode;
2322

2423
import io.substrait.proto.Plan;
2524
import io.substrait.proto.PlanRel;
@@ -33,19 +32,16 @@ public class PlanNode implements Serializable {
3332
private final List<RelNode> relNodes;
3433
private final List<String> outNames;
3534

36-
private TypeNode outputSchema = null;
3735
private AdvancedExtensionNode extension = null;
3836

3937
PlanNode(
4038
List<FunctionMappingNode> mappingNodes,
4139
List<RelNode> relNodes,
4240
List<String> outNames,
43-
TypeNode outputSchema,
4441
AdvancedExtensionNode extension) {
4542
this.mappingNodes = mappingNodes;
4643
this.relNodes = relNodes;
4744
this.outNames = outNames;
48-
this.outputSchema = outputSchema;
4945
this.extension = extension;
5046
}
5147

@@ -64,9 +60,6 @@ public Plan toProtobuf() {
6460
for (String name : outNames) {
6561
relRootBuilder.addNames(name);
6662
}
67-
if (outputSchema != null) {
68-
relRootBuilder.setOutputSchema(outputSchema.toProtobuf().getStruct());
69-
}
7063
planRelBuilder.setRoot(relRootBuilder.build());
7164

7265
planBuilder.addRelations(planRelBuilder.build());

gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,6 @@ message RelRoot {
525525
Rel input = 1;
526526
// Field names in depth-first order
527527
repeated string names = 2;
528-
Type.Struct output_schema = 3;
529528
}
530529

531530
// A relation (used internally in a plan)

gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ import org.apache.gluten.config.GlutenConfig
2121
import org.apache.gluten.expression._
2222
import org.apache.gluten.extension.columnar.transition.Convention
2323
import org.apache.gluten.metrics.{GlutenTimeMetric, MetricsUpdater}
24-
import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode}
2524
import org.apache.gluten.substrait.SubstraitContext
25+
import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode}
2626
import org.apache.gluten.substrait.plan.{PlanBuilder, PlanNode}
27-
import org.apache.gluten.substrait.rel.{LocalFilesNode, RelNode, SplitInfo}
27+
import org.apache.gluten.substrait.rel.{LocalFilesNode, RelBuilder, RelNode, SplitInfo}
2828
import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
2929
import org.apache.gluten.utils.SubstraitPlanPrinterUtil
3030

@@ -172,24 +172,45 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
172172
@transient
173173
private var wholeStageTransformerContext: Option[WholeStageTransformContext] = None
174174

175-
private var outputSchemaForPlan: Option[TypeNode] = None
175+
private var expectedOutputForPlan: Option[Seq[Attribute]] = None
176176

177-
private def inferSchemaFromAttributes(attrs: Seq[Attribute]): TypeNode = {
178-
val outputTypeNodeList = new java.util.ArrayList[TypeNode]()
179-
for (attr <- attrs) {
180-
outputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
177+
def setOutputSchemaForPlan(expectOutput: Seq[Attribute]): Unit = {
178+
if (expectedOutputForPlan.isDefined) {
179+
return
181180
}
182181

183-
TypeBuilder.makeStruct(false, outputTypeNodeList)
182+
// Fixes issue-1874: store expected output attributes for generating a ProjectRel with casts.
183+
expectedOutputForPlan = Some(expectOutput)
184184
}
185185

186-
def setOutputSchemaForPlan(expectOutput: Seq[Attribute]): Unit = {
187-
if (outputSchemaForPlan.isDefined) {
188-
return
186+
/**
187+
* Creates a ProjectRel that casts each input column to the expected output type. This is used to
188+
* enforce nullability and type constraints when the child plan's output may not match the
189+
* expected schema (e.g., in union operations).
190+
*/
191+
private def createOutputCastProjectRel(
192+
input: RelNode,
193+
inputAttrs: Seq[Attribute],
194+
expectedAttrs: Seq[Attribute],
195+
substraitContext: SubstraitContext): RelNode = {
196+
val castExpressions = new java.util.ArrayList[ExpressionNode]()
197+
for (i <- inputAttrs.indices) {
198+
val inputAttr = inputAttrs(i)
199+
val expectedAttr = expectedAttrs(i)
200+
val fieldRef = ExpressionBuilder.makeSelection(i)
201+
// If types differ (including nullability), add a cast; otherwise pass through.
202+
if (
203+
inputAttr.dataType != expectedAttr.dataType ||
204+
inputAttr.nullable != expectedAttr.nullable
205+
) {
206+
val targetType = ConverterUtils.getTypeNode(expectedAttr.dataType, expectedAttr.nullable)
207+
castExpressions.add(ExpressionBuilder.makeCast(targetType, fieldRef, false))
208+
} else {
209+
castExpressions.add(fieldRef)
210+
}
189211
}
190-
191-
// Fixes issue-1874
192-
outputSchemaForPlan = Some(inferSchemaFromAttributes(expectOutput))
212+
// Use emitStartIndex = 0 to emit only the projected expressions (not input + expressions).
213+
RelBuilder.makeProjectRel(input, castExpressions, substraitContext, -1L, 0)
193214
}
194215

195216
def substraitPlan: PlanNode = {
@@ -241,21 +262,27 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
241262
throw new IllegalStateException(s"WholeStageTransformer can't do Transform on $child")
242263
}
243264

244-
val outNames = childCtx.outputAttributes.map(ConverterUtils.genColumnNameWithExprId).asJava
245-
246-
val planNode = if (BackendsApiManager.getSettings.needOutputSchemaForPlan()) {
247-
val outputSchema =
248-
outputSchemaForPlan.getOrElse(inferSchemaFromAttributes(childCtx.outputAttributes))
265+
val (finalRoot, finalOutputAttrs) =
266+
if (BackendsApiManager.getSettings.needOutputSchemaForPlan()) {
267+
// If expected output schema differs from child's output, wrap in a ProjectRel with casts.
268+
// This fixes issue-1874 by explicitly converting types (including nullability) in the plan.
269+
expectedOutputForPlan match {
270+
case Some(expectedAttrs) =>
271+
val projectRel = createOutputCastProjectRel(
272+
childCtx.root,
273+
childCtx.outputAttributes,
274+
expectedAttrs,
275+
substraitContext)
276+
(projectRel, expectedAttrs)
277+
case None =>
278+
(childCtx.root, childCtx.outputAttributes)
279+
}
280+
} else {
281+
(childCtx.root, childCtx.outputAttributes)
282+
}
249283

250-
PlanBuilder.makePlan(
251-
substraitContext,
252-
Lists.newArrayList(childCtx.root),
253-
outNames,
254-
outputSchema,
255-
null)
256-
} else {
257-
PlanBuilder.makePlan(substraitContext, Lists.newArrayList(childCtx.root), outNames)
258-
}
284+
val outNames = finalOutputAttrs.map(ConverterUtils.genColumnNameWithExprId).asJava
285+
val planNode = PlanBuilder.makePlan(substraitContext, Lists.newArrayList(finalRoot), outNames)
259286

260287
WholeStageTransformContext(planNode, substraitContext, isCudf)
261288
}

0 commit comments

Comments
 (0)