Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 85 additions & 55 deletions cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,25 @@ struct StringBinaryTransformExecBase {
using ArrayType1 = typename TypeTraits<Type1>::ArrayType;
using ArrayType2 = typename TypeTraits<Type2>::ArrayType;

// Reads element `i` of a binary/string-like ArraySpan as a ViewType2 directly
// from the span's buffers. This overload participates only when the second
// argument type is a base binary type; a sibling overload handles the rest.
//
// \param span the span to read from (buffer 1 = offsets, buffer 2 = bytes)
// \param i    index used directly into the offsets pointer from GetValues(1)
// \return a view over bytes [offsets[i], offsets[i + 1]) of the data buffer
template <typename T = Type2>
static std::enable_if_t<is_base_binary_type<T>::value, ViewType2> GetViewFromSpan(
    const ArraySpan& span, int64_t i) {
  using OffsetT = typename T::offset_type;
  const OffsetT* value_offsets = span.GetValues<OffsetT>(1);
  // The data buffer is requested with an explicit offset of 0: the stored
  // offsets are used as byte positions from the start of that buffer.
  const auto* bytes =
      reinterpret_cast<const char*>(span.GetValues<uint8_t>(2, /*offset=*/0));
  const OffsetT first = value_offsets[i];
  const OffsetT last = value_offsets[i + 1];
  return ViewType2(bytes + first, last - first);
}

// Overload for second-argument types that are not base binary types: element
// `i` is read straight out of buffer 1, interpreted as an array of ViewType2.
//
// NOTE(review): this assumes buffer 1 is a dense array of ViewType2 values; that
// would not hold for bit-packed storage -- confirm Type2 is only instantiated
// with fixed-width numeric types here.
template <typename T = Type2>
static std::enable_if_t<!is_base_binary_type<T>::value, ViewType2> GetViewFromSpan(
    const ArraySpan& span, int64_t i) {
  const ViewType2* values = span.GetValues<ViewType2>(1);
  return values[i];
}

static Status Execute(KernelContext* ctx, StringTransform* transform,
const ExecSpan& batch, ExecResult* out) {
if (batch[0].is_scalar()) {
Expand Down Expand Up @@ -325,14 +344,11 @@ struct StringBinaryTransformExecBase {
output_offsets[0] = 0;
offset_type output_ncodeunits = 0;

// TODO(wesm): rewrite to not require boxing
const ArrayType2 array2(data2.ToArrayData());

// Apply transform
RETURN_NOT_OK(arrow::internal::VisitBitBlocks(
data2.buffers[0].data, data2.offset, data2.length,
[&](int64_t i) {
ViewType2 value2 = array2.GetView(i);
ViewType2 value2 = GetViewFromSpan(data2, i);
ARROW_ASSIGN_OR_RAISE(
auto encoded_nbytes_,
transform->Transform(input_string, input_ncodeunits, value2,
Expand Down Expand Up @@ -377,17 +393,14 @@ struct StringBinaryTransformExecBase {
const offset_type* data1_offsets = data1.GetValues<offset_type>(1);
const uint8_t* data1_data = data1.GetValues<uint8_t>(2, /*offset=*/0);

// TODO(wesm): rewrite to not require boxing
const ArrayType2 array2(data2.ToArrayData());

// Apply transform
RETURN_NOT_OK(arrow::internal::VisitTwoBitBlocks(
data1.buffers[0].data, data1.offset, data2.buffers[0].data, data2.offset,
data1.length,
[&](int64_t i) {
const offset_type input_ncodeunits = data1_offsets[i + 1] - data1_offsets[i];
const uint8_t* input_string = data1_data + data1_offsets[i];
ViewType2 value2 = array2.GetView(i);
ViewType2 value2 = GetViewFromSpan(data2, i);
ARROW_ASSIGN_OR_RAISE(
auto encoded_nbytes_,
transform->Transform(input_string, input_ncodeunits, value2,
Expand Down Expand Up @@ -3051,6 +3064,8 @@ struct BinaryJoin {
return ExecArrayArray(ctx, batch[0].array, batch[1].array, out);
}

// Lookup helpers for accessing list offsets and separator strings

struct ListScalarOffsetLookup {
const ArrayType& values;

Expand All @@ -3059,16 +3074,21 @@ struct BinaryJoin {
bool IsNull(int64_t i) { return false; }
};

// NOTE(review): this region is a diff rendering. Lines of an Array-based
// ListArrayOffsetLookup and of an ArraySpan-based ListArraySpanOffsetLookup
// appear interleaved; the Array-based lines look like the removed side --
// confirm against the actual commit.
//
// Both variants expose the same three accessors: GetStart(i) / GetStop(i)
// return offsets_[i] / offsets_[i + 1], and IsNull(i) reports whether list
// slot `i` is null.
struct ListArrayOffsetLookup {
explicit ListArrayOffsetLookup(const ListArrayType& lists)
: lists_(lists), offsets_(lists.raw_value_offsets()) {}
// Span-based variant: captures the validity bitmap pointer (buffer 0), the
// span's offset, and the list offsets pointer obtained from GetValues(1).
struct ListArraySpanOffsetLookup {
explicit ListArraySpanOffsetLookup(const ArraySpan& lists)
: validity_(lists.buffers[0].data),
offset_(lists.offset),
offsets_(lists.GetValues<ListOffsetType>(1)) {}

// Start position of list `i`, widened to int64_t.
int64_t GetStart(int64_t i) { return offsets_[i]; }
// One-past-the-end position of list `i`, widened to int64_t.
int64_t GetStop(int64_t i) { return offsets_[i + 1]; }
bool IsNull(int64_t i) { return lists_.IsNull(i); }
// A null validity pointer yields false (no slot is reported null); otherwise
// the bit at position offset_ + i is consulted, and a cleared bit means null.
bool IsNull(int64_t i) {
return validity_ && !bit_util::GetBit(validity_, offset_ + i);
}

private:
const ListArrayType& lists_;
// Validity bitmap of the list span; may be null (see IsNull).
const uint8_t* validity_;
// Added to `i` when indexing the validity bitmap.
int64_t offset_;
// List offsets, indexed directly with `i`.
const ListOffsetType* offsets_;
};

Expand All @@ -3079,11 +3099,27 @@ struct BinaryJoin {
std::string_view GetView(int64_t i) { return separator; }
};

// NOTE(review): this region is a diff rendering. Lines of an Array-based
// SeparatorArrayLookup and of an ArraySpan-based SeparatorArraySpanLookup
// appear interleaved; the Array-based lines look like the removed side --
// confirm against the actual commit.
//
// Both variants answer two questions about separator slot `i`: whether it is
// null (IsNull) and what its bytes are (GetView).
struct SeparatorArrayLookup {
const ArrayType& separators;
// Span-based variant: captures the validity bitmap pointer (buffer 0), the
// span's offset, the offsets pointer from GetValues(1), and the data buffer
// pointer requested with an explicit offset of 0.
struct SeparatorArraySpanLookup {
explicit SeparatorArraySpanLookup(const ArraySpan& separators)
: validity_(separators.buffers[0].data),
offset_(separators.offset),
offsets_(separators.GetValues<typename ArrayType::offset_type>(1)),
data_(separators.GetValues<uint8_t>(2, 0)) {}

bool IsNull(int64_t i) { return separators.IsNull(i); }
std::string_view GetView(int64_t i) { return separators.GetView(i); }
// A null validity pointer yields false (no slot is reported null); otherwise
// the bit at position offset_ + i is consulted, and a cleared bit means null.
bool IsNull(int64_t i) {
return validity_ && !bit_util::GetBit(validity_, offset_ + i);
}
// Returns a non-owning view over bytes [offsets_[i], offsets_[i + 1]) of the
// data buffer; the view is only valid while the underlying buffer is alive.
std::string_view GetView(int64_t i) {
auto start = offsets_[i];
auto end = offsets_[i + 1];
return std::string_view(reinterpret_cast<const char*>(data_ + start), end - start);
}

private:
// Validity bitmap of the separator span; may be null (see IsNull).
const uint8_t* validity_;
// Added to `i` when indexing the validity bitmap.
int64_t offset_;
// Separator value offsets, indexed directly with `i`.
const typename ArrayType::offset_type* offsets_;
// Start of the separator byte data.
const uint8_t* data_;
};

// Scalar, array -> array
Expand All @@ -3105,104 +3141,98 @@ struct BinaryJoin {
out->value = std::move(nulls->data());
return Status::OK();
}
// TODO(wesm): rewrite to not use ArrayData
const ArrayType separators(right.ToArrayData());

using offset_type = typename ArrayType::offset_type;
const offset_type* sep_offsets = right.GetValues<offset_type>(1);

BuilderType builder(ctx->memory_pool());
RETURN_NOT_OK(builder.Reserve(separators.length()));
RETURN_NOT_OK(builder.Reserve(right.length));

// Presize data to avoid multiple reallocations when joining strings
int64_t total_data_length = 0;
const int64_t list_length = strings.length();
if (list_length) {
const int64_t string_length = strings.total_values_length();
total_data_length +=
string_length * (separators.length() - separators.null_count());
for (int64_t i = 0; i < separators.length(); ++i) {
if (separators.IsNull(i)) {
total_data_length += string_length * (right.length - right.GetNullCount());
for (int64_t i = 0; i < right.length; ++i) {
if (right.IsNull(i)) {
continue;
}
total_data_length += (list_length - 1) * separators.value_length(i);
offset_type sep_length = sep_offsets[i + 1] - sep_offsets[i];
total_data_length += (list_length - 1) * sep_length;
}
}
RETURN_NOT_OK(builder.ReserveData(total_data_length));

return JoinStrings(separators.length(), strings, ListScalarOffsetLookup{strings},
SeparatorArrayLookup{separators}, &builder, out);
return JoinStrings(right.length, strings, ListScalarOffsetLookup{strings},
SeparatorArraySpanLookup{right}, &builder, out);
}

// Array, scalar -> array
// Joins each list of strings in `left` using the single separator held by the
// scalar `right`, delegating the actual joining to JoinStrings.
//
// \param ctx   kernel context; supplies the memory pool for the builder and
//              for the all-null result
// \param left  list array span whose child 0 holds the strings to join
// \param right separator scalar (cast to BaseBinaryScalar)
// \param out   receives the result
// \return the Status from JoinStrings, or an error from any allocation step
//
// NOTE(review): this region is a diff rendering; several statements appear
// twice, once using a boxed ListArrayType (`lists`) and once reading from the
// ArraySpan (`left`). The `lists`-based lines look like the removed side --
// confirm against the actual commit.
static Status ExecArrayScalar(KernelContext* ctx, const ArraySpan& left,
const Scalar& right, ExecResult* out) {
// TODO(wesm): rewrite to not use ArrayData
const ListArrayType lists(left.ToArrayData());
const auto& separator_scalar = checked_cast<const BaseBinaryScalar&>(right);

// An invalid (null) separator short-circuits: the output is an all-null array
// with one slot per input list, typed as the list's value type.
if (!separator_scalar.is_valid) {
ARROW_ASSIGN_OR_RAISE(
auto nulls,
MakeArrayOfNull(lists.value_type(), lists.length(), ctx->memory_pool()));
ARROW_ASSIGN_OR_RAISE(auto nulls, MakeArrayOfNull(left.type->field(0)->type(),
left.length, ctx->memory_pool()));
out->value = std::move(nulls->data());
return Status::OK();
}

std::string_view separator(*separator_scalar.value);
const auto& strings = checked_cast<const ArrayType&>(*lists.values());
const auto list_offsets = lists.raw_value_offsets();
const auto list_offsets = left.GetValues<ListOffsetType>(1);
// The child values are still materialized as an ArrayType via ToArrayData();
// only the outer list is read directly from the span.
const ArraySpan& strings_span = left.child_data[0];
const ArrayType strings(strings_span.ToArrayData());

BuilderType builder(ctx->memory_pool());
RETURN_NOT_OK(builder.Reserve(lists.length()));
RETURN_NOT_OK(builder.Reserve(left.length));

// Presize data to avoid multiple reallocations when joining strings
// Estimate: all string bytes, plus (count - 1) separators for every non-empty
// list whose values contain no null.
int64_t total_data_length = strings.total_values_length();
for (int64_t i = 0; i < lists.length(); ++i) {
for (int64_t i = 0; i < left.length; ++i) {
const auto start = list_offsets[i], end = list_offsets[i + 1];
if (end > start && !ValuesContainNull(strings, start, end)) {
total_data_length += (end - start - 1) * separator.length();
}
}
RETURN_NOT_OK(builder.ReserveData(total_data_length));

return JoinStrings(lists.length(), strings, ListArrayOffsetLookup{lists},
return JoinStrings(left.length, strings, ListArraySpanOffsetLookup{left},
SeparatorScalarLookup{separator}, &builder, out);
}

// Array, array -> array
// Joins each list of strings in `left` using the separator at the same index
// in `right`, delegating the actual joining to JoinStrings.
//
// \param ctx   kernel context; supplies the memory pool for the builder
// \param left  list array span whose child 0 holds the strings to join
// \param right separator array span, read per index `i`
// \param out   receives the result
// \return the Status from JoinStrings, or an error from any allocation step
//
// NOTE(review): this region is a diff rendering; several statements appear
// twice, once using boxed arrays (`lists`, `separators`) and once reading from
// the ArraySpans (`left`, `right`). The boxed lines look like the removed
// side -- confirm against the actual commit.
static Status ExecArrayArray(KernelContext* ctx, const ArraySpan& left,
const ArraySpan& right, ExecResult* out) {
// TODO(wesm): rewrite to not use ArrayData
const ListArrayType lists(left.ToArrayData());
const ArrayType separators(right.ToArrayData());

const auto& strings = checked_cast<const ArrayType&>(*lists.values());
const auto list_offsets = lists.raw_value_offsets();
const auto list_offsets = left.GetValues<ListOffsetType>(1);
// The child values are still materialized as an ArrayType via ToArrayData();
// the outer list and the separators are read directly from their spans.
const ArraySpan& strings_span = left.child_data[0];
const ArrayType strings(strings_span.ToArrayData());
const auto string_offsets = strings.raw_value_offsets();

using offset_type = typename ArrayType::offset_type;
const offset_type* sep_offsets = right.GetValues<offset_type>(1);

BuilderType builder(ctx->memory_pool());
RETURN_NOT_OK(builder.Reserve(lists.length()));
RETURN_NOT_OK(builder.Reserve(left.length));

// Presize data to avoid multiple reallocations when joining strings
// Estimate: slots with a null separator add nothing; every other non-empty
// list whose values contain no null adds its string bytes plus (count - 1)
// copies of that slot's separator.
int64_t total_data_length = 0;
for (int64_t i = 0; i < lists.length(); ++i) {
if (separators.IsNull(i)) {
for (int64_t i = 0; i < left.length; ++i) {
if (right.IsNull(i)) {
continue;
}
const auto start = list_offsets[i], end = list_offsets[i + 1];
if (end > start && !ValuesContainNull(strings, start, end)) {
total_data_length += string_offsets[end] - string_offsets[start];
total_data_length += (end - start - 1) * separators.value_length(i);
offset_type sep_length = sep_offsets[i + 1] - sep_offsets[i];
total_data_length += (end - start - 1) * sep_length;
}
}
RETURN_NOT_OK(builder.ReserveData(total_data_length));

// NOTE(review): this local SeparatorLookup is not referenced by either
// JoinStrings call below; presumably it belongs to the removed side -- confirm.
struct SeparatorLookup {
const ArrayType& separators;

bool IsNull(int64_t i) { return separators.IsNull(i); }
std::string_view GetView(int64_t i) { return separators.GetView(i); }
};
return JoinStrings(lists.length(), strings, ListArrayOffsetLookup{lists},
SeparatorArrayLookup{separators}, &builder, out);
return JoinStrings(left.length, strings, ListArraySpanOffsetLookup{left},
SeparatorArraySpanLookup{right}, &builder, out);
}

template <typename ListOffsetLookup, typename SeparatorLookup>
Expand Down
Loading