diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc index dd5abed16c3..f1bcbca96bb 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc @@ -224,6 +224,25 @@ struct StringBinaryTransformExecBase { using ArrayType1 = typename TypeTraits::ArrayType; using ArrayType2 = typename TypeTraits::ArrayType; + // Helper to get view from ArraySpan (works for both string and numeric types) + template + static typename std::enable_if::value, ViewType2>::type + GetViewFromSpan(const ArraySpan& span, int64_t i) { + using offset_type_2 = typename T::offset_type; + const offset_type_2* offsets = span.GetValues(1); + const uint8_t* data = span.GetValues(2, 0); + const auto start = offsets[i]; + const auto end = offsets[i + 1]; + return ViewType2(reinterpret_cast(data + start), end - start); + } + + template + static typename std::enable_if::value, ViewType2>::type + GetViewFromSpan(const ArraySpan& span, int64_t i) { + // For numeric types, just return the value directly + return span.GetValues(1)[i]; + } + static Status Execute(KernelContext* ctx, StringTransform* transform, const ExecSpan& batch, ExecResult* out) { if (batch[0].is_scalar()) { @@ -325,14 +344,11 @@ struct StringBinaryTransformExecBase { output_offsets[0] = 0; offset_type output_ncodeunits = 0; - // TODO(wesm): rewrite to not require boxing - const ArrayType2 array2(data2.ToArrayData()); - // Apply transform RETURN_NOT_OK(arrow::internal::VisitBitBlocks( data2.buffers[0].data, data2.offset, data2.length, [&](int64_t i) { - ViewType2 value2 = array2.GetView(i); + ViewType2 value2 = GetViewFromSpan(data2, i); ARROW_ASSIGN_OR_RAISE( auto encoded_nbytes_, transform->Transform(input_string, input_ncodeunits, value2, @@ -377,9 +393,6 @@ struct StringBinaryTransformExecBase { const offset_type* data1_offsets = data1.GetValues(1); const uint8_t* data1_data = data1.GetValues(2, /*offset=*/0); - // TODO(wesm): rewrite to not require boxing - const ArrayType2 array2(data2.ToArrayData()); - // Apply transform RETURN_NOT_OK(arrow::internal::VisitTwoBitBlocks( data1.buffers[0].data, data1.offset, data2.buffers[0].data, data2.offset, @@ -387,7 +400,7 @@ struct StringBinaryTransformExecBase { [&](int64_t i) { const offset_type input_ncodeunits = data1_offsets[i + 1] - data1_offsets[i]; const uint8_t* input_string = data1_data + data1_offsets[i]; - ViewType2 value2 = array2.GetView(i); + ViewType2 value2 = GetViewFromSpan(data2, i); ARROW_ASSIGN_OR_RAISE( auto encoded_nbytes_, transform->Transform(input_string, input_ncodeunits, value2, @@ -3051,6 +3064,8 @@ struct BinaryJoin { return ExecArrayArray(ctx, batch[0].array, batch[1].array, out); } + // Lookup helpers for accessing list offsets and separator strings + struct ListScalarOffsetLookup { const ArrayType& values; @@ -3059,16 +3074,21 @@ struct BinaryJoin { bool IsNull(int64_t i) { return false; } }; - struct ListArrayOffsetLookup { - explicit ListArrayOffsetLookup(const ListArrayType& lists) - : lists_(lists), offsets_(lists.raw_value_offsets()) {} + struct ListArraySpanOffsetLookup { + explicit ListArraySpanOffsetLookup(const ArraySpan& lists) + : validity_(lists.buffers[0].data), + offset_(lists.offset), + offsets_(lists.GetValues(1)) {} int64_t GetStart(int64_t i) { return offsets_[i]; } int64_t GetStop(int64_t i) { return offsets_[i + 1]; } - bool IsNull(int64_t i) { return lists_.IsNull(i); } + bool IsNull(int64_t i) { + return validity_ && !bit_util::GetBit(validity_, offset_ + i); + } private: - const ListArrayType& lists_; + const uint8_t* validity_; + int64_t offset_; const ListOffsetType* offsets_; }; @@ -3079,11 +3099,27 @@ struct BinaryJoin { std::string_view GetView(int64_t i) { return separator; } }; - struct SeparatorArrayLookup { - const ArrayType& separators; + struct SeparatorArraySpanLookup { + explicit SeparatorArraySpanLookup(const ArraySpan& separators) + : validity_(separators.buffers[0].data), + offset_(separators.offset), + offsets_(separators.GetValues(1)), + data_(separators.GetValues(2, 0)) {} - bool IsNull(int64_t i) { return separators.IsNull(i); } - std::string_view GetView(int64_t i) { return separators.GetView(i); } + bool IsNull(int64_t i) { + return validity_ && !bit_util::GetBit(validity_, offset_ + i); + } + std::string_view GetView(int64_t i) { + auto start = offsets_[i]; + auto end = offsets_[i + 1]; + return std::string_view(reinterpret_cast(data_ + start), end - start); + } + + private: + const uint8_t* validity_; + int64_t offset_; + const typename ArrayType::offset_type* offsets_; + const uint8_t* data_; }; // Scalar, array -> array @@ -3105,57 +3141,56 @@ struct BinaryJoin { out->value = std::move(nulls->data()); return Status::OK(); } - // TODO(wesm): rewrite to not use ArrayData - const ArrayType separators(right.ToArrayData()); + + using offset_type = typename ArrayType::offset_type; + const offset_type* sep_offsets = right.GetValues(1); BuilderType builder(ctx->memory_pool()); - RETURN_NOT_OK(builder.Reserve(separators.length())); + RETURN_NOT_OK(builder.Reserve(right.length)); // Presize data to avoid multiple reallocations when joining strings int64_t total_data_length = 0; const int64_t list_length = strings.length(); if (list_length) { const int64_t string_length = strings.total_values_length(); - total_data_length += - string_length * (separators.length() - separators.null_count()); - for (int64_t i = 0; i < separators.length(); ++i) { - if (separators.IsNull(i)) { + total_data_length += string_length * (right.length - right.GetNullCount()); + for (int64_t i = 0; i < right.length; ++i) { + if (right.IsNull(i)) { continue; } - total_data_length += (list_length - 1) * separators.value_length(i); + offset_type sep_length = sep_offsets[i + 1] - sep_offsets[i]; + total_data_length += (list_length - 1) * sep_length; } } RETURN_NOT_OK(builder.ReserveData(total_data_length)); - return JoinStrings(separators.length(), strings, ListScalarOffsetLookup{strings}, - SeparatorArrayLookup{separators}, &builder, out); + return JoinStrings(right.length, strings, ListScalarOffsetLookup{strings}, + SeparatorArraySpanLookup{right}, &builder, out); } // Array, scalar -> array static Status ExecArrayScalar(KernelContext* ctx, const ArraySpan& left, const Scalar& right, ExecResult* out) { - // TODO(wesm): rewrite to not use ArrayData - const ListArrayType lists(left.ToArrayData()); const auto& separator_scalar = checked_cast(right); if (!separator_scalar.is_valid) { - ARROW_ASSIGN_OR_RAISE( - auto nulls, - MakeArrayOfNull(lists.value_type(), lists.length(), ctx->memory_pool())); + ARROW_ASSIGN_OR_RAISE(auto nulls, MakeArrayOfNull(left.type->field(0)->type(), + left.length, ctx->memory_pool())); out->value = std::move(nulls->data()); return Status::OK(); } std::string_view separator(*separator_scalar.value); - const auto& strings = checked_cast(*lists.values()); - const auto list_offsets = lists.raw_value_offsets(); + const auto list_offsets = left.GetValues(1); + const ArraySpan& strings_span = left.child_data[0]; + const ArrayType strings(strings_span.ToArrayData()); BuilderType builder(ctx->memory_pool()); - RETURN_NOT_OK(builder.Reserve(lists.length())); + RETURN_NOT_OK(builder.Reserve(left.length)); // Presize data to avoid multiple reallocations when joining strings int64_t total_data_length = strings.total_values_length(); - for (int64_t i = 0; i < lists.length(); ++i) { + for (int64_t i = 0; i < left.length; ++i) { const auto start = list_offsets[i], end = list_offsets[i + 1]; if (end > start && !ValuesContainNull(strings, start, end)) { total_data_length += (end - start - 1) * separator.length(); @@ -3163,46 +3198,41 @@ struct BinaryJoin { } RETURN_NOT_OK(builder.ReserveData(total_data_length)); - return JoinStrings(lists.length(), strings, ListArrayOffsetLookup{lists}, + return JoinStrings(left.length, strings, ListArraySpanOffsetLookup{left}, SeparatorScalarLookup{separator}, &builder, out); } // Array, array -> array static Status ExecArrayArray(KernelContext* ctx, const ArraySpan& left, const ArraySpan& right, ExecResult* out) { - // TODO(wesm): rewrite to not use ArrayData - const ListArrayType lists(left.ToArrayData()); - const ArrayType separators(right.ToArrayData()); - - const auto& strings = checked_cast(*lists.values()); - const auto list_offsets = lists.raw_value_offsets(); + const auto list_offsets = left.GetValues(1); + const ArraySpan& strings_span = left.child_data[0]; + const ArrayType strings(strings_span.ToArrayData()); const auto string_offsets = strings.raw_value_offsets(); + using offset_type = typename ArrayType::offset_type; + const offset_type* sep_offsets = right.GetValues(1); + BuilderType builder(ctx->memory_pool()); - RETURN_NOT_OK(builder.Reserve(lists.length())); + RETURN_NOT_OK(builder.Reserve(left.length)); // Presize data to avoid multiple reallocations when joining strings int64_t total_data_length = 0; - for (int64_t i = 0; i < lists.length(); ++i) { - if (separators.IsNull(i)) { + for (int64_t i = 0; i < left.length; ++i) { + if (right.IsNull(i)) { continue; } const auto start = list_offsets[i], end = list_offsets[i + 1]; if (end > start && !ValuesContainNull(strings, start, end)) { total_data_length += string_offsets[end] - string_offsets[start]; - total_data_length += (end - start - 1) * separators.value_length(i); + offset_type sep_length = sep_offsets[i + 1] - sep_offsets[i]; + total_data_length += (end - start - 1) * sep_length; } } RETURN_NOT_OK(builder.ReserveData(total_data_length)); - struct SeparatorLookup { - const ArrayType& separators; - - bool IsNull(int64_t i) { return separators.IsNull(i); } - std::string_view GetView(int64_t i) { return separators.GetView(i); } - }; - return JoinStrings(lists.length(), strings, ListArrayOffsetLookup{lists}, - SeparatorArrayLookup{separators}, &builder, out); + return JoinStrings(left.length, strings, ListArraySpanOffsetLookup{left}, + SeparatorArraySpanLookup{right}, &builder, out); } template