Skip to content

Commit cc01c04

Browse files
branch-4.0: [fix](ann range search) range search prepare failed on NULL literal #60564 (#60821)
cherry pick from #60564
1 parent ce72b33 commit cc01c04

File tree

3 files changed

+352
-44
lines changed

3 files changed

+352
-44
lines changed

be/src/vec/exprs/vectorized_fn_call.cpp

Lines changed: 61 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -391,31 +391,7 @@ void VectorizedFnCall::prepare_ann_range_search(
391391
auto left_child = get_child(0);
392392
auto right_child = get_child(1);
393393

394-
auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child);
395-
if (right_literal == nullptr) {
396-
suitable_for_ann_index = false;
397-
return;
398-
}
399-
400-
auto right_col = right_literal->get_column_ptr()->convert_to_full_column_if_const();
401-
auto right_type = right_literal->get_data_type();
402-
403-
PrimitiveType right_primitive = right_type->get_primitive_type();
404-
const bool float32_literal = right_primitive == PrimitiveType::TYPE_FLOAT;
405-
const bool float64_literal = right_primitive == PrimitiveType::TYPE_DOUBLE;
406-
if (!float32_literal && !float64_literal) {
407-
mark_unsuitable("Right child is not a Float32Literal or Float64Literal.");
408-
return;
409-
}
410-
411-
if (float32_literal) {
412-
const ColumnFloat32* cf32_right = assert_cast<const ColumnFloat32*>(right_col.get());
413-
range_search_runtime.radius = cf32_right->get_data()[0];
414-
} else if (float64_literal) {
415-
const ColumnFloat64* cf64_right = assert_cast<const ColumnFloat64*>(right_col.get());
416-
range_search_runtime.radius = static_cast<float>(cf64_right->get_data()[0]);
417-
}
418-
394+
// ========== Step 1: Check left child - must be a distance function ==========
419395
auto get_virtual_expr = [&](const VExprSPtr& expr,
420396
std::shared_ptr<VirtualSlotRef>& slot_ref) -> VExprSPtr {
421397
auto virtual_ref = std::dynamic_pointer_cast<VirtualSlotRef>(expr);
@@ -430,40 +406,40 @@ void VectorizedFnCall::prepare_ann_range_search(
430406
std::shared_ptr<VirtualSlotRef> vir_slot_ref;
431407
auto normalized_left = get_virtual_expr(left_child, vir_slot_ref);
432408

433-
std::shared_ptr<VectorizedFnCall> function_call;
434-
if (float32_literal) {
435-
function_call = std::dynamic_pointer_cast<VectorizedFnCall>(normalized_left);
436-
if (function_call == nullptr) {
437-
mark_unsuitable("Left child is not a function call.");
438-
return;
439-
}
440-
} else {
441-
auto cast_float_to_double = std::dynamic_pointer_cast<VCastExpr>(normalized_left);
442-
if (cast_float_to_double == nullptr) {
443-
mark_unsuitable("Left child is not a cast expression.");
409+
// Try to find the distance function call; it may be wrapped in a Cast(Float->Double)
410+
std::shared_ptr<VectorizedFnCall> function_call =
411+
std::dynamic_pointer_cast<VectorizedFnCall>(normalized_left);
412+
bool has_float_to_double_cast = false;
413+
414+
if (function_call == nullptr) {
415+
// Check if it's a Cast expression wrapping a function call
416+
auto cast_expr = std::dynamic_pointer_cast<VCastExpr>(normalized_left);
417+
if (cast_expr == nullptr) {
418+
mark_unsuitable("Left child is neither a function call nor a cast expression.");
444419
return;
445420
}
446-
447-
auto normalized_cast_child =
448-
get_virtual_expr(cast_float_to_double->get_child(0), vir_slot_ref);
421+
has_float_to_double_cast = true;
422+
auto normalized_cast_child = get_virtual_expr(cast_expr->get_child(0), vir_slot_ref);
449423
function_call = std::dynamic_pointer_cast<VectorizedFnCall>(normalized_cast_child);
450424
if (function_call == nullptr) {
451425
mark_unsuitable("Left child of cast is not a function call.");
452426
return;
453427
}
454428
}
455429

430+
// Check if it's a supported distance function
456431
if (DISTANCE_FUNCS.find(function_call->_function_name) == DISTANCE_FUNCS.end()) {
457432
mark_unsuitable(fmt::format("Left child is not a supported distance function: {}",
458433
function_call->_function_name));
459434
return;
460-
} else {
461-
// Strip the _approximate suffix.
462-
std::string metric_name = function_call->_function_name;
463-
metric_name = metric_name.substr(0, metric_name.size() - 12);
464-
range_search_runtime.metric_type = segment_v2::string_to_metric(metric_name);
465435
}
466436

437+
// Strip the _approximate suffix to get metric type
438+
std::string metric_name = function_call->_function_name;
439+
metric_name = metric_name.substr(0, metric_name.size() - 12);
440+
range_search_runtime.metric_type = segment_v2::string_to_metric(metric_name);
441+
442+
// ========== Step 2: Validate distance function arguments ==========
467443
// Identify the slot ref child and the constant query array child (ArrayLiteral or CAST to array)
468444
Int32 idx_of_slot_ref = -1;
469445
Int32 idx_of_array_expr = -1;
@@ -502,6 +478,47 @@ void VectorizedFnCall::prepare_ann_range_search(
502478
}
503479
range_search_runtime.query_value = extract_result.value();
504480
range_search_runtime.dim = range_search_runtime.query_value->size();
481+
482+
// ========== Step 3: Check right child - must be a float/double literal ==========
483+
auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child);
484+
if (right_literal == nullptr) {
485+
mark_unsuitable("Right child is not a literal.");
486+
return;
487+
}
488+
489+
// Handle nullable literal gracefully - just mark as unsuitable instead of crashing
490+
if (right_literal->is_nullable()) {
491+
mark_unsuitable("Right literal is nullable, not supported for ANN range search.");
492+
return;
493+
}
494+
495+
auto right_type = right_literal->get_data_type();
496+
PrimitiveType right_primitive = right_type->get_primitive_type();
497+
const bool float32_literal = right_primitive == PrimitiveType::TYPE_FLOAT;
498+
const bool float64_literal = right_primitive == PrimitiveType::TYPE_DOUBLE;
499+
500+
if (!float32_literal && !float64_literal) {
501+
mark_unsuitable("Right child is not a Float32Literal or Float64Literal.");
502+
return;
503+
}
504+
505+
// Validate consistency: if we have Cast(Float->Double), right must be double literal
506+
if (has_float_to_double_cast && !float64_literal) {
507+
mark_unsuitable("Cast expression expects double literal on right side.");
508+
return;
509+
}
510+
511+
// Extract radius value
512+
auto right_col = right_literal->get_column_ptr()->convert_to_full_column_if_const();
513+
if (float32_literal) {
514+
const ColumnFloat32* cf32_right = assert_cast<const ColumnFloat32*>(right_col.get());
515+
range_search_runtime.radius = cf32_right->get_data()[0];
516+
} else {
517+
const ColumnFloat64* cf64_right = assert_cast<const ColumnFloat64*>(right_col.get());
518+
range_search_runtime.radius = static_cast<float>(cf64_right->get_data()[0]);
519+
}
520+
521+
// ========== Done: Mark as suitable for ANN range search ==========
505522
range_search_runtime.is_ann_range_search = true;
506523
range_search_runtime.user_params = user_params;
507524
VLOG_DEBUG << fmt::format("Ann range search params: {}", range_search_runtime.to_string());
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
-- This file is automatically generated. You should know what you did if you want to edit this
2+
-- !nullable_subquery_empty --
3+
4+
-- !nullable_subquery_empty_ge --
5+
6+
-- !nullable_subquery_all_null --
7+
8+
-- !nullable_subquery_normal --
9+
0 [1, 2, 3, 4]
10+
1 [2, 3, 4, 5]
11+
2 [3, 4, 5, 6]
12+
13+
-- !nullable_subquery_normal_max --
14+
0 [1, 2, 3, 4]
15+
1 [2, 3, 4, 5]
16+
2 [3, 4, 5, 6]
17+
3 [4, 5, 6, 7]
18+
4 [5, 6, 7, 8]
19+
20+
-- !coalesce_with_null --
21+
0 [1, 2, 3, 4]
22+
1 [2, 3, 4, 5]
23+
2 [3, 4, 5, 6]
24+
25+
-- !case_nullable --
26+
0 [1, 2, 3, 4]
27+
1 [2, 3, 4, 5]
28+
2 [3, 4, 5, 6]
29+
30+
-- !normal_literal --
31+
0 [1, 2, 3, 4]
32+
1 [2, 3, 4, 5]
33+
2 [3, 4, 5, 6]
34+
35+
-- !ip_nullable_subquery --
36+
37+
-- !non_dist_nullable_empty --
38+
39+
-- !non_dist_nullable_all_null --
40+
41+
-- !non_dist_nullable_normal --
42+
0 [1, 2, 3, 4]
43+
1 [2, 3, 4, 5]
44+
45+
-- !non_dist_func_nullable --
46+
47+
-- !arithmetic_nullable --
48+
49+
-- !mixed_dist_and_regular_nullable --
50+
51+
-- !dist_normal_regular_nullable --
52+
53+
-- !or_condition_nullable --
54+

0 commit comments

Comments
 (0)