Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/core/search/ast_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ AstKnnNode::AstKnnNode(AstNode&& filter, AstKnnNode&& self) {
this->filter = make_unique<AstNode>(std::move(filter));
}

bool AstKnnNode::Filter() const {
return filter == nullptr;
}

} // namespace dfly::search

namespace std {
Expand Down
2 changes: 2 additions & 0 deletions src/core/search/ast_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ struct AstKnnNode {
OwnedFtVector vec;
std::string score_alias;
std::optional<float> ef_runtime;

bool Filter() const;
};

using NodeVariants =
Expand Down
29 changes: 21 additions & 8 deletions src/core/search/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,24 @@
#include "absl/container/flat_hash_set.h"
#include "base/pmr/memory_resource.h"
#include "core/string_map.h"
#include "server/tx_base.h"

namespace dfly::search {

using DocId = uint32_t;
using GlobalDocId = uint64_t;

inline GlobalDocId CreateGlobalDocId(ShardId shard_id, DocId local_doc_id) {
return ((uint64_t)shard_id << 32) | local_doc_id;
}

inline ShardId GlobalDocIdShardId(GlobalDocId id) {
return (id >> 32);
}

inline search::DocId GlobalDocIdLocalId(GlobalDocId id) {
return (id)&0xFFFF;
}

enum class VectorSimilarity { L2, IP, COSINE };

Expand Down Expand Up @@ -79,16 +93,16 @@ struct DocumentAccessor {
//
// Queries should be done directly on subclasses with their distinc
// query functions. All results for all index types should be sorted.
struct BaseIndex {
template <typename T> struct BaseIndex {
virtual ~BaseIndex() = default;

// Returns true if the document was added / indexed
virtual bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
virtual void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
virtual bool Add(T id, const DocumentAccessor& doc, std::string_view field) = 0;
virtual void Remove(T id, const DocumentAccessor& doc, std::string_view field) = 0;

// Returns documents that have non-null values for this field (used for @field:* queries)
// Result must be sorted
virtual std::vector<DocId> GetAllDocsWithNonNullValues() const = 0;
virtual std::vector<T> GetAllDocsWithNonNullValues() const = 0;

/* Called at the end of indexes rebuilding after all initial Add calls are done.
Some indices may need to finalize internal structures. See RangeTree for example. */
Expand All @@ -97,10 +111,9 @@ struct BaseIndex {
};

// Base class for type-specific sorting indices.
struct BaseSortIndex : BaseIndex {
virtual SortableValue Lookup(DocId doc) const = 0;
virtual std::vector<SortableValue> Sort(std::vector<DocId>* ids, size_t limit,
bool desc) const = 0;
template <typename T> struct BaseSortIndex : BaseIndex<T> {
virtual SortableValue Lookup(T doc) const = 0;
virtual std::vector<SortableValue> Sort(std::vector<T>* ids, size_t limit, bool desc) const = 0;
};

/* Used in iterators of inverse indices.
Expand Down
189 changes: 119 additions & 70 deletions src/core/search/indices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
#include <absl/strings/str_split.h>

#include <boost/iterator/function_output_iterator.hpp>
#include <shared_mutex>

#include "core/search/base.h"
#include "core/search/vector_utils.h"
#include "util/fibers/synchronization.h"

#define UNI_ALGO_DISABLE_NFKC_NFKD

Expand Down Expand Up @@ -128,6 +133,16 @@ std::optional<GeoIndex::point> GetGeoPoint(const DocumentAccessor& doc, string_v
return GeoIndex::point{lon, lat};
}

template <typename Q, typename T = GlobalDocId> vector<pair<float, T>> QueueToVec(Q queue) {
vector<pair<float, T>> out(queue.size());
size_t idx = out.size();
while (!queue.empty()) {
out[--idx] = queue.top();
queue.pop();
}
return out;
}

}; // namespace

class RangeTreeAdapter : public NumericIndex::RangeTreeBase {
Expand Down Expand Up @@ -473,14 +488,12 @@ absl::flat_hash_set<std::string> TagIndex::Tokenize(std::string_view value) cons
return NormalizeTags(value, case_sensitive_, separator_);
}

BaseVectorIndex::BaseVectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim} {
template <typename T>
BaseVectorIndex<T>::BaseVectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim} {
}

std::pair<size_t /*dim*/, VectorSimilarity> BaseVectorIndex::Info() const {
return {dim_, sim_};
}

bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) {
template <typename T>
bool BaseVectorIndex<T>::Add(T id, const DocumentAccessor& doc, std::string_view field) {
auto vector = doc.GetVector(field);
if (!vector)
return false;
Expand All @@ -494,63 +507,73 @@ bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_vie
return true;
}

FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params,
ShardNoOpVectorIndex::ShardNoOpVectorIndex(const SchemaField::VectorParams& params)
: BaseVectorIndex<DocId>{params.dim, params.sim} {
}

FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params, ShardId shard_set_size,
PMR_NS::memory_resource* mr)
: BaseVectorIndex{params.dim, params.sim}, entries_{mr} {
: BaseVectorIndex<GlobalDocId>{params.dim, params.sim},
entries_{mr},
shard_vector_locks_(shard_set_size) {
DCHECK(!params.use_hnsw);
entries_.reserve(params.capacity * params.dim);
entries_.resize(shard_set_size);
for (size_t i = 0; i < shard_set_size; i++) {
entries_[i].reserve(params.capacity * params.dim);
}
}

void FlatVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
DCHECK_LE(id * dim_, entries_.size());
if (id * dim_ == entries_.size())
entries_.resize((id + 1) * dim_);

// TODO: Let get vector write to buf itself
void FlatVectorIndex::AddVector(GlobalDocId id,
const typename BaseVectorIndex<GlobalDocId>::VectorPtr& vector) {
auto shard_id = search::GlobalDocIdShardId(id);
auto shard_doc_id = search::GlobalDocIdLocalId(id);
DCHECK_LE(shard_doc_id * BaseVectorIndex<GlobalDocId>::dim_, entries_[shard_id].size());
if (shard_doc_id * BaseVectorIndex<GlobalDocId>::dim_ == entries_[shard_id].size()) {
unique_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
entries_[shard_id].resize((shard_doc_id + 1) * BaseVectorIndex<GlobalDocId>::dim_);
}
if (vector) {
memcpy(&entries_[id * dim_], vector.get(), dim_ * sizeof(float));
memcpy(&entries_[shard_id][shard_doc_id * BaseVectorIndex<GlobalDocId>::dim_], vector.get(),
BaseVectorIndex<GlobalDocId>::dim_ * sizeof(float));
}
}

void FlatVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
void FlatVectorIndex::Remove(GlobalDocId id, const DocumentAccessor& doc, string_view field) {
// noop
}

const float* FlatVectorIndex::Get(DocId doc) const {
return &entries_[doc * dim_];
}

std::vector<DocId> FlatVectorIndex::GetAllDocsWithNonNullValues() const {
std::vector<DocId> result;
std::vector<std::pair<float, GlobalDocId>> FlatVectorIndex::Knn(float* target, size_t k) const {
std::priority_queue<std::pair<float, search::GlobalDocId>> queue;

size_t num_vectors = entries_.size() / dim_;
result.reserve(num_vectors);
for (size_t shard_id = 0; shard_id < entries_.size(); shard_id++) {
shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
size_t num_vectors = entries_[shard_id].size() / BaseVectorIndex<GlobalDocId>::dim_;
for (GlobalDocId id = 0; id < num_vectors; ++id) {
const float* vec = &entries_[shard_id][id * dim_];
float dist = VectorDistance(target, vec, dim_, sim_);
queue.emplace(dist, CreateGlobalDocId(shard_id, id));
}
}

for (DocId id = 0; id < num_vectors; ++id) {
// Check if the vector is not zero (all elements are 0)
// TODO: Valid vector can contain 0s, we should use a better approach
const float* vec = Get(id);
bool is_zero_vector = true;
return QueueToVec(queue);
}

// TODO: Consider don't use check for zero vector
for (size_t i = 0; i < dim_; ++i) {
if (vec[i] != 0.0f) { // TODO: Consider using a threshold for float comparison
is_zero_vector = false;
break;
}
}
std::vector<std::pair<float, GlobalDocId>> FlatVectorIndex::Knn(
float* target, size_t k, const std::vector<FilterShardDocs>& allowed_docs) const {
std::priority_queue<std::pair<float, search::GlobalDocId>> queue;

if (!is_zero_vector) {
result.push_back(id);
for (size_t shard_id = 0; shard_id < allowed_docs.size(); shard_id++) {
shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
for (auto& shard_doc_id : allowed_docs[shard_id]) {
const float* vec = &entries_[shard_id][shard_doc_id * dim_];
float dist = VectorDistance(target, vec, dim_, sim_);
queue.emplace(dist, CreateGlobalDocId(shard_id, shard_doc_id));
}
}

// Result is already sorted by id, no need to sort again
// Also it has no duplicates
return result;
return QueueToVec(queue);
}

struct HnswlibAdapter {
template <typename T> struct HnswlibAdapter {
// Default setting of hnswlib/hnswalg
constexpr static size_t kDefaultEfRuntime = 10;

Expand All @@ -560,34 +583,45 @@ struct HnswlibAdapter {
100 /* seed*/} {
}

void Add(const float* data, DocId id) {
if (world_.cur_element_count + 1 >= world_.max_elements_)
world_.resizeIndex(world_.cur_element_count * 2);
world_.addPoint(data, id);
void Add(const float* data, T id) {
while (true) {
try {
absl::ReaderMutexLock lock(&resize_mutex_);
world_.addPoint(data, id);
return;
} catch (const std::exception& e) {
std::string error_msg = e.what();
if (absl::StrContains(error_msg, "The number of elements exceeds the specified limit")) {
ResizeIfFull();
continue;
}
throw e;
}
}
}

void Remove(DocId id) {
void Remove(T id) {
try {
world_.markDelete(id);
} catch (const std::exception& e) {
}
}

vector<pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef) {
vector<pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef) {
world_.setEf(ef.value_or(kDefaultEfRuntime));
return QueueToVec(world_.searchKnn(target, k));
}

vector<pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
const vector<DocId>& allowed) {
vector<pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef,
const vector<T>& allowed) {
struct BinsearchFilter : hnswlib::BaseFilterFunctor {
virtual bool operator()(hnswlib::labeltype id) {
return binary_search(allowed->begin(), allowed->end(), id);
}

BinsearchFilter(const vector<DocId>* allowed) : allowed{allowed} {
BinsearchFilter(const vector<T>* allowed) : allowed{allowed} {
}
const vector<DocId>* allowed;
const vector<T>* allowed;
};

world_.setEf(ef.value_or(kDefaultEfRuntime));
Expand All @@ -609,46 +643,61 @@ struct HnswlibAdapter {
return visit([](auto& space) -> hnswlib::SpaceInterface<float>* { return &space; }, space_);
}

template <typename Q> static vector<pair<float, DocId>> QueueToVec(Q queue) {
vector<pair<float, DocId>> out(queue.size());
size_t idx = out.size();
while (!queue.empty()) {
out[--idx] = queue.top();
queue.pop();
void ResizeIfFull() {
{
absl::ReaderMutexLock lock(&resize_mutex_);
if (world_.getCurrentElementCount() < world_.getMaxElements() ||
(world_.allow_replace_deleted_ && world_.getDeletedCount() > 0)) {
return;
}
}
try {
absl::WriterMutexLock lock(&resize_mutex_);
if (world_.getCurrentElementCount() == world_.getMaxElements() &&
(!world_.allow_replace_deleted_ || world_.getDeletedCount() == 0)) {
auto max_elements = world_.getMaxElements();
world_.resizeIndex(max_elements * 2);
LOG(INFO) << "Resizing HNSW Index, current size: " << max_elements
<< ", expand by: " << max_elements * 2;
}
} catch (const std::exception& e) {
throw e;
}
return out;
}

SpaceUnion space_;
hnswlib::HierarchicalNSW<float> world_;
absl::Mutex resize_mutex_;
};

HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
: BaseVectorIndex{params.dim, params.sim}, adapter_{make_unique<HnswlibAdapter>(params)} {
: BaseVectorIndex<GlobalDocId>{params.dim, params.sim},
adapter_{make_unique<HnswlibAdapter<GlobalDocId>>(params)} {
DCHECK(params.use_hnsw);
// TODO: Patch hnsw to use MR
}

HnswVectorIndex::~HnswVectorIndex() {
}

void HnswVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
void HnswVectorIndex::AddVector(GlobalDocId id,
const typename BaseVectorIndex<GlobalDocId>::VectorPtr& vector) {
if (vector) {
adapter_->Add(vector.get(), id);
}
}

std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k,
std::optional<size_t> ef) const {
std::vector<std::pair<float, GlobalDocId>> HnswVectorIndex::Knn(float* target, size_t k,
std::optional<size_t> ef) const {
return adapter_->Knn(target, k, ef);
}
std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k,
std::optional<size_t> ef,
const std::vector<DocId>& allowed) const {

std::vector<std::pair<float, GlobalDocId>> HnswVectorIndex::Knn(
float* target, size_t k, std::optional<size_t> ef,
const std::vector<GlobalDocId>& allowed) const {
return adapter_->Knn(target, k, ef, allowed);
}

void HnswVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
void HnswVectorIndex::Remove(GlobalDocId id, const DocumentAccessor& doc, string_view field) {
adapter_->Remove(id);
}

Expand Down
Loading
Loading