Skip to content

Commit 7a20948

Browse files
jagill authored and facebook-github-bot committed
feat: Add SpatialJoinNode to presto_protocol
To send SpatialJoinNodes to Velox, we need to serialize and deserialize them via presto_protocol. This change requires facebookincubator/velox#14339; without it, spatial joins still cause an error. Once both this PR and that one land, spatial joins should be enabled, implemented as nested loop joins. This is not efficient, but it should be correct.
1 parent 3425c7e commit 7a20948

File tree

7 files changed

+180
-3
lines changed

7 files changed

+180
-3
lines changed

presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,6 +1186,17 @@ core::JoinType toJoinType(protocol::JoinType type) {
11861186

11871187
VELOX_UNSUPPORTED("Unknown join type");
11881188
}
1189+
1190+
core::JoinType toJoinType(protocol::SpatialJoinType type) {
1191+
switch (type) {
1192+
case protocol::SpatialJoinType::INNER:
1193+
return core::JoinType::kInner;
1194+
case protocol::SpatialJoinType::LEFT:
1195+
return core::JoinType::kLeft;
1196+
}
1197+
1198+
VELOX_UNSUPPORTED("Unknown spatial join type");
1199+
}
11891200
} // namespace
11901201

11911202
core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan(
@@ -1264,6 +1275,20 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan(
12641275
ROW(std::move(outputNames), std::move(outputTypes)));
12651276
}
12661277

1278+
core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan(
1279+
const std::shared_ptr<const protocol::SpatialJoinNode>& node,
1280+
const std::shared_ptr<protocol::TableWriteInfo>& tableWriteInfo,
1281+
const protocol::TaskId& taskId) {
1282+
auto joinType = toJoinType(node->type);
1283+
1284+
return std::make_shared<core::SpatialJoinNode>(
1285+
node->id,
1286+
joinType,
1287+
exprConverter_.toVeloxExpr(node->filter),
1288+
toVeloxQueryPlan(node->left, tableWriteInfo, taskId),
1289+
toVeloxQueryPlan(node->right, tableWriteInfo, taskId),
1290+
toRowType(node->outputVariables, typeParser_));}
1291+
12671292
std::shared_ptr<const core::IndexLookupJoinNode>
12681293
VeloxQueryPlanConverterBase::toVeloxQueryPlan(
12691294
const std::shared_ptr<const protocol::IndexJoinNode>& node,
@@ -1842,6 +1867,10 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan(
18421867
std::dynamic_pointer_cast<const protocol::MergeJoinNode>(node)) {
18431868
return toVeloxQueryPlan(join, tableWriteInfo, taskId);
18441869
}
1870+
if (auto spatialJoin =
1871+
std::dynamic_pointer_cast<const protocol::SpatialJoinNode>(node)) {
1872+
return toVeloxQueryPlan(spatialJoin, tableWriteInfo, taskId);
1873+
}
18451874
if (auto remoteSource =
18461875
std::dynamic_pointer_cast<const protocol::RemoteSourceNode>(node)) {
18471876
return toVeloxQueryPlan(remoteSource, tableWriteInfo, taskId);

presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ class VeloxQueryPlanConverterBase {
110110
const std::shared_ptr<protocol::TableWriteInfo>& tableWriteInfo,
111111
const protocol::TaskId& taskId);
112112

113+
velox::core::PlanNodePtr toVeloxQueryPlan(
114+
const std::shared_ptr<const protocol::SpatialJoinNode>& node,
115+
const std::shared_ptr<protocol::TableWriteInfo>& tableWriteInfo,
116+
const protocol::TaskId& taskId);
117+
113118
std::shared_ptr<const velox::core::IndexLookupJoinNode> toVeloxQueryPlan(
114119
const std::shared_ptr<const protocol::IndexJoinNode>& node,
115120
const std::shared_ptr<protocol::TableWriteInfo>& tableWriteInfo,

presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,10 @@ void to_json(json& j, const std::shared_ptr<PlanNode>& p) {
728728
j = *std::static_pointer_cast<SemiJoinNode>(p);
729729
return;
730730
}
731+
if (type == ".SpatialJoinNode") {
732+
j = *std::static_pointer_cast<SpatialJoinNode>(p);
733+
return;
734+
}
731735
if (type == ".TableScanNode") {
732736
j = *std::static_pointer_cast<TableScanNode>(p);
733737
return;
@@ -896,6 +900,12 @@ void from_json(const json& j, std::shared_ptr<PlanNode>& p) {
896900
p = std::static_pointer_cast<PlanNode>(k);
897901
return;
898902
}
903+
if (type == ".SpatialJoinNode") {
904+
std::shared_ptr<SpatialJoinNode> k = std::make_shared<SpatialJoinNode>();
905+
j.get_to(*k);
906+
p = std::static_pointer_cast<PlanNode>(k);
907+
return;
908+
}
899909
if (type == ".TableScanNode") {
900910
std::shared_ptr<TableScanNode> k = std::make_shared<TableScanNode>();
901911
j.get_to(*k);
@@ -9485,6 +9495,115 @@ void from_json(const json& j, SortedRangeSet& p) {
94859495
namespace facebook::presto::protocol {
94869496
// Loosely copied here from NLOHMANN_JSON_SERIALIZE_ENUM()
94879497

9498+
// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays
9499+
static const std::pair<SpatialJoinType, json> SpatialJoinType_enum_table[] =
9500+
{ // NOLINT: cert-err58-cpp
9501+
{SpatialJoinType::INNER, "INNER"},
9502+
{SpatialJoinType::LEFT, "LEFT"}};
9503+
void to_json(json& j, const SpatialJoinType& e) {
9504+
static_assert(
9505+
std::is_enum<SpatialJoinType>::value, "SpatialJoinType must be an enum!");
9506+
const auto* it = std::find_if(
9507+
std::begin(SpatialJoinType_enum_table),
9508+
std::end(SpatialJoinType_enum_table),
9509+
[e](const std::pair<SpatialJoinType, json>& ej_pair) -> bool {
9510+
return ej_pair.first == e;
9511+
});
9512+
j = ((it != std::end(SpatialJoinType_enum_table))
9513+
? it
9514+
: std::begin(SpatialJoinType_enum_table))
9515+
->second;
9516+
}
9517+
void from_json(const json& j, SpatialJoinType& e) {
9518+
static_assert(
9519+
std::is_enum<SpatialJoinType>::value, "SpatialJoinType must be an enum!");
9520+
const auto* it = std::find_if(
9521+
std::begin(SpatialJoinType_enum_table),
9522+
std::end(SpatialJoinType_enum_table),
9523+
[&j](const std::pair<SpatialJoinType, json>& ej_pair) -> bool {
9524+
return ej_pair.second == j;
9525+
});
9526+
e = ((it != std::end(SpatialJoinType_enum_table))
9527+
? it
9528+
: std::begin(SpatialJoinType_enum_table))
9529+
->first;
9530+
}
9531+
} // namespace facebook::presto::protocol
9532+
namespace facebook::presto::protocol {
9533+
SpatialJoinNode::SpatialJoinNode() noexcept {
9534+
_type = ".SpatialJoinNode";
9535+
}
9536+
9537+
void to_json(json& j, const SpatialJoinNode& p) {
9538+
j = json::object();
9539+
j["@type"] = ".SpatialJoinNode";
9540+
to_json_key(j, "id", p.id, "SpatialJoinNode", "PlanNodeId", "id");
9541+
to_json_key(j, "type", p.type, "SpatialJoinNode", "SpatialJoinType", "type");
9542+
to_json_key(j, "left", p.left, "SpatialJoinNode", "PlanNode", "left");
9543+
to_json_key(j, "right", p.right, "SpatialJoinNode", "PlanNode", "right");
9544+
to_json_key(
9545+
j,
9546+
"outputVariables",
9547+
p.outputVariables,
9548+
"SpatialJoinNode",
9549+
"List<VariableReferenceExpression>",
9550+
"outputVariables");
9551+
to_json_key(
9552+
j, "filter", p.filter, "SpatialJoinNode", "RowExpression", "filter");
9553+
to_json_key(
9554+
j,
9555+
"leftPartitionVariable",
9556+
p.leftPartitionVariable,
9557+
"SpatialJoinNode",
9558+
"VariableReferenceExpression",
9559+
"leftPartitionVariable");
9560+
to_json_key(
9561+
j,
9562+
"rightPartitionVariable",
9563+
p.rightPartitionVariable,
9564+
"SpatialJoinNode",
9565+
"VariableReferenceExpression",
9566+
"rightPartitionVariable");
9567+
to_json_key(j, "kdbTree", p.kdbTree, "SpatialJoinNode", "String", "kdbTree");
9568+
}
9569+
9570+
void from_json(const json& j, SpatialJoinNode& p) {
9571+
p._type = j["@type"];
9572+
from_json_key(j, "id", p.id, "SpatialJoinNode", "PlanNodeId", "id");
9573+
from_json_key(
9574+
j, "type", p.type, "SpatialJoinNode", "SpatialJoinType", "type");
9575+
from_json_key(j, "left", p.left, "SpatialJoinNode", "PlanNode", "left");
9576+
from_json_key(j, "right", p.right, "SpatialJoinNode", "PlanNode", "right");
9577+
from_json_key(
9578+
j,
9579+
"outputVariables",
9580+
p.outputVariables,
9581+
"SpatialJoinNode",
9582+
"List<VariableReferenceExpression>",
9583+
"outputVariables");
9584+
from_json_key(
9585+
j, "filter", p.filter, "SpatialJoinNode", "RowExpression", "filter");
9586+
from_json_key(
9587+
j,
9588+
"leftPartitionVariable",
9589+
p.leftPartitionVariable,
9590+
"SpatialJoinNode",
9591+
"VariableReferenceExpression",
9592+
"leftPartitionVariable");
9593+
from_json_key(
9594+
j,
9595+
"rightPartitionVariable",
9596+
p.rightPartitionVariable,
9597+
"SpatialJoinNode",
9598+
"VariableReferenceExpression",
9599+
"rightPartitionVariable");
9600+
from_json_key(
9601+
j, "kdbTree", p.kdbTree, "SpatialJoinNode", "String", "kdbTree");
9602+
}
9603+
} // namespace facebook::presto::protocol
9604+
namespace facebook::presto::protocol {
9605+
// Loosely copied here from NLOHMANN_JSON_SERIALIZE_ENUM()
9606+
94889607
// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays
94899608
static const std::pair<Form, json> Form_enum_table[] =
94909609
{ // NOLINT: cert-err58-cpp

presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,6 +2188,27 @@ void to_json(json& j, const SortedRangeSet& p);
21882188
void from_json(const json& j, SortedRangeSet& p);
21892189
} // namespace facebook::presto::protocol
21902190
namespace facebook::presto::protocol {
2191+
enum class SpatialJoinType { INNER, LEFT };
2192+
extern void to_json(json& j, const SpatialJoinType& e);
2193+
extern void from_json(const json& j, SpatialJoinType& e);
2194+
} // namespace facebook::presto::protocol
2195+
namespace facebook::presto::protocol {
2196+
struct SpatialJoinNode : public PlanNode {
2197+
SpatialJoinType type = {};
2198+
std::shared_ptr<PlanNode> left = {};
2199+
std::shared_ptr<PlanNode> right = {};
2200+
List<VariableReferenceExpression> outputVariables = {};
2201+
std::shared_ptr<RowExpression> filter = {};
2202+
std::shared_ptr<VariableReferenceExpression> leftPartitionVariable = {};
2203+
std::shared_ptr<VariableReferenceExpression> rightPartitionVariable = {};
2204+
std::shared_ptr<String> kdbTree = {};
2205+
2206+
SpatialJoinNode() noexcept;
2207+
};
2208+
void to_json(json& j, const SpatialJoinNode& p);
2209+
void from_json(const json& j, SpatialJoinNode& p);
2210+
} // namespace facebook::presto::protocol
2211+
namespace facebook::presto::protocol {
21912212
enum class Form {
21922213
IF,
21932214
NULL_IF,

presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ AbstractClasses:
157157
- { name: RemoteSourceNode, key: com.facebook.presto.sql.planner.plan.RemoteSourceNode }
158158
- { name: SampleNode, key: com.facebook.presto.sql.planner.plan.SampleNode }
159159
- { name: SemiJoinNode, key: .SemiJoinNode }
160+
- { name: SpatialJoinNode, key: .SpatialJoinNode }
160161
- { name: TableScanNode, key: .TableScanNode }
161162
- { name: TableWriterNode, key: .TableWriterNode }
162163
- { name: TableWriterMergeNode, key: com.facebook.presto.sql.planner.plan.TableWriterMergeNode }
@@ -317,6 +318,7 @@ JavaClasses:
317318
- presto-spi/src/main/java/com/facebook/presto/spi/plan/JoinNode.java
318319
- presto-spi/src/main/java/com/facebook/presto/spi/plan/SemiJoinNode.java
319320
- presto-spi/src/main/java/com/facebook/presto/spi/plan/MergeJoinNode.java
321+
- presto-spi/src/main/java/com/facebook/presto/spi/plan/SpatialJoinNode.java
320322
- presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/IndexJoinNode.java
321323
- presto-spi/src/main/java/com/facebook/presto/spi/plan/IndexSourceNode.java
322324
- presto-spi/src/main/java/com/facebook/presto/spi/plan/TopNNode.java

presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ AbstractClasses:
155155
- { name: RemoteSourceNode, key: com.facebook.presto.sql.planner.plan.RemoteSourceNode }
156156
- { name: SampleNode, key: com.facebook.presto.sql.planner.plan.SampleNode }
157157
- { name: SemiJoinNode, key: .SemiJoinNode }
158+
- { name: SpatialJoinNode, key: .SpatialJoinNode }
158159
- { name: TableScanNode, key: .TableScanNode }
159160
- { name: TableWriterNode, key: .TableWriterNode }
160161
- { name: TableWriterMergeNode, key: com.facebook.presto.sql.planner.plan.TableWriterMergeNode }
@@ -360,6 +361,7 @@ JavaClasses:
360361
- presto-spi/src/main/java/com/facebook/presto/spi/plan/JoinNode.java
361362
- presto-spi/src/main/java/com/facebook/presto/spi/plan/SemiJoinNode.java
362363
- presto-spi/src/main/java/com/facebook/presto/spi/plan/MergeJoinNode.java
364+
- presto-spi/src/main/java/com/facebook/presto/spi/plan/SpatialJoinNode.java
363365
- presto-spi/src/main/java/com/facebook/presto/spi/plan/TopNNode.java
364366
- presto-hive/src/main/java/com/facebook/presto/hive/HivePartitioningHandle.java
365367
- presto-main/src/main/java/com/facebook/presto/split/EmptySplit.java

presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,16 +412,15 @@ public void testGeometryQueries()
412412
assertQuery("SELECT " +
413413
"ST_DISTANCE(ST_POINT(a.nationkey, a.regionkey), ST_POINT(b.nationkey, b.regionkey)) " +
414414
"FROM nation a JOIN nation b ON a.nationkey < b.nationkey");
415-
assertQueryFails(
415+
assertQuery(
416416
"WITH regions(name, geom) AS (VALUES" +
417417
" ('A', ST_GeometryFromText('POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0))'))," +
418418
" ('B', ST_GeometryFromText('POLYGON ((5 0, 5 5, 10 5, 10 0, 5 0))')))," +
419419
"points(id, geom) AS (VALUES" +
420420
" ('P1', ST_Point(1, 1))," +
421421
" ('P2', ST_Point(6, 1))," +
422422
" ('P3', ST_Point(8, 4)))" +
423-
"SELECT p.id, r.name FROM points p LEFT JOIN regions r ON ST_Within(p.geom, r.geom)",
424-
"Error from native plan checker: .SpatialJoinNode no abstract type PlanNode ");
423+
"SELECT p.id, r.name FROM points p LEFT JOIN regions r ON ST_Within(p.geom, r.geom)");
425424
}
426425

427426
@Test

0 commit comments

Comments
 (0)