diff --git a/modin/core/storage_formats/pandas/query_compiler_caster.py b/modin/core/storage_formats/pandas/query_compiler_caster.py index 9de0debe3e2..a5ca35f764b 100644 --- a/modin/core/storage_formats/pandas/query_compiler_caster.py +++ b/modin/core/storage_formats/pandas/query_compiler_caster.py @@ -128,12 +128,16 @@ def _normalize_class_name(class_of_wrapped_fn: Optional[str]) -> str: _AUTO_SWITCH_CLASS = defaultdict[BackendAndClassName, set[str]] +# For pre-op switch methods, we store method_name -> is_arg_based mapping +# where is_arg_based=True means switch only if parameters are unsupported +_AUTO_SWITCH_PRE_OP_CLASS = defaultdict[BackendAndClassName, dict[str, bool]] + _CLASS_AND_BACKEND_TO_POST_OP_SWITCH_METHODS: _AUTO_SWITCH_CLASS = _AUTO_SWITCH_CLASS( set ) -_CLASS_AND_BACKEND_TO_PRE_OP_SWITCH_METHODS: _AUTO_SWITCH_CLASS = _AUTO_SWITCH_CLASS( - set +_CLASS_AND_BACKEND_TO_PRE_OP_SWITCH_METHODS: _AUTO_SWITCH_PRE_OP_CLASS = _AUTO_SWITCH_PRE_OP_CLASS( + dict ) @@ -621,21 +625,45 @@ def _maybe_switch_backend_pre_op( to the new query compiler type. """ input_backend = input_qc.get_backend() - if ( - function_name - in _CLASS_AND_BACKEND_TO_PRE_OP_SWITCH_METHODS[ - BackendAndClassName( - backend=input_qc.get_backend(), class_name=class_of_wrapped_fn + backend_class_key = BackendAndClassName( + backend=input_qc.get_backend(), class_name=class_of_wrapped_fn + ) + + # Check if this function is registered for pre-op switch + registered_methods = _CLASS_AND_BACKEND_TO_PRE_OP_SWITCH_METHODS[backend_class_key] + + if function_name in registered_methods: + is_arg_based = registered_methods[function_name] + + if is_arg_based: + # Arg-based switch: only switch if parameters are unsupported + stay_cost = input_qc.stay_cost( + api_cls_name=class_of_wrapped_fn, + operation=function_name, + arguments=arguments, + ) + + # Only trigger switch if parameters are unsupported (COST_IMPOSSIBLE) + if stay_cost is not None and stay_cost >= QCCoercionCost.COST_IMPOSSIBLE: + result_backend = _get_backend_for_auto_switch( + input_qc=input_qc, + class_of_wrapped_fn=class_of_wrapped_fn, + function_name=function_name, + arguments=arguments, + ) + else: + # Parameters are supported, no need to switch + result_backend = input_backend + else: + # Non-arg-based switch: always consider switching + result_backend = _get_backend_for_auto_switch( + input_qc=input_qc, + class_of_wrapped_fn=class_of_wrapped_fn, + function_name=function_name, + arguments=arguments, ) - ] - ): - result_backend = _get_backend_for_auto_switch( - input_qc=input_qc, - class_of_wrapped_fn=class_of_wrapped_fn, - function_name=function_name, - arguments=arguments, - ) else: + # No registration found, stay on current backend result_backend = input_backend def cast_to_qc(arg: Any) -> Any: @@ -773,12 +801,18 @@ def _get_backend_for_auto_switch( min_move_stay_delta = None best_backend = starting_backend - + all_backends_impossible = True + stay_cost = input_qc.stay_cost( api_cls_name=class_of_wrapped_fn, operation=function_name, arguments=arguments, ) + + # Check if the current backend can handle the workload + if stay_cost is not None and stay_cost < QCCoercionCost.COST_IMPOSSIBLE: + all_backends_impossible = False + data_max_shape = input_qc._max_shape() emit_metric( f"hybrid.auto.api.{class_of_wrapped_fn}.{function_name}.group.{metrics_group}", @@ -835,6 +869,12 @@ def _get_backend_for_auto_switch( # We can execute this workload if we need to, consider # move_to_cost/transfer time in our decision move_stay_delta = (move_to_cost + other_execute_cost) - stay_cost + + # Check if this backend can handle the workload (both execution and transfer must be possible) + if (other_execute_cost < QCCoercionCost.COST_IMPOSSIBLE and + move_to_cost < QCCoercionCost.COST_IMPOSSIBLE): + all_backends_impossible = False + if move_stay_delta < 0 and ( min_move_stay_delta is None or move_stay_delta < min_move_stay_delta ): @@ -861,6 +901,20 @@ def _get_backend_for_auto_switch( + f"{move_stay_delta}" ) + # Check if all backends are impossible and raise exception + if all_backends_impossible: + emit_metric(f"hybrid.auto.decision.impossible.group.{metrics_group}", 1) + get_logger().error( + f"All backends impossible for {class_of_wrapped_fn}.{function_name}: " + f"starting_backend={starting_backend}, stay_cost={stay_cost}" + ) + ErrorMessage.not_implemented( + f"No available backend can handle the workload for operation " + f"{class_of_wrapped_fn}.{function_name}. All backends returned COST_IMPOSSIBLE. " + f"Current backend: {starting_backend}, stay_cost: {stay_cost}. " + f"This operation cannot be executed due to memory or capability constraints across all backends." + ) + if best_backend == starting_backend: emit_metric(f"hybrid.auto.decision.{best_backend}.group.{metrics_group}", 0) get_logger().info( @@ -1228,7 +1282,7 @@ def register_function_for_post_op_switch( def register_function_for_pre_op_switch( - class_name: Optional[str], backend: str, method: str + class_name: Optional[str], backend: str, method: str, *, arg_based: bool = False ) -> None: """ Register a function for pre-operation backend switch. @@ -1242,7 +1296,12 @@ def register_function_for_pre_op_switch( Only consider switching when the starting backend is this one. method : str The name of the method to register. + arg_based : bool, default: False + If True, the switch will only be triggered if unsupported parameters are detected + for the operation, avoiding unnecessary backend switching when parameters + are supported. If False, the switch will always be considered (existing behavior). """ _CLASS_AND_BACKEND_TO_PRE_OP_SWITCH_METHODS[ BackendAndClassName(backend=backend, class_name=class_name) - ].add(method) + ][method] = arg_based +