Skip to content

Commit bfed461

Browse files
authored
chore: suffix search (#5327)
feat(search): Implement suffix/infix search for text fields
1 parent 42de8c3 commit bfed461

File tree

7 files changed

+235
-128
lines changed

7 files changed

+235
-128
lines changed

src/core/search/indices.cc

Lines changed: 83 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
#include <algorithm>
2323
#include <cctype>
2424

25-
#include "base/logging.h"
26-
2725
namespace dfly::search {
2826

2927
using namespace std;
@@ -74,6 +72,16 @@ absl::flat_hash_set<string> NormalizeTags(string_view taglist, bool case_sensiti
7472
return tags;
7573
}
7674

75+
// Iterate over all suffixes of all words
76+
void IterateAllSuffixes(const absl::flat_hash_set<string>& words,
77+
absl::FunctionRef<void(std::string_view)> cb) {
78+
for (string_view word : words) {
79+
for (size_t offs = 0; offs < word.length(); offs++) {
80+
cb(word.substr(offs));
81+
}
82+
}
83+
}
84+
7785
}; // namespace
7886

7987
NumericIndex::NumericIndex(PMR_NS::memory_resource* mr) : entries_{mr} {
@@ -147,8 +155,11 @@ vector<DocId> NumericIndex::GetAllDocsWithNonNullValues() const {
147155
}
148156

149157
template <typename C>
150-
BaseStringIndex<C>::BaseStringIndex(PMR_NS::memory_resource* mr, bool case_sensitive)
158+
BaseStringIndex<C>::BaseStringIndex(PMR_NS::memory_resource* mr, bool case_sensitive,
159+
bool with_suffix)
151160
: case_sensitive_{case_sensitive}, entries_{mr} {
161+
if (with_suffix)
162+
suffix_trie_.emplace(mr);
152163
}
153164

154165
template <typename C>
@@ -169,18 +180,49 @@ const typename BaseStringIndex<C>::Container* BaseStringIndex<C>::Matching(
169180
}
170181

171182
template <typename C>
172-
void BaseStringIndex<C>::MatchingPrefix(std::string_view prefix,
173-
absl::FunctionRef<void(const Container*)> cb) const {
183+
void BaseStringIndex<C>::MatchPrefix(std::string_view prefix,
184+
absl::FunctionRef<void(const Container*)> cb) const {
185+
// TODO(vlad): Use right iterator to avoid string comparison?
174186
for (auto it = entries_.lower_bound(prefix);
175187
it != entries_.end() && (*it).first.rfind(prefix, 0) == 0; ++it) {
176188
cb(&(*it).second);
177189
}
178190
}
179191

180192
template <typename C>
181-
typename BaseStringIndex<C>::Container* BaseStringIndex<C>::GetOrCreate(string_view word) {
182-
auto* mr = entries_.get_allocator().resource();
183-
return &entries_.try_emplace(PMR_NS::string{word, mr}, mr, 1000 /* block size */).first->second;
193+
void BaseStringIndex<C>::MatchSuffix(std::string_view suffix,
194+
absl::FunctionRef<void(const Container*)> cb) const {
195+
// If we have a suffix trie built, we just need to fetch the relevant suffix
196+
if (suffix_trie_) {
197+
auto it = suffix_trie_->find(suffix);
198+
cb((it != suffix_trie_->end()) ? &it->second : nullptr);
199+
return;
200+
}
201+
202+
// Otherwise, iterate over all entries and look for the suffix
203+
for (const auto& entry : entries_) {
204+
int32_t start = entry.first.size() - suffix.size();
205+
if (start >= 0 && entry.first.substr(start) == suffix)
206+
cb(&entry.second);
207+
}
208+
}
209+
210+
template <typename C>
211+
void BaseStringIndex<C>::MatchInfix(std::string_view infix,
212+
absl::FunctionRef<void(const Container*)> cb) const {
213+
// If we have a suffix trie built, we just need to match the prefix
214+
if (suffix_trie_) {
215+
for (auto it = suffix_trie_->lower_bound(infix);
216+
it != suffix_trie_->end() && (*it).first.rfind(infix, 0) == 0; ++it)
217+
cb(&(*it).second);
218+
return;
219+
}
220+
221+
// Otherwise, iterate over all entries and check if it contains the entry
222+
for (const auto& entry : entries_) {
223+
if (entry.first.find(infix) != string::npos)
224+
cb(&entry.second);
225+
}
184226
}
185227

186228
template <typename C>
@@ -197,7 +239,12 @@ bool BaseStringIndex<C>::Add(DocId id, const DocumentAccessor& doc, string_view
197239
if (tokens.size() > 1)
198240
unique_ids_ = false;
199241
for (string_view token : tokens)
200-
GetOrCreate(token)->Insert(id);
242+
GetOrCreate(&entries_, token)->Insert(id);
243+
244+
if (suffix_trie_)
245+
IterateAllSuffixes(tokens,
246+
[&](string_view str) { GetOrCreate(&*suffix_trie_, str)->Insert(id); });
247+
201248
return true;
202249
}
203250

@@ -209,15 +256,11 @@ void BaseStringIndex<C>::Remove(DocId id, const DocumentAccessor& doc, string_vi
209256
for (string_view str : strings_list)
210257
tokens.merge(Tokenize(str));
211258

212-
for (const auto& token : tokens) {
213-
auto it = entries_.find(token);
214-
if (it == entries_.end())
215-
continue;
259+
for (string_view token : tokens)
260+
Remove(&entries_, id, token);
216261

217-
it->second.Remove(id);
218-
if (it->second.Size() == 0)
219-
entries_.erase(it);
220-
}
262+
if (suffix_trie_)
263+
IterateAllSuffixes(tokens, [&](string_view str) { Remove(&*suffix_trie_, id, str); });
221264
}
222265

223266
template <typename C> vector<string> BaseStringIndex<C>::GetTerms() const {
@@ -259,9 +302,32 @@ template <typename C> vector<DocId> BaseStringIndex<C>::GetAllDocsWithNonNullVal
259302
return result;
260303
}
261304

305+
template <typename C>
306+
typename BaseStringIndex<C>::Container* BaseStringIndex<C>::GetOrCreate(
307+
search::RaxTreeMap<Container>* map, string_view word) {
308+
auto* mr = map->get_allocator().resource();
309+
return &map->try_emplace(PMR_NS::string{word, mr}, mr, 1000 /* block size */).first->second;
310+
}
311+
312+
template <typename C>
313+
void BaseStringIndex<C>::Remove(search::RaxTreeMap<Container>* map, DocId id, string_view word) {
314+
auto it = map->find(word);
315+
if (it == map->end())
316+
return;
317+
318+
it->second.Remove(id);
319+
if (it->second.Size() == 0)
320+
map->erase(it);
321+
}
322+
262323
template struct BaseStringIndex<CompressedSortedSet>;
263324
template struct BaseStringIndex<SortedVector>;
264325

326+
TextIndex::TextIndex(PMR_NS::memory_resource* mr, const StopWords* stopwords,
327+
const Synonyms* synonyms, bool with_suffixtrie)
328+
: BaseStringIndex(mr, false, with_suffixtrie), stopwords_{stopwords}, synonyms_{synonyms} {
329+
}
330+
265331
std::optional<DocumentAccessor::StringList> TextIndex::GetStrings(const DocumentAccessor& doc,
266332
std::string_view field) const {
267333
return doc.GetStrings(field);

src/core/search/indices.h

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,24 @@ struct NumericIndex : public BaseIndex {
4646
// Base index for string based indices.
4747
template <typename C> struct BaseStringIndex : public BaseIndex {
4848
using Container = BlockList<C>;
49+
using VecOrPtr = std::variant<std::vector<DocId>, const Container*>;
4950

50-
BaseStringIndex(PMR_NS::memory_resource* mr, bool case_sensitive);
51+
BaseStringIndex(PMR_NS::memory_resource* mr, bool case_sensitive, bool with_suffixtrie);
5152

5253
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
5354
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
5455

5556
// Pointer is valid as long as index is not mutated. Nullptr if not found
5657
const Container* Matching(std::string_view str, bool strip_whitespace = true) const;
5758

58-
// Iterate over all Matching on prefix.
59-
void MatchingPrefix(std::string_view prefix, absl::FunctionRef<void(const Container*)> cb) const;
59+
// Iterate over all nodes matching on prefix.
60+
void MatchPrefix(std::string_view prefix, absl::FunctionRef<void(const Container*)> cb) const;
61+
62+
// Iterate over all nodes matching suffix query. Faster if suffix trie is built.
63+
void MatchSuffix(std::string_view suffix, absl::FunctionRef<void(const Container*)> cb) const;
64+
65+
// Iterate over all nodes matching infix query. Faster if suffix trie is built.
66+
void MatchInfix(std::string_view prefix, absl::FunctionRef<void(const Container*)> cb) const;
6067

6168
// Returns all the terms that appear as keys in the reverse index.
6269
std::vector<std::string> GetTerms() const;
@@ -73,21 +80,22 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
7380
// Used by Add & Remove to tokenize text value
7481
virtual absl::flat_hash_set<std::string> Tokenize(std::string_view value) const = 0;
7582

76-
Container* GetOrCreate(std::string_view word);
83+
static Container* GetOrCreate(search::RaxTreeMap<Container>* map, std::string_view word);
84+
static void Remove(search::RaxTreeMap<Container>* map, DocId id, std::string_view word);
7785

7886
bool case_sensitive_ = false;
7987
bool unique_ids_ = true; // If true, docs ids are unique in the index, otherwise they can repeat.
8088
search::RaxTreeMap<Container> entries_;
89+
std::optional<search::RaxTreeMap<Container>> suffix_trie_;
8190
};
8291

8392
// Index for text fields.
8493
// Hashmap based lookup per word.
8594
struct TextIndex : public BaseStringIndex<CompressedSortedSet> {
8695
using StopWords = absl::flat_hash_set<std::string>;
8796

88-
TextIndex(PMR_NS::memory_resource* mr, const StopWords* stopwords, const Synonyms* synonyms)
89-
: BaseStringIndex(mr, false), stopwords_{stopwords}, synonyms_{synonyms} {
90-
}
97+
TextIndex(PMR_NS::memory_resource* mr, const StopWords* stopwords, const Synonyms* synonyms,
98+
bool with_suffixtrie);
9199

92100
protected:
93101
std::optional<StringList> GetStrings(const DocumentAccessor& doc,
@@ -103,7 +111,7 @@ struct TextIndex : public BaseStringIndex<CompressedSortedSet> {
103111
// Hashmap based lookup per word.
104112
struct TagIndex : public BaseStringIndex<SortedVector> {
105113
TagIndex(PMR_NS::memory_resource* mr, SchemaField::TagParams params)
106-
: BaseStringIndex(mr, params.case_sensitive), separator_{params.separator} {
114+
: BaseStringIndex(mr, params.case_sensitive, false), separator_{params.separator} {
107115
}
108116

109117
protected:

src/core/search/search.cc

Lines changed: 45 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "core/search/indices.h"
2222
#include "core/search/query_driver.h"
2323
#include "core/search/sort_indices.h"
24+
#include "core/search/tag_types.h"
2425
#include "core/search/vector_utils.h"
2526

2627
using namespace std;
@@ -268,28 +269,14 @@ struct BasicSearch {
268269
return out;
269270
}
270271

271-
template <typename C>
272-
IndexResult CollectPrefixMatches(BaseStringIndex<C>* index, std::string_view prefix) {
272+
template <typename C, typename F>
273+
IndexResult CollectMatches(BaseStringIndex<C>* index, std::string_view word, F&& f) {
273274
IndexResult result{};
274-
index->MatchingPrefix(
275-
prefix, [&result, this](const auto* c) { Merge(IndexResult{c}, &result, LogicOp::OR); });
275+
invoke(f, *index, word,
276+
[&result, this](const auto* c) { Merge(IndexResult{c}, &result, LogicOp::OR); });
276277
return result;
277278
}
278279

279-
template <typename C>
280-
IndexResult CollectSuffixMatches(BaseStringIndex<C>* index, std::string_view suffix) {
281-
// TODO: Implement full text search for suffix
282-
error_ = "Not implemented";
283-
return IndexResult{};
284-
}
285-
286-
template <typename C>
287-
IndexResult CollectInfixMatches(BaseStringIndex<C>* index, std::string_view infix) {
288-
// TODO: Implement full text search for infix
289-
error_ = "Not implemented";
290-
return IndexResult{};
291-
}
292-
293280
IndexResult Search(monostate, string_view) {
294281
return IndexResult{};
295282
}
@@ -299,32 +286,6 @@ struct BasicSearch {
299286
return {&indices_->GetAllDocs()};
300287
}
301288

302-
// "term": access field's text index or unify results from all text indices if no field is set
303-
IndexResult Search(const AstTermNode& node, string_view active_field) {
304-
std::string term = node.affix;
305-
bool strip_whitespace = true;
306-
307-
if (auto synonyms = indices_->GetSynonyms(); synonyms) {
308-
if (auto group_id = synonyms->GetGroupToken(term); group_id) {
309-
term = *group_id;
310-
strip_whitespace = false;
311-
}
312-
}
313-
314-
if (!active_field.empty()) {
315-
if (auto* index = GetIndex<TextIndex>(active_field); index)
316-
return index->Matching(term, strip_whitespace);
317-
return IndexResult{};
318-
}
319-
320-
vector<TextIndex*> selected_indices = indices_->GetAllTextIndices();
321-
auto mapping = [&term, strip_whitespace](TextIndex* index) {
322-
return index->Matching(term, strip_whitespace);
323-
};
324-
325-
return UnifyResults(GetSubResults(selected_indices, mapping), LogicOp::OR);
326-
}
327-
328289
IndexResult Search(const AstStarFieldNode& node, string_view active_field) {
329290
// Try to get a sort index first, as `@field:*` might imply wanting sortable behavior
330291
BaseSortIndex* sort_index = indices_->GetSortIndex(active_field);
@@ -337,7 +298,7 @@ struct BasicSearch {
337298
return base_index ? IndexResult{base_index->GetAllDocsWithNonNullValues()} : IndexResult{};
338299
}
339300

340-
IndexResult Search(const AstPrefixNode& node, string_view active_field) {
301+
template <TagType T> IndexResult Search(const AstAffixNode<T>& node, string_view active_field) {
341302
vector<TextIndex*> indices;
342303
if (!active_field.empty()) {
343304
if (auto* index = GetIndex<TextIndex>(active_field); index)
@@ -349,21 +310,42 @@ struct BasicSearch {
349310
}
350311

351312
auto mapping = [&node, this](TextIndex* index) {
352-
return CollectPrefixMatches(index, node.affix);
313+
if constexpr (T == TagType::PREFIX)
314+
return CollectMatches(index, node.affix, &TextIndex::MatchPrefix);
315+
else if constexpr (T == TagType::SUFFIX)
316+
return CollectMatches(index, node.affix, &TextIndex::MatchSuffix);
317+
else if constexpr (T == TagType::INFIX)
318+
return CollectMatches(index, node.affix, &TextIndex::MatchInfix);
319+
else
320+
return vector<DocId>{};
353321
};
354322
return UnifyResults(GetSubResults(indices, mapping), LogicOp::OR);
355323
}
356324

357-
IndexResult Search(const AstSuffixNode& node, string_view active_field) {
358-
// TODO: Implement full text search for suffix
359-
error_ = "Not implemented";
360-
return IndexResult{};
361-
}
325+
// "term": access field's text index or unify results from all text indices if no field is set
326+
IndexResult Search(const AstAffixNode<TagType::REGULAR> node, string_view active_field) {
327+
std::string term = node.affix;
328+
bool strip_whitespace = true;
362329

363-
IndexResult Search(const AstInfixNode& node, string_view active_field) {
364-
// TODO: Implement full text search for infix
365-
error_ = "Not implemented";
366-
return IndexResult{};
330+
if (auto synonyms = indices_->GetSynonyms(); synonyms) {
331+
if (auto group_id = synonyms->GetGroupToken(term); group_id) {
332+
term = *group_id;
333+
strip_whitespace = false;
334+
}
335+
}
336+
337+
if (!active_field.empty()) {
338+
if (auto* index = GetIndex<TextIndex>(active_field); index)
339+
return index->Matching(term, strip_whitespace);
340+
return IndexResult{};
341+
}
342+
343+
vector<TextIndex*> selected_indices = indices_->GetAllTextIndices();
344+
auto mapping = [&term, strip_whitespace](TextIndex* index) {
345+
return index->Matching(term, strip_whitespace);
346+
};
347+
348+
return UnifyResults(GetSubResults(selected_indices, mapping), LogicOp::OR);
367349
}
368350

369351
// [range]: access field's numeric index
@@ -411,13 +393,13 @@ struct BasicSearch {
411393
return tag_index->Matching(term.affix);
412394
},
413395
[tag_index, this](const AstPrefixNode& prefix) {
414-
return CollectPrefixMatches(tag_index, prefix.affix);
396+
return CollectMatches(tag_index, prefix.affix, &TagIndex::MatchPrefix);
415397
},
416398
[tag_index, this](const AstSuffixNode& suffix) {
417-
return CollectSuffixMatches(tag_index, suffix.affix);
399+
return CollectMatches(tag_index, suffix.affix, &TagIndex::MatchSuffix);
418400
},
419401
[tag_index, this](const AstInfixNode& infix) {
420-
return CollectInfixMatches(tag_index, infix.affix);
402+
return CollectMatches(tag_index, infix.affix, &TagIndex::MatchInfix);
421403
}};
422404
auto mapping = [ov](const auto& tag) { return visit(ov, tag); };
423405
return UnifyResults(GetSubResults(node.tags, mapping), LogicOp::OR);
@@ -570,9 +552,12 @@ void FieldIndices::CreateIndices(PMR_NS::memory_resource* mr) {
570552
continue;
571553

572554
switch (field_info.type) {
573-
case SchemaField::TEXT:
574-
indices_[field_ident] = make_unique<TextIndex>(mr, &options_.stopwords, synonyms_);
555+
case SchemaField::TEXT: {
556+
const auto& tparams = std::get<SchemaField::TextParams>(field_info.special_params);
557+
indices_[field_ident] =
558+
make_unique<TextIndex>(mr, &options_.stopwords, synonyms_, tparams.with_suffixtrie);
575559
break;
560+
}
576561
case SchemaField::NUMERIC:
577562
indices_[field_ident] = make_unique<NumericIndex>(mr);
578563
break;

0 commit comments

Comments
 (0)