@@ -391,31 +391,7 @@ void VectorizedFnCall::prepare_ann_range_search(
391391 auto left_child = get_child (0 );
392392 auto right_child = get_child (1 );
393393
394- auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child);
395- if (right_literal == nullptr ) {
396- suitable_for_ann_index = false ;
397- return ;
398- }
399-
400- auto right_col = right_literal->get_column_ptr ()->convert_to_full_column_if_const ();
401- auto right_type = right_literal->get_data_type ();
402-
403- PrimitiveType right_primitive = right_type->get_primitive_type ();
404- const bool float32_literal = right_primitive == PrimitiveType::TYPE_FLOAT;
405- const bool float64_literal = right_primitive == PrimitiveType::TYPE_DOUBLE;
406- if (!float32_literal && !float64_literal) {
407- mark_unsuitable (" Right child is not a Float32Literal or Float64Literal." );
408- return ;
409- }
410-
411- if (float32_literal) {
412- const ColumnFloat32* cf32_right = assert_cast<const ColumnFloat32*>(right_col.get ());
413- range_search_runtime.radius = cf32_right->get_data ()[0 ];
414- } else if (float64_literal) {
415- const ColumnFloat64* cf64_right = assert_cast<const ColumnFloat64*>(right_col.get ());
416- range_search_runtime.radius = static_cast <float >(cf64_right->get_data ()[0 ]);
417- }
418-
394+ // ========== Step 1: Check left child - must be a distance function ==========
419395 auto get_virtual_expr = [&](const VExprSPtr& expr,
420396 std::shared_ptr<VirtualSlotRef>& slot_ref) -> VExprSPtr {
421397 auto virtual_ref = std::dynamic_pointer_cast<VirtualSlotRef>(expr);
@@ -430,40 +406,40 @@ void VectorizedFnCall::prepare_ann_range_search(
430406 std::shared_ptr<VirtualSlotRef> vir_slot_ref;
431407 auto normalized_left = get_virtual_expr (left_child, vir_slot_ref);
432408
433- std::shared_ptr<VectorizedFnCall> function_call;
434- if (float32_literal) {
435- function_call = std::dynamic_pointer_cast<VectorizedFnCall>(normalized_left);
436- if (function_call == nullptr ) {
437- mark_unsuitable (" Left child is not a function call." );
438- return ;
439- }
440- } else {
441- auto cast_float_to_double = std::dynamic_pointer_cast<VCastExpr>(normalized_left);
442- if (cast_float_to_double == nullptr ) {
443- mark_unsuitable (" Left child is not a cast expression." );
409+ // Try to find the distance function call, it may be wrapped in a Cast(Float->Double)
410+ std::shared_ptr<VectorizedFnCall> function_call =
411+ std::dynamic_pointer_cast<VectorizedFnCall>(normalized_left);
412+ bool has_float_to_double_cast = false ;
413+
414+ if (function_call == nullptr ) {
415+ // Check if it's a Cast expression wrapping a function call
416+ auto cast_expr = std::dynamic_pointer_cast<VCastExpr>(normalized_left);
417+ if (cast_expr == nullptr ) {
418+ mark_unsuitable (" Left child is neither a function call nor a cast expression." );
444419 return ;
445420 }
446-
447- auto normalized_cast_child =
448- get_virtual_expr (cast_float_to_double->get_child (0 ), vir_slot_ref);
421+ has_float_to_double_cast = true ;
422+ auto normalized_cast_child = get_virtual_expr (cast_expr->get_child (0 ), vir_slot_ref);
449423 function_call = std::dynamic_pointer_cast<VectorizedFnCall>(normalized_cast_child);
450424 if (function_call == nullptr ) {
451425 mark_unsuitable (" Left child of cast is not a function call." );
452426 return ;
453427 }
454428 }
455429
430+ // Check if it's a supported distance function
456431 if (DISTANCE_FUNCS.find (function_call->_function_name ) == DISTANCE_FUNCS.end ()) {
457432 mark_unsuitable (fmt::format (" Left child is not a supported distance function: {}" ,
458433 function_call->_function_name ));
459434 return ;
460- } else {
461- // Strip the _approximate suffix.
462- std::string metric_name = function_call->_function_name ;
463- metric_name = metric_name.substr (0 , metric_name.size () - 12 );
464- range_search_runtime.metric_type = segment_v2::string_to_metric (metric_name);
465435 }
466436
437+ // Strip the _approximate suffix to get metric type
438+ std::string metric_name = function_call->_function_name ;
439+ metric_name = metric_name.substr (0 , metric_name.size () - 12 );
440+ range_search_runtime.metric_type = segment_v2::string_to_metric (metric_name);
441+
442+ // ========== Step 2: Validate distance function arguments ==========
467443 // Identify the slot ref child and the constant query array child (ArrayLiteral or CAST to array)
468444 Int32 idx_of_slot_ref = -1 ;
469445 Int32 idx_of_array_expr = -1 ;
@@ -502,6 +478,47 @@ void VectorizedFnCall::prepare_ann_range_search(
502478 }
503479 range_search_runtime.query_value = extract_result.value ();
504480 range_search_runtime.dim = range_search_runtime.query_value ->size ();
481+
482+ // ========== Step 3: Check right child - must be a float/double literal ==========
483+ auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child);
484+ if (right_literal == nullptr ) {
485+ mark_unsuitable (" Right child is not a literal." );
486+ return ;
487+ }
488+
489+ // Handle nullable literal gracefully - just mark as unsuitable instead of crash
490+ if (right_literal->is_nullable ()) {
491+ mark_unsuitable (" Right literal is nullable, not supported for ANN range search." );
492+ return ;
493+ }
494+
495+ auto right_type = right_literal->get_data_type ();
496+ PrimitiveType right_primitive = right_type->get_primitive_type ();
497+ const bool float32_literal = right_primitive == PrimitiveType::TYPE_FLOAT;
498+ const bool float64_literal = right_primitive == PrimitiveType::TYPE_DOUBLE;
499+
500+ if (!float32_literal && !float64_literal) {
501+ mark_unsuitable (" Right child is not a Float32Literal or Float64Literal." );
502+ return ;
503+ }
504+
505+ // Validate consistency: if we have Cast(Float->Double), right must be double literal
506+ if (has_float_to_double_cast && !float64_literal) {
507+ mark_unsuitable (" Cast expression expects double literal on right side." );
508+ return ;
509+ }
510+
511+ // Extract radius value
512+ auto right_col = right_literal->get_column_ptr ()->convert_to_full_column_if_const ();
513+ if (float32_literal) {
514+ const ColumnFloat32* cf32_right = assert_cast<const ColumnFloat32*>(right_col.get ());
515+ range_search_runtime.radius = cf32_right->get_data ()[0 ];
516+ } else {
517+ const ColumnFloat64* cf64_right = assert_cast<const ColumnFloat64*>(right_col.get ());
518+ range_search_runtime.radius = static_cast <float >(cf64_right->get_data ()[0 ]);
519+ }
520+
521+ // ========== Done: Mark as suitable for ANN range search ==========
505522 range_search_runtime.is_ann_range_search = true ;
506523 range_search_runtime.user_params = user_params;
507524 VLOG_DEBUG << fmt::format (" Ann range search params: {}" , range_search_runtime.to_string ());
0 commit comments