Skip to content

Commit d4f6346

Browse files
authored
feat(offline): support UNION ALL/DISTINCT (#3653)
1 parent 2eb802c commit d4f6346

File tree

8 files changed

+175
-13
lines changed

8 files changed

+175
-13
lines changed

hybridse/include/vm/physical_op.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,7 +1423,7 @@ class PhysicalRequestJoinNode : public PhysicalBinaryNode {
14231423
class PhysicalSetOperationNode : public PhysicalOpNode {
14241424
public:
14251425
PhysicalSetOperationNode(node::SetOperationType type, absl::Span<PhysicalOpNode *const> inputs, bool distinct)
1426-
: PhysicalOpNode(kPhysicalOpSetOperation, false), op_type_(type), distinct_(distinct) {
1426+
: PhysicalOpNode(kPhysicalOpSetOperation, false), set_type_(type), distinct_(distinct) {
14271427
for (auto n : inputs) {
14281428
AddProducer(n);
14291429
}
@@ -1435,7 +1435,7 @@ class PhysicalSetOperationNode : public PhysicalOpNode {
14351435
}
14361436
}
14371437

1438-
if (group_optimized && op_type_ == node::SetOperationType::UNION) {
1438+
if (group_optimized && set_type_ == node::SetOperationType::UNION) {
14391439
output_type_ = kSchemaTypeGroup;
14401440
} else {
14411441
output_type_ = kSchemaTypeTable;
@@ -1452,7 +1452,7 @@ class PhysicalSetOperationNode : public PhysicalOpNode {
14521452

14531453
absl::StatusOr<ColProducerTraceInfo> TraceColID(absl::string_view col_name) const override;
14541454

1455-
node::SetOperationType op_type_;
1455+
node::SetOperationType set_type_;
14561456
const bool distinct_ = false;
14571457
static PhysicalSetOperationNode *CastFrom(PhysicalOpNode *node);
14581458
};

hybridse/src/passes/physical/group_and_sort_optimized.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ bool GroupAndSortOptimized::KeysOptimizedImpl(const SchemasContext* root_schemas
729729
}
730730
vm::PhysicalSetOperationNode* opt_set = nullptr;
731731
if (!plan_ctx_
732-
->CreateOp<vm::PhysicalSetOperationNode>(&opt_set, set_op->op_type_, opt_inputs, set_op->distinct_)
732+
->CreateOp<vm::PhysicalSetOperationNode>(&opt_set, set_op->set_type_, opt_inputs, set_op->distinct_)
733733
.isOK()) {
734734
return false;
735735
}

hybridse/src/vm/physical_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,13 +1234,13 @@ Status PhysicalSetOperationNode::InitSchema(PhysicalPlanContext* ctx) {
12341234
Status PhysicalSetOperationNode::WithNewChildren(node::NodeManager* nm, const std::vector<PhysicalOpNode*>& children,
12351235
PhysicalOpNode** out) {
12361236
absl::Span<PhysicalOpNode* const> sp = absl::MakeSpan(children);
1237-
*out = nm->RegisterNode(new PhysicalSetOperationNode(op_type_, sp, distinct_));
1237+
*out = nm->RegisterNode(new PhysicalSetOperationNode(set_type_, sp, distinct_));
12381238
return Status::OK();
12391239
}
12401240

12411241
void PhysicalSetOperationNode::Print(std::ostream& output, const std::string& tab) const {
12421242
PhysicalOpNode::Print(output, tab);
1243-
output << "(" << node::SetOperatorName(op_type_, distinct_) << ")\n";
1243+
output << "(" << node::SetOperatorName(set_type_, distinct_) << ")\n";
12441244
PrintChildren(output, tab);
12451245
}
12461246

hybridse/src/vm/runner_builder.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,8 +523,13 @@ ClusterTask RunnerBuilder::Build(PhysicalOpNode* node, Status& status) {
523523
}
524524
case kPhysicalOpSetOperation: {
525525
auto set = dynamic_cast<vm::PhysicalSetOperationNode*>(node);
526+
if (set->distinct_) {
527+
status.msg = "online un-implemented: UNION DISTINCT";
528+
status.code = common::kExecutionPlanError;
529+
return fail;
530+
}
526531
auto set_runner =
527-
CreateRunner<SetOperationRunner>(id_++, node->schemas_ctx(), set->op_type_, set->distinct_);
532+
CreateRunner<SetOperationRunner>(id_++, node->schemas_ctx(), set->set_type_, set->distinct_);
528533
std::vector<ClusterTask> tasks;
529534
for (auto n : node->GetProducers()) {
530535
auto task = Build(n, status);

hybridse/src/vm/transform.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,6 @@ Status BatchModeTransformer::TransformSetOperation(const node::SetOperationPlanN
340340
PhysicalSetOperationNode** out) {
341341
CHECK_TRUE(node != nullptr && out != nullptr, kPlanError, "Input node or output node is null");
342342

343-
CHECK_TRUE(!node->distinct(), common::kPhysicalPlanError, "un-implemented: UNION DISTINCT");
344-
345343
std::vector<PhysicalOpNode*> inputs;
346344
const SchemasContext* expect_sc = nullptr;
347345
for (auto n : node->inputs()) {

java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkPlanner.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ import com._4paradigm.hybridse.vm.{CoreAPI, Engine, PhysicalConstProjectNode, Ph
2323
PhysicalDataProviderNode, PhysicalFilterNode, PhysicalGroupAggrerationNode, PhysicalGroupNode, PhysicalJoinNode,
2424
PhysicalLimitNode, PhysicalLoadDataNode, PhysicalOpNode, PhysicalOpType, PhysicalProjectNode, PhysicalRenameNode,
2525
PhysicalSelectIntoNode, PhysicalSimpleProjectNode, PhysicalSortNode, PhysicalTableProjectNode,
26-
PhysicalWindowAggrerationNode, ProjectType}
26+
PhysicalWindowAggrerationNode, ProjectType, PhysicalSetOperationNode}
2727
import com._4paradigm.openmldb.batch.api.OpenmldbSession
2828
import com._4paradigm.openmldb.batch.nodes.{ConstProjectPlan, CreateTablePlan, DataProviderPlan, FilterPlan,
2929
GroupByAggregationPlan, GroupByPlan, JoinPlan, LimitPlan, LoadDataPlan, RenamePlan, RowProjectPlan, SelectIntoPlan,
30-
SimpleProjectPlan, SortByPlan, WindowAggPlan}
30+
SimpleProjectPlan, SortByPlan, WindowAggPlan, SetOperationPlan}
3131
import com._4paradigm.openmldb.batch.utils.{DataTypeUtil, ExternalUdfUtil, GraphvizUtil, HybridseUtil, NodeIndexInfo,
3232
NodeIndexType}
3333
import com._4paradigm.openmldb.sdk.impl.SqlClusterExecutor
@@ -271,6 +271,8 @@ class SparkPlanner(session: SparkSession, config: OpenmldbBatchConfig, sparkAppN
271271
SelectIntoPlan.gen(ctx, PhysicalSelectIntoNode.CastFrom(root), children.head)
272272
case PhysicalOpType.kPhysicalCreateTable =>
273273
CreateTablePlan.gen(ctx, PhysicalCreateTableNode.CastFrom(root))
274+
case PhysicalOpType.kPhysicalOpSetOperation =>
275+
SetOperationPlan.gen(ctx, PhysicalSetOperationNode.CastFrom(root), children)
274276
case _ =>
275277
throw new UnsupportedHybridSeException(s"Plan type $opType not supported")
276278
}
@@ -399,5 +401,3 @@ class SparkPlanner(session: SparkSession, config: OpenmldbBatchConfig, sparkAppN
399401
}
400402
}
401403
}
402-
403-
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright 2021 4Paradigm
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com._4paradigm.openmldb.batch.nodes
18+
19+
import com._4paradigm.openmldb.batch.PlanContext
20+
import com._4paradigm.hybridse.vm.PhysicalSetOperationNode
21+
import com._4paradigm.hybridse.node.SetOperationType
22+
import com._4paradigm.openmldb.batch.SparkInstance
23+
import org.slf4j.LoggerFactory
24+
import com._4paradigm.hybridse.sdk.HybridSeException
25+
26+
// UNION [ ALL | DISTINCT ] : YES
27+
// EXCEPT : NO
28+
// INTERSECT : NO
29+
object SetOperationPlan {
30+
private val logger = LoggerFactory.getLogger(this.getClass)
31+
32+
def gen(
33+
ctx: PlanContext,
34+
node: PhysicalSetOperationNode,
35+
inputs: Array[SparkInstance]
36+
): SparkInstance = {
37+
val setType = node.getSet_type_()
38+
if (setType != SetOperationType.UNION) {
39+
throw new HybridSeException(s"Set Operation type $setType not supported")
40+
}
41+
42+
if (inputs.size < 2) {
43+
throw new HybridSeException(s"Set Operation requires input size >= 2")
44+
}
45+
46+
val unionAll = inputs
47+
.map(inst => inst.getDf())
48+
.reduceLeft({ (acc, df) =>
49+
{
50+
acc.union(df)
51+
}
52+
})
53+
54+
val outputDf = if (node.getDistinct_()) {
55+
unionAll.distinct()
56+
} else {
57+
unionAll
58+
}
59+
60+
SparkInstance.createConsideringIndex(ctx, node.GetNodeId(), outputDf)
61+
}
62+
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright 2021 4Paradigm
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com._4paradigm.openmldb.batch
18+
19+
import com._4paradigm.openmldb.batch.api.OpenmldbSession
20+
import org.apache.spark.sql.Row
21+
import org.apache.spark.sql.types.{
22+
IntegerType,
23+
StringType,
24+
StructField,
25+
StructType
26+
}
27+
import com._4paradigm.openmldb.batch.utils.SparkUtil
28+
29+
class TestSetOperation extends SparkTestSuite {
30+
31+
test("Test UNION ALL") {
32+
val spark = getSparkSession
33+
val sess = new OpenmldbSession(spark)
34+
35+
val schema = StructType(
36+
List(StructField("id", IntegerType), StructField("user", StringType))
37+
)
38+
val data1 = Seq(Row(1, "tom"), Row(2, "amy"))
39+
val df1 = spark.createDataFrame(spark.sparkContext.makeRDD(data1), schema)
40+
val data2 = Seq(Row(1, "tom"))
41+
val df2 = spark.createDataFrame(spark.sparkContext.makeRDD(data2), schema)
42+
43+
sess.registerTable("t1", df1)
44+
sess.registerTable("t2", df2)
45+
df1.createOrReplaceTempView("t1")
46+
df2.createOrReplaceTempView("t2")
47+
48+
val sqlText = "SELECT * FROM t1 UNION ALL SELECT * FROM t2"
49+
50+
val outputDf = sess.sql(sqlText)
51+
outputDf.show()
52+
val sparksqlOutputDf = sess.sparksql(sqlText)
53+
sparksqlOutputDf.show()
54+
assert(outputDf.getSparkDf().count() == 3)
55+
assert(
56+
SparkUtil.approximateDfEqual(
57+
outputDf.getSparkDf(),
58+
sparksqlOutputDf,
59+
true
60+
)
61+
)
62+
}
63+
64+
test("Test UNION DISTINCT") {
65+
val spark = getSparkSession
66+
val sess = new OpenmldbSession(spark)
67+
68+
val schema = StructType(
69+
List(StructField("id", IntegerType), StructField("user", StringType))
70+
)
71+
val data1 = Seq(Row(1, "tom"), Row(2, "amy"))
72+
val df1 = spark.createDataFrame(spark.sparkContext.makeRDD(data1), schema)
73+
val data2 = Seq(Row(1, "tom"))
74+
val df2 = spark.createDataFrame(spark.sparkContext.makeRDD(data2), schema)
75+
76+
sess.registerTable("t1", df1)
77+
sess.registerTable("t2", df2)
78+
df1.createOrReplaceTempView("t1")
79+
df2.createOrReplaceTempView("t2")
80+
81+
val sqlText = "SELECT * FROM t1 UNION DISTINCT SELECT * FROM t2"
82+
83+
val outputDf = sess.sql(sqlText)
84+
outputDf.show()
85+
val sparksqlOutputDf = sess.sparksql(sqlText)
86+
sparksqlOutputDf.show()
87+
assert(outputDf.getSparkDf().count() == 2)
88+
assert(
89+
SparkUtil.approximateDfEqual(
90+
outputDf.getSparkDf(),
91+
sparksqlOutputDf,
92+
true
93+
)
94+
)
95+
}
96+
97+
}

0 commit comments

Comments
 (0)