diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 61a16f5f5eb..e36a7acabdb 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -325,6 +325,9 @@ static auto kElementWiseAggregateOptionsType = DataMember("skip_nulls", &ElementWiseAggregateOptions::skip_nulls)); static auto kExtractRegexOptionsType = GetFunctionOptionsType( DataMember("pattern", &ExtractRegexOptions::pattern)); +static auto kExtractRegexSpanOptionsType = + GetFunctionOptionsType( + DataMember("pattern", &ExtractRegexSpanOptions::pattern)); static auto kJoinOptionsType = GetFunctionOptionsType( DataMember("null_handling", &JoinOptions::null_handling), DataMember("null_replacement", &JoinOptions::null_replacement)); @@ -438,6 +441,12 @@ ExtractRegexOptions::ExtractRegexOptions(std::string pattern) ExtractRegexOptions::ExtractRegexOptions() : ExtractRegexOptions("") {} constexpr char ExtractRegexOptions::kTypeName[]; +ExtractRegexSpanOptions::ExtractRegexSpanOptions(std::string pattern) + : FunctionOptions(internal::kExtractRegexSpanOptionsType), + pattern(std::move(pattern)) {} +ExtractRegexSpanOptions::ExtractRegexSpanOptions() : ExtractRegexSpanOptions("") {} +constexpr char ExtractRegexSpanOptions::kTypeName[]; + JoinOptions::JoinOptions(NullHandlingBehavior null_handling, std::string null_replacement) : FunctionOptions(internal::kJoinOptionsType), null_handling(null_handling), @@ -684,6 +693,7 @@ void RegisterScalarOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kDayOfWeekOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kElementWiseAggregateOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexSpanOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kJoinOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kListSliceOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kMakeStructOptionsType)); diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 0e5a388b107..492ea05f6d5 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -265,6 +265,16 @@ class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions { std::string pattern; }; +class ARROW_EXPORT ExtractRegexSpanOptions : public FunctionOptions { + public: + explicit ExtractRegexSpanOptions(std::string pattern); + ExtractRegexSpanOptions(); + static constexpr char const kTypeName[] = "ExtractRegexSpanOptions"; + + /// Regular expression with named capture fields + std::string pattern; +}; + /// Options for IsIn and IndexIn functions class ARROW_EXPORT SetLookupOptions : public FunctionOptions { public: diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc index e58f7b065a8..6f02432d3d4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc @@ -22,6 +22,7 @@ #include #include "arrow/array/builder_nested.h" +#include "arrow/array/builder_primitive.h" #include "arrow/compute/kernels/scalar_string_internal.h" #include "arrow/result.h" #include "arrow/util/config.h" @@ -2184,20 +2185,12 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) { using ExtractRegexState = OptionsWrapper; -// TODO cache this once per ExtractRegexOptions -struct ExtractRegexData { - // Use unique_ptr<> because RE2 is non-movable (for ARROW_ASSIGN_OR_RAISE) - std::unique_ptr regex; - std::vector group_names; - - static Result Make(const ExtractRegexOptions& options, - bool is_utf8 = true) { - ExtractRegexData data(options.pattern, is_utf8); - RETURN_NOT_OK(RegexStatus(*data.regex)); - - const int group_count = data.regex->NumberOfCapturingGroups(); - const auto& name_map = data.regex->CapturingGroupNames(); - data.group_names.reserve(group_count); +struct BaseExtractRegexData { + Status Init() { + RETURN_NOT_OK(RegexStatus(*regex)); + const int group_count = regex->NumberOfCapturingGroups(); + const auto& name_map = regex->CapturingGroupNames(); + group_names.reserve(group_count); for (int i = 0; i < group_count; i++) { auto item = name_map.find(i + 1); // re2 starts counting from 1 @@ -2205,8 +2198,27 @@ struct ExtractRegexData { // XXX should we instead just create fields with an empty name? return Status::Invalid("Regular expression contains unnamed groups"); } - data.group_names.emplace_back(item->second); + group_names.emplace_back(item->second); } + return Status::OK(); + } + + int64_t num_groups() const { return static_cast(group_names.size()); } + + std::unique_ptr regex; + std::vector group_names; + + protected: + explicit BaseExtractRegexData(const std::string& pattern, bool is_utf8 = true) + : regex(new RE2(pattern, MakeRE2Options(is_utf8))) {} +}; + +// TODO cache this once per ExtractRegexOptions +struct ExtractRegexData : public BaseExtractRegexData { + static Result Make(const ExtractRegexOptions& options, + bool is_utf8 = true) { + ExtractRegexData data(options.pattern, is_utf8); + ARROW_RETURN_NOT_OK(data.Init()); return data; } @@ -2220,7 +2232,7 @@ struct ExtractRegexData { // of each field in the output struct type. DCHECK(is_base_binary_like(input_type->id())); FieldVector fields; - fields.reserve(group_names.size()); + fields.reserve(num_groups()); std::shared_ptr owned_type = input_type->GetSharedPtr(); std::transform(group_names.begin(), group_names.end(), std::back_inserter(fields), [&](const std::string& name) { return field(name, owned_type); }); @@ -2229,7 +2241,7 @@ struct ExtractRegexData { private: explicit ExtractRegexData(const std::string& pattern, bool is_utf8 = true) - : regex(new RE2(pattern, MakeRE2Options(is_utf8))) {} + : BaseExtractRegexData(pattern, is_utf8) {} }; Result ResolveExtractRegexOutput(KernelContext* ctx, @@ -2240,7 +2252,7 @@ Result ResolveExtractRegexOutput(KernelContext* ctx, } struct ExtractRegexBase { - const ExtractRegexData& data; + const BaseExtractRegexData& data; const int group_count; std::vector found_values; std::vector args; @@ -2248,9 +2260,9 @@ struct ExtractRegexBase { const RE2::Arg** args_pointers_start; const RE2::Arg* null_arg = nullptr; - explicit ExtractRegexBase(const ExtractRegexData& data) + explicit ExtractRegexBase(const BaseExtractRegexData& data) : data(data), - group_count(static_cast(data.group_names.size())), + group_count(static_cast(data.num_groups())), found_values(group_count) { args.reserve(group_count); args_pointers.reserve(group_count); @@ -2280,25 +2292,23 @@ struct ExtractRegex : public ExtractRegexBase { static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { ExtractRegexOptions options = ExtractRegexState::Get(ctx); ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options, Type::is_utf8)); - return ExtractRegex{data}.Extract(ctx, batch, out); + return ExtractRegex(data).Extract(ctx, batch, out); } Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - // TODO: why is this needed? Type resolution should already be - // done and the output type set in the output variable - ARROW_ASSIGN_OR_RAISE(TypeHolder out_type, data.ResolveOutputType(batch.GetTypes())); - DCHECK_NE(out_type.type, nullptr); - std::shared_ptr type = out_type.GetSharedPtr(); - - std::unique_ptr array_builder; - RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), type, &array_builder)); + DCHECK_NE(out->array_data(), nullptr); + std::shared_ptr type = out->array_data()->type; + ARROW_ASSIGN_OR_RAISE(std::unique_ptr array_builder, + MakeBuilder(type, ctx->memory_pool())); StructBuilder* struct_builder = checked_cast(array_builder.get()); + ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].length())); std::vector field_builders; field_builders.reserve(group_count); for (int i = 0; i < group_count; i++) { field_builders.push_back( checked_cast(struct_builder->field_builder(i))); + RETURN_NOT_OK(field_builders.back()->Reserve(batch[0].length())); } auto visit_null = [&]() { return struct_builder->AppendNull(); }; @@ -2347,6 +2357,142 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) { } DCHECK_OK(registry->AddFunction(std::move(func))); } + +struct ExtractRegexSpanData : public BaseExtractRegexData { + static Result Make(const std::string& pattern, + bool is_utf8 = true) { + auto data = ExtractRegexSpanData(pattern, is_utf8); + ARROW_RETURN_NOT_OK(data.Init()); + return data; + } + + Result ResolveOutputType(const std::vector& types) const { + const DataType* input_type = types[0].type; + if (input_type == nullptr) { + return nullptr; + } + DCHECK(is_base_binary_like(input_type->id())); + FieldVector fields; + fields.reserve(num_groups()); + auto index_type = is_binary_like(input_type->id()) ? int32() : int64(); + for (const auto& group_name : group_names) { + // list size is 2 as every span contains position and length + fields.push_back(field(group_name, fixed_size_list(index_type, 2))); + } + return struct_(std::move(fields)); + } + + private: + ExtractRegexSpanData(const std::string& pattern, const bool is_utf8) + : BaseExtractRegexData(pattern, is_utf8) {} +}; + +template +struct ExtractRegexSpan : ExtractRegexBase { + using ArrayType = typename TypeTraits::ArrayType; + using BuilderType = typename TypeTraits::BuilderType; + using offset_type = typename Type::offset_type; + using OffsetBuilderType = + typename TypeTraits::ArrowType>::BuilderType; + using OffsetCType = + typename TypeTraits::ArrowType>::CType; + + using ExtractRegexBase::ExtractRegexBase; + + static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + auto options = OptionsWrapper::Get(ctx); + ARROW_ASSIGN_OR_RAISE(auto data, + ExtractRegexSpanData::Make(options.pattern, Type::is_utf8)); + return ExtractRegexSpan{data}.Extract(ctx, batch, out); + } + + Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + DCHECK_NE(out->array_data(), nullptr); + std::shared_ptr out_type = out->array_data()->type; + ARROW_ASSIGN_OR_RAISE(auto out_builder, MakeBuilder(out_type, ctx->memory_pool())); + StructBuilder* struct_builder = checked_cast(out_builder.get()); + ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].array.length)); + + std::vector span_builders; + std::vector array_builders; + span_builders.reserve(group_count); + array_builders.reserve(group_count); + for (int i = 0; i < group_count; i++) { + span_builders.push_back( + checked_cast(struct_builder->field_builder(i))); + array_builders.push_back( + checked_cast(span_builders.back()->value_builder())); + RETURN_NOT_OK(span_builders.back()->Reserve(batch[0].length())); + RETURN_NOT_OK(array_builders.back()->Reserve(2 * batch[0].length())); + } + + auto visit_null = [&]() { return struct_builder->AppendNull(); }; + auto visit_value = [&](std::string_view element) -> Status { + if (Match(element)) { + for (int i = 0; i < group_count; i++) { + // https://github.com/google/re2/issues/24#issuecomment-97653183 + if (found_values[i].data() != nullptr) { + int64_t begin = found_values[i].data() - element.data(); + int64_t size = found_values[i].size(); + array_builders[i]->UnsafeAppend(static_cast(begin)); + array_builders[i]->UnsafeAppend(static_cast(size)); + ARROW_RETURN_NOT_OK(span_builders[i]->Append()); + } else { + ARROW_RETURN_NOT_OK(span_builders[i]->AppendNull()); + } + } + ARROW_RETURN_NOT_OK(struct_builder->Append()); + } else { + ARROW_RETURN_NOT_OK(struct_builder->AppendNull()); + } + return Status::OK(); + }; + ARROW_RETURN_NOT_OK( + VisitArraySpanInline(batch[0].array, visit_value, visit_null)); + + ARROW_ASSIGN_OR_RAISE(auto out_array, struct_builder->Finish()); + out->value = std::move(out_array->data()); + return Status::OK(); + } +}; + +const FunctionDoc extract_regex_span_doc( + "Extract string spans captured by a regex pattern", + ("For each string in strings, match the regular expression and, if\n" + "successful, emit a struct with field names and values coming from the\n" + "regular expression's named capture groups. Each struct field value\n" + "will be a fixed_size_list(offset_type, 2) where offset_type is int32\n" + "or int64, depending on the input string type. The two elements in\n" + "each fixed-size list are the index and the length of the substring\n" + "matched by the corresponding named capture group.\n" + "\n" + "If the input is null or the regular expression fails matching,\n" + "a null output value is emitted.\n" + "\n" + "Regular expression matching is done using the Google RE2 library."), + {"strings"}, "ExtractRegexSpanOptions", /*options_required=*/true); + +Result ResolveExtractRegexSpanOutputType( + KernelContext* ctx, const std::vector& types) { + auto options = OptionsWrapper::Get(*ctx->state()); + ARROW_ASSIGN_OR_RAISE(auto span, ExtractRegexSpanData::Make(options.pattern)); + return span.ResolveOutputType(types); +} + +void AddAsciiStringExtractRegexSpan(FunctionRegistry* registry) { + auto func = std::make_shared("extract_regex_span", Arity::Unary(), + extract_regex_span_doc); + OutputType output_type(ResolveExtractRegexSpanOutputType); + for (const auto& type : BaseBinaryTypes()) { + ScalarKernel kernel({type}, output_type, + GenerateVarBinaryToVarBinary(type), + OptionsWrapper::Init); + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + DCHECK_OK(func->AddKernel(std::move(kernel))); + } + DCHECK_OK(registry->AddFunction(func)); +} #endif // ARROW_WITH_RE2 // ---------------------------------------------------------------------- @@ -3457,6 +3603,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { AddAsciiStringSplitWhitespace(registry); #ifdef ARROW_WITH_RE2 AddAsciiStringSplitRegex(registry); + AddAsciiStringExtractRegexSpan(registry); #endif AddAsciiStringJoin(registry); AddAsciiStringRepeat(registry); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 38455dc1467..672839f3cca 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -314,6 +314,7 @@ TYPED_TEST(TestBinaryKernels, NonUtf8Regex) { this->MakeArray({"\xfc\x40", "this \xfc\x40 that \xfc\x40"}), this->MakeArray({"bazz", "this bazz that \xfc\x40"}), &options); } + // TODO the following test is broken (GH-45735) { ExtractRegexOptions options("(?P[\\xfc])(?P\\d)"); auto null_bitmap = std::make_shared("0"); @@ -370,6 +371,7 @@ TYPED_TEST(TestBinaryKernels, NonUtf8WithNullRegex) { this->template MakeArray({{"\x00\x40", 2}}), this->type(), R"(["bazz"])", &options); } + // TODO the following test is broken (GH-45735) { ExtractRegexOptions options("(?P[\\x00])(?P\\d)"); auto null_bitmap = std::make_shared("0"); @@ -1959,6 +1961,62 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegex) { &options); } +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpan) { + ExtractRegexSpanOptions options{"(?P[ab]+)(?P\\d+)"}; + auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() : int64(); + auto out_type = struct_({field("letter", fixed_size_list(type_fixe_size_list, 2)), + field("digit", fixed_size_list(type_fixe_size_list, 2))}); + this->CheckUnary("extract_regex_span", R"([])", out_type, R"([])", &options); + this->CheckUnary("extract_regex_span", R"([ null,"123ab","cd123ab","cd123abef"])", + out_type, R"([null,null,null,null])", &options); + this->CheckUnary( + "extract_regex_span", + R"(["a1", "b2", "c3", null,"123ab","abb12","abc13","cedbb15","cedaabb125efg"])", + out_type, + R"([{"letter":[0,1], "digit":[1,1]}, + {"letter":[0,1], "digit":[1,1]}, + null, + null, + null, + {"letter":[0,3], "digit":[3,2]}, + null, + {"letter":[3,2], "digit":[5,2]}, + {"letter":[3,4], "digit":[7,3]}])", + &options); + this->CheckUnary("extract_regex_span", R"([ "a3","b2","cdaa123","cdab123ef"])", + out_type, + R"([{"letter":[0,1], "digit":[1,1]}, + {"letter":[0,1], "digit":[1,1]}, + {"letter":[2,2], "digit":[4,3]}, + {"letter":[2,2], "digit":[4,3]}])", + &options); +} + +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanCaptureOption) { + ExtractRegexSpanOptions options{"(?Pfoo)?(?P\\d+)?"}; + auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() : int64(); + auto out_type = struct_({field("foo", fixed_size_list(type_fixe_size_list, 2)), + field("digit", fixed_size_list(type_fixe_size_list, 2))}); + this->CheckUnary("extract_regex_span", R"([])", out_type, R"([])", &options); + this->CheckUnary("extract_regex_span", R"(["foo","foo123","abcfoo123","abc",null])", + out_type, + R"([{"foo":[0,3],"digit":null}, + {"foo":[0,3],"digit":[3,3]}, + {"foo":null,"digit":null}, + {"foo":null,"digit":null}, + null])", + &options); + options = ExtractRegexSpanOptions{"(?Pfoo)(?P\\d+)?"}; + this->CheckUnary("extract_regex_span", R"(["foo123","foo","123","abc","abcfoo"])", + out_type, + R"([{"foo":[0,3],"digit":[3,3]}, + {"foo":[0,3],"digit":null}, + null, + null, + {"foo":[3,3],"digit":null}])", + &options); +} + TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoCapture) { // XXX Should we accept this or is it a user error? ExtractRegexOptions options{"foo"}; @@ -1967,11 +2025,24 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoCapture) { R"([{}, null, null])", &options); } +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanNoCapture) { + // XXX Should we accept this or is it a user error? + ExtractRegexSpanOptions options{"foo"}; + auto type = struct_({}); + this->CheckUnary("extract_regex_span", R"(["oofoo", "bar", null])", type, + R"([{}, null, null])", &options); +} + TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoOptions) { Datum input = ArrayFromJSON(this->type(), "[]"); ASSERT_RAISES(Invalid, CallFunction("extract_regex", {input})); } +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanNoOptions) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ASSERT_RAISES(Invalid, CallFunction("extract_regex_span", {input})); +} + TYPED_TEST(TestBaseBinaryKernels, ExtractRegexInvalid) { Datum input = ArrayFromJSON(this->type(), "[]"); ExtractRegexOptions options{"invalid["}; @@ -1985,6 +2056,18 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegexInvalid) { CallFunction("extract_regex", {input}, &options)); } +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanInvalid) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ExtractRegexSpanOptions options{"invalid["}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Invalid regular expression: missing ]"), + CallFunction("extract_regex_span", {input}, &options)); + options = ExtractRegexSpanOptions{"(.)"}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Regular expression contains unnamed groups"), + CallFunction("extract_regex_span", {input}, &options)); +} + #endif TYPED_TEST(TestStringKernels, Strptime) { diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 8825ffebf2a..57673dfe1fc 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1128,17 +1128,26 @@ when a positive ``max_splits`` is given. String component extraction ~~~~~~~~~~~~~~~~~~~~~~~~~~~ -+---------------+-------+------------------------+-------------+-------------------------------+-------+ -| Function name | Arity | Input types | Output type | Options class | Notes | -+===============+=======+========================+=============+===============================+=======+ -| extract_regex | Unary | Binary- or String-like | Struct | :struct:`ExtractRegexOptions` | \(1) | -+---------------+-------+------------------------+-------------+-------------------------------+-------+ ++--------------------+-------+------------------------+-------------+-----------------------------------+-------+ +| Function name | Arity | Input types | Output type | Options class | Notes | ++====================+=======+========================+=============+===================================+=======+ +| extract_regex | Unary | Binary- or String-like | Struct | :struct:`ExtractRegexOptions` | \(1) | ++--------------------+-------+------------------------+-------------+-----------------------------------+-------+ +| extract_regex_span | Unary | Binary- or String-like | Struct | :struct:`ExtractRegexSpanOptions` | \(2) | ++--------------------+-------+------------------------+-------------+-----------------------------------+-------+ * \(1) Extract substrings defined by a regular expression using the Google RE2 library. The output struct field names refer to the named capture groups, e.g. 'letter' and 'digit' for the regular expression ``(?P[ab])(?P\\d)``. +* \(2) Extract the offset and length of substrings defined by a regular expression + using the Google RE2 library. The output struct field names refer to the named + capture groups, e.g. 'letter' and 'digit' for the regular expression + ``(?P[ab])(?P\\d)``. Each output struct field is a fixed size list + of two integers: the index to the start of the captured group and the length + of the captured group, respectively. + String joining ~~~~~~~~~~~~~~ diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 63370c938b9..db6cf5b45d4 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -1222,6 +1222,25 @@ class ExtractRegexOptions(_ExtractRegexOptions): self._set_options(pattern) +cdef class _ExtractRegexSpanOptions(FunctionOptions): + def _set_options(self, pattern): + self.wrapped.reset(new CExtractRegexSpanOptions(tobytes(pattern))) + + +class ExtractRegexSpanOptions(_ExtractRegexSpanOptions): + """ + Options for the `extract_regex_span` function. + + Parameters + ---------- + pattern : str + Regular expression with named capture fields. + """ + + def __init__(self, pattern): + self._set_options(pattern) + + cdef class _SliceOptions(FunctionOptions): def _set_options(self, start, stop, step): self.wrapped.reset(new CSliceOptions(start, stop, step)) diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 8040cf9ff03..1809c74afc5 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -40,6 +40,7 @@ RunEndEncodeOptions, ElementWiseAggregateOptions, ExtractRegexOptions, + ExtractRegexSpanOptions, FilterOptions, IndexOptions, JoinOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index c3ddaba88fd..f9fa091171d 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2500,6 +2500,11 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CExtractRegexOptions(c_string pattern) c_string pattern + cdef cppclass CExtractRegexSpanOptions \ + "arrow::compute::ExtractRegexSpanOptions"(CFunctionOptions): + CExtractRegexSpanOptions(c_string pattern) + c_string pattern + cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions): CCastOptions() CCastOptions(c_bool safe) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 8a756a262b6..73506fedfc8 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -152,6 +152,7 @@ def test_option_class_equality(request): pc.RunEndEncodeOptions(), pc.ElementWiseAggregateOptions(skip_nulls=True), pc.ExtractRegexOptions("pattern"), + pc.ExtractRegexSpanOptions("pattern"), pc.FilterOptions(), pc.IndexOptions(pa.scalar(1)), pc.JoinOptions(), @@ -1092,6 +1093,16 @@ def test_extract_regex(): assert struct.tolist() == expected +def test_extract_regex_span(): + ar = pa.array(['a1', 'zb234z']) + expected = [{'letter': [0, 1], 'digit': [1, 1]}, + {'letter': [1, 1], 'digit': [2, 3]}] + struct = pc.extract_regex_span(ar, pattern=r'(?P[ab])(?P\d+)') + assert struct.tolist() == expected + struct = pc.extract_regex_span(ar, r'(?P[ab])(?P\d+)') + assert struct.tolist() == expected + + def test_binary_join(): ar_list = pa.array([['foo', 'bar'], None, []]) expected = pa.array(['foo-bar', None, ''])