Skip to content

Commit 6f62c2a

Browse files
authored
GH-48268: [C++][Acero] Enhance the type checking for hash join residual filter (#48272)
### Rationale for this change

Type checking for the hash join residual filter isn't enforced in some corner cases (e.g., a literal filter expression). As a result, some invalid test cases had been introduced.

### What changes are included in this PR?

Enforce the type checking for all cases. Also fix the problematic test cases, and refine the trivial residual filter handling in swiss join.

### Are these changes tested?

Yes, tests are included.

### Are there any user-facing changes?

None.

* GitHub Issue: #48268

Authored-by: Rossi Sun <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
1 parent 5aa7dd1 commit 6f62c2a

File tree

4 files changed

+67
-20
lines changed

4 files changed

+67
-20
lines changed

cpp/src/arrow/acero/hash_join_node.cc

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,19 @@ Result<Expression> HashJoinSchema::BindFilter(Expression filter,
370370
const Schema& left_schema,
371371
const Schema& right_schema,
372372
ExecContext* exec_context) {
373-
if (filter.IsBound() || filter == literal(true)) {
373+
auto ValidateFilterTypeAndReturn = [](Expression filter) -> Result<Expression> {
374+
if (filter.type()->id() != Type::BOOL) {
375+
return Status::TypeError("Filter expression must evaluate to bool, but ",
376+
filter.ToString(), " evaluates to ",
377+
filter.type()->ToString());
378+
}
374379
return filter;
380+
};
381+
382+
if (filter.IsBound()) {
383+
return ValidateFilterTypeAndReturn(std::move(filter));
375384
}
385+
376386
// Step 1: Construct filter schema
377387
FieldVector fields;
378388
auto left_f_to_i =
@@ -401,12 +411,8 @@ Result<Expression> HashJoinSchema::BindFilter(Expression filter,
401411

402412
// Step 3: Bind
403413
ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(filter_schema, exec_context));
404-
if (filter.type()->id() != Type::BOOL) {
405-
return Status::TypeError("Filter expression must evaluate to bool, but ",
406-
filter.ToString(), " evaluates to ",
407-
filter.type()->ToString());
408-
}
409-
return filter;
414+
415+
return ValidateFilterTypeAndReturn(std::move(filter));
410416
}
411417

412418
Expression HashJoinSchema::RewriteFilterToUseFilterSchema(

cpp/src/arrow/acero/hash_join_node_test.cc

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1902,6 +1902,41 @@ TEST(HashJoin, CheckHashJoinNodeOptionsValidation) {
19021902
}
19031903
}
19041904

1905+
TEST(HashJoin, CheckResidualFilterType) {
1906+
BatchesWithSchema input_left;
1907+
input_left.schema = schema({field("lkey", int32()), field("lpayload", int32())});
1908+
1909+
BatchesWithSchema input_right;
1910+
input_right.schema = schema({field("rkey", int32()), field("rpayload", int32())});
1911+
1912+
Declaration left{"source",
1913+
SourceNodeOptions{input_left.schema, input_left.gen(/*parallel=*/false,
1914+
/*slow=*/false)}};
1915+
Declaration right{
1916+
"source", SourceNodeOptions{input_right.schema, input_right.gen(/*parallel=*/false,
1917+
/*slow=*/false)}};
1918+
1919+
for (const auto& filter :
1920+
{literal(MakeNullScalar(boolean())), literal(true), literal(false),
1921+
equal(field_ref("lpayload"), field_ref("rpayload"))}) {
1922+
HashJoinNodeOptions options{
1923+
JoinType::INNER, {FieldRef("lkey")}, {FieldRef("rkey")}, filter};
1924+
Declaration join{"hashjoin", {left, right}, options};
1925+
ASSERT_OK(DeclarationToStatus(std::move(join)));
1926+
}
1927+
1928+
for (const auto& filter :
1929+
{literal(NullScalar()), literal(42),
1930+
call("add", {field_ref("lpayload"), field_ref("rpayload")})}) {
1931+
HashJoinNodeOptions options{
1932+
JoinType::INNER, {FieldRef("lkey")}, {FieldRef("rkey")}, filter};
1933+
Declaration join{"hashjoin", {left, right}, options};
1934+
EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError,
1935+
::testing::HasSubstr("must evaluate to bool"),
1936+
DeclarationToStatus(std::move(join)));
1937+
}
1938+
}
1939+
19051940
class ResidualFilterCaseRunner {
19061941
public:
19071942
ResidualFilterCaseRunner(BatchesWithSchema left_input, BatchesWithSchema right_input)
@@ -2369,8 +2404,8 @@ TEST(HashJoin, FineGrainedResidualFilter) {
23692404
{
23702405
// Literal false, null, and scalar false, null.
23712406
for (Expression filter :
2372-
{literal(false), literal(NullScalar()), equal(literal(0), literal(1)),
2373-
equal(literal(1), literal(NullScalar()))}) {
2407+
{literal(false), literal(MakeNullScalar(boolean())),
2408+
equal(literal(0), literal(1)), equal(literal(1), literal(NullScalar()))}) {
23742409
std::vector<FieldRef> left_keys{"l_key", "l_filter"},
23752410
right_keys{"r_key", "r_filter"};
23762411
{

cpp/src/arrow/acero/swiss_join.cc

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1845,6 +1845,11 @@ void JoinResidualFilter::Init(Expression filter, QueryContext* ctx, MemoryPool*
18451845
const HashJoinProjectionMaps* build_schemas,
18461846
SwissTableForJoin* hash_table) {
18471847
filter_ = std::move(filter);
1848+
if (auto lit = filter_.literal(); lit) {
1849+
const auto& scalar = lit->scalar_as<BooleanScalar>();
1850+
is_trivial_ = true;
1851+
is_literal_true_ = scalar.is_valid && scalar.value;
1852+
}
18481853
ctx_ = ctx;
18491854
pool_ = pool;
18501855
hardware_flags_ = hardware_flags;
@@ -1918,14 +1923,14 @@ Status JoinResidualFilter::FilterLeftSemi(const ExecBatch& keypayload_batch,
19181923
arrow::util::TempVectorStack* temp_stack,
19191924
int* num_passing_ids,
19201925
uint16_t* passing_batch_row_ids) const {
1921-
if (filter_ == literal(true)) {
1926+
if (is_literal_true_) {
19221927
CollectPassingBatchIds(1, hardware_flags_, batch_start_row, num_batch_rows,
19231928
match_bitvector, num_passing_ids, passing_batch_row_ids);
19241929
return Status::OK();
19251930
}
19261931

19271932
*num_passing_ids = 0;
1928-
if (filter_.IsNullLiteral() || filter_ == literal(false)) {
1933+
if (is_trivial_ && !is_literal_true_) {
19291934
return Status::OK();
19301935
}
19311936

@@ -1993,7 +1998,7 @@ Status JoinResidualFilter::FilterLeftAnti(const ExecBatch& keypayload_batch,
19931998
arrow::util::TempVectorStack* temp_stack,
19941999
int* num_passing_ids,
19952000
uint16_t* passing_batch_row_ids) const {
1996-
if (filter_ == literal(true)) {
2001+
if (is_literal_true_) {
19972002
CollectPassingBatchIds(0, hardware_flags_, batch_start_row, num_batch_rows,
19982003
match_bitvector, num_passing_ids, passing_batch_row_ids);
19992004
return Status::OK();
@@ -2032,12 +2037,12 @@ Status JoinResidualFilter::FilterRightSemiAnti(
20322037
int64_t thread_id, const ExecBatch& keypayload_batch, int batch_start_row,
20332038
int num_batch_rows, const uint8_t* match_bitvector, const uint32_t* key_ids,
20342039
bool no_duplicate_keys, arrow::util::TempVectorStack* temp_stack) const {
2035-
if (filter_.IsNullLiteral() || filter_ == literal(false)) {
2040+
if (is_trivial_ && !is_literal_true_) {
20362041
return Status::OK();
20372042
}
20382043

20392044
int num_matching_ids = 0;
2040-
if (filter_ == literal(true)) {
2045+
if (is_literal_true_) {
20412046
auto match_relative_batch_ids_buf =
20422047
arrow::util::TempVectorHolder<uint16_t>(temp_stack, num_batch_rows);
20432048
auto match_key_ids_buf =
@@ -2091,13 +2096,13 @@ Status JoinResidualFilter::FilterInner(
20912096
const ExecBatch& keypayload_batch, int num_batch_rows, uint16_t* batch_row_ids,
20922097
uint32_t* key_ids, uint32_t* payload_ids_maybe_null, bool output_payload_ids,
20932098
arrow::util::TempVectorStack* temp_stack, int* num_passing_rows) const {
2094-
if (filter_ == literal(true)) {
2099+
if (is_literal_true_) {
20952100
*num_passing_rows = num_batch_rows;
20962101
return Status::OK();
20972102
}
20982103

20992104
*num_passing_rows = 0;
2100-
if (filter_.IsNullLiteral() || filter_ == literal(false)) {
2105+
if (is_trivial_ && !is_literal_true_) {
21012106
return Status::OK();
21022107
}
21032108

@@ -2114,8 +2119,7 @@ Status JoinResidualFilter::FilterOneBatch(const ExecBatch& keypayload_batch,
21142119
arrow::util::TempVectorStack* temp_stack,
21152120
int* num_passing_rows) const {
21162121
// Caller must do shortcuts for trivial filter.
2117-
ARROW_DCHECK(!filter_.IsNullLiteral() && filter_ != literal(true) &&
2118-
filter_ != literal(false));
2122+
ARROW_DCHECK(!is_trivial_);
21192123
ARROW_DCHECK(!output_key_ids || key_ids_maybe_null);
21202124
ARROW_DCHECK(!output_payload_ids || payload_ids_maybe_null);
21212125

@@ -2128,6 +2132,7 @@ Status JoinResidualFilter::FilterOneBatch(const ExecBatch& keypayload_batch,
21282132
ARROW_ASSIGN_OR_RAISE(Datum mask,
21292133
EvalFilter(keypayload_batch, num_batch_rows, batch_row_ids,
21302134
key_ids_maybe_null, payload_ids_maybe_null));
2135+
DCHECK_EQ(mask.type()->id(), Type::BOOL);
21312136
if (mask.is_scalar()) {
21322137
const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
21332138
if (mask_scalar.is_valid && mask_scalar.value) {
@@ -2162,8 +2167,7 @@ Status JoinResidualFilter::FilterOneBatch(const ExecBatch& keypayload_batch,
21622167
Result<Datum> JoinResidualFilter::EvalFilter(
21632168
const ExecBatch& keypayload_batch, int num_batch_rows, const uint16_t* batch_row_ids,
21642169
const uint32_t* key_ids_maybe_null, const uint32_t* payload_ids_maybe_null) const {
2165-
ARROW_DCHECK(!filter_.IsNullLiteral() && filter_ != literal(true) &&
2166-
filter_ != literal(false));
2170+
ARROW_DCHECK(!is_trivial_);
21672171

21682172
ARROW_ASSIGN_OR_RAISE(
21692173
ExecBatch input,

cpp/src/arrow/acero/swiss_join_internal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,8 @@ class JoinResidualFilter {
980980

981981
private:
982982
Expression filter_;
983+
bool is_trivial_ = false;
984+
bool is_literal_true_ = false;
983985

984986
QueryContext* ctx_;
985987
MemoryPool* pool_;

0 commit comments

Comments
 (0)