diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
index dd5abed16c3..edc3d44cd90 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
@@ -1223,7 +1223,77 @@ void StringBoolTransform(KernelContext* ctx, const ExecSpan& batch,
   }
 }
 
-using MatchSubstringState = OptionsWrapper<MatchSubstringOptions>;
+// Similar to OptionsWrapper, but caches a compiled object to avoid recompiling on each
+// invocation (e.g., regex matchers). Follows the same pattern as OptionsWrapper.
+template <typename OptionsType>
+struct CachedOptionsWrapper : public KernelState {
+  explicit CachedOptionsWrapper(OptionsType options) : options(std::move(options)) {}
+
+  static Result<std::unique_ptr<KernelState>> Init(KernelContext*,
+                                                   const KernelInitArgs& args) {
+    if (auto options_ptr = static_cast<const OptionsType*>(args.options)) {
+      return std::make_unique<CachedOptionsWrapper>(*options_ptr);
+    }
+    return Status::Invalid(
+        "Attempted to initialize KernelState from null FunctionOptions");
+  }
+
+  static const OptionsType& Get(const KernelState& state) {
+    return checked_cast<const CachedOptionsWrapper&>(state).options;
+  }
+
+  static const OptionsType& Get(KernelContext* ctx) { return Get(*ctx->state()); }
+
+  static CachedOptionsWrapper& GetMutable(KernelContext* ctx) {
+    return checked_cast<CachedOptionsWrapper&>(*ctx->state());
+  }
+
+  // Get or create the cached object of a specific type.
+  //
+  // Kernel execution may be parallelized, with synchronization left to the kernel
+  // implementation, so the cache is read and published through the shared_ptr atomic
+  // access functions from <memory> rather than a plain load/store. Racing first
+  // calls may each build an object, but only one is published and every caller
+  // returns it. (C++20 deprecates these in favor of std::atomic<std::shared_ptr>.)
+  //
+  // All calls on one state must use the same ObjectType: the cache is type-erased
+  // and the cast below is unchecked.
+  template <typename ObjectType, typename... Args>
+  Result<ObjectType*> GetOrCreate(Args&&... args) {
+    auto current = std::atomic_load(&cached_object);
+    if (!current) {
+      ARROW_ASSIGN_OR_RAISE(auto object, ObjectType::Make(std::forward<Args>(args)...));
+      // Convert to shared_ptr (handles both unique_ptr and value returns)
+      auto fresh = MakeCachedObject(std::move(object));
+      // Publish only if still empty; on failure `current` receives the winner.
+      if (std::atomic_compare_exchange_strong(&cached_object, &current, fresh)) {
+        current = std::move(fresh);
+      }
+    }
+    return static_cast<ObjectType*>(current.get());
+  }
+
+  OptionsType options;
+  // Type-erased cache for compiled objects (can store any object type)
+  // Concurrent access must go through the atomic shared_ptr functions (see GetOrCreate).
+  std::shared_ptr<void> cached_object;
+
+ private:
+  // Convert unique_ptr to shared_ptr
+  template <typename T>
+  static std::shared_ptr<void> MakeCachedObject(std::unique_ptr<T>&& ptr) {
+    return std::shared_ptr<void>(std::move(ptr));
+  }
+
+  // Convert value to shared_ptr
+  template <typename T>
+  static std::shared_ptr<void> MakeCachedObject(T&& value) {
+    return std::shared_ptr<void>(std::make_shared<T>(std::move(value)));
+  }
+};
+
+// State for match/count/find substring operations with cached matcher compilation
+using MatchSubstringState = CachedOptionsWrapper<MatchSubstringOptions>;
 
 // This is an implementation of the Knuth-Morris-Pratt algorithm
 struct PlainSubstringMatcher {
@@ -1368,9 +1438,10 @@ struct MatchSubstringImpl {
 template <typename Type, typename Matcher>
 struct MatchSubstring {
   static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-    // TODO Cache matcher across invocations (for regex compilation)
-    ARROW_ASSIGN_OR_RAISE(auto matcher, Matcher::Make(MatchSubstringState::Get(ctx)));
-    return MatchSubstringImpl<Type, Matcher>::Exec(ctx, batch, out, matcher.get());
+    // Get or create cached matcher to avoid recompiling on each invocation
+    auto& state = MatchSubstringState::GetMutable(ctx);
+    ARROW_ASSIGN_OR_RAISE(auto matcher, state.GetOrCreate<Matcher>(state.options));
+    return MatchSubstringImpl<Type, Matcher>::Exec(ctx, batch, out, matcher);
   }
 };
 
@@ -1378,12 +1449,13 @@ struct MatchSubstring {
 template <typename Type>
 struct MatchSubstring<Type, RegexSubstringMatcher> {
   static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-    // TODO Cache matcher across invocations (for regex compilation)
-    ARROW_ASSIGN_OR_RAISE(auto matcher,
-                          RegexSubstringMatcher::Make(MatchSubstringState::Get(ctx),
-                                                      /*is_utf8=*/Type::is_utf8));
+    // Get or create cached matcher to avoid recompiling regex on each invocation
+    auto& state = MatchSubstringState::GetMutable(ctx);
+    constexpr bool is_utf8 = Type::is_utf8;
+    ARROW_ASSIGN_OR_RAISE(auto matcher, state.GetOrCreate<RegexSubstringMatcher>(
+                                            state.options, /*is_utf8=*/is_utf8));
     return MatchSubstringImpl<Type, RegexSubstringMatcher>::Exec(ctx, batch, out,
-                                                                 matcher.get());
+                                                                 matcher);
   }
 };
 #endif
@@ -1391,67 +1463,78 @@ struct MatchSubstring<Type, RegexSubstringMatcher> {
 template <typename Type>
 struct MatchSubstring<Type, PlainSubstringMatcher> {
   static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-    auto options = MatchSubstringState::Get(ctx);
-    if (options.ignore_case) {
+    auto& state = MatchSubstringState::GetMutable(ctx);
+    if (state.options.ignore_case) {
 #ifdef ARROW_WITH_RE2
-      ARROW_ASSIGN_OR_RAISE(
-          auto matcher, RegexSubstringMatcher::Make(options, /*is_utf8=*/Type::is_utf8,
-                                                    /*literal=*/true));
+      // Get or create cached regex matcher for case-insensitive plain substring
+      constexpr bool is_utf8 = Type::is_utf8;
+      ARROW_ASSIGN_OR_RAISE(auto matcher,
+                            state.GetOrCreate<RegexSubstringMatcher>(
+                                state.options, /*is_utf8=*/is_utf8, /*literal=*/true));
       return MatchSubstringImpl<Type, RegexSubstringMatcher>::Exec(ctx, batch, out,
-                                                                   matcher.get());
+                                                                   matcher);
 #else
       return Status::NotImplemented("ignore_case requires RE2");
 #endif
     }
-    ARROW_ASSIGN_OR_RAISE(auto matcher, PlainSubstringMatcher::Make(options));
+    // Get or create cached plain matcher (caches KMP prefix table)
+    ARROW_ASSIGN_OR_RAISE(auto matcher,
+                          state.GetOrCreate<PlainSubstringMatcher>(state.options));
     return MatchSubstringImpl<Type, PlainSubstringMatcher>::Exec(ctx, batch, out,
-                                                                 matcher.get());
+                                                                 matcher);
   }
 };
 
 template <typename Type>
 struct MatchSubstring<Type, PlainStartsWithMatcher> {
   static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-    auto options = MatchSubstringState::Get(ctx);
-    if (options.ignore_case) {
+    auto& state = MatchSubstringState::GetMutable(ctx);
+    if (state.options.ignore_case) {
 #ifdef ARROW_WITH_RE2
-      MatchSubstringOptions converted_options = options;
-      converted_options.pattern = "^" + RE2::QuoteMeta(options.pattern);
+      // Get or create cached regex matcher for case-insensitive starts_with
+      MatchSubstringOptions converted_options = state.options;
+      converted_options.pattern = "^" + RE2::QuoteMeta(state.options.pattern);
+      constexpr bool is_utf8 = Type::is_utf8;
       ARROW_ASSIGN_OR_RAISE(
-          auto matcher,
-          RegexSubstringMatcher::Make(converted_options, /*is_utf8=*/Type::is_utf8));
+          auto matcher, state.GetOrCreate<RegexSubstringMatcher>(converted_options,
+                                                                 /*is_utf8=*/is_utf8));
       return MatchSubstringImpl<Type, RegexSubstringMatcher>::Exec(ctx, batch, out,
-                                                                   matcher.get());
+                                                                   matcher);
 #else
       return Status::NotImplemented("ignore_case requires RE2");
 #endif
     }
-    ARROW_ASSIGN_OR_RAISE(auto matcher, PlainStartsWithMatcher::Make(options));
+    // Get or create cached plain starts_with matcher
+    ARROW_ASSIGN_OR_RAISE(auto matcher,
+                          state.GetOrCreate<PlainStartsWithMatcher>(state.options));
     return MatchSubstringImpl<Type, PlainStartsWithMatcher>::Exec(ctx, batch, out,
-                                                                  matcher.get());
+                                                                  matcher);
   }
 };
 
 template <typename Type>
 struct MatchSubstring<Type, PlainEndsWithMatcher> {
   static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-    auto options = MatchSubstringState::Get(ctx);
-    if (options.ignore_case) {
+    auto& state = MatchSubstringState::GetMutable(ctx);
+    if (state.options.ignore_case) {
 #ifdef ARROW_WITH_RE2
-      MatchSubstringOptions converted_options = options;
-      converted_options.pattern = RE2::QuoteMeta(options.pattern) + "$";
+      // Get or create cached regex matcher for case-insensitive ends_with
+      MatchSubstringOptions converted_options = state.options;
+      converted_options.pattern = RE2::QuoteMeta(state.options.pattern) + "$";
+      constexpr bool is_utf8 = Type::is_utf8;
       ARROW_ASSIGN_OR_RAISE(
-          auto matcher,
-          RegexSubstringMatcher::Make(converted_options, /*is_utf8=*/Type::is_utf8));
+          auto matcher, state.GetOrCreate<RegexSubstringMatcher>(converted_options,
+                                                                 /*is_utf8=*/is_utf8));
       return MatchSubstringImpl<Type, RegexSubstringMatcher>::Exec(ctx, batch, out,
-                                                                   matcher.get());
+                                                                   matcher);
 #else
       return Status::NotImplemented("ignore_case requires RE2");
 #endif
     }
-    ARROW_ASSIGN_OR_RAISE(auto matcher, PlainEndsWithMatcher::Make(options));
-    return MatchSubstringImpl<Type, PlainEndsWithMatcher>::Exec(ctx, batch, out,
-                                                                matcher.get());
+    // Get or create cached plain ends_with matcher
+    ARROW_ASSIGN_OR_RAISE(auto matcher,
+                          state.GetOrCreate<PlainEndsWithMatcher>(state.options));
+    return MatchSubstringImpl<Type, PlainEndsWithMatcher>::Exec(ctx, batch, out, matcher);
   }
 };
 
@@ -1952,7 +2035,8 @@ void AddAsciiStringCountSubstring(FunctionRegistry* registry) {
 // ----------------------------------------------------------------------
 // Replace substring (plain, regex)
 
-using ReplaceState = OptionsWrapper<ReplaceSubstringOptions>;
+// State for replace substring operations with cached replacer compilation
+using ReplaceState = CachedOptionsWrapper<ReplaceSubstringOptions>;
 
 template <typename Type, typename Replacer>
 struct ReplaceSubstring {
@@ -1962,8 +2046,9 @@ struct ReplaceSubstring {
   using OffsetBuilder = TypedBufferBuilder<offset_type>;
 
   static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-    // TODO Cache replacer across invocations (for regex compilation)
-    ARROW_ASSIGN_OR_RAISE(auto replacer, Replacer::Make(ReplaceState::Get(ctx)));
+    // Get or create cached replacer to avoid recompiling on each invocation
+    auto& state = ReplaceState::GetMutable(ctx);
+    ARROW_ASSIGN_OR_RAISE(auto replacer, state.GetOrCreate<Replacer>(state.options));
     return Replace(ctx, batch, *replacer, out);
   }
 
@@ -2185,7 +2270,8 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) {
 
 #ifdef ARROW_WITH_RE2
 
-using ExtractRegexState = OptionsWrapper<ExtractRegexOptions>;
+// State for extract_regex operations with cached ExtractRegexData compilation
+using ExtractRegexState = CachedOptionsWrapper<ExtractRegexOptions>;
 
 struct BaseExtractRegexData {
   Status Init() {
@@ -2215,7 +2301,6 @@ struct BaseExtractRegexData {
       : regex(new RE2(pattern, MakeRE2Options(is_utf8))) {}
 };
 
-// TODO cache this once per ExtractRegexOptions
 struct ExtractRegexData : public BaseExtractRegexData {
   static Result<ExtractRegexData> Make(const ExtractRegexOptions& options, bool is_utf8) {
     ExtractRegexData data(options.pattern, is_utf8);
@@ -2247,9 +2332,11 @@ Result<TypeHolder> ResolveExtractRegexOutput(KernelContext* ctx,
   }
   DCHECK(is_base_binary_like(input_type->id()));
   auto is_utf8 = is_string(input_type->id());
-  ExtractRegexOptions options = ExtractRegexState::Get(ctx);
-  ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options, is_utf8));
-  return data.ResolveOutputType(types);
+  // Get or create cached data to avoid recompiling regex
+  auto& state = ExtractRegexState::GetMutable(ctx);
+  ARROW_ASSIGN_OR_RAISE(auto data,
+                        state.GetOrCreate<ExtractRegexData>(state.options, is_utf8));
+  return data->ResolveOutputType(types);
 }
 
 struct ExtractRegexBase {
@@ -2291,9 +2378,12 @@ struct ExtractRegex : public ExtractRegexBase {
   using ExtractRegexBase::ExtractRegexBase;
 
   static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-    ExtractRegexOptions options = ExtractRegexState::Get(ctx);
-    ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options, Type::is_utf8));
-    return ExtractRegex(data).Extract(ctx, batch, out);
+    // Get or create cached data to avoid recompiling regex
+    auto& state = ExtractRegexState::GetMutable(ctx);
+    constexpr bool is_utf8 = Type::is_utf8;
+    ARROW_ASSIGN_OR_RAISE(auto data,
+                          state.GetOrCreate<ExtractRegexData>(state.options, is_utf8));
+    return ExtractRegex(*data).Extract(ctx, batch, out);
   }
 
   Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {