Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 122 additions & 47 deletions cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1223,7 +1223,62 @@ void StringBoolTransform(KernelContext* ctx, const ExecSpan& batch,
}
}

// Kernel state that stores a copy of the function options together with a lazily
// built, type-erased companion object (e.g. a compiled regex matcher), so that the
// companion is built once per state rather than once per Exec call. The Init/Get
// helpers mirror the interface the kernels previously used on OptionsWrapper.
//
// NOTE(review): the cache is filled lazily and without any synchronization. Confirm
// that a single KernelState is never executed concurrently from several threads, or
// build the cached object eagerly in Init instead.
template <typename OptionsType>
struct CachedOptionsWrapper : public KernelState {
  explicit CachedOptionsWrapper(OptionsType options) : options(std::move(options)) {}

  // Creates the state from `args.options`; fails with Invalid when no options
  // were supplied.
  static Result<std::unique_ptr<KernelState>> Init(KernelContext*,
                                                   const KernelInitArgs& args) {
    if (auto options_ptr = static_cast<const OptionsType*>(args.options)) {
      return std::make_unique<CachedOptionsWrapper>(*options_ptr);
    }
    return Status::Invalid(
        "Attempted to initialize KernelState from null FunctionOptions");
  }

  // Read-only access to the stored options.
  static const OptionsType& Get(const KernelState& state) {
    return checked_cast<const CachedOptionsWrapper&>(state).options;
  }

  static const OptionsType& Get(KernelContext* ctx) { return Get(*ctx->state()); }

  // Mutable access to the whole state; needed to populate the cache.
  static CachedOptionsWrapper& GetMutable(KernelContext* ctx) {
    return checked_cast<CachedOptionsWrapper&>(*ctx->state());
  }

  // Returns the cached ObjectType, building it with ObjectType::Make(args...) on the
  // first call. On later calls `args` are ignored and the first object is returned,
  // so callers must pass equivalent arguments every time.
  //
  // Errors: propagates a failure of ObjectType::Make (nothing is cached in that
  // case), and returns Invalid if the cache already holds an object that was
  // created through GetOrCreate with a different ObjectType.
  template <typename ObjectType, typename... Args>
  Result<const ObjectType*> GetOrCreate(Args&&... args) {
    if (!cached_object) {
      ARROW_ASSIGN_OR_RAISE(auto object, ObjectType::Make(std::forward<Args>(args)...));
      // Convert to shared_ptr (handles both unique_ptr and value returns)
      cached_object = MakeCachedObject<ObjectType>(std::move(object));
      cached_type_tag = TypeTag<ObjectType>();
    } else if (cached_type_tag != nullptr && cached_type_tag != TypeTag<ObjectType>()) {
      // The pointer below is recovered with an unchecked cast from void*, so a
      // mismatching type must be rejected instead of being reinterpreted.
      return Status::Invalid(
          "CachedOptionsWrapper: cached object was created with a different type");
    }
    return static_cast<const ObjectType*>(cached_object.get());
  }

  OptionsType options;
  // Type-erased cache for compiled objects (can store any object type)
  std::shared_ptr<void> cached_object;

 private:
  // Unique address per ObjectType, used to remember what `cached_object` holds.
  // The tag is deliberately non-const so that linkers cannot fold the tags of
  // different instantiations into a single read-only object.
  template <typename T>
  static const void* TypeTag() {
    static char tag;
    return &tag;
  }

  // Convert unique_ptr to shared_ptr<void>; the original deleter is preserved.
  template <typename T>
  static std::shared_ptr<void> MakeCachedObject(std::unique_ptr<T>&& ptr) {
    return std::shared_ptr<void>(std::move(ptr));
  }

  // Convert value to shared_ptr<void> by moving it onto the heap.
  template <typename T>
  static std::shared_ptr<void> MakeCachedObject(T&& value) {
    return std::make_shared<T>(std::move(value));
  }

  // Tag of the type stored by GetOrCreate; null while nothing has been stored
  // through it (a null tag is not checked).
  const void* cached_type_tag = nullptr;
};

// State for match/count/find substring operations with cached matcher compilation
using MatchSubstringState = CachedOptionsWrapper<MatchSubstringOptions>;

// This is an implementation of the Knuth-Morris-Pratt algorithm
struct PlainSubstringMatcher {
Expand Down Expand Up @@ -1368,90 +1423,103 @@ struct MatchSubstringImpl {
// Generic kernel entry point: obtains a Matcher built from the kernel's
// MatchSubstringOptions and runs MatchSubstringImpl with it.
template <typename Type, typename Matcher>
struct MatchSubstring {
  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
    // Get or create cached matcher to avoid recompiling on each invocation.
    // The matcher is owned by the kernel state; `matcher` is a borrowed pointer.
    auto& state = MatchSubstringState::GetMutable(ctx);
    ARROW_ASSIGN_OR_RAISE(auto matcher, state.GetOrCreate<Matcher>(state.options));
    return MatchSubstringImpl<Type, Matcher>::Exec(ctx, batch, out, matcher);
  }
};

#ifdef ARROW_WITH_RE2
// Regex specialization: the matcher additionally needs to know whether the input
// type is UTF-8.
template <typename Type>
struct MatchSubstring<Type, RegexSubstringMatcher> {
  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
    // Get or create cached matcher to avoid recompiling regex on each invocation
    auto& state = MatchSubstringState::GetMutable(ctx);
    // Local copy: GetOrCreate takes its arguments by forwarding reference, so
    // passing Type::is_utf8 directly would bind a reference to the static member.
    constexpr bool is_utf8 = Type::is_utf8;
    ARROW_ASSIGN_OR_RAISE(auto matcher, state.GetOrCreate<RegexSubstringMatcher>(
                                            state.options, /*is_utf8=*/is_utf8));
    return MatchSubstringImpl<Type, RegexSubstringMatcher>::Exec(ctx, batch, out,
                                                                 matcher);
  }
};
#endif

// Plain (literal) substring specialization. Case-insensitive matching is delegated
// to a literal RE2 matcher and therefore requires ARROW_WITH_RE2. Whether the
// regex or the plain matcher is cached depends only on `ignore_case`, which is
// fixed for the lifetime of the state, so one state only ever caches one type.
template <typename Type>
struct MatchSubstring<Type, PlainSubstringMatcher> {
  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
    auto& state = MatchSubstringState::GetMutable(ctx);
    if (state.options.ignore_case) {
#ifdef ARROW_WITH_RE2
      // Get or create cached regex matcher for case-insensitive plain substring
      constexpr bool is_utf8 = Type::is_utf8;
      ARROW_ASSIGN_OR_RAISE(auto matcher,
                            state.GetOrCreate<RegexSubstringMatcher>(
                                state.options, /*is_utf8=*/is_utf8, /*literal=*/true));
      return MatchSubstringImpl<Type, RegexSubstringMatcher>::Exec(ctx, batch, out,
                                                                   matcher);
#else
      return Status::NotImplemented("ignore_case requires RE2");
#endif
    }
    // Get or create cached plain matcher (caches KMP prefix table)
    ARROW_ASSIGN_OR_RAISE(auto matcher,
                          state.GetOrCreate<PlainSubstringMatcher>(state.options));
    return MatchSubstringImpl<Type, PlainSubstringMatcher>::Exec(ctx, batch, out,
                                                                 matcher);
  }
};

// starts_with specialization. The case-insensitive variant is expressed as the
// anchored regex "^" + QuoteMeta(pattern) and therefore requires ARROW_WITH_RE2.
template <typename Type>
struct MatchSubstring<Type, PlainStartsWithMatcher> {
  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
    auto& state = MatchSubstringState::GetMutable(ctx);
    if (state.options.ignore_case) {
#ifdef ARROW_WITH_RE2
      // Get or create cached regex matcher for case-insensitive starts_with.
      // The converted options are only consumed when the cache is still empty.
      MatchSubstringOptions converted_options = state.options;
      converted_options.pattern = "^" + RE2::QuoteMeta(state.options.pattern);
      constexpr bool is_utf8 = Type::is_utf8;
      ARROW_ASSIGN_OR_RAISE(
          auto matcher, state.GetOrCreate<RegexSubstringMatcher>(converted_options,
                                                                 /*is_utf8=*/is_utf8));
      return MatchSubstringImpl<Type, RegexSubstringMatcher>::Exec(ctx, batch, out,
                                                                   matcher);
#else
      return Status::NotImplemented("ignore_case requires RE2");
#endif
    }
    // Get or create cached plain starts_with matcher
    ARROW_ASSIGN_OR_RAISE(auto matcher,
                          state.GetOrCreate<PlainStartsWithMatcher>(state.options));
    return MatchSubstringImpl<Type, PlainStartsWithMatcher>::Exec(ctx, batch, out,
                                                                  matcher);
  }
};

// ends_with specialization. The case-insensitive variant is expressed as the
// anchored regex QuoteMeta(pattern) + "$" and therefore requires ARROW_WITH_RE2.
template <typename Type>
struct MatchSubstring<Type, PlainEndsWithMatcher> {
  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
    auto& state = MatchSubstringState::GetMutable(ctx);
    if (state.options.ignore_case) {
#ifdef ARROW_WITH_RE2
      // Get or create cached regex matcher for case-insensitive ends_with.
      // The converted options are only consumed when the cache is still empty.
      MatchSubstringOptions converted_options = state.options;
      converted_options.pattern = RE2::QuoteMeta(state.options.pattern) + "$";
      constexpr bool is_utf8 = Type::is_utf8;
      ARROW_ASSIGN_OR_RAISE(
          auto matcher, state.GetOrCreate<RegexSubstringMatcher>(converted_options,
                                                                 /*is_utf8=*/is_utf8));
      return MatchSubstringImpl<Type, RegexSubstringMatcher>::Exec(ctx, batch, out,
                                                                   matcher);
#else
      return Status::NotImplemented("ignore_case requires RE2");
#endif
    }
    // Get or create cached plain ends_with matcher
    ARROW_ASSIGN_OR_RAISE(auto matcher,
                          state.GetOrCreate<PlainEndsWithMatcher>(state.options));
    return MatchSubstringImpl<Type, PlainEndsWithMatcher>::Exec(ctx, batch, out, matcher);
  }
};

Expand Down Expand Up @@ -1952,7 +2020,8 @@ void AddAsciiStringCountSubstring(FunctionRegistry* registry) {
// ----------------------------------------------------------------------
// Replace substring (plain, regex)

// State for replace substring operations with cached replacer compilation
using ReplaceState = CachedOptionsWrapper<ReplaceSubstringOptions>;

template <typename Type, typename Replacer>
struct ReplaceSubstring {
Expand All @@ -1962,8 +2031,9 @@ struct ReplaceSubstring {
using OffsetBuilder = TypedBufferBuilder<offset_type>;

// Kernel entry point: obtains the Replacer built from the kernel's
// ReplaceSubstringOptions and forwards to Replace().
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
  // Get or create cached replacer to avoid recompiling on each invocation.
  // The replacer is owned by the kernel state; `replacer` is a borrowed pointer.
  auto& state = ReplaceState::GetMutable(ctx);
  ARROW_ASSIGN_OR_RAISE(auto replacer, state.GetOrCreate<Replacer>(state.options));
  return Replace(ctx, batch, *replacer, out);
}


Expand Down Expand Up @@ -2185,7 +2255,8 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) {

#ifdef ARROW_WITH_RE2

// State for extract_regex operations with cached ExtractRegexData compilation
using ExtractRegexState = CachedOptionsWrapper<ExtractRegexOptions>;

struct BaseExtractRegexData {
Status Init() {
Expand Down Expand Up @@ -2215,7 +2286,6 @@ struct BaseExtractRegexData {
: regex(new RE2(pattern, MakeRE2Options(is_utf8))) {}
};

// TODO cache this once per ExtractRegexOptions
struct ExtractRegexData : public BaseExtractRegexData {
static Result<ExtractRegexData> Make(const ExtractRegexOptions& options, bool is_utf8) {
ExtractRegexData data(options.pattern, is_utf8);
Expand Down Expand Up @@ -2247,9 +2317,11 @@ Result<TypeHolder> ResolveExtractRegexOutput(KernelContext* ctx,
}
DCHECK(is_base_binary_like(input_type->id()));
auto is_utf8 = is_string(input_type->id());
ExtractRegexOptions options = ExtractRegexState::Get(ctx);
ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options, is_utf8));
return data.ResolveOutputType(types);
// Get or create cached data to avoid recompiling regex
auto& state = ExtractRegexState::GetMutable(ctx);
ARROW_ASSIGN_OR_RAISE(auto data,
state.GetOrCreate<ExtractRegexData>(state.options, is_utf8));
return data->ResolveOutputType(types);
}

struct ExtractRegexBase {
Expand Down Expand Up @@ -2291,9 +2363,12 @@ struct ExtractRegex : public ExtractRegexBase {
using ExtractRegexBase::ExtractRegexBase;

// Kernel entry point: obtains the ExtractRegexData for the kernel's options and
// runs the extraction with it.
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
  // Get or create cached data to avoid recompiling regex
  auto& state = ExtractRegexState::GetMutable(ctx);
  // Local copy: GetOrCreate takes its arguments by forwarding reference, so
  // passing Type::is_utf8 directly would bind a reference to the static member.
  constexpr bool is_utf8 = Type::is_utf8;
  ARROW_ASSIGN_OR_RAISE(auto data,
                        state.GetOrCreate<ExtractRegexData>(state.options, is_utf8));
  // `data` is owned by the kernel state and is only borrowed here.
  return ExtractRegex(*data).Extract(ctx, batch, out);
}

Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
Expand Down
Loading