Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 96 additions & 92 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ Result<std::unique_ptr<substrait::Expression>> MakeListElementReference(
return MakeDirectReference(std::move(expr), std::move(ref_segment));
}

Result<std::unique_ptr<substrait::Expression::ScalarFunction>> EncodeSubstraitCall(
Result<std::unique_ptr<substrait::Expression::ScalarFunction>> EncodeSubstraitCall(
const SubstraitCall& call, ExtensionSet* ext_set,
const ConversionOptions& conversion_options) {
ARROW_ASSIGN_OR_RAISE(uint32_t anchor, ext_set->EncodeFunction(call.id()));
Expand All @@ -1272,19 +1272,18 @@ Result<std::unique_ptr<substrait::Expression::ScalarFunction>> EncodeSubstraitCa
" arguments but no argument could be found at index ", i);
}
}

for (const auto& option : call.options()) {
substrait::FunctionOption* fn_option = scalar_fn->add_options();
fn_option->set_name(option.first);
for (const auto& opt_val : option.second) {
std::string* pref = fn_option->add_preference();
*pref = opt_val;
*fn_option->add_preference() = opt_val;
}
}

return scalar_fn;
}


Result<std::vector<std::unique_ptr<substrait::Expression>>> DatumToLiterals(
const Datum& datum, ExtensionSet* ext_set,
const ConversionOptions& conversion_options) {
Expand Down Expand Up @@ -1365,83 +1364,87 @@ Result<std::unique_ptr<substrait::Expression>> ToProto(
auto call = CallNotNull(expr);

if (call->function_name == "case_when") {
auto conditions = call->arguments[0].call();
if (conditions && conditions->function_name == "make_struct") {
// catch the special case of calls convertible to IfThen
auto if_then_ = std::make_unique<substrait::Expression::IfThen>();

// don't try to convert argument 0 of the case_when; we have to convert the elements
// of make_struct individually
std::vector<std::unique_ptr<substrait::Expression>> arguments(
call->arguments.size() - 1);
for (size_t i = 1; i < call->arguments.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(arguments[i - 1],
ToProto(call->arguments[i], ext_set, conversion_options));
}
auto conditions = call->arguments[0].call();
if (conditions && conditions->function_name == "make_struct") {
auto if_then_ = std::make_unique<substrait::Expression::IfThen>();

for (size_t i = 0; i < conditions->arguments.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(auto cond_substrait, ToProto(conditions->arguments[i],
ext_set, conversion_options));
auto clause = std::make_unique<substrait::Expression::IfThen::IfClause>();
clause->set_allocated_if_(cond_substrait.release());
clause->set_allocated_then(arguments[i].release());
if_then_->mutable_ifs()->AddAllocated(clause.release());
}
std::vector<std::unique_ptr<substrait::Expression>> converted_args;
converted_args.reserve(call->arguments.size());

if_then_->set_allocated_else_(arguments.back().release());
for (size_t i = 1; i < call->arguments.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(
auto arg, ToProto(call->arguments[i], ext_set, conversion_options));
converted_args.push_back(std::move(arg));
}

out->set_allocated_if_then(if_then_.release());
return out;
for (size_t i = 0; i < conditions->arguments.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(
auto cond, ToProto(conditions->arguments[i], ext_set, conversion_options));
auto clause = std::make_unique<substrait::Expression::IfThen::IfClause>();
clause->set_allocated_if_(cond.release());
clause->set_allocated_then(converted_args[i].release());
if_then_->mutable_ifs()->AddAllocated(clause.release());
}

if_then_->set_allocated_else_(converted_args.back().release());
out->set_allocated_if_then(if_then_.release());
return out;
}
}

// the remaining function pattern matchers only convert the function itself, so we
// should be able to convert all its arguments first here
std::vector<std::unique_ptr<substrait::Expression>> arguments(call->arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(arguments[i],
ToProto(call->arguments[i], ext_set, conversion_options));
}

if (call->function_name == "struct_field") {
// catch the special case of calls convertible to a StructField
const auto& field_options =
checked_cast<const compute::StructFieldOptions&>(*call->options);
const DataType& struct_type = *call->arguments[0].type();
DCHECK_EQ(struct_type.id(), Type::STRUCT);

ARROW_ASSIGN_OR_RAISE(auto field_path, field_options.field_ref.FindOne(struct_type));
out = std::move(arguments[0]);
for (int index : field_path.indices()) {
ARROW_ASSIGN_OR_RAISE(out, MakeStructFieldReference(std::move(out), index));
}
return out;
ARROW_ASSIGN_OR_RAISE(
auto base, ToProto(call->arguments[0], ext_set, conversion_options));

const auto& field_options =
checked_cast<const compute::StructFieldOptions&>(*call->options);
const DataType& struct_type = *call->arguments[0].type();

ARROW_ASSIGN_OR_RAISE(auto field_path, field_options.field_ref.FindOne(struct_type));
out = std::move(base);

for (int index : field_path.indices()) {
ARROW_ASSIGN_OR_RAISE(out, MakeStructFieldReference(std::move(out), index));
}
return out;
}


if (call->function_name == "list_element") {
// catch the special case of calls convertible to a ListElement
if (arguments[0]->has_selection() &&
arguments[0]->selection().has_direct_reference()) {
if (arguments[1]->has_literal() && arguments[1]->literal().literal_type_case() ==
substrait::Expression::Literal::kI32) {
return MakeListElementReference(std::move(arguments[0]),
arguments[1]->literal().i32());
}
}
ARROW_ASSIGN_OR_RAISE(
auto base, ToProto(call->arguments[0], ext_set, conversion_options));
ARROW_ASSIGN_OR_RAISE(
auto offset, ToProto(call->arguments[1], ext_set, conversion_options));

if (base->has_selection() &&
offset->has_literal() &&
offset->literal().literal_type_case() ==
substrait::Expression::Literal::kI32) {
return MakeListElementReference(std::move(base), offset->literal().i32());
}
}


if (call->function_name == "if_else") {
// catch the special case of calls convertible to IfThen
auto if_clause = std::make_unique<substrait::Expression::IfThen::IfClause>();
if_clause->set_allocated_if_(arguments[0].release());
if_clause->set_allocated_then(arguments[1].release());
ARROW_ASSIGN_OR_RAISE(
auto if_, ToProto(call->arguments[0], ext_set, conversion_options));
ARROW_ASSIGN_OR_RAISE(
auto then_, ToProto(call->arguments[1], ext_set, conversion_options));
ARROW_ASSIGN_OR_RAISE(
auto else_, ToProto(call->arguments[2], ext_set, conversion_options));

auto if_then = std::make_unique<substrait::Expression::IfThen>();
if_then->mutable_ifs()->AddAllocated(if_clause.release());
if_then->set_allocated_else_(arguments[2].release());
auto clause = std::make_unique<substrait::Expression::IfThen::IfClause>();
clause->set_allocated_if_(if_.release());
clause->set_allocated_then(then_.release());

out->set_allocated_if_then(if_then.release());
return out;
auto if_then = std::make_unique<substrait::Expression::IfThen>();
if_then->mutable_ifs()->AddAllocated(clause.release());
if_then->set_allocated_else_(else_.release());

out->set_allocated_if_then(if_then.release());
return out;
} else if (call->function_name == "cast") {
auto cast = std::make_unique<substrait::Expression::Cast>();

Expand All @@ -1456,44 +1459,45 @@ Result<std::unique_ptr<substrait::Expression>> ToProto(
return Status::Invalid("Substrait is only capable of representing unsafe casts");
}

if (arguments.size() != 1) {
return Status::Invalid(
"A call to the cast function must have exactly one argument");
}

cast->set_allocated_input(arguments[0].release());

ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Type> to_type,
ToProto(*cast_options->to_type.type, /*nullable=*/true, ext_set,
conversion_options));
if (call->arguments.size() != 1) {
return Status::Invalid(
"A call to the cast function must have exactly one argument");
}

cast->set_allocated_type(to_type.release());
ARROW_ASSIGN_OR_RAISE(
auto input, ToProto(call->arguments[0], ext_set, conversion_options));
cast->set_allocated_input(input.release());

out->set_allocated_cast(cast.release());
return out;
} else if (call->function_name == "is_in") {
auto or_list = std::make_unique<substrait::Expression::SingularOrList>();
auto or_list = std::make_unique<substrait::Expression::SingularOrList>();

if (arguments.size() != 1) {
return Status::Invalid(
"A call to the is_in function must have exactly one argument");
}
if (call->arguments.size() != 1) {
return Status::Invalid(
"A call to the is_in function must have exactly one argument");
}

or_list->set_allocated_value(arguments[0].release());
std::shared_ptr<compute::SetLookupOptions> is_in_options =
internal::checked_pointer_cast<compute::SetLookupOptions>(call->options);
ARROW_ASSIGN_OR_RAISE(
auto value, ToProto(call->arguments[0], ext_set, conversion_options));
or_list->set_allocated_value(value.release());

// TODO(GH-36420) Acero does not currently handle nulls correctly
ARROW_ASSIGN_OR_RAISE(
std::vector<std::unique_ptr<substrait::Expression>> options,
DatumToLiterals(is_in_options->value_set, ext_set, conversion_options));
for (auto& option : options) {
or_list->mutable_options()->AddAllocated(option.release());
}
out->set_allocated_singular_or_list(or_list.release());
return out;
std::shared_ptr<compute::SetLookupOptions> is_in_options =
internal::checked_pointer_cast<compute::SetLookupOptions>(call->options);

ARROW_ASSIGN_OR_RAISE(
std::vector<std::unique_ptr<substrait::Expression>> options,
DatumToLiterals(is_in_options->value_set, ext_set, conversion_options));

for (auto& option : options) {
or_list->mutable_options()->AddAllocated(option.release());
}

out->set_allocated_singular_or_list(or_list.release());
return out;
}


// other expression types dive into extensions immediately
Result<ExtensionIdRegistry::ArrowToSubstraitCall> maybe_converter =
ext_set->registry()->GetArrowToSubstraitCall(call->function_name);
Expand Down