Skip to content

Commit 4d90dd7

Browse files
authored
GH-47268: [C++][Compute] Fix discarded bad status for call binding (#47284)
### Rationale for this change

Faithfully propagate the bad status raised during call expression binding.

### What changes are included in this PR?

Return early when the status is bad.

### Are these changes tested?

Yes, a unit test is included.

### Are there any user-facing changes?

None.

* GitHub Issue: #47268

Authored-by: Rossi Sun <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
1 parent 0c280fd commit 4d90dd7

File tree

2 files changed

+79
-46
lines changed

2 files changed

+79
-46
lines changed

cpp/src/arrow/compute/expression.cc

Lines changed: 40 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -545,67 +545,61 @@ Result<Expression> BindNonRecursive(Expression::Call call, bool insert_implicit_
545545
std::vector<TypeHolder> types = GetTypes(call.arguments);
546546
ARROW_ASSIGN_OR_RAISE(call.function, GetFunction(call, exec_context));
547547

548-
auto FinishBind = [&] {
549-
compute::KernelContext kernel_context(exec_context, call.kernel);
550-
if (call.kernel->init) {
551-
const FunctionOptions* options =
552-
call.options ? call.options.get() : call.function->default_options();
553-
ARROW_ASSIGN_OR_RAISE(
554-
call.kernel_state,
555-
call.kernel->init(&kernel_context, {call.kernel, types, options}));
556-
557-
kernel_context.SetState(call.kernel_state.get());
558-
}
559-
560-
ARROW_ASSIGN_OR_RAISE(
561-
call.type, call.kernel->signature->out_type().Resolve(&kernel_context, types));
562-
return Status::OK();
563-
};
564-
565548
// First try and bind exactly
566549
Result<const Kernel*> maybe_exact_match = call.function->DispatchExact(types);
567550
if (maybe_exact_match.ok()) {
568551
call.kernel = *maybe_exact_match;
569-
if (FinishBind().ok()) {
570-
return Expression(std::move(call));
552+
} else {
553+
if (!insert_implicit_casts) {
554+
return maybe_exact_match.status();
571555
}
572-
}
573556

574-
if (!insert_implicit_casts) {
575-
return maybe_exact_match.status();
576-
}
557+
// If exact binding fails, and we are allowed to cast, then prefer casting literals
558+
// first. Since DispatchBest generally prefers up-casting the best way to do this is
559+
// first down-cast the literals as much as possible
560+
types = GetTypesWithSmallestLiteralRepresentation(call.arguments);
561+
ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&types));
577562

578-
// If exact binding fails, and we are allowed to cast, then prefer casting literals
579-
// first. Since DispatchBest generally prefers up-casting the best way to do this is
580-
// first down-cast the literals as much as possible
581-
types = GetTypesWithSmallestLiteralRepresentation(call.arguments);
582-
ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&types));
563+
for (size_t i = 0; i < types.size(); ++i) {
564+
if (types[i] == call.arguments[i].type()) continue;
583565

584-
for (size_t i = 0; i < types.size(); ++i) {
585-
if (types[i] == call.arguments[i].type()) continue;
566+
if (const Datum* lit = call.arguments[i].literal()) {
567+
ARROW_ASSIGN_OR_RAISE(Datum new_lit,
568+
compute::Cast(*lit, types[i].GetSharedPtr()));
569+
call.arguments[i] = literal(std::move(new_lit));
570+
continue;
571+
}
586572

587-
if (const Datum* lit = call.arguments[i].literal()) {
588-
ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, types[i].GetSharedPtr()));
589-
call.arguments[i] = literal(std::move(new_lit));
590-
continue;
591-
}
573+
// construct an implicit cast Expression with which to replace this argument
574+
Expression::Call implicit_cast;
575+
implicit_cast.function_name = "cast";
576+
implicit_cast.arguments = {std::move(call.arguments[i])};
592577

593-
// construct an implicit cast Expression with which to replace this argument
594-
Expression::Call implicit_cast;
595-
implicit_cast.function_name = "cast";
596-
implicit_cast.arguments = {std::move(call.arguments[i])};
578+
// TODO(wesm): Use TypeHolder in options
579+
implicit_cast.options = std::make_shared<compute::CastOptions>(
580+
compute::CastOptions::Safe(types[i].GetSharedPtr()));
597581

598-
// TODO(wesm): Use TypeHolder in options
599-
implicit_cast.options = std::make_shared<compute::CastOptions>(
600-
compute::CastOptions::Safe(types[i].GetSharedPtr()));
582+
ARROW_ASSIGN_OR_RAISE(
583+
call.arguments[i],
584+
BindNonRecursive(std::move(implicit_cast),
585+
/*insert_implicit_casts=*/false, exec_context));
586+
}
587+
}
601588

589+
compute::KernelContext kernel_context(exec_context, call.kernel);
590+
if (call.kernel->init) {
591+
const FunctionOptions* options =
592+
call.options ? call.options.get() : call.function->default_options();
602593
ARROW_ASSIGN_OR_RAISE(
603-
call.arguments[i],
604-
BindNonRecursive(std::move(implicit_cast),
605-
/*insert_implicit_casts=*/false, exec_context));
594+
call.kernel_state,
595+
call.kernel->init(&kernel_context, {call.kernel, types, options}));
596+
597+
kernel_context.SetState(call.kernel_state.get());
606598
}
607599

608-
RETURN_NOT_OK(FinishBind());
600+
ARROW_ASSIGN_OR_RAISE(
601+
call.type, call.kernel->signature->out_type().Resolve(&kernel_context, types));
602+
609603
return Expression(std::move(call));
610604
}
611605

cpp/src/arrow/compute/expression_test.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,45 @@ TEST(Expression, BindCall) {
622622
add(cast(field_ref("i32"), float32()), literal(3.5F)));
623623
}
624624

625+
static Status RegisterInvalidInit() {
626+
const std::string name = "invalid_init";
627+
struct CastableFunction : public ScalarFunction {
628+
using ScalarFunction::ScalarFunction;
629+
630+
Result<const Kernel*> DispatchBest(std::vector<TypeHolder>* types) const override {
631+
return Status::Invalid("Shouldn't call DispatchBest on this function");
632+
}
633+
};
634+
auto func =
635+
std::make_shared<CastableFunction>(name, Arity::Unary(), FunctionDoc::Empty());
636+
637+
auto func_exec = [](KernelContext*, const ExecSpan&, ExecResult*) -> Status {
638+
return Status::OK();
639+
};
640+
auto func_init = [](KernelContext*,
641+
const KernelInitArgs&) -> Result<std::unique_ptr<KernelState>> {
642+
return Status::Invalid("Invalid Init");
643+
};
644+
645+
ScalarKernel kernel({int64()}, int64(), func_exec, func_init);
646+
ARROW_RETURN_NOT_OK(func->AddKernel(kernel));
647+
648+
auto registry = GetFunctionRegistry();
649+
ARROW_RETURN_NOT_OK(registry->AddFunction(std::move(func)));
650+
651+
return Status::OK();
652+
}
653+
654+
// GH-47268: The bad status in call binding is discarded.
655+
TEST(Expression, BindCallError) {
656+
ASSERT_OK(RegisterInvalidInit());
657+
auto expr = call("invalid_init", {field_ref("i64")});
658+
EXPECT_FALSE(expr.IsBound());
659+
660+
ASSERT_RAISES_WITH_MESSAGE(Invalid, "Invalid: Invalid Init",
661+
expr.Bind(*kBoringSchema).status());
662+
}
663+
625664
TEST(Expression, BindWithAliasCasts) {
626665
auto fm = GetFunctionRegistry();
627666
EXPECT_OK(fm->AddAlias("alias_cast", "cast"));

0 commit comments

Comments
 (0)