diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 833dbc7c6c5c..fb945eacdbea 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -24,7 +24,7 @@ add_library(dfly_core allocation_tracker.cc bloom.cc compact_object.cc dense_set dragonfly_core.cc extent_tree.cc huff_coder.cc interpreter.cc glob_matcher.cc mi_memory_resource.cc qlist.cc sds_utils.cc segment_allocator.cc score_map.cc small_string.cc sorted_map.cc task_queue.cc - tx_queue.cc string_set.cc string_map.cc top_keys.cc detail/bitpacking.cc) + tx_queue.cc string_set.cc string_map.cc top_keys.cc detail/bitpacking.cc prob/cuckoo_filter.cc) cxx_link(dfly_core base absl::flat_hash_map absl::str_format redis_lib TRDP::lua lua_modules fibers2 ${SEARCH_LIB} jsonpath OpenSSL::Crypto TRDP::dconv TRDP::lz4) diff --git a/src/core/compact_object.cc b/src/core/compact_object.cc index 0c1148a68123..8fae2434aa30 100644 --- a/src/core/compact_object.cc +++ b/src/core/compact_object.cc @@ -13,6 +13,7 @@ extern "C" { #include "redis/quicklist.h" #include "redis/redis_aux.h" #include "redis/stream.h" +#include "redis/tdigest.h" #include "redis/util.h" #include "redis/zmalloc.h" // for non-string objects. #include "redis/zset.h" @@ -26,6 +27,7 @@ extern "C" { #include "core/bloom.h" #include "core/detail/bitpacking.h" #include "core/huff_coder.h" +#include "core/prob/cuckoo_filter.h" #include "core/qlist.h" #include "core/sorted_map.h" #include "core/string_map.h" @@ -398,6 +400,12 @@ static_assert(sizeof(CompactObj) == 18); namespace detail { +size_t MallocUsedTDigest(const td_histogram_t* tdigest) { + size_t size = sizeof(tdigest); + size += (2 * (tdigest->cap * sizeof(double))); + return size; +} + size_t RobjWrapper::MallocUsed(bool slow) const { if (!inner_obj_) return 0; @@ -418,6 +426,8 @@ size_t RobjWrapper::MallocUsed(bool slow) const { return MallocUsedZSet(encoding_, inner_obj_); case OBJ_STREAM: return slow ? 
MallocUsedStream((stream*)inner_obj_) : sz_; + case OBJ_TDIGEST: + return MallocUsedTDigest((td_histogram_t*)inner_obj_); default: LOG(FATAL) << "Not supported " << type_; @@ -477,11 +487,17 @@ size_t RobjWrapper::Size() const { case OBJ_STREAM: // Size mean malloc bytes for streams return sz_; + case OBJ_TDIGEST: + return 0; default:; } return 0; } +inline void FreeObjTDigest(void* ptr) { + td_free((td_histogram*)ptr); +} + void RobjWrapper::Free(MemoryResource* mr) { if (!inner_obj_) return; @@ -511,6 +527,9 @@ void RobjWrapper::Free(MemoryResource* mr) { case OBJ_STREAM: FreeObjStream(inner_obj_); break; + case OBJ_TDIGEST: + FreeObjTDigest(inner_obj_); + break; default: LOG(FATAL) << "Unknown object type"; break; @@ -603,6 +622,9 @@ bool RobjWrapper::DefragIfNeeded(float ratio) { return do_defrag(DefragSet); } else if (type() == OBJ_ZSET) { return do_defrag(DefragZSet); + } else if (type() == OBJ_TDIGEST) { + // TODO implement this + return false; } return false; } @@ -826,6 +848,14 @@ size_t CompactObj::Size() const { DCHECK_EQ(mask_bits_.encoding, NONE_ENC); raw_size = u_.sbf->current_size(); break; + case TOPK_TAG: + DCHECK_EQ(mask_bits_.encoding, NONE_ENC); + raw_size = 0; + break; + case CUCKOO_FILTER_TAG: + DCHECK_EQ(mask_bits_.encoding, NONE_ENC); + raw_size = GetCuckooFilter()->NumItems(); + break; default: LOG(DFATAL) << "Should not reach " << int(taglen_); } @@ -892,6 +922,14 @@ CompactObjType CompactObj::ObjType() const { return OBJ_SBF; } + if (taglen_ == TOPK_TAG) { + return OBJ_TOPK; + } + + if (taglen_ == CUCKOO_FILTER_TAG) { + return OBJ_CUCKOO_FILTER; + } + LOG(FATAL) << "TBD " << int(taglen_); return kInvalidCompactObjType; } @@ -995,11 +1033,49 @@ void CompactObj::SetSBF(uint64_t initial_capacity, double fp_prob, double grow_f } } +void CompactObj::SetTopK(size_t topk, size_t width, size_t depth, double decay) { + TopKeys::Options opts; + size_t total_buckets = 4; + // Heuristic + if (topk > 4) { + total_buckets = topk / 4; + } + opts.buckets = total_buckets; + opts.depth = 4; + // fingerprints = buckets * depth = topk + opts.decay_base = decay; + // We need this so we can set the key. The problem with this is upon cell reset, + // we don't set the key and a query for TopK won't return that key because we never set it. 
+ opts.min_key_count_to_record = 0; + SetMeta(TOPK_TAG); + u_.topk = AllocateMR(opts); +} + +void CompactObj::SetCuckooFilter(prob::CuckooFilter filter) { + SetMeta(CUCKOO_FILTER_TAG); + u_.cuckoo_filter = AllocateMR(std::move(filter)); +} + +prob::CuckooFilter* CompactObj::GetCuckooFilter() { + DCHECK(taglen_ == CUCKOO_FILTER_TAG); + return u_.cuckoo_filter; +} + +const prob::CuckooFilter* CompactObj::GetCuckooFilter() const { + DCHECK(taglen_ == CUCKOO_FILTER_TAG); + return u_.cuckoo_filter; +} + SBF* CompactObj::GetSBF() const { DCHECK_EQ(SBF_TAG, taglen_); return u_.sbf; } +TopKeys* CompactObj::GetTopK() const { + DCHECK_EQ(TOPK_TAG, taglen_); + return u_.topk; +} + void CompactObj::SetString(std::string_view str) { CHECK(!IsExternal()); mask_bits_.encoding = NONE_ENC; @@ -1090,6 +1166,9 @@ bool CompactObj::DefragIfNeeded(float ratio) { return false; case EXTERNAL_TAG: return false; + case CUCKOO_FILTER_TAG: + // TODO: implement this + return false; default: // This is the case when the object is at inline_str return false; @@ -1101,7 +1180,8 @@ bool CompactObj::HasAllocated() const { (taglen_ == ROBJ_TAG && u_.r_obj.inner_obj() == nullptr)) return false; - DCHECK(taglen_ == ROBJ_TAG || taglen_ == SMALL_TAG || taglen_ == JSON_TAG || taglen_ == SBF_TAG); + DCHECK(taglen_ == ROBJ_TAG || taglen_ == SMALL_TAG || taglen_ == JSON_TAG || taglen_ == SBF_TAG || + taglen_ == TOPK_TAG || taglen_ == CUCKOO_FILTER_TAG); return true; } @@ -1295,6 +1375,10 @@ void CompactObj::Free() { } } else if (taglen_ == SBF_TAG) { DeleteMR(u_.sbf); + } else if (taglen_ == TOPK_TAG) { + DeleteMR(u_.topk); + } else if (taglen_ == CUCKOO_FILTER_TAG) { + DeleteMR(u_.cuckoo_filter); } else { LOG(FATAL) << "Unsupported tag " << int(taglen_); } @@ -1327,6 +1411,12 @@ size_t CompactObj::MallocUsed(bool slow) const { if (taglen_ == SBF_TAG) { return u_.sbf->MallocUsed(); } + if (taglen_ == TOPK_TAG) { + return 0; + } + if (taglen_ == CUCKOO_FILTER_TAG) { + return GetCuckooFilter()->UsedBytes(); + } LOG(DFATAL) << "should not reach"; return 0; } diff --git a/src/core/compact_object.h b/src/core/compact_object.h index e136fa5a724d..4e518c9add20 100644 --- a/src/core/compact_object.h +++ b/src/core/compact_object.h @@ -15,6 +15,7 @@ #include "core/mi_memory_resource.h" #include "core/small_string.h" #include "core/string_or_view.h" +#include "core/top_keys.h" namespace dfly { @@ -27,6 +28,11 @@ constexpr unsigned kEncodingJsonFlat = 1; class SBF; +namespace prob { +class CuckooFilter; +class CuckooReserveParams; +} // namespace prob + namespace detail { // redis objects or blobs of upto 4GB size. @@ -123,6 +129,8 @@ class CompactObj { EXTERNAL_TAG = 20, JSON_TAG = 21, SBF_TAG = 22, + TOPK_TAG = 23, + CUCKOO_FILTER_TAG = 24, }; // String encoding types. @@ -311,6 +319,13 @@ class CompactObj { void SetSBF(uint64_t initial_capacity, double fp_prob, double grow_factor); SBF* GetSBF() const; + void SetTopK(size_t topk, size_t width, size_t depth, double decay); + TopKeys* GetTopK() const; + + void SetCuckooFilter(prob::CuckooFilter filter); + prob::CuckooFilter* GetCuckooFilter(); + const prob::CuckooFilter* GetCuckooFilter() const; + // dest must have at least Size() bytes available void GetString(char* dest) const; @@ -479,6 +494,8 @@ class CompactObj { // using 'packed' to reduce alignement of U to 1. 
JsonWrapper json_obj __attribute__((packed)); SBF* sbf __attribute__((packed)); + TopKeys* topk __attribute__((packed)); + prob::CuckooFilter* cuckoo_filter __attribute__((packed)); int64_t ival __attribute__((packed)); ExternalPtr ext_ptr; diff --git a/src/core/compact_object_test.cc b/src/core/compact_object_test.cc index 633b9c4d53c3..63973a46e9b2 100644 --- a/src/core/compact_object_test.cc +++ b/src/core/compact_object_test.cc @@ -22,6 +22,8 @@ extern "C" { #include "redis/intset.h" #include "redis/redis_aux.h" #include "redis/stream.h" +#include "redis/td_malloc.h" +#include "redis/tdigest.h" #include "redis/zmalloc.h" } @@ -682,6 +684,24 @@ TEST_F(CompactObjectTest, HuffMan) { } } +TEST_F(CompactObjectTest, TDigst) { + // Allocators + ASSERT_EQ(zmalloc, __td_malloc); + ASSERT_EQ(zcalloc, __td_calloc); + ASSERT_EQ(zrealloc, __td_realloc); + ASSERT_EQ(zfree, __td_free); + + // Basic usage + td_histogram_t* hist = td_new(10); + cobj_.InitRobj(OBJ_TDIGEST, 0, hist); + ASSERT_EQ(cobj_.GetRobjWrapper()->type(), OBJ_TDIGEST); + ASSERT_EQ(cobj_.RObjPtr(), hist); + ASSERT_EQ(0, hist->unmerged_weight); + ASSERT_EQ(0, hist->merged_weight); + ASSERT_EQ(td_add(hist, 0.0, 1), 0); + cobj_.Reset(); +} + static void ascii_pack_naive(const char* ascii, size_t len, uint8_t* bin) { const char* end = ascii + len; diff --git a/src/core/detail/fixed_array.h b/src/core/detail/fixed_array.h new file mode 100644 index 000000000000..a0d557f64500 --- /dev/null +++ b/src/core/detail/fixed_array.h @@ -0,0 +1,115 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include +#include +#include + +#include "base/logging.h" + +namespace dfly::detail { + +/* Analogous to absl::FixedArray but uses a memory resource for allocation. */ +template class PmrFixedArray { + private: + static_assert(std::is_default_constructible_v, + "PmrFixedArray requires default-constructible T"); + + static_assert(std::is_nothrow_destructible_v, + "PmrFixedArray requires nothrow-destructible T"); + + using Allocator = std::pmr::polymorphic_allocator; + + public: + PmrFixedArray(size_t size, std::pmr::memory_resource* mr); + + PmrFixedArray(const PmrFixedArray&) = delete; + PmrFixedArray& operator=(const PmrFixedArray&) = delete; + + PmrFixedArray(PmrFixedArray&& other) noexcept; + PmrFixedArray& operator=(PmrFixedArray&&) = delete; + + ~PmrFixedArray() noexcept; + + T& operator[](size_t i); + const T& operator[](size_t i) const; + + size_t size() const; + + private: + void Reset(); + + private: + size_t size_ = 0; + T* data_ = nullptr; + Allocator alloc_; +}; + +// Implementation +/******************************************************************/ +template +PmrFixedArray::PmrFixedArray(size_t size, std::pmr::memory_resource* mr) + : size_(size), alloc_(mr) { + DCHECK(mr); + if (!size_) { + return; + } + + data_ = alloc_.allocate(size_); + + // Construct elements one by one, with rollback on exception + size_t constructed = 0; + try { + for (; constructed < size_; ++constructed) { + ::new (static_cast(data_ + constructed)) T(); + } + } catch (...) 
{ + // Destroy all already-constructed objects + std::destroy_n(data_, constructed); + alloc_.deallocate(data_, size_); + Reset(); + throw; + } +} + +template +PmrFixedArray::PmrFixedArray(PmrFixedArray&& other) noexcept + : size_(other.size_), data_(other.data_), alloc_(std::move(other.alloc_)) { + other.Reset(); +} + +template PmrFixedArray::~PmrFixedArray() noexcept { + DCHECK((size_ > 0 && data_) || (size_ == 0 && !data_)); + if (!data_) { + return; + } + + // Deallocate memory (should not throw) + std::destroy_n(data_, size_); + alloc_.deallocate(data_, size_); + Reset(); +} + +template T& PmrFixedArray::operator[](size_t i) { + CHECK(i < size_); + return data_[i]; +} + +template const T& PmrFixedArray::operator[](size_t i) const { + CHECK(i < size_); + return data_[i]; +} + +template size_t PmrFixedArray::size() const { + return size_; +} + +template void PmrFixedArray::Reset() { + size_ = 0; + data_ = nullptr; +} + +} // namespace dfly::detail diff --git a/src/core/prob/cuckoo_filter.cc b/src/core/prob/cuckoo_filter.cc new file mode 100644 index 000000000000..df962935ced2 --- /dev/null +++ b/src/core/prob/cuckoo_filter.cc @@ -0,0 +1,228 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "core/prob/cuckoo_filter.h" + +#include + +#include +#include + +#include "absl/numeric/bits.h" +#include "glog/logging.h" + +namespace dfly::prob { + +namespace { + +bool IsPowerOfTwo(uint64_t n) { + return absl::has_single_bit(n); +} + +uint64_t GetNextPowerOfTwo(uint64_t n) { + return absl::bit_ceil(n); +} + +uint8_t GetFingerprint(uint64_t hash) { + return static_cast(hash % 255 + 1); +} + +uint64_t AltIndex(uint8_t fp, uint64_t index) { + return index ^ (static_cast(fp) * 0x5bd1e995); +} + +} // anonymous namespace + +std::optional CuckooFilter::Init(const CuckooReserveParams& params, + std::pmr::memory_resource* mr) { + CuckooFilter filter{params, mr}; + if (!filter.AddNewSubFilter()) { + return std::nullopt; + } + return filter; +} + +CuckooFilter::Hash CuckooFilter::GetHash(std::string_view item) { + return XXH3_64bits_withSeed(item.data(), item.size(), 0xc6a4a7935bd1e995ULL); +} + +CuckooFilter::CuckooFilter(const CuckooReserveParams& params, std::pmr::memory_resource* mr) + : bucket_size_(params.bucket_size), + max_iterations_(params.max_iterations), + expansion_(GetNextPowerOfTwo(params.expansion)), + filters_(mr), + mr_(mr) { + if (bucket_size_) { + num_buckets_ = GetNextPowerOfTwo(params.capacity / bucket_size_); + if (!num_buckets_) { + num_buckets_ = 1; + } + } else { + num_buckets_ = 1; + } + + DCHECK(IsPowerOfTwo(num_buckets_)); +} + +bool CuckooFilter::AddNewSubFilter() { + static constexpr uint64_t kCfMaxNumBuckets = (1ULL << 56) - 1; + DCHECK(filters_.size() < std::numeric_limits::max()); + + const uint64_t growth = std::pow(expansion_, filters_.size()); + if (growth > kCfMaxNumBuckets / num_buckets_) { + return false; + } + + const uint64_t new_buckets_count = num_buckets_ * growth; + if (new_buckets_count > std::numeric_limits::max() / bucket_size_) { + return false; + } + + filters_.emplace_back(new_buckets_count * bucket_size_, mr_); + return true; +} + +CuckooFilter::LookupParams CuckooFilter::MakeLookupParams(uint64_t hash, uint64_t num_buckets) { + const uint8_t fp = GetFingerprint(hash); + const uint64_t h1 = hash % num_buckets; + return {fp, h1, AltIndex(fp, h1) % num_buckets}; +} + +bool CuckooFilter::Insert(Hash hash) { + LookupParams p = MakeLookupParams(hash, num_buckets_); + + for (int i = 
filters_.size() - 1; i >= 0; --i) { + SubFilter& f = filters_[i]; + for (uint64_t idx : GetIndexesInSubFilter(f, p)) { + for (uint8_t b = 0; b < bucket_size_; ++b) { + size_t offset = idx * bucket_size_ + b; + if (f[offset] == 0) { + f[offset] = p.fp; + ++num_items_; + return true; + } + } + } + } + + if (KOInsert(p, &filters_.back())) { + ++num_items_; + return true; + } + + if (expansion_ == 0) { + LOG(WARNING) << "Cuckoo filter is full, unable to allocate new subfilter due to expansion = 0"; + return false; + } + + if (!AddNewSubFilter()) { + return false; + } + + return Insert(hash); +} + +bool CuckooFilter::KOInsert(const LookupParams& p, SubFilter* sub_filter) { + const uint64_t num_buckets = GetNumBucketsInSubFilter(*sub_filter); + uint64_t idx = p.h1 % num_buckets; + Entry fp = p.fp; + + uint16_t victim_idx = 0; + for (uint16_t i = 0; i < max_iterations_; ++i) { + const size_t base = idx * bucket_size_; + const size_t victim_offset = base + victim_idx; + + std::swap((*sub_filter)[victim_offset], fp); + idx = AltIndex(fp, idx) % num_buckets; + + const uint16_t new_base = idx * bucket_size_; + for (uint8_t b = 0; b < bucket_size_; ++b) { + const size_t offset = new_base + b; + if ((*sub_filter)[offset] == 0) { + (*sub_filter)[offset] = fp; + return true; + } + } + + victim_idx = (victim_idx + 1) % bucket_size_; + } + + // Roll back + for (uint16_t i = 0; i < max_iterations_; ++i) { + victim_idx = (victim_idx + bucket_size_ - 1) % bucket_size_; + idx = AltIndex(fp, idx) % num_buckets; + + const size_t base = idx * bucket_size_; + const size_t victim_offset = base + victim_idx; + + std::swap((*sub_filter)[victim_offset], fp); + } + + return false; +} + +bool CuckooFilter::Exists(std::string_view item) const { + return Exists(GetHash(item)); +} + +bool CuckooFilter::Exists(Hash hash) const { + LookupParams p = MakeLookupParams(hash, num_buckets_); + + for (const auto& f : filters_) { + for (uint64_t idx : GetIndexesInSubFilter(f, p)) { + for (uint8_t b = 0; b < bucket_size_; ++b) { + size_t offset = idx * bucket_size_ + b; + if (f[offset] == p.fp) { + return true; + } + } + } + } + + return false; +} + +bool CuckooFilter::Delete(std::string_view item) { + const auto hash = GetHash(item); + LookupParams p = MakeLookupParams(hash, num_buckets_); + + for (int i = filters_.size() - 1; i >= 0; --i) { + SubFilter& f = filters_[i]; + for (uint64_t idx : GetIndexesInSubFilter(f, p)) { + const size_t base = idx * bucket_size_; + for (uint8_t b = 0; b < bucket_size_; ++b) { + const size_t offset = base + b; + if (f[offset] == p.fp) { + f[offset] = 0; + --num_items_; + ++num_deletes_; + return true; + } + } + } + } + + return false; +} + +uint64_t CuckooFilter::Count(std::string_view item) const { + const auto hash = GetHash(item); + LookupParams p = MakeLookupParams(hash, num_buckets_); + uint64_t count = 0; + + for (const SubFilter& f : filters_) { + for (uint64_t idx : GetIndexesInSubFilter(f, p)) { + const size_t base = idx * bucket_size_; + for (uint8_t b = 0; b < bucket_size_; ++b) { + if (f[base + b] == p.fp) { + ++count; + } + } + } + } + + return count; +} + +} // namespace dfly::prob diff --git a/src/core/prob/cuckoo_filter.h b/src/core/prob/cuckoo_filter.h new file mode 100644 index 000000000000..5679979f87f3 --- /dev/null +++ b/src/core/prob/cuckoo_filter.h @@ -0,0 +1,118 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. 
+// + +#pragma once + +#include +#include +#include +#include + +#include "core/detail/fixed_array.h" + +namespace dfly::prob { + +struct CuckooReserveParams { + static constexpr uint8_t kDefaultBucketSize = 2; + static constexpr uint16_t kDefaultMaxIterations = 20; + static constexpr uint16_t kDefaultExpansion = 1; + + uint8_t bucket_size = kDefaultBucketSize; + uint16_t max_iterations = kDefaultMaxIterations; + uint16_t expansion = kDefaultExpansion; + uint64_t capacity; +}; + +class CuckooFilter { + private: + using Entry = uint8_t; + // SubFilter stores num_buckets * bucket_size entries. + using SubFilter = detail::PmrFixedArray; + + struct LookupParams { + Entry fp; // fingerprint + uint64_t h1; // first hash + uint64_t h2; // second hash + }; + + static LookupParams MakeLookupParams(uint64_t hash, uint64_t num_buckets); + + public: + static std::optional Init(const CuckooReserveParams& params, + std::pmr::memory_resource* mr); + + using Hash = uint64_t; + static Hash GetHash(std::string_view item); + + CuckooFilter(const CuckooFilter&) = delete; + CuckooFilter& operator=(const CuckooFilter&) = delete; + + CuckooFilter(CuckooFilter&&) = default; + CuckooFilter& operator=(CuckooFilter&&) = delete; + + ~CuckooFilter() = default; + + bool Insert(Hash hash); + + bool Exists(std::string_view item) const; + bool Exists(Hash hash) const; + + bool Delete(std::string_view item); + + uint64_t Count(std::string_view item) const; + + size_t UsedBytes() const; + size_t NumItems() const; + + private: + explicit CuckooFilter(const CuckooReserveParams& params, std::pmr::memory_resource* mr); + + // Inserts new subfilter with expansion ^ filters_.size() buckets. + bool AddNewSubFilter(); + + /* Attempts to insert the fingerprint by randomly evicting existing entries ("kick-out"). + The evicted fingerprint is recursively reinserted into its alternate bucket up to + max_iterations. 
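+     If no free slot is found within max_iterations, the evictions are rolled back and
+     Insert() falls back to adding a new sub-filter (when expansion allows it).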
*/ + bool KOInsert(const LookupParams& p, SubFilter* sub_filter); + + std::array GetIndexesInSubFilter(const SubFilter& sub_filter, + const LookupParams& p) const; + + uint64_t GetNumBucketsInSubFilter(const SubFilter& sub_filter) const; + + private: + const uint8_t bucket_size_; + const uint16_t max_iterations_; + const uint16_t expansion_; + + uint64_t num_buckets_ = 0; + uint64_t num_items_ = 0; + uint64_t num_deletes_ = 0; + + std::pmr::vector filters_; + std::pmr::memory_resource* mr_; +}; + +// Implementation +/******************************************************************/ +inline std::array CuckooFilter::GetIndexesInSubFilter(const SubFilter& sub_filter, + const LookupParams& p) const { + const uint64_t num_buckets = GetNumBucketsInSubFilter(sub_filter); + return {p.h1 % num_buckets, p.h2 % num_buckets}; +} + +inline uint64_t CuckooFilter::GetNumBucketsInSubFilter(const SubFilter& sub_filter) const { + return sub_filter.size() / bucket_size_; +} + +inline size_t CuckooFilter::UsedBytes() const { + // TODO: there is a bug in the code + return filters_.capacity() * sizeof(SubFilter) + sizeof(CuckooFilter); +} + +inline size_t CuckooFilter::NumItems() const { + return num_items_; +} + +} // namespace dfly::prob diff --git a/src/core/top_keys.cc b/src/core/top_keys.cc index 0282d184c52b..f70a746c4f8b 100644 --- a/src/core/top_keys.cc +++ b/src/core/top_keys.cc @@ -21,10 +21,10 @@ TopKeys::TopKeys(Options options) } } -void TopKeys::Touch(std::string_view key) { - auto ResetCell = [&](Cell& cell, uint64_t fingerprint) { +void TopKeys::Touch(std::string_view key, size_t incr) { + auto ResetCell = [&](Cell& cell, uint64_t fingerprint, size_t size = 1) { cell.fingerprint = fingerprint; - cell.count = 1; + cell.count = size; cell.key.clear(); }; @@ -36,13 +36,16 @@ void TopKeys::Touch(std::string_view key) { Cell& cell = GetCell(id, bucket); if (cell.count == 0) { // No fingerprint in cell. - ResetCell(cell, fingerprint); + ResetCell(cell, fingerprint, incr); + if (incr > 1) { + cell.key = key; + } } else if (cell.fingerprint == fingerprint) { // Same fingerprint, simply increment count. // We could make sure that, if !cell.key.empty(), then key == cell.key.empty() here. However, // what do we do in case they are different? - ++cell.count; + cell.count += incr; if (cell.count >= options_.min_key_count_to_record && cell.key.empty()) { cell.key = key; @@ -51,7 +54,12 @@ void TopKeys::Touch(std::string_view key) { // Different fingerprint, apply exponential decay. 
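+    // The decrement happens with probability decay_base^(-count), so heavily counted
+    // fingerprints become increasingly hard to evict (HeavyKeeper behaviour).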
const double rand = absl::Uniform(bitgen_, 0, 1.0); if (rand < std::pow(options_.decay_base, -static_cast(cell.count))) { - --cell.count; + if (incr != 1 && cell.count < incr) { + incr -= cell.count; + cell.count = 0; + } else { + cell.count -= incr; + } if (cell.count == 0) { ResetCell(cell, fingerprint); } @@ -88,4 +96,15 @@ const TopKeys::Cell& TopKeys::GetCell(uint32_t d, uint32_t bucket) const { return fingerprints_[d * options_.buckets + bucket]; } +void TopKeys::Query(absl::flat_hash_map* keys) { + for (unsigned array = 0; array < options_.depth; ++array) { + for (unsigned bucket = 0; bucket < options_.buckets; ++bucket) { + const Cell& cell = GetCell(array, bucket); + if (!cell.key.empty() && keys->contains(cell.key)) { + keys->find(cell.key)->second = true; + } + } + } +} + } // end of namespace dfly diff --git a/src/core/top_keys.h b/src/core/top_keys.h index 5382c80f173e..3b8eb3beed25 100644 --- a/src/core/top_keys.h +++ b/src/core/top_keys.h @@ -9,6 +9,7 @@ #include #include #include + #include "base/random.h" namespace dfly { @@ -32,6 +33,7 @@ namespace dfly { class TopKeys { TopKeys(const TopKeys&) = delete; TopKeys& operator=(const TopKeys&) = delete; + public: struct Options { // HeavyKeeper options @@ -49,9 +51,17 @@ class TopKeys { explicit TopKeys(Options options); - void Touch(std::string_view key); + void Touch(std::string_view key, size_t incr = 1); absl::flat_hash_map GetTopKeys() const; + // Checks whether each item in the list exists in the current set of TopKeys + // If a key in keys exists in TopKeys, we set its bool to True + void Query(absl::flat_hash_map* keys); + + Options GetOptions() const { + return options_; + }; + private: // Each cell consists of a key-fingerprint, a count, and potentially the key itself, when it's // above options_.min_key_count_to_record. @@ -66,7 +76,7 @@ class TopKeys { Options options_; base::Xoroshiro128p bitgen_; - // fingerprints_'s size is options_.buckets * options_.arrays. Always access fields via GetCell(). + // fingerprints_'s size is options_.buckets * options_.depth. Always access fields via GetCell(). 
std::vector fingerprints_; }; diff --git a/src/facade/resp_expr.h b/src/facade/resp_expr.h index 52725e993acf..c799c67e7168 100644 --- a/src/facade/resp_expr.h +++ b/src/facade/resp_expr.h @@ -57,6 +57,10 @@ class RespExpr { : std::nullopt; } + double GetDouble() const { + return std::get(u); + } + size_t UsedMemory() const { return 0; } diff --git a/src/redis/CMakeLists.txt b/src/redis/CMakeLists.txt index 7dc19a719930..b9798dd2d510 100644 --- a/src/redis/CMakeLists.txt +++ b/src/redis/CMakeLists.txt @@ -8,11 +8,11 @@ else() set(ZMALLOC_DEPS "") endif() -add_library(redis_lib crc16.c crc64.c crcspeed.c debug.c intset.c geo.c +add_library(redis_lib crc16.c crc64.c crcspeed.c debug.c intset.c geo.c geohash.c geohash_helper.c t_zset.c listpack.c lzf_c.c lzf_d.c sds.c quicklist.c rax.c redis_aux.c t_stream.c - util.c ziplist.c hyperloglog.c ${ZMALLOC_SRC}) + util.c ziplist.c hyperloglog.c tdigest.c ${ZMALLOC_SRC}) cxx_link(redis_lib ${ZMALLOC_DEPS}) diff --git a/src/redis/dict.c b/src/redis/dict.c index 44deeeeb3bae..2bc524c11602 100644 --- a/src/redis/dict.c +++ b/src/redis/dict.c @@ -169,7 +169,7 @@ int _dictExpand(dict *d, unsigned long size, int* malloc_failed) if (*malloc_failed) return DICT_ERR; } else - new_ht_table = zcalloc(newsize*sizeof(dictEntry*)); + new_ht_table = zcalloc(newsize, sizeof(dictEntry*)); new_ht_used = 0; diff --git a/src/redis/redis_aux.h b/src/redis/redis_aux.h index 8ec55263e840..7618ede03108 100644 --- a/src/redis/redis_aux.h +++ b/src/redis/redis_aux.h @@ -6,6 +6,9 @@ /* redis.h auxiliary definitions */ /* the last one in object.h is OBJ_STREAM and it is 6, * this will add enough place for Redis types to grow */ +#define OBJ_CUCKOO_FILTER 12u +#define OBJ_TOPK 13U +#define OBJ_TDIGEST 14U #define OBJ_JSON 15U #define OBJ_SBF 16U diff --git a/src/redis/t-digest.LICENSE.md b/src/redis/t-digest.LICENSE.md new file mode 100644 index 000000000000..c36552ca387f --- /dev/null +++ b/src/redis/t-digest.LICENSE.md @@ -0,0 +1,21 @@ +# MIT License + +Copyright (c) 2019 Bob Rudis + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/redis/td_malloc.h b/src/redis/td_malloc.h new file mode 100644 index 000000000000..031c8f05dcfd --- /dev/null +++ b/src/redis/td_malloc.h @@ -0,0 +1,24 @@ +/** + * Adaptive histogram based on something like streaming k-means crossed with Q-digest. + * The implementation is a direct descendent of MergingDigest + * https://github.com/tdunning/t-digest/ + * + * Copyright (c) 2021 Redis, All rights reserved. + * + * Allocator selection. 
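+ * In this tree the allocator is mapped onto the zmalloc family (see the defines below),
+ * so t-digest allocations are tracked by zmalloc's used-memory accounting.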
+ * + * This file is used in order to change the t-digest allocator at compile time. + * Just define the following defines to what you want to use. Also add + * the include of your alternate allocator if needed (not needed in order + * to use the default libc allocator). */ +#ifndef TD_ALLOC_H +#define TD_ALLOC_H + +#include "zmalloc.h" + +#define __td_malloc zmalloc +#define __td_calloc zcalloc +#define __td_realloc zrealloc +#define __td_free zfree + +#endif diff --git a/src/redis/tdigest.c b/src/redis/tdigest.c new file mode 100644 index 000000000000..08879bce9328 --- /dev/null +++ b/src/redis/tdigest.c @@ -0,0 +1,674 @@ +#include +#include +#include +#include +#include "tdigest.h" +#include +#include + +#ifndef TD_MALLOC_INCLUDE +#define TD_MALLOC_INCLUDE "td_malloc.h" +#endif + +#include TD_MALLOC_INCLUDE + +#define __td_max(x, y) (((x) > (y)) ? (x) : (y)) +#define __td_min(x, y) (((x) < (y)) ? (x) : (y)) + +static inline double weighted_average_sorted(double x1, double w1, double x2, double w2) { + const double x = (x1 * w1 + x2 * w2) / (w1 + w2); + return __td_max(x1, __td_min(x, x2)); +} + +static inline bool _tdigest_long_long_add_safe(long long a, long long b) { + if (b < 0) { + return (a >= __LONG_LONG_MAX__ - b); + } else { + return (a <= __LONG_LONG_MAX__ - b); + } +} + +static inline double weighted_average(double x1, double w1, double x2, double w2) { + if (x1 <= x2) { + return weighted_average_sorted(x1, w1, x2, w2); + } else { + return weighted_average_sorted(x2, w2, x1, w1); + } +} + +inline static void swap(double *arr, int i, int j) { + const double temp = arr[i]; + arr[i] = arr[j]; + arr[j] = temp; +} + +inline static void swap_l(long long *arr, int i, int j) { + const long long temp = arr[i]; + arr[i] = arr[j]; + arr[j] = temp; +} + +static unsigned int partition(double *means, long long *weights, unsigned int start, + unsigned int end, unsigned int pivot_idx) { + const double pivotMean = means[pivot_idx]; + swap(means, pivot_idx, end); + swap_l(weights, pivot_idx, end); + + int i = start - 1; + + for (unsigned int j = start; j < end; j++) { + // If current element is smaller than the pivot + if (means[j] < pivotMean) { + // increment index of smaller element + i++; + swap(means, i, j); + swap_l(weights, i, j); + } + } + swap(means, i + 1, end); + swap_l(weights, i + 1, end); + return i + 1; +} + +/** + * Standard quick sort except that sorting rearranges parallel arrays + * + * @param means Values to sort on + * @param weights The auxillary values to sort. 
+ * @param start The beginning of the values to sort + * @param end The value after the last value to sort + */ +static void td_qsort(double *means, long long *weights, unsigned int start, unsigned int end) { + if (start < end) { + // two elements can be directly compared + if ((end - start) == 1) { + if (means[start] > means[end]) { + swap(means, start, end); + swap_l(weights, start, end); + } + return; + } + // generating a random number as a pivot was very expensive vs the array size + // const unsigned int pivot_idx = start + rand()%(end - start + 1); + const unsigned int pivot_idx = (end + start) / 2; // central pivot + const unsigned int new_pivot_idx = partition(means, weights, start, end, pivot_idx); + if (new_pivot_idx > start) { + td_qsort(means, weights, start, new_pivot_idx - 1); + } + td_qsort(means, weights, new_pivot_idx + 1, end); + } +} + +static inline size_t cap_from_compression(double compression) { + if ((size_t)compression > ((SIZE_MAX / sizeof(double) / 6) - 10)) { + return 0; + } + + return (6 * (size_t)(compression)) + 10; +} + +static inline bool should_td_compress(td_histogram_t *h) { + return ((h->merged_nodes + h->unmerged_nodes) >= (h->cap - 1)); +} + +static inline int next_node(td_histogram_t *h) { return h->merged_nodes + h->unmerged_nodes; } + +int td_compress(td_histogram_t *h); + +static inline int _check_overflow(const double v) { + // double-precision overflow detected on h->unmerged_weight + if (v == INFINITY) { + return EDOM; + } + return 0; +} + +static inline int _check_td_overflow(const double new_unmerged_weight, + const double new_total_weight) { + // double-precision overflow detected on h->unmerged_weight + if (new_unmerged_weight == INFINITY) { + return EDOM; + } + if (new_total_weight == INFINITY) { + return EDOM; + } + const double denom = 2 * MM_PI * new_total_weight * log(new_total_weight); + if (denom == INFINITY) { + return EDOM; + } + + return 0; +} + +int td_centroid_count(td_histogram_t *h) { return next_node(h); } + +void td_reset(td_histogram_t *h) { + if (!h) { + return; + } + h->min = __DBL_MAX__; + h->max = -h->min; + h->merged_nodes = 0; + h->merged_weight = 0; + h->unmerged_nodes = 0; + h->unmerged_weight = 0; + h->total_compressions = 0; +} + +int td_init(double compression, td_histogram_t **result) { + + const size_t capacity = cap_from_compression(compression); + if (capacity < 1) { + return 1; + } + td_histogram_t *histogram; + histogram = (td_histogram_t *)__td_malloc(sizeof(td_histogram_t)); + if (!histogram) { + return 1; + } + histogram->cap = capacity; + histogram->compression = (double)compression; + td_reset(histogram); + histogram->nodes_mean = (double *)__td_calloc(capacity, sizeof(double)); + if (!histogram->nodes_mean) { + td_free(histogram); + return 1; + } + histogram->nodes_weight = (long long *)__td_calloc(capacity, sizeof(long long)); + if (!histogram->nodes_weight) { + td_free(histogram); + return 1; + } + *result = histogram; + + return 0; +} + +td_histogram_t *td_new(double compression) { + td_histogram_t *mdigest = NULL; + td_init(compression, &mdigest); + return mdigest; +} + +void td_free(td_histogram_t *histogram) { + if (histogram->nodes_mean) { + __td_free((void *)(histogram->nodes_mean)); + } + if (histogram->nodes_weight) { + __td_free((void *)(histogram->nodes_weight)); + } + __td_free((void *)(histogram)); +} + +int td_merge(td_histogram_t *into, td_histogram_t *from) { + if (td_compress(into) != 0) + return EDOM; + if (td_compress(from) != 0) + return EDOM; + const int pos = from->merged_nodes 
+ from->unmerged_nodes; + for (int i = 0; i < pos; i++) { + const double mean = from->nodes_mean[i]; + const long long weight = from->nodes_weight[i]; + if (td_add(into, mean, weight) != 0) { + return EDOM; + } + } + return 0; +} + +long long td_size(td_histogram_t *h) { return h->merged_weight + h->unmerged_weight; } + +double td_cdf(td_histogram_t *h, double val) { + td_compress(h); + // no data to examine + if (h->merged_nodes == 0) { + return NAN; + } + // bellow lower bound + if (val < h->min) { + return 0; + } + // above upper bound + if (val > h->max) { + return 1; + } + if (h->merged_nodes == 1) { + // exactly one centroid, should have max==min + const double width = h->max - h->min; + if (val - h->min <= width) { + // min and max are too close together to do any viable interpolation + return 0.5; + } else { + // interpolate if somehow we have weight > 0 and max != min + return (val - h->min) / width; + } + } + const int n = h->merged_nodes; + // check for the left tail + const double left_centroid_mean = h->nodes_mean[0]; + const double left_centroid_weight = (double)h->nodes_weight[0]; + const double merged_weight_d = (double)h->merged_weight; + if (val < left_centroid_mean) { + // note that this is different than h->nodes_mean[0] > min + // ... this guarantees we divide by non-zero number and interpolation works + const double width = left_centroid_mean - h->min; + if (width > 0) { + // must be a sample exactly at min + if (val == h->min) { + return 0.5 / merged_weight_d; + } else { + return (1 + (val - h->min) / width * (left_centroid_weight / 2 - 1)) / + merged_weight_d; + } + } else { + // this should be redundant with the check val < h->min + return 0; + } + } + // and the right tail + const double right_centroid_mean = h->nodes_mean[n - 1]; + const double right_centroid_weight = (double)h->nodes_weight[n - 1]; + if (val > right_centroid_mean) { + const double width = h->max - right_centroid_mean; + if (width > 0) { + if (val == h->max) { + return 1 - 0.5 / merged_weight_d; + } else { + // there has to be a single sample exactly at max + const double dq = (1 + (h->max - val) / width * (right_centroid_weight / 2 - 1)) / + merged_weight_d; + return 1 - dq; + } + } else { + return 1; + } + } + // we know that there are at least two centroids and mean[0] < x < mean[n-1] + // that means that there are either one or more consecutive centroids all at exactly x + // or there are consecutive centroids, c0 < x < c1 + double weightSoFar = 0; + for (int it = 0; it < n - 1; it++) { + // weightSoFar does not include weight[it] yet + if (h->nodes_mean[it] == val) { + // we have one or more centroids == x, treat them as one + // dw will accumulate the weight of all of the centroids at x + double dw = 0; + while (it < n && h->nodes_mean[it] == val) { + dw += (double)h->nodes_weight[it]; + it++; + } + return (weightSoFar + dw / 2) / (double)h->merged_weight; + } else if (h->nodes_mean[it] <= val && val < h->nodes_mean[it + 1]) { + const double node_weight = (double)h->nodes_weight[it]; + const double node_weight_next = (double)h->nodes_weight[it + 1]; + const double node_mean = h->nodes_mean[it]; + const double node_mean_next = h->nodes_mean[it + 1]; + // landed between centroids ... 
check for floating point madness + if (node_mean_next - node_mean > 0) { + // note how we handle singleton centroids here + // the point is that for singleton centroids, we know that their entire + // weight is exactly at the centroid and thus shouldn't be involved in + // interpolation + double leftExcludedW = 0; + double rightExcludedW = 0; + if (node_weight == 1) { + if (node_weight_next == 1) { + // two singletons means no interpolation + // left singleton is in, right is out + return (weightSoFar + 1) / merged_weight_d; + } else { + leftExcludedW = 0.5; + } + } else if (node_weight_next == 1) { + rightExcludedW = 0.5; + } + double dw = (node_weight + node_weight_next) / 2; + + // adjust endpoints for any singleton + double dwNoSingleton = dw - leftExcludedW - rightExcludedW; + + double base = weightSoFar + node_weight / 2 + leftExcludedW; + return (base + dwNoSingleton * (val - node_mean) / (node_mean_next - node_mean)) / + merged_weight_d; + } else { + // this is simply caution against floating point madness + // it is conceivable that the centroids will be different + // but too near to allow safe interpolation + double dw = (node_weight + node_weight_next) / 2; + return (weightSoFar + dw) / merged_weight_d; + } + } else { + weightSoFar += (double)h->nodes_weight[it]; + } + } + return 1 - 0.5 / merged_weight_d; +} + +static double td_internal_iterate_centroids_to_index(const td_histogram_t *h, const double index, + const double left_centroid_weight, + const int total_centroids, double *weightSoFar, + int *node_pos) { + if (left_centroid_weight > 1 && index < left_centroid_weight / 2) { + // there is a single sample at min so we interpolate with less weight + return h->min + (index - 1) / (left_centroid_weight / 2 - 1) * (h->nodes_mean[0] - h->min); + } + + // usually the last centroid will have unit weight so this test will make it moot + if (index > h->merged_weight - 1) { + return h->max; + } + + // if the right-most centroid has more than one sample, we still know + // that one sample occurred at max so we can do some interpolation + const double right_centroid_weight = (double)h->nodes_weight[total_centroids - 1]; + const double right_centroid_mean = h->nodes_mean[total_centroids - 1]; + if (right_centroid_weight > 1 && + (double)h->merged_weight - index <= right_centroid_weight / 2) { + return h->max - ((double)h->merged_weight - index - 1) / (right_centroid_weight / 2 - 1) * + (h->max - right_centroid_mean); + } + + for (; *node_pos < total_centroids - 1; (*node_pos)++) { + const int i = *node_pos; + const double node_weight = (double)h->nodes_weight[i]; + const double node_weight_next = (double)h->nodes_weight[i + 1]; + const double node_mean = h->nodes_mean[i]; + const double node_mean_next = h->nodes_mean[i + 1]; + const double dw = (node_weight + node_weight_next) / 2; + if (*weightSoFar + dw > index) { + // centroids i and i+1 bracket our current point + // check for unit weight + double leftUnit = 0; + if (node_weight == 1) { + if (index - *weightSoFar < 0.5) { + // within the singleton's sphere + return node_mean; + } else { + leftUnit = 0.5; + } + } + double rightUnit = 0; + if (node_weight_next == 1) { + if (*weightSoFar + dw - index <= 0.5) { + // no interpolation needed near singleton + return node_mean_next; + } + rightUnit = 0.5; + } + const double z1 = index - *weightSoFar - leftUnit; + const double z2 = *weightSoFar + dw - index - rightUnit; + return weighted_average(node_mean, z2, node_mean_next, z1); + } + *weightSoFar += dw; + } + + // weightSoFar = 
totalWeight - weight[total_centroids-1]/2 (very nearly) + // so we interpolate out to max value ever seen + const double z1 = index - h->merged_weight - right_centroid_weight / 2.0; + const double z2 = right_centroid_weight / 2 - z1; + return weighted_average(right_centroid_mean, z1, h->max, z2); +} + +double td_quantile(td_histogram_t *h, double q) { + td_compress(h); + // q should be in [0,1] + if (q < 0.0 || q > 1.0 || h->merged_nodes == 0) { + return NAN; + } + // with one data point, all quantiles lead to Rome + if (h->merged_nodes == 1) { + return h->nodes_mean[0]; + } + + // if values were stored in a sorted array, index would be the offset we are interested in + const double index = q * (double)h->merged_weight; + + // beyond the boundaries, we return min or max + // usually, the first centroid will have unit weight so this will make it moot + if (index < 1) { + return h->min; + } + + // we know that there are at least two centroids now + const int n = h->merged_nodes; + + // if the left centroid has more than one sample, we still know + // that one sample occurred at min so we can do some interpolation + const double left_centroid_weight = (double)h->nodes_weight[0]; + + // in between extremes we interpolate between centroids + double weightSoFar = left_centroid_weight / 2; + int i = 0; + return td_internal_iterate_centroids_to_index(h, index, left_centroid_weight, n, &weightSoFar, + &i); +} + +int td_quantiles(td_histogram_t *h, const double *quantiles, double *values, size_t length) { + td_compress(h); + + if (NULL == quantiles || NULL == values) { + return EINVAL; + } + + const int n = h->merged_nodes; + if (n == 0) { + for (size_t i = 0; i < length; i++) { + values[i] = NAN; + } + return 0; + } + if (n == 1) { + for (size_t i = 0; i < length; i++) { + const double requested_quantile = quantiles[i]; + + // q should be in [0,1] + if (requested_quantile < 0.0 || requested_quantile > 1.0) { + values[i] = NAN; + } else { + // with one data point, all quantiles lead to Rome + values[i] = h->nodes_mean[0]; + } + } + return 0; + } + + // we know that there are at least two centroids now + // if the left centroid has more than one sample, we still know + // that one sample occurred at min so we can do some interpolation + const double left_centroid_weight = (double)h->nodes_weight[0]; + + // in between extremes we interpolate between centroids + double weightSoFar = left_centroid_weight / 2; + int node_pos = 0; + + // to avoid allocations we use the values array for intermediate computation + // i.e. to store the expected cumulative count at each percentile + for (size_t qpos = 0; qpos < length; qpos++) { + const double index = quantiles[qpos] * (double)h->merged_weight; + values[qpos] = td_internal_iterate_centroids_to_index(h, index, left_centroid_weight, n, + &weightSoFar, &node_pos); + } + return 0; +} + +static double td_internal_trimmed_mean(const td_histogram_t *h, const double leftmost_weight, + const double rightmost_weight) { + double count_done = 0; + double trimmed_sum = 0; + double trimmed_count = 0; + for (int i = 0; i < h->merged_nodes; i++) { + + const double n_weight = (double)h->nodes_weight[i]; + // Assume the whole centroid falls into the range + double count_add = n_weight; + + // If we haven't reached the low threshold yet, skip appropriate part of the centroid. + count_add -= __td_min(__td_max(0, leftmost_weight - count_done), count_add); + + // If we have reached the upper threshold, ignore the overflowing part of the centroid. 
+ + count_add = __td_min(__td_max(0, rightmost_weight - count_done), count_add); + + // consider the whole centroid processed + count_done += n_weight; + + // increment the sum / count + trimmed_sum += h->nodes_mean[i] * count_add; + trimmed_count += count_add; + + // break once we cross the high threshold + if (count_done >= rightmost_weight) + break; + } + + return trimmed_sum / trimmed_count; +} + +double td_trimmed_mean_symmetric(td_histogram_t *h, double proportion_to_cut) { + td_compress(h); + // proportion_to_cut should be in [0,1] + if (h->merged_nodes == 0 || proportion_to_cut < 0.0 || proportion_to_cut > 1.0) { + return NAN; + } + // with one data point, all values lead to Rome + if (h->merged_nodes == 1) { + return h->nodes_mean[0]; + } + + /* translate the percentiles to counts */ + const double leftmost_weight = floor((double)h->merged_weight * proportion_to_cut); + const double rightmost_weight = ceil((double)h->merged_weight * (1.0 - proportion_to_cut)); + + return td_internal_trimmed_mean(h, leftmost_weight, rightmost_weight); +} + +double td_trimmed_mean(td_histogram_t *h, double leftmost_cut, double rightmost_cut) { + td_compress(h); + // leftmost_cut and rightmost_cut should be in [0,1] + if (h->merged_nodes == 0 || leftmost_cut < 0.0 || leftmost_cut > 1.0 || rightmost_cut < 0.0 || + rightmost_cut > 1.0) { + return NAN; + } + // with one data point, all values lead to Rome + if (h->merged_nodes == 1) { + return h->nodes_mean[0]; + } + + /* translate the percentiles to counts */ + const double leftmost_weight = floor((double)h->merged_weight * leftmost_cut); + const double rightmost_weight = ceil((double)h->merged_weight * rightmost_cut); + + return td_internal_trimmed_mean(h, leftmost_weight, rightmost_weight); +} + +int td_add(td_histogram_t *h, double mean, long long weight) { + if (should_td_compress(h)) { + const int overflow_res = td_compress(h); + if (overflow_res != 0) + return overflow_res; + } + const int pos = next_node(h); + if (pos >= h->cap) + return EDOM; + if (_tdigest_long_long_add_safe(h->unmerged_weight, weight) == false) + return EDOM; + const long long new_unmerged_weight = h->unmerged_weight + weight; + if (_tdigest_long_long_add_safe(new_unmerged_weight, h->merged_weight) == false) + return EDOM; + const long long new_total_weight = new_unmerged_weight + h->merged_weight; + // double-precision overflow detected + const int overflow_res = + _check_td_overflow((double)new_unmerged_weight, (double)new_total_weight); + if (overflow_res != 0) + return overflow_res; + + if (mean < h->min) { + h->min = mean; + } + if (mean > h->max) { + h->max = mean; + } + h->nodes_mean[pos] = mean; + h->nodes_weight[pos] = weight; + h->unmerged_nodes++; + h->unmerged_weight = new_unmerged_weight; + return 0; +} + +int td_compress(td_histogram_t *h) { + if (h->unmerged_nodes == 0) { + return 0; + } + int N = h->merged_nodes + h->unmerged_nodes; + td_qsort(h->nodes_mean, h->nodes_weight, 0, N - 1); + const double total_weight = (double)h->merged_weight + (double)h->unmerged_weight; + // double-precision overflow detected + const int overflow_res = _check_td_overflow((double)h->unmerged_weight, (double)total_weight); + if (overflow_res != 0) + return overflow_res; + if (total_weight <= 1) + return 0; + const double denom = 2 * MM_PI * total_weight * log(total_weight); + if (_check_overflow(denom) != 0) + return EDOM; + + // Compute the normalizer given compression and number of points. 
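+    // With this normalizer, two adjacent centroids are merged only while their combined
+    // weight stays below about 2*pi*total_weight*log(total_weight)*q*(1-q)/compression,
+    // so centroids stay small near the tails (q near 0 or 1) and can grow near the median.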
+ const double normalizer = h->compression / denom; + if (_check_overflow(normalizer) != 0) + return EDOM; + int cur = 0; + double weight_so_far = 0; + + for (int i = 1; i < N; i++) { + const double proposed_weight = (double)h->nodes_weight[cur] + (double)h->nodes_weight[i]; + const double z = proposed_weight * normalizer; + // quantile up to cur + const double q0 = weight_so_far / total_weight; + // quantile up to cur + i + const double q2 = (weight_so_far + proposed_weight) / total_weight; + // Convert a quantile to the k-scale + const bool should_add = (z <= (q0 * (1 - q0))) && (z <= (q2 * (1 - q2))); + // next point will fit + // so merge into existing centroid + if (should_add) { + h->nodes_weight[cur] += h->nodes_weight[i]; + const double delta = h->nodes_mean[i] - h->nodes_mean[cur]; + const double weighted_delta = (delta * h->nodes_weight[i]) / h->nodes_weight[cur]; + h->nodes_mean[cur] += weighted_delta; + } else { + weight_so_far += h->nodes_weight[cur]; + cur++; + h->nodes_weight[cur] = h->nodes_weight[i]; + h->nodes_mean[cur] = h->nodes_mean[i]; + } + if (cur != i) { + h->nodes_weight[i] = 0; + h->nodes_mean[i] = 0.0; + } + } + h->merged_nodes = cur + 1; + h->merged_weight = total_weight; + h->unmerged_nodes = 0; + h->unmerged_weight = 0; + h->total_compressions++; + return 0; +} + +double td_min(td_histogram_t *h) { return h->min; } + +double td_max(td_histogram_t *h) { return h->max; } + +int td_compression(td_histogram_t *h) { return h->compression; } + +const long long *td_centroids_weight(td_histogram_t *h) { return h->nodes_weight; } + +const double *td_centroids_mean(td_histogram_t *h) { return h->nodes_mean; } + +long long td_centroids_weight_at(td_histogram_t *h, int pos) { return h->nodes_weight[pos]; } + +double td_centroids_mean_at(td_histogram_t *h, int pos) { + if (pos < 0 || pos > h->merged_nodes) { + return NAN; + } + return h->nodes_mean[pos]; +} diff --git a/src/redis/tdigest.h b/src/redis/tdigest.h new file mode 100644 index 000000000000..c07436c54064 --- /dev/null +++ b/src/redis/tdigest.h @@ -0,0 +1,258 @@ +#pragma once +#include + +/** + * Adaptive histogram based on something like streaming k-means crossed with Q-digest. + * The implementation is a direct descendent of MergingDigest + * https://github.com/tdunning/t-digest/ + * + * Copyright (c) 2021 Redis, All rights reserved. + * Copyright (c) 2018 Andrew Werner, All rights reserved. + * + * The special characteristics of this algorithm are: + * + * - smaller summaries than Q-digest + * + * - provides part per million accuracy for extreme quantiles and typically <1000 ppm accuracy + * for middle quantiles + * + * - fast + * + * - simple + * + * - easy to adapt for use with map-reduce + */ + +#define MM_PI 3.14159265358979323846 + +struct td_histogram { + // compression is a setting used to configure the size of centroids when merged. + double compression; + + double min; + double max; + + // cap is the total size of nodes + int cap; + // merged_nodes is the number of merged nodes at the front of nodes. + int merged_nodes; + // unmerged_nodes is the number of buffered nodes. 
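+    // (i.e. nodes appended by td_add since the last td_compress() call)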
+ int unmerged_nodes; + + // we run the merge in reverse every other merge to avoid left-to-right bias in merging + long long total_compressions; + + long long merged_weight; + long long unmerged_weight; + + double *nodes_mean; + long long *nodes_weight; +}; + +typedef struct td_histogram td_histogram_t; + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Allocate the memory, initialise the t-digest, and return the histogram as output parameter. + * @param compression The compression parameter. + * 100 is a common value for normal uses. + * 1000 is extremely large. + * The number of centroids retained will be a smallish (usually less than 10) multiple of this + * number. + * @return the histogram on success, NULL if allocation failed. + */ +td_histogram_t *td_new(double compression); + +/** + * Allocate the memory and initialise the t-digest. + * + * @param compression The compression parameter. + * 100 is a common value for normal uses. + * 1000 is extremely large. + * The number of centroids retained will be a smallish (usually less than 10) multiple of this + * number. + * @param result Output parameter to capture allocated histogram. + * @return 0 on success, 1 if allocation failed. + */ +int td_init(double compression, td_histogram_t **result); + +/** + * Frees the memory associated with the t-digest. + * + * @param h The histogram you want to free. + */ +void td_free(td_histogram_t *h); + +/** + * Reset a histogram to zero - empty out a histogram and re-initialise it + * + * If you want to re-use an existing histogram, but reset everything back to zero, this + * is the routine to use. + * + * @param h The histogram you want to reset to empty. + * + */ +void td_reset(td_histogram_t *h); + +/** + * Adds a sample to a histogram. + * + * @param val The value to add. + * @param weight The weight of this point. + * @return 0 on success, EDOM if overflow was detected as a consequence of adding the provided + * weight. + * + */ +int td_add(td_histogram_t *h, double val, long long weight); + +/** + * Re-examines a t-digest to determine whether some centroids are redundant. If your data are + * perversely ordered, this may be a good idea. Even if not, this may save 20% or so in space. + * + * The cost is roughly the same as adding as many data points as there are centroids. This + * is typically < 10 * compression, but could be as high as 100 * compression. + * This is a destructive operation that is not thread-safe. + * + * @param h The histogram you want to compress. + * @return 0 on success, EDOM if overflow was detected as a consequence of adding the provided + * weight. If overflow is detected the histogram is not changed. + * + */ +int td_compress(td_histogram_t *h); + +/** + * Merges all of the values from 'from' to 'this' histogram. + * + * @param h "This" pointer + * @param from Histogram to copy values from. + * * @return 0 on success, EDOM if overflow was detected as a consequence of merging the the + * provided histogram. If overflow is detected the original histogram is not detected. + */ +int td_merge(td_histogram_t *h, td_histogram_t *from); + +/** + * Returns the fraction of all points added which are ≤ x. + * + * @param x The cutoff for the cdf. + * @return The fraction of all data which is less or equal to x. + */ +double td_cdf(td_histogram_t *h, double x); + +/** + * Returns an estimate of the cutoff such that a specified fraction of the data + * added to this TDigest would be less than or equal to the cutoff. 
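+ * For example, td_quantile(h, 0.5) estimates the median of the values added so far.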
+ * + * @param q The desired fraction + * @return The value x such that cdf(x) == q; + */ +double td_quantile(td_histogram_t *h, double q); + +/** + * Returns an estimate of the cutoff such that a specified fraction of the data + * added to this TDigest would be less than or equal to the cutoffs. + * + * @param quantiles The ordered percentiles array to get the values for. + * @param values Destination array containing the values at the given quantiles. + * The values array should be allocated by the caller. + * @return 0 on success, ENOMEM if the provided destination array is null. + */ +int td_quantiles(td_histogram_t *h, const double *quantiles, double *values, size_t length); + +/** + * Returns the trimmed mean ignoring values outside given cutoff upper and lower limits. + * + * @param leftmost_cut Fraction to cut off of the left tail of the distribution. + * @param rightmost_cut Fraction to cut off of the right tail of the distribution. + * @return The trimmed mean ignoring values outside given cutoff upper and lower limits; + */ +double td_trimmed_mean(td_histogram_t *h, double leftmost_cut, double rightmost_cut); + +/** + * Returns the trimmed mean ignoring values outside given a symmetric cutoff limits. + * + * @param proportion_to_cut Fraction to cut off of the left and right tails of the distribution. + * @return The trimmed mean ignoring values outside given cutoff upper and lower limits; + */ +double td_trimmed_mean_symmetric(td_histogram_t *h, double proportion_to_cut); + +/** + * Returns the current compression factor. + * + * @return The compression factor originally used to set up the TDigest. + */ +int td_compression(td_histogram_t *h); + +/** + * Returns the number of points that have been added to this TDigest. + * + * @return The sum of the weights on all centroids. + */ +long long td_size(td_histogram_t *h); + +/** + * Returns the number of centroids being used by this TDigest. + * + * @return The number of centroids being used. + */ +int td_centroid_count(td_histogram_t *h); + +/** + * Get minimum value from the histogram. Will return __DBL_MAX__ if the histogram + * is empty. + * + * @param h "This" pointer + */ +double td_min(td_histogram_t *h); + +/** + * Get maximum value from the histogram. Will return - __DBL_MAX__ if the histogram + * is empty. + * + * @param h "This" pointer + */ +double td_max(td_histogram_t *h); + +/** + * Get the full centroids weight array for 'this' histogram. + * + * @param h "This" pointer + * + * @return The full centroids weight array. + */ +const long long *td_centroids_weight(td_histogram_t *h); + +/** + * Get the full centroids mean array for 'this' histogram. + * + * @param h "This" pointer + * + * @return The full centroids mean array. + */ +const double *td_centroids_mean(td_histogram_t *h); + +/** + * Get the centroid weight for 'this' histogram and 'pos'. + * + * @param h "This" pointer + * @param pos centroid position. + * + * @return The centroid weight. + */ +long long td_centroids_weight_at(td_histogram_t *h, int pos); + +/** + * Get the centroid mean for 'this' histogram and 'pos'. + * + * @param h "This" pointer + * @param pos centroid position. + * + * @return The centroid mean. 
+ */ +double td_centroids_mean_at(td_histogram_t *h, int pos); + +#ifdef __cplusplus +} +#endif diff --git a/src/redis/zmalloc.c b/src/redis/zmalloc.c index 4ed616916294..fb02341c1a53 100644 --- a/src/redis/zmalloc.c +++ b/src/redis/zmalloc.c @@ -197,10 +197,11 @@ void *ztrycalloc_usable(size_t size, size_t *usable) { } /* Allocate memory and zero it or panic */ -void *zcalloc(size_t size) { - void *ptr = ztrycalloc_usable(size, NULL); +void *zcalloc(size_t num, size_t size) { + size_t bytes = num * size; + void *ptr = ztrycalloc_usable(bytes, NULL); - if (!ptr) zmalloc_oom_handler(size); + if (!ptr) zmalloc_oom_handler(bytes); return ptr; } diff --git a/src/redis/zmalloc.h b/src/redis/zmalloc.h index 91012b79f703..c74bcb906d61 100644 --- a/src/redis/zmalloc.h +++ b/src/redis/zmalloc.h @@ -32,6 +32,8 @@ #define __ZMALLOC_H #include +#include +#include /* Double expansion needed for stringification of macro values. */ #define __xstr(s) __zm_str(s) @@ -88,7 +90,7 @@ #endif void *zmalloc(size_t size); -void *zcalloc(size_t size); +void *zcalloc(size_t num, size_t size); void *zrealloc(void *ptr, size_t size); void *ztrymalloc(size_t size); void *ztrycalloc(size_t size); diff --git a/src/redis/zmalloc_mi.c b/src/redis/zmalloc_mi.c index 82915a9444b0..4b60a193ffe6 100644 --- a/src/redis/zmalloc_mi.c +++ b/src/redis/zmalloc_mi.c @@ -28,6 +28,20 @@ void* zmalloc(size_t size) { return res; } +void *zcalloc(size_t num, size_t size) { + assert(zmalloc_heap); + void* res = mi_heap_calloc(zmalloc_heap, num, size); + size_t usable = mi_usable_size(res); + + // assertion does not hold. Basically mi_good_size is not a good function for + // doing accounting. + // assert(usable == mi_good_size(size)); + zmalloc_used_memory_tl += usable; + + return res; + +} + void* ztrymalloc_usable(size_t size, size_t* usable) { return zmalloc_usable(size, usable); } @@ -50,16 +64,6 @@ void* zrealloc(void* ptr, size_t size) { return zrealloc_usable(ptr, size, &usable); } -void* zcalloc(size_t size) { - // mi_good_size(size) is not working. try for example, size=690557. 
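/* Call-shape note for the two-argument zcalloc introduced above: callers now
 * pass the element count and the element size separately, calloc(3)-style.
 * A hypothetical call, not taken from this patch:
 *
 *   double *means = zcalloc(capacity, sizeof(double));
 *
 * The mimalloc-backed path forwards (num, size) to mi_heap_calloc, while the
 * zmalloc.c fallback computes num * size itself before allocating. */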
- - void* res = mi_heap_calloc(zmalloc_heap, 1, size); - size_t usable = mi_usable_size(res); - zmalloc_used_memory_tl += usable; - - return res; -} - void* zmalloc_usable(size_t size, size_t* usable) { assert(zmalloc_heap); void* res = mi_heap_malloc(zmalloc_heap, size); diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index 27236153d44c..a97ae46108a5 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -62,7 +62,7 @@ add_library(dragonfly_lib bloom_family.cc cluster/cluster_config.cc cluster/cluster_family.cc cluster/incoming_slot_migration.cc cluster/outgoing_slot_migration.cc cluster/cluster_defs.cc cluster/cluster_utility.cc acl/user.cc acl/user_registry.cc acl/acl_family.cc - acl/validator.cc) + acl/validator.cc tdigest_family.cc topk_family.cc prob/cuckoo_filter_family.cc) if (DF_ENABLE_MEMORY_TRACKING) target_compile_definitions(dragonfly_lib PRIVATE DFLY_ENABLE_MEMORY_TRACKING) @@ -127,6 +127,9 @@ cxx_test(cluster/cluster_family_test dfly_test_lib LABELS DFLY) cxx_test(acl/acl_family_test dfly_test_lib LABELS DFLY) cxx_test(engine_shard_set_test dfly_test_lib LABELS DFLY) cxx_test(search/search_family_test dfly_test_lib LABELS DFLY) +cxx_test(tdigest_family_test dfly_test_lib LABELS DFLY) +cxx_test(topk_family_test dfly_test_lib LABELS DFLY) +cxx_test(prob/cuckoo_filter_family_test dfly_test_lib LABELS DFLY) if (WITH_ASAN OR WITH_USAN) target_compile_definitions(stream_family_test PRIVATE SANITIZERS) target_compile_definitions(multi_test PRIVATE SANITIZERS) @@ -145,4 +148,4 @@ add_dependencies(check_dfly dragonfly_test json_family_test list_family_test redis_parser_test stream_family_test string_family_test bitops_family_test set_family_test zset_family_test geo_family_test hll_family_test cluster_config_test cluster_family_test acl_family_test - json_family_memory_test) + json_family_memory_test tdigest_family_test topk_family_test cuckoo_filter_family_test) diff --git a/src/server/acl/acl_commands_def.h b/src/server/acl/acl_commands_def.h index 99b8257c7591..b806cde1de4b 100644 --- a/src/server/acl/acl_commands_def.h +++ b/src/server/acl/acl_commands_def.h @@ -41,6 +41,9 @@ enum AclCat { SCRIPTING = 1ULL << 20, // Extensions + CUCKOO_FILTER = 1ULL << 25, + TOPK = 1ULL << 27, + TDIGEST = 1ULL << 27, BLOOM = 1ULL << 28, FT_SEARCH = 1ULL << 29, THROTTLE = 1ULL << 30, diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 7583a0ff7dbb..bfd811d006ce 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -51,12 +51,15 @@ extern "C" { #include "server/list_family.h" #include "server/multi_command_squasher.h" #include "server/namespaces.h" +#include "server/prob/cuckoo_filter_family.h" #include "server/script_mgr.h" #include "server/search/search_family.h" #include "server/server_state.h" #include "server/set_family.h" #include "server/stream_family.h" #include "server/string_family.h" +#include "server/tdigest_family.h" +#include "server/topk_family.h" #include "server/transaction.h" #include "server/version.h" #include "server/zset_family.h" @@ -2719,6 +2722,9 @@ void Service::RegisterCommands() { BloomFamily::Register(®istry_); server_family_.Register(®istry_); cluster_family_.Register(®istry_); + TDigestFamily::Register(®istry_); + TopKeysFamily::Register(®istry_); + CuckooFilterFamily::Register(®istry_); // AclFamily should always be registered last // If we add a new familly, register that first above and *not* below diff --git a/src/server/prob/cuckoo_filter_family.cc 
b/src/server/prob/cuckoo_filter_family.cc new file mode 100644 index 000000000000..a7d3b236f484 --- /dev/null +++ b/src/server/prob/cuckoo_filter_family.cc @@ -0,0 +1,352 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "server/prob/cuckoo_filter_family.h" + +#include "absl/functional/function_ref.h" +#include "core/prob/cuckoo_filter.h" +#include "facade/cmd_arg_parser.h" +#include "facade/op_status.h" +#include "server/acl/acl_commands_def.h" +#include "server/command_registry.h" +#include "server/db_slice.h" +#include "server/error.h" +#include "server/transaction.h" + +namespace dfly { +using namespace facade; +using CI = CommandId; + +namespace { + +OpStatus InitCuckooFilter(CompactObj* obj, const prob::CuckooReserveParams& params) { + auto cuckoo_filter = prob::CuckooFilter::Init(params, obj->memory_resource()); + if (!cuckoo_filter) { + return OpStatus::OUT_OF_MEMORY; + } + + obj->SetCuckooFilter(std::move(cuckoo_filter).value()); + return OpStatus::OK; +} + +OpResult OpReserve(const OpArgs& op_args, std::string_view key, + const prob::CuckooReserveParams& params) { + auto res_it = op_args.GetDbSlice().AddOrFind(op_args.db_cntx, key); + RETURN_ON_BAD_STATUS(res_it); + + if (!res_it->is_new) { + return OpStatus::KEY_EXISTS; + } + + auto status = InitCuckooFilter(&res_it->it->second, params); + if (status != OpStatus::OK) { + return status; + } + return true; +} + +OpResult OpAddCommon(const OpArgs& op_args, std::string_view key, std::string_view item, + bool is_nx) { + auto res_it = op_args.GetDbSlice().AddOrFind(op_args.db_cntx, key); + RETURN_ON_BAD_STATUS(res_it); + + auto& obj = res_it->it->second; + if (res_it->is_new) { + auto status = InitCuckooFilter(&obj, {}); + if (status != OpStatus::OK) { + return status; + } + } + + auto* filter = obj.GetCuckooFilter(); + const uint64_t hash = prob::CuckooFilter::GetHash(item); + if (is_nx && filter->Exists(hash)) { + // TODO: improve this to not to call Exists and Insert + // We should add InsertUnique method to the CuckooFilter + return false; + } + + return filter->Insert(hash); +} + +OpResult> OpInsertCommon(const OpArgs& op_args, std::string_view key, bool is_nx, + const prob::CuckooReserveParams& params, + bool create_cuckoo_if_not_exists, + absl::Span items) { + OpResult res_it; + if (create_cuckoo_if_not_exists) { + res_it = op_args.GetDbSlice().AddOrFind(op_args.db_cntx, key); + } else { + res_it = op_args.GetDbSlice().FindMutable(op_args.db_cntx, key, OBJ_CUCKOO_FILTER); + } + + RETURN_ON_BAD_STATUS(res_it); + + auto& obj = res_it->it->second; + if (res_it->is_new) { + DCHECK(create_cuckoo_if_not_exists); + auto status = InitCuckooFilter(&obj, params); + if (status != OpStatus::OK) { + return status; + } + } + + auto* filter = obj.GetCuckooFilter(); + + auto insert = [&](std::string_view item) { + const auto hash = prob::CuckooFilter::GetHash(item); + return filter->Insert(hash); + }; + + auto insert_with_exists = [&](std::string_view item) { + const auto hash = prob::CuckooFilter::GetHash(item); + if (filter->Exists(hash)) { + return false; + } + return filter->Insert(hash); + }; + + auto cb = + is_nx ? 
static_cast>(insert_with_exists) : insert; + + std::vector result(items.size()); + for (size_t i = 0; i < items.size(); i++) { + result[i] = cb(items[i]); + } + + return result; +} + +OpResult> OpExistsCommon(const OpArgs& op_args, std::string_view key, + absl::Span items) { + auto res_it = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_CUCKOO_FILTER); + RETURN_ON_BAD_STATUS(res_it); + + const auto* filter = res_it->GetInnerIt()->second.GetCuckooFilter(); + std::vector results(items.size()); + for (size_t i = 0; i < items.size(); i++) { + results[i] = filter->Exists(items[i]); + } + + return results; +} + +OpResult OpDel(const OpArgs& op_args, std::string_view key, std::string_view item) { + auto res_it = op_args.GetDbSlice().FindMutable(op_args.db_cntx, key, OBJ_CUCKOO_FILTER); + RETURN_ON_BAD_STATUS(res_it); + + auto* filter = res_it->it->second.GetCuckooFilter(); + return filter->Delete(item); +} + +OpResult OpCount(const OpArgs& op_args, std::string_view key, std::string_view item) { + auto res_it = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_CUCKOO_FILTER); + RETURN_ON_BAD_STATUS(res_it); + + const auto* filter = res_it->GetInnerIt()->second.GetCuckooFilter(); + return filter->Count(item); +} + +void AddImplCommon(CmdArgList args, const CommandContext& cmd_cntx, bool is_nx) { + CmdArgParser parser{args}; + std::string_view key = parser.Next(); + std::string_view item = parser.Next(); + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpAddCommon(t->GetOpArgs(shard), key, item, is_nx); + }; + + auto result = cmd_cntx.tx->ScheduleSingleHopT(std::move(cb)); + auto* rb = static_cast(cmd_cntx.rb); + if (result) { + rb->SendLong(result.value() ? 1 : 0); + } else { + rb->SendError(result.status()); + } +} + +void InsertImplCommon(CmdArgList args, const CommandContext& cmd_cntx, bool is_nx) { + CmdArgParser parser{args}; + std::string_view key = parser.Next(); + + /* TODO: improve agruments parsing. + If CAPACITY and NOCREATE are specified at the same time -> error. */ + prob::CuckooReserveParams params; + bool create_cuckoo_if_not_exists = true; + while (parser.HasNext()) { + if (parser.Check("CAPACITY")) { + params.capacity = parser.Next(); + } else if (parser.Check("NOCREATE")) { + create_cuckoo_if_not_exists = false; + } else if (parser.Check("ITEMS")) { + break; + } + } + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpInsertCommon(t->GetOpArgs(shard), key, is_nx, params, create_cuckoo_if_not_exists, + parser.Tail()); + }; + + auto result = cmd_cntx.tx->ScheduleSingleHopT(std::move(cb)); + auto* rb = static_cast(cmd_cntx.rb); + if (result) { + const auto& result_array = result.value(); + rb->StartArray(result_array.size()); + for (const auto& was_added : result_array) { + rb->SendLong(was_added ? 
1 : 0); + } + } else { + rb->SendError(result.status()); + } +} + +void ExistsImplCommon(CmdArgList args, const CommandContext& cmd_cntx, bool is_multi) { + CmdArgParser parser{args}; + std::string_view key = parser.Next(); + + std::vector items; + if (is_multi) { + while (parser.HasNext()) { + items.push_back(parser.Next()); + } + } else { + items.push_back(parser.Next()); + } + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpExistsCommon(t->GetOpArgs(shard), key, items); + }; + + auto result = cmd_cntx.tx->ScheduleSingleHopT(std::move(cb)); + auto* rb = static_cast(cmd_cntx.rb); + if (result) { + if (is_multi) { + rb->StartArray(result->size()); + for (auto res : *result) { + rb->SendLong(res); + } + } else { + DCHECK(result->size() == 1); + rb->SendLong((*result)[0]); + } + } else { + rb->SendError(result.status()); + } +} + +} // anonymous namespace + +void CuckooFilterFamily::Reserve(CmdArgList args, const CommandContext& cmd_cntx) { + CmdArgParser parser{args}; + std::string_view key = parser.Next(); + + prob::CuckooReserveParams params{.capacity = parser.Next()}; + while (parser.HasNext()) { + if (parser.Check("BUCKETSIZE")) { + params.bucket_size = parser.Next(); + } else if (parser.Check("MAXITERATIONS")) { + params.max_iterations = parser.Next(); + } else if (parser.Check("EXPANSION")) { + params.expansion = parser.Next(); + } else { + break; + } + } + + if (!parser.Finalize()) { + return cmd_cntx.rb->SendError(parser.Error()->MakeReply()); + } + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpReserve(t->GetOpArgs(shard), key, params); + }; + + auto result = cmd_cntx.tx->ScheduleSingleHopT(std::move(cb)); + auto* rb = static_cast(cmd_cntx.rb); + if (result) { + rb->SendOk(); + } else { + rb->SendError(result.status()); + } +} + +void CuckooFilterFamily::Add(CmdArgList args, const CommandContext& cmd_cntx) { + return AddImplCommon(args, cmd_cntx, false); +} + +void CuckooFilterFamily::AddNx(CmdArgList args, const CommandContext& cmd_cntx) { + return AddImplCommon(args, cmd_cntx, true); +} + +void CuckooFilterFamily::Insert(CmdArgList args, const CommandContext& cmd_cntx) { + return InsertImplCommon(args, cmd_cntx, false); +} + +void CuckooFilterFamily::InsertNx(CmdArgList args, const CommandContext& cmd_cntx) { + return InsertImplCommon(args, cmd_cntx, true); +} + +void CuckooFilterFamily::Exists(CmdArgList args, const CommandContext& cmd_cntx) { + return ExistsImplCommon(args, cmd_cntx, false); +} + +void CuckooFilterFamily::MExists(CmdArgList args, const CommandContext& cmd_cntx) { + return ExistsImplCommon(args, cmd_cntx, true); +} + +void CuckooFilterFamily::Del(CmdArgList args, const CommandContext& cmd_cntx) { + CmdArgParser parser{args}; + std::string_view key = parser.Next(); + std::string_view item = parser.Next(); + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpDel(t->GetOpArgs(shard), key, item); + }; + + auto result = cmd_cntx.tx->ScheduleSingleHopT(std::move(cb)); + auto* rb = static_cast(cmd_cntx.rb); + if (result) { + rb->SendLong(result.value() ? 
1 : 0); + } else { + rb->SendError(result.status()); + } +} + +void CuckooFilterFamily::Count(CmdArgList args, const CommandContext& cmd_cntx) { + CmdArgParser parser{args}; + std::string_view key = parser.Next(); + std::string_view item = parser.Next(); + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpCount(t->GetOpArgs(shard), key, item); + }; + + auto result = cmd_cntx.tx->ScheduleSingleHopT(std::move(cb)); + auto* rb = static_cast(cmd_cntx.rb); + if (result) { + rb->SendLong(result.value()); + } else { + rb->SendError(result.status()); + } +} + +#define HFUNC(x) SetHandler(&CuckooFilterFamily::x) + +void CuckooFilterFamily::Register(CommandRegistry* registry) { + registry->StartFamily(); + + *registry + << CI{"CF.RESERVE", CO::WRITE | CO::DENYOOM, -3, 1, 1, acl::CUCKOO_FILTER}.HFUNC(Reserve) + << CI{"CF.ADD", CO::WRITE | CO::DENYOOM, 3, 1, 1, acl::CUCKOO_FILTER}.HFUNC(Add) + << CI{"CF.ADDNX", CO::WRITE | CO::DENYOOM, 3, 1, 1, acl::CUCKOO_FILTER}.HFUNC(AddNx) + << CI{"CF.INSERT", CO::WRITE | CO::DENYOOM, -4, 1, 1, acl::CUCKOO_FILTER}.HFUNC(Insert) + << CI{"CF.INSERTNX", CO::WRITE | CO::DENYOOM, -4, 1, 1, acl::CUCKOO_FILTER}.HFUNC(InsertNx) + << CI{"CF.EXISTS", CO::READONLY | CO::FAST, 3, 1, 1, acl::CUCKOO_FILTER}.HFUNC(Exists) + << CI{"CF.MEXISTS", CO::READONLY | CO::FAST, -3, 1, 1, acl::CUCKOO_FILTER}.HFUNC(MExists) + << CI{"CF.DEL", CO::WRITE, 3, 1, 1, acl::CUCKOO_FILTER}.HFUNC(Del) + << CI{"CF.COUNT", CO::READONLY | CO::FAST, 3, 1, 1, acl::CUCKOO_FILTER}.HFUNC(Count); +}; + +} // namespace dfly diff --git a/src/server/prob/cuckoo_filter_family.h b/src/server/prob/cuckoo_filter_family.h new file mode 100644 index 000000000000..13c276f31970 --- /dev/null +++ b/src/server/prob/cuckoo_filter_family.h @@ -0,0 +1,36 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include "server/common.h" + +namespace facade { +class SinkReplyBuilder; +} // namespace facade + +namespace dfly { + +class CommandRegistry; +struct CommandContext; + +class CuckooFilterFamily { + public: + static void Register(CommandRegistry* registry); + + private: + using SinkReplyBuilder = facade::SinkReplyBuilder; + + static void Reserve(CmdArgList args, const CommandContext& cmd_cntx); + static void Add(CmdArgList args, const CommandContext& cmd_cntx); + static void AddNx(CmdArgList args, const CommandContext& cmd_cntx); + static void Insert(CmdArgList args, const CommandContext& cmd_cntx); + static void InsertNx(CmdArgList args, const CommandContext& cmd_cntx); + static void Exists(CmdArgList args, const CommandContext& cmd_cntx); + static void MExists(CmdArgList args, const CommandContext& cmd_cntx); + static void Del(CmdArgList args, const CommandContext& cmd_cntx); + static void Count(CmdArgList args, const CommandContext& cmd_cntx); +}; + +} // namespace dfly diff --git a/src/server/prob/cuckoo_filter_family_test.cc b/src/server/prob/cuckoo_filter_family_test.cc new file mode 100644 index 000000000000..32d253b51b2f --- /dev/null +++ b/src/server/prob/cuckoo_filter_family_test.cc @@ -0,0 +1,63 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. 
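//
// The test below drives the filter purely through the command layer; the
// reply shapes it relies on, summarized for orientation:
//   CF.RESERVE key capacity [...]  -> OK, errors if the key already exists
//   CF.ADD key item                -> 1 on insert (duplicates are allowed)
//   CF.COUNT key item              -> number of fingerprints stored for item
//   CF.DEL key item                -> 1 while a fingerprint remains, else 0
//   CF.EXISTS key item             -> 0 only once every copy has been deleted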
+// + +#include + +#include "base/gtest.h" +#include "base/logging.h" +#include "facade/facade_test.h" +#include "server/tdigest_family.h" +#include "server/test_utils.h" + +using namespace testing; +using namespace util; + +namespace dfly { + +class CuckooFilterFamilyTest : public BaseFamilyTest { + protected: +}; + +TEST_F(CuckooFilterFamilyTest, Simple) { + auto resp = Run({"CF.RESERVE", "my_filter", "100"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"CF.ADD", "my_filter", "foo"}); + EXPECT_THAT(resp, IntArg(1)); + + resp = Run({"CF.EXISTS", "my_filter", "foo"}); + EXPECT_THAT(resp, IntArg(1)); + + resp = Run({"CF.ADD", "my_filter", "foo"}); + EXPECT_THAT(resp, IntArg(1)); + + resp = Run({"CF.EXISTS", "my_filter", "foo"}); + EXPECT_THAT(resp, IntArg(1)); + + resp = Run({"CF.COUNT", "my_filter", "foo"}); + EXPECT_THAT(resp, IntArg(2)); + + resp = Run({"CF.DEL", "my_filter", "foo"}); + EXPECT_THAT(resp, IntArg(1)); + + resp = Run({"CF.EXISTS", "my_filter", "foo"}); + EXPECT_THAT(resp, IntArg(1)); + + resp = Run({"CF.COUNT", "my_filter", "foo"}); + EXPECT_THAT(resp, IntArg(1)); + + resp = Run({"CF.DEL", "my_filter", "foo"}); + EXPECT_THAT(resp, IntArg(1)); + + resp = Run({"CF.EXISTS", "my_filter", "foo"}); + EXPECT_THAT(resp, IntArg(0)); + + resp = Run({"CF.COUNT", "my_filter", "foo"}); + EXPECT_THAT(resp, IntArg(0)); + + resp = Run({"CF.DEL", "my_filter", "foo"}); + EXPECT_THAT(resp, IntArg(0)); +} + +} // namespace dfly diff --git a/src/server/rdb_save.cc b/src/server/rdb_save.cc index 8a17dc8972c9..351804bb45eb 100644 --- a/src/server/rdb_save.cc +++ b/src/server/rdb_save.cc @@ -272,6 +272,15 @@ error_code RdbSerializer::SelectDb(uint32_t dbid) { // Called by snapshot io::Result RdbSerializer::SaveEntry(const PrimeKey& pk, const PrimeValue& pv, uint64_t expire_ms, uint32_t mc_flags, DbIndex dbid) { + if (pv.ObjType() == OBJ_TDIGEST) { + return 0; + } + if (pv.ObjType() == OBJ_TOPK) { + return 0; + } + if (pv.ObjType() == OBJ_CUCKOO_FILTER) { + return 0; + } if (!pv.TagAllowsEmptyValue() && pv.Size() == 0) { string_view key = pk.GetSlice(&tmp_str_); LOG(DFATAL) << "SaveEntry skipped empty PrimeValue with key: " << key << " with tag " diff --git a/src/server/tdigest_family.cc b/src/server/tdigest_family.cc new file mode 100644 index 000000000000..ea9195e47adb --- /dev/null +++ b/src/server/tdigest_family.cc @@ -0,0 +1,656 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. 
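//
// Storage model used throughout this file: a TDIGEST key holds a raw
// td_histogram_t* wrapped by the generic RobjWrapper, so every command follows
// the same pattern (condensed from the callbacks below):
//
//   td_histogram_t* td = td_new(compression);
//   res->it->second.InitRobj(OBJ_TDIGEST, 0, td);   // attach on create
//   auto* cur = (td_histogram_t*)it->it->second.GetRobjWrapper()->inner_obj();
//                                                    // fetch on access
//
// Destruction happens through the generic CompactObj teardown, not here.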
+// + +#include "server/tdigest_family.h" + +#include "facade/cmd_arg_parser.h" +#include "facade/error.h" +#include "server/acl/acl_commands_def.h" +#include "server/command_registry.h" +#include "server/db_slice.h" +#include "server/engine_shard.h" +#include "server/engine_shard_set.h" +#include "server/transaction.h" + +extern "C" { +#include "redis/tdigest.h" +} + +namespace dfly { + +void TDigestFamily::Create(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + size_t compression = 50; + if (parser.HasNext() && !parser.Check("COMPRESSION", &compression)) { + return cmd_cntx.rb->SendError(facade::kSyntaxErr); + } + + auto cb = [key, compression](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto res = db_slice.AddOrFind(db_cntx, key); + if (!res) { + return res.status(); + } + + if (!res->is_new) { + return OpStatus::KEY_EXISTS; + } + + td_histogram_t* td = td_new(compression); + // DENYOOM should cover this + if (!td) { + db_slice.Del(db_cntx, res->it); + return OpStatus::OUT_OF_MEMORY; + } + res->it->second.InitRobj(OBJ_TDIGEST, 0, td); + return OpStatus::OK; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + // SendError covers ok + if (res.status() == OpStatus::KEY_EXISTS) { + return cmd_cntx.rb->SendError("key already exists"); + } + return cmd_cntx.rb->SendError(res.status()); +} + +void TDigestFamily::Add(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + std::vector values; + while (parser.HasNext()) { + double val = parser.Next(); + values.push_back(val); + } + + if (parser.HasError()) { + return cmd_cntx.rb->SendError(parser.Error()->MakeReply()); + } + + auto cb = [key, &values](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindMutable(db_cntx, key, OBJ_TDIGEST); + if (!it) { + return it.status(); + } + auto* wrapper = it->it->second.GetRobjWrapper(); + auto* td = (td_histogram_t*)wrapper->inner_obj(); + for (auto value : values) { + if (td_add(td, value, 1) != 0) { + return OpStatus::OUT_OF_RANGE; + } + } + return OpStatus::OK; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + // SendError covers ok + return cmd_cntx.rb->SendError(res.status()); +} + +double TdGetByRank(td_histogram_t* td, double total_obs, double rnk) { + const double input_p = rnk / total_obs; + return td_quantile(td, input_p); +} + +void ByRankImpl(CmdArgList args, const CommandContext& cmd_cntx, bool reverse) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + std::vector ranks; + while (parser.HasNext()) { + double val = parser.Next(); + ranks.push_back(val); + } + + if (parser.HasError()) { + cmd_cntx.rb->SendError(parser.Error()->MakeReply()); + } + + using ByRankResult = std::vector; + auto cb = [key, reverse, &ranks](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindMutable(db_cntx, key, OBJ_TDIGEST); + if (!it) { + return it.status(); + } + auto* wrapper = it->it->second.GetRobjWrapper(); + auto* td = (td_histogram_t*)wrapper->inner_obj(); + const size_t size = td_size(td); + const double min = td_min(td); + const double max = td_max(td); + + ByRankResult result; + for (auto rnk : ranks) { + if (size == 0) { + 
result.push_back(NAN); + } else if (rnk == 0) { + result.push_back(reverse ? max : min); + } else if (rnk >= size) { + result.push_back(reverse ? -INFINITY : INFINITY); + } else { + result.push_back(TdGetByRank(td, size, reverse ? (size - rnk - 1) : rnk)); + } + } + return result; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + if (!res) { + return cmd_cntx.rb->SendError(res.status()); + } + + auto* rb = static_cast(cmd_cntx.rb); + rb->StartArray(res->size()); + for (auto res : *res) { + rb->SendDouble(res); + } +} + +void TDigestFamily::ByRank(CmdArgList args, const CommandContext& cmd_cntx) { + return ByRankImpl(args, cmd_cntx, false); +} + +void TDigestFamily::ByRevRank(CmdArgList args, const CommandContext& cmd_cntx) { + return ByRankImpl(args, cmd_cntx, true); +} + +struct InfoResult { + double compression; + int cap; + int merged_nodes; + int unmerged_nodes; + int64_t merged_weight; + int64_t unmerged_weight; + int64_t observations; + int64_t total_compressions; + size_t mem_usage; +}; + +void TDigestFamily::Info(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + + auto cb = [key](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindReadOnly(db_cntx, key, OBJ_TDIGEST); + if (!it) { + return it.status(); + } + auto* wrapper = it->GetInnerIt()->second.GetRobjWrapper(); + auto* td = (td_histogram_t*)wrapper->inner_obj(); + InfoResult res; + res.compression = td->compression; + res.cap = td->cap; + res.merged_nodes = td->merged_nodes; + res.unmerged_nodes = td->unmerged_nodes; + res.merged_weight = td->merged_weight; + res.unmerged_weight = td->unmerged_weight; + res.observations = res.unmerged_weight + res.merged_weight; + res.total_compressions = td->total_compressions; + res.mem_usage = wrapper->MallocUsed(false); + return res; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + if (!res) { + return cmd_cntx.rb->SendError(res.status()); + } + auto* rb = static_cast(cmd_cntx.rb); + rb->StartArray(9 * 2); + rb->SendSimpleString("Compression"); + rb->SendLong(res->compression); + rb->SendSimpleString("Capacity"); + rb->SendLong(res->cap); + rb->SendSimpleString("Merged nodes"); + rb->SendLong(res->merged_nodes); + rb->SendSimpleString("Unmerged nodes"); + rb->SendLong(res->unmerged_nodes); + rb->SendSimpleString("Merged weight"); + rb->SendLong(res->merged_weight); + rb->SendSimpleString("Unmerged weight"); + rb->SendLong(res->unmerged_weight); + rb->SendSimpleString("Observations"); + rb->SendLong(res->observations); + rb->SendSimpleString("Total compressions"); + rb->SendLong(res->total_compressions); + rb->SendSimpleString("Memory usage"); + rb->SendLong(res->mem_usage); +} + +struct MinMax { + double min = 0; + double max = 0; +}; + +OpResult MinMaxImpl(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + + auto cb = [key](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindReadOnly(db_cntx, key, OBJ_TDIGEST); + if (!it) { + return it.status(); + } + auto* wrapper = it->GetInnerIt()->second.GetRobjWrapper(); + auto* td = (td_histogram_t*)wrapper->inner_obj(); + const double min = (td_size(td) > 0) ? td_min(td) : NAN; + const double max = (td_size(td) > 0) ? 
td_max(td) : NAN; + return MinMax{min, max}; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + return res; +} + +void TDigestFamily::Max(CmdArgList args, const CommandContext& cmd_cntx) { + auto res = MinMaxImpl(args, cmd_cntx); + if (!res) { + return cmd_cntx.rb->SendError(res.status()); + } + auto* rb = static_cast(cmd_cntx.rb); + rb->SendDouble(res->max); +} + +void TDigestFamily::Min(CmdArgList args, const CommandContext& cmd_cntx) { + auto res = MinMaxImpl(args, cmd_cntx); + if (!res) { + return cmd_cntx.rb->SendError(res.status()); + } + auto* rb = static_cast(cmd_cntx.rb); + rb->SendDouble(res->min); +} + +void TDigestFamily::Reset(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + auto key = parser.Next(); + + auto cb = [key](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindMutable(db_cntx, key, OBJ_TDIGEST); + if (!it) { + return it.status(); + } + auto* wrapper = it->it->second.GetRobjWrapper(); + auto* td = (td_histogram_t*)wrapper->inner_obj(); + td_reset(td); + return OpStatus::OK; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + // SendError covers ok + return cmd_cntx.rb->SendError(res.status()); +} + +double HalfRoundDown(double f) { + double round; + double frac = modf(f, &round); + + if (fabs(frac) <= 0.5) + return round; + + if (round >= 0.0) { + return round + 1.0; + } + + return round - 1.0; +} + +void RankImpl(CmdArgList args, const CommandContext& cmd_cntx, bool reverse) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + std::vector ranks; + while (parser.HasNext()) { + double val = parser.Next(); + ranks.push_back(val); + } + + if (parser.HasError()) { + cmd_cntx.rb->SendError(parser.Error()->MakeReply()); + } + + using RankResult = std::vector; + auto cb = [key, reverse, &ranks](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindMutable(db_cntx, key, OBJ_TDIGEST); + if (!it) { + return it.status(); + } + auto* wrapper = it->it->second.GetRobjWrapper(); + auto* td = (td_histogram_t*)wrapper->inner_obj(); + const size_t size = td_size(td); + const double min = td_min(td); + const double max = td_max(td); + + RankResult result; + + for (auto rnk : ranks) { + if (size == 0) { + result.push_back(-2); + } else if (rnk < min) { + result.push_back(reverse ? size : -1); + } else if (rnk > max) { + result.push_back(reverse ? -1 : size); + } else { + const double cdf_val = td_cdf(td, rnk); + const double cdf_val_prior_round = cdf_val * size; + const size_t cdf_to_absolute = + reverse ? round(cdf_val_prior_round) : HalfRoundDown(cdf_val_prior_round); + const size_t res = reverse ? 
round(size - cdf_to_absolute) : cdf_to_absolute; + result.push_back(res); + } + } + return result; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + // SendError covers ok + if (!res) { + return cmd_cntx.rb->SendError(res.status()); + } + auto* rb = static_cast(cmd_cntx.rb); + rb->StartArray(res->size()); + for (auto res : *res) { + rb->SendLong(res); + } +} + +void TDigestFamily::Rank(CmdArgList args, const CommandContext& cmd_cntx) { + return RankImpl(args, cmd_cntx, false); +} + +void TDigestFamily::RevRank(CmdArgList args, const CommandContext& cmd_cntx) { + return RankImpl(args, cmd_cntx, true); +} + +void TDigestFamily::Cdf(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + std::vector values; + while (parser.HasNext()) { + double val = parser.Next(); + values.push_back(val); + } + + if (parser.HasError()) { + cmd_cntx.rb->SendError(parser.Error()->MakeReply()); + } + + using Result = std::vector; + auto cb = [key, &values](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindMutable(db_cntx, key, OBJ_TDIGEST); + if (!it) { + return it.status(); + } + auto* wrapper = it->it->second.GetRobjWrapper(); + auto* td = (td_histogram_t*)wrapper->inner_obj(); + Result result; + for (auto val : values) { + result.push_back(td_cdf(td, val)); + } + return result; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + if (!res) { + return cmd_cntx.rb->SendError(res.status()); + } + + auto* rb = static_cast(cmd_cntx.rb); + rb->StartArray(res->size()); + for (auto res : *res) { + rb->SendDouble(res); + } +} + +void TDigestFamily::Quantile(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + std::vector quantiles; + while (parser.HasNext()) { + double val = parser.Next(); + if (val < 0 || val > 1.0) { + cmd_cntx.rb->SendError("quantile should be in [0,1]"); + } + quantiles.push_back(val); + } + + if (parser.HasError()) { + cmd_cntx.rb->SendError(parser.Error()->MakeReply()); + } + + using Result = std::vector; + auto cb = [key, &quantiles](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindMutable(db_cntx, key, OBJ_TDIGEST); + if (!it) { + return it.status(); + } + auto* wrapper = it->it->second.GetRobjWrapper(); + auto* td = (td_histogram_t*)wrapper->inner_obj(); + Result result; + result.resize(quantiles.size()); + auto total = quantiles.size(); + for (size_t i = 0; i < total; ++i) { + int start = i; + while (i < total - 1 && quantiles[i] <= quantiles[i + 1]) { + ++i; + } + td_quantiles(td, quantiles.data() + start, result.data() + start, i - start + 1); + } + return result; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + if (!res) { + return cmd_cntx.rb->SendError(res.status()); + } + + auto* rb = static_cast(cmd_cntx.rb); + rb->StartArray(res->size()); + for (auto res : *res) { + rb->SendDouble(res); + } +} + +void TDigestFamily::TrimmedMean(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + auto low_cut = parser.Next(); + auto high_cut = parser.Next(); + auto out_of_range = [](auto e) { return e < 0 || e > 1.0; }; + + if (out_of_range(low_cut) || out_of_range(high_cut)) { + cmd_cntx.rb->SendError("cut value should be in [0,1]"); + } + + if (parser.Error()) 
{ + cmd_cntx.rb->SendError(parser.Error()->MakeReply()); + } + + auto cb = [key, high_cut, low_cut](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindMutable(db_cntx, key, OBJ_TDIGEST); + if (!it) { + return it.status(); + } + auto* wrapper = it->it->second.GetRobjWrapper(); + auto* td = (td_histogram_t*)wrapper->inner_obj(); + const double value = td_trimmed_mean(td, low_cut, high_cut); + return value; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + if (!res) { + return cmd_cntx.rb->SendError(res.status()); + } + + auto* rb = static_cast(cmd_cntx.rb); + rb->SendDouble(*res); +} + +struct MergeInput { + bool has_compression_defined = false; + size_t compression = 0; + bool override = false; +}; + +void TDigestFamily::Merge(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto dest_key = parser.Next(); + auto total_keys = parser.Next(); + MergeInput input; + std::vector source_keys; + for (size_t i = 0; i < total_keys; ++i) { + source_keys.push_back(parser.Next()); + } + + if (parser.HasNext()) { + if (!parser.Check("COMPRESSION", &input.compression)) { + return cmd_cntx.rb->SendError(facade::kSyntaxErr); + } + input.has_compression_defined = true; + } + + if (parser.HasNext()) { + if (!parser.Check("OVERRIDE")) { + return cmd_cntx.rb->SendError(facade::kSyntaxErr); + } + input.override = true; + } + + if (parser.Error()) { + cmd_cntx.rb->SendError(parser.Error()->MakeReply()); + return; + } + + using PerThread = std::vector; + std::vector errors(shard_set->size(), false); + std::vector sources(shard_set->size()); + + auto cb = [&errors, &sources, input, dest_key](Transaction* t, EngineShard* es) { + auto id = es->shard_id(); + ShardArgs keys = t->GetShardArgs(id); + auto& db_slice = t->GetDbSlice(id); + auto db_cntx = t->GetDbContext(); + for (auto key : keys) { + auto it = db_slice.FindMutable(db_cntx, key, OBJ_TDIGEST); + if (key == dest_key && it.status() != OpStatus::WRONG_TYPE) { + continue; + } + if (!it) { + errors[id] = true; + return OpStatus::OK; + } + auto* wrapper = it->it->second.GetRobjWrapper(); + auto* td = (td_histogram_t*)wrapper->inner_obj(); + sources[id].push_back(td); + } + return OpStatus::OK; + }; + + cmd_cntx.tx->Execute(std::move(cb), false); + + for (auto error : errors) { + if (error) { + cmd_cntx.rb->SendError(facade::kKeyNotFoundErr); + cmd_cntx.tx->Conclude(); + return; + } + } + + auto dest_shard = Shard(dest_key, sources.size()); + + auto hop_cb = [&sources, dest_key, input, dest_shard](Transaction* tx, EngineShard* es) { + if (es->shard_id() != dest_shard) { + return OpStatus::OK; + } + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto res = db_slice.AddOrFind(db_cntx, dest_key); + + double compression = 100; + if (input.has_compression_defined) { + compression = input.compression; + } else if (input.override) { + for (auto& t : sources) { + for (auto* hist : t) { + compression = std::max(compression, hist->compression); + } + } + } + + td_histogram_t* td_dst = nullptr; + td_histogram_t* td_result = nullptr; + + td_init(compression, &td_result); + if (!res->is_new && !input.override) { + auto* wrapper = res->it->second.GetRobjWrapper(); + td_dst = (td_histogram_t*)wrapper->inner_obj(); + td_merge(td_result, td_dst); + } + + for (auto& t : sources) { + for (auto* hist : t) { + td_merge(td_result, hist); + } + } + + res->it->second.Reset(); + 
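      // The destination is cleared first: when OVERRIDE is not given, its
      // previous digest has already been folded into td_result above, so it is
      // safe to drop before attaching the merged histogram.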
res->it->second.InitRobj(OBJ_TDIGEST, 0, td_result); + + return OpStatus::OK; + }; + + cmd_cntx.tx->Execute(std::move(hop_cb), true); + cmd_cntx.rb->SendOk(); +} + +using CI = CommandId; + +#define HFUNC(x) SetHandler(&TDigestFamily::x) + +void TDigestFamily::Register(CommandRegistry* registry) { + registry->StartFamily(); + + *registry << CI{"TDIGEST.CREATE", CO::WRITE | CO::DENYOOM, -1, 1, 1, acl::TDIGEST}.HFUNC(Create) + << CI{"TDIGEST.ADD", CO::WRITE | CO::DENYOOM, -2, 1, 1, acl::TDIGEST}.HFUNC(Add) + << CI{"TDIGEST.RESET", CO::WRITE, 2, 1, 1, acl::TDIGEST}.HFUNC(Reset) + << CI{"TDIGEST.CDF", CO::READONLY, -3, 1, 1, acl::TDIGEST}.HFUNC(Cdf) + << CI{"TDIGEST.RANK", CO::READONLY, -3, 1, 1, acl::TDIGEST}.HFUNC(Rank) + << CI{"TDIGEST.REVRANK", CO::READONLY, -3, 1, 1, acl::TDIGEST}.HFUNC(RevRank) + << CI{"TDIGEST.BYRANK", CO::READONLY, -3, 1, 1, acl::TDIGEST}.HFUNC(ByRank) + << CI{"TDIGEST.BYREVRANK", CO::READONLY, -3, 1, 1, acl::TDIGEST}.HFUNC(ByRevRank) + << CI{"TDIGEST.INFO", CO::READONLY, 2, 1, 1, acl::TDIGEST}.HFUNC(Info) + << CI{"TDIGEST.MAX", CO::READONLY, 2, 1, 1, acl::TDIGEST}.HFUNC(Max) + << CI{"TDIGEST.MIN", CO::READONLY, 2, 1, 1, acl::TDIGEST}.HFUNC(Min) + << CI{"TDIGEST.TRIMMED_MEAN", CO::READONLY, 4, 1, 1, acl::TDIGEST}.HFUNC(TrimmedMean) + << CI{"TDIGEST.MERGE", CO::WRITE | CO::VARIADIC_KEYS, -3, 3, 3, acl::TDIGEST}.HFUNC( + Merge) + << CI{"TDIGEST.QUANTILE", CO::READONLY, -3, 1, 1, acl::TDIGEST}.HFUNC(Quantile); +}; + +} // namespace dfly diff --git a/src/server/tdigest_family.h b/src/server/tdigest_family.h new file mode 100644 index 000000000000..28f0b23fcc96 --- /dev/null +++ b/src/server/tdigest_family.h @@ -0,0 +1,35 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include "server/common.h" + +namespace dfly { + +class CommandRegistry; +struct CommandContext; + +class TDigestFamily { + public: + static void Register(CommandRegistry* registry); + + private: + static void Create(CmdArgList args, const CommandContext& cmd_cntx); + static void Add(CmdArgList args, const CommandContext& cmd_cntx); + static void Rank(CmdArgList args, const CommandContext& cmd_cntx); + static void RevRank(CmdArgList args, const CommandContext& cmd_cntx); + static void ByRank(CmdArgList args, const CommandContext& cmd_cntx); + static void ByRevRank(CmdArgList args, const CommandContext& cmd_cntx); + static void Reset(CmdArgList args, const CommandContext& cmd_cntx); + static void Info(CmdArgList args, const CommandContext& cmd_cntx); + static void Max(CmdArgList args, const CommandContext& cmd_cntx); + static void Min(CmdArgList args, const CommandContext& cmd_cntx); + static void Cdf(CmdArgList args, const CommandContext& cmd_cntx); + static void Quantile(CmdArgList args, const CommandContext& cmd_cntx); + static void TrimmedMean(CmdArgList args, const CommandContext& cmd_cntx); + static void Merge(CmdArgList args, const CommandContext& cmd_cntx); +}; + +} // namespace dfly diff --git a/src/server/tdigest_family_test.cc b/src/server/tdigest_family_test.cc new file mode 100644 index 000000000000..c54e0f9537c1 --- /dev/null +++ b/src/server/tdigest_family_test.cc @@ -0,0 +1,209 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. 
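//
// A note on the hand-computed expectations in the Cdf test below: for this
// small sample set td_cdf effectively follows the midpoint convention,
// counting half the weight of samples equal to x. For the 15 samples used
// there, cdf(3) = (3 + 0.5 * 3) / 15 = 0.3 and cdf(4) = (6 + 0.5 * 4) / 15 =
// 8 / 15, which is where the 0.2999... and 0.5333... literals come from.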
+// + +#include "server/tdigest_family.h" + +#include + +#include "base/gtest.h" +#include "base/logging.h" +#include "facade/facade_test.h" +#include "server/test_utils.h" + +using namespace testing; +using namespace util; + +namespace dfly { + +class TDigestFamilyTest : public BaseFamilyTest { + protected: +}; + +TEST_F(TDigestFamilyTest, Basic) { + // errors + std::string err = "ERR wrong number of arguments for 'tdigest.create' command"; + ASSERT_THAT(Run({"TDIGEST.CREATE", "k1", "k2"}), ErrArg("ERR syntax error")); + // Triggers a check in InitByArgs -- a logical error + // TODO fix this + // ASSERT_THAT(Run({"TDIGEST.CREATE"}), ErrArg(err)); + + auto resp = Run({"TDIGEST.CREATE", "k1"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"TDIGEST.CREATE", "k1", "COMPRESSION", "200"}); + ASSERT_THAT(resp, ErrArg("ERR key already exists")); + + resp = Run({"TDIGEST.CREATE", "k2", "COMPRESSION", "200"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"TDIGEST.ADD", "k1", "10.0", "20.0"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"TDIGEST.ADD", "k2", "30.0", "40.0"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"TDIGEST.RESET", "k1"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"TDIGEST.RESET", "k2"}); + EXPECT_EQ(resp, "OK"); +} + +TEST_F(TDigestFamilyTest, Merge) { + auto resp = Run({"TDIGEST.CREATE", "k1"}); + resp = Run({"TDIGEST.CREATE", "k2"}); + + Run({"TDIGEST.ADD", "k1", "10.0", "20.0"}); + Run({"TDIGEST.ADD", "k2", "30.0", "40.0"}); + + resp = Run({"TDIGEST.MERGE", "res", "2", "k1", "k2"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"TDIGEST.BYRANK", "res", "0", "1", "2", "3", "4"}); + auto results = resp.GetVec(); + ASSERT_THAT(results, ElementsAre(DoubleArg(10), DoubleArg(20), DoubleArg(30), DoubleArg(40), + DoubleArg(INFINITY))); + + resp = Run({"TDIGEST.INFO", "res"}); + results = resp.GetVec(); + ASSERT_THAT(results, + ElementsAre("Compression", 100, "Capacity", 610, "Merged nodes", 4, "Unmerged nodes", + 0, "Merged weight", 4, "Unmerged weight", 0, "Observations", 4, + "Total compressions", 2, "Memory usage", 9768)); + + Run({"TDIGEST.CREATE", "k3"}); + Run({"TDIGEST.CREATE", "k4"}); + Run({"TDIGEST.CREATE", "k5"}); + Run({"TDIGEST.CREATE", "k6"}); + + Run({"TDIGEST.ADD", "k3", "11.0", "21.0"}); + Run({"TDIGEST.ADD", "k4", "31.1", "40.1"}); + Run({"TDIGEST.ADD", "k5", "10.0", "20.0"}); + Run({"TDIGEST.ADD", "k6", "32.2", "42.1"}); + + // OVERIDE overides the key + // compression sets the compression level + resp = Run({"TDIGEST.MERGE", "res", "6", "k1", "k2", "k3", "k4", "k5", "k6", "COMPRESSION", "50", + "OVERRIDE"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"TDIGEST.INFO", "res"}); + results = resp.GetVec(); + ASSERT_THAT(results, + ElementsAre("Compression", IntArg(50), "Capacity", IntArg(310), "Merged nodes", + IntArg(10), "Unmerged nodes", IntArg(2), "Merged weight", IntArg(10), + "Unmerged weight", IntArg(2), "Observations", IntArg(12), + "Total compressions", IntArg(5), "Memory usage", IntArg(4968))); + + Run({"SET", "foo", "bar"}); + resp = Run({"TDIGEST.MERGE", "foo", "2", "k1", "k2"}); + ASSERT_THAT(resp, ErrArg("ERR no such key")); + resp = Run({"TDIGEST.MERGE", "k1", "2", "foo", "k2"}); + ASSERT_THAT(resp, ErrArg("ERR no such key")); +} + +TEST_F(TDigestFamilyTest, MinMax) { + // errors + std::string min_err = "ERR wrong number of arguments for 'tdigest.min' command"; + std::string max_err = "ERR wrong number of arguments for 'tdigest.max' command"; + ASSERT_THAT(Run({"TDIGEST.MAX", "k1", "k2"}), ErrArg(max_err)); + ASSERT_THAT(Run({"TDIGEST.MAX"}), ErrArg(max_err)); + 
ASSERT_THAT(Run({"TDIGEST.MIN", "k1", "k2"}), ErrArg(min_err)); + ASSERT_THAT(Run({"TDIGEST.MIN"}), ErrArg(min_err)); + + Run({"TDIGEST.CREATE", "k1"}); + Run({"TDIGEST.ADD", "k1", "10.0", "22.0", "33.0", "44.4", "55.5"}); + + ASSERT_THAT(Run({"TDIGEST.MIN", "k1"}), DoubleArg(10)); + ASSERT_THAT(Run({"TDIGEST.MAX", "k1"}), DoubleArg(55.5)); +} + +TEST_F(TDigestFamilyTest, Rank) { + // errors + auto error = [](std::string_view msg) { + std::string err = "ERR wrong number of arguments for "; + return absl::StrCat(err, "'", msg, "'", " command"); + }; + ASSERT_THAT(Run({"TDIGEST.RANK", "k1"}), ErrArg(error("tdigest.rank"))); + ASSERT_THAT(Run({"TDIGEST.REVRANK", "k1"}), ErrArg(error("tdigest.revrank"))); + ASSERT_THAT(Run({"TDIGEST.BYRANK", "k1"}), ErrArg(error("tdigest.byrank"))); + ASSERT_THAT(Run({"TDIGEST.BYREVRANK", "k1"}), ErrArg(error("tdigest.byrevrank"))); + + Run({"TDIGEST.CREATE", "k1"}); + Run({"TDIGEST.ADD", "k1", "10.0", "22.0", "33.0", "44.4", "55.5"}); + + auto resp = Run({"TDIGEST.BYRANK", "k1", "0", "1", "2", "3", "4", "5"}); + auto results = resp.GetVec(); + ASSERT_THAT(results, ElementsAre(DoubleArg(10), DoubleArg(22), DoubleArg(33), DoubleArg(44.4), + DoubleArg(55.5), DoubleArg(INFINITY))); + + resp = Run({"TDIGEST.BYREVRANK", "k1", "0", "1", "2", "3", "4", "5"}); + results = resp.GetVec(); + ASSERT_THAT(results, ElementsAre(DoubleArg(55.5), DoubleArg(44.4), DoubleArg(33), DoubleArg(22), + DoubleArg(10), DoubleArg(-INFINITY))); + + ASSERT_THAT(Run({"TDIGEST.RANK", "k1", "1"}), IntArg(-1)); + ASSERT_THAT(Run({"TDIGEST.REVRANK", "k1", "1"}), IntArg(5)); + + ASSERT_THAT(Run({"TDIGEST.RANK", "k1", "50"}), IntArg(4)); + ASSERT_THAT(Run({"TDIGEST.REVRANK", "k1", "50"}), IntArg(1)); +} + +TEST_F(TDigestFamilyTest, Cdf) { + Run({"TDIGEST.CREATE", "k1"}); + // errors + std::string err = "ERR wrong number of arguments for 'tdigest.cdf' command"; + ASSERT_THAT(Run({"TDIGEST.CDF", "k1"}), ErrArg(err)); + + Run({"TDIGEST.ADD", "k1", "1", "2", "2", "3", "3", "3", "4", "4", "4", "4", "5", "5", "5", "5", + "5"}); + + auto resp = Run({"TDIGEST.CDF", "k1", "0", "1", "2", "3", "4", "5", "6"}); + + const auto& results = resp.GetVec(); + ASSERT_THAT(results, ElementsAre(DoubleArg(0), DoubleArg(0.033333333333333333), + DoubleArg(0.13333333333333333), DoubleArg(0.29999999999999999), + DoubleArg(0.53333333333333333), DoubleArg(0.83333333333333337), + DoubleArg(1))); +} + +TEST_F(TDigestFamilyTest, Quantile) { + Run({"TDIGEST.CREATE", "k1"}); + // errors + std::string err = "ERR wrong number of arguments for 'tdigest.quantile' command"; + + ASSERT_THAT(Run({"TDIGEST.QUANTILE", "k1"}), ErrArg(err)); + + Run({"TDIGEST.ADD", "k1", "1", "2", "2", "3", "3", "3", "4", "4", "4", "4", "5", "5", "5", "5", + "5"}); + + auto resp = Run({"TDIGEST.QUANTILE", "k1", "0", "0.1", "0.2", "0.3", "0.4", "0.5", "0.6", "0.7", + "0.8", "0.9", "1"}); + + const auto& results = resp.GetVec(); + ASSERT_THAT(results, ElementsAre(DoubleArg(1), DoubleArg(2), DoubleArg(3), DoubleArg(3), + DoubleArg(4), DoubleArg(4), DoubleArg(4), DoubleArg(5), + DoubleArg(5), DoubleArg(5), DoubleArg(5))); +} + +TEST_F(TDigestFamilyTest, TrimmedMean) { + Run({"TDIGEST.CREATE", "k1", "compression", "1000"}); + // errors + std::string err = "ERR wrong number of arguments for 'tdigest.trimmed_mean' command"; + + ASSERT_THAT(Run({"TDIGEST.TRIMMED_MEAN", "k1"}), ErrArg(err)); + ASSERT_THAT(Run({"TDIGEST.TRIMMED_MEAN", "k1", "0.1"}), ErrArg(err)); + + Run({"TDIGEST.ADD", "k1", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}); + + auto resp = 
Run({"TDIGEST.TRIMMED_MEAN", "k1", "0.1", "0.6"}); + ASSERT_THAT(resp, DoubleArg(4)); + + resp = Run({"TDIGEST.TRIMMED_MEAN", "k1", "0.3", "0.9"}); + ASSERT_THAT(resp, DoubleArg(6.5)); + + resp = Run({"TDIGEST.TRIMMED_MEAN", "k1", "0", "1"}); + ASSERT_THAT(resp, DoubleArg(5.5)); +} +} // namespace dfly diff --git a/src/server/topk_family.cc b/src/server/topk_family.cc new file mode 100644 index 000000000000..2219834aacf2 --- /dev/null +++ b/src/server/topk_family.cc @@ -0,0 +1,269 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "server/topk_family.h" + +#include + +#include "facade/cmd_arg_parser.h" +#include "facade/error.h" +#include "server/acl/acl_commands_def.h" +#include "server/command_registry.h" +#include "server/db_slice.h" +#include "server/engine_shard.h" +#include "server/engine_shard_set.h" +#include "server/transaction.h" + +namespace dfly { + +void TopKeysFamily::Reserve(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + auto total_elements = parser.Next(); + if (parser.HasError()) { + return cmd_cntx.rb->SendError(parser.Error()->MakeReply()); + } + + auto cb = [key, total_elements](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto res = db_slice.AddOrFind(db_cntx, key); + if (!res) { + return res.status(); + } + + if (!res->is_new) { + return OpStatus::KEY_EXISTS; + } + + res->it->second.SetTopK(total_elements, (1 << 16), 4, 1.08); + return OpStatus::OK; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + // SendError covers ok + if (res.status() == OpStatus::KEY_EXISTS) { + return cmd_cntx.rb->SendError("key already exists"); + } + return cmd_cntx.rb->SendOk(); +} + +void TopKeysFamily::Add(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + std::vector values; + while (parser.HasNext()) { + auto val = parser.Next(); + values.push_back(val); + } + + if (parser.HasError()) { + return cmd_cntx.rb->SendError(parser.Error()->MakeReply()); + } + + using Result = std::vector; + auto cb = [key, &values](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindMutable(db_cntx, key, OBJ_TOPK); + if (!it) { + return it.status(); + } + auto* topk = it->it->second.GetTopK(); + Result results; + results.reserve(values.size()); + for (auto val : values) { + // TODO return key if removed because of an exponential delay + topk->Touch(val); + results.push_back("nil"); + } + return results; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + if (!res) { + return cmd_cntx.rb->SendError(res.status()); + } + auto* rb = static_cast(cmd_cntx.rb); + // TODO fix return reply once Touch signature changes. 
See comment in cb above + rb->StartArray(res->size()); + for ([[maybe_unused]] const auto& reply : *res) { + rb->SendNull(); + } +} + +void TopKeysFamily::Query(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + absl::flat_hash_map items; + std::vector results; + while (parser.HasNext()) { + auto val = parser.Next(); + items[val] = false; + results.push_back(val); + } + + if (parser.HasError()) { + return cmd_cntx.rb->SendError(parser.Error()->MakeReply()); + } + + auto cb = [key, &items](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindMutable(db_cntx, key, OBJ_TOPK); + if (!it) { + return it.status(); + } + auto* topk = it->it->second.GetTopK(); + topk->Query(&items); + return OpStatus::OK; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + if (!res) { + return cmd_cntx.rb->SendError(res.status()); + } + auto* rb = static_cast(cmd_cntx.rb); + rb->StartArray(results.size()); + for (auto res : results) { + DCHECK(items.contains(res)); + if (items.find(res)->second) { + rb->SendLong(1); + continue; + } + rb->SendLong(0); + } +} + +void TopKeysFamily::List(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + // if (parser.HasError()) { + // return cmd_cntx.rb->SendError(parser.Error()->MakeReply()); + // } + + // even declval() is too verbose here :scream: + // using Result = std::invoke_result_t; + using Result = absl::flat_hash_map; + auto cb = [key](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindMutable(db_cntx, key, OBJ_TOPK); + if (!it) { + return it.status(); + } + auto* topk = it->it->second.GetTopK(); + return topk->GetTopKeys(); + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + if (!res) { + return cmd_cntx.rb->SendError(res.status()); + } + auto* rb = static_cast(cmd_cntx.rb); + rb->StartArray(res->size()); + for (const auto& reply : *res) { + rb->SendSimpleString(reply.first); + } +} + +void TopKeysFamily::Info(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + + auto cb = [key](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindMutable(db_cntx, key, OBJ_TOPK); + if (!it) { + return it.status(); + } + auto* topk = it->it->second.GetTopK(); + return topk->GetOptions(); + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + if (!res) { + return cmd_cntx.rb->SendError(res.status()); + } + auto* rb = static_cast(cmd_cntx.rb); + rb->StartArray(4 * 2); + rb->SendSimpleString("k"); + rb->SendLong(res->buckets * res->depth); + rb->SendSimpleString("width"); + rb->SendLong(res->buckets); + rb->SendSimpleString("depth"); + rb->SendLong(res->depth); + rb->SendSimpleString("decay"); + rb->SendDouble(res->decay_base); +} + +struct IncrByT { + std::string_view name; + size_t incr; +}; + +void TopKeysFamily::IncrBy(CmdArgList args, const CommandContext& cmd_cntx) { + facade::CmdArgParser parser{args}; + + auto key = parser.Next(); + std::vector incrs; + while (parser.HasNext()) { + auto name = parser.Next(); + auto val = parser.Next(); + if (val < 1 || val > 100000) { + return cmd_cntx.rb->SendError(facade::kInvalidIntErr); + } + 
incrs.push_back({name, val}); + } + + if (parser.HasError()) { + return cmd_cntx.rb->SendError(parser.Error()->MakeReply()); + } + + using Result = std::vector; + auto cb = [key, &incrs](Transaction* tx, EngineShard* es) -> OpResult { + auto& db_slice = tx->GetDbSlice(es->shard_id()); + auto db_cntx = tx->GetDbContext(); + auto it = db_slice.FindMutable(db_cntx, key, OBJ_TOPK); + if (!it) { + return it.status(); + } + auto* topk = it->it->second.GetTopK(); + Result results; + for (auto [name, incr] : incrs) { + // TODO return key if removed because of an exponential delay + topk->Touch(name, incr); + results.push_back("nil"); + } + return results; + }; + + auto res = cmd_cntx.tx->ScheduleSingleHopT(cb); + // Todo we should sent an array reply instead. If a key got evicted from TopKeys while it was + // touched/incremented we should return that. Otherwise `nill` + return cmd_cntx.rb->SendError(res.status()); +} + +using CI = CommandId; + +#define HFUNC(x) SetHandler(&TopKeysFamily::x) + +void TopKeysFamily::Register(CommandRegistry* registry) { + registry->StartFamily(); + + *registry << CI{"TOPK.RESERVE", CO::WRITE | CO::DENYOOM, -3, 1, 1, acl::TOPK}.HFUNC(Reserve) + << CI{"TOPK.LIST", CO::READONLY, -2, 1, 1, acl::TOPK}.HFUNC(List) + << CI{"TOPK.QUERY", CO::READONLY, -2, 1, 1, acl::TOPK}.HFUNC(Query) + << CI{"TOPK.INFO", CO::READONLY, 2, 1, 1, acl::TOPK}.HFUNC(Info) + << CI{"TOPK.INCRBY", CO::WRITE | CO::DENYOOM, -4, 1, 1, acl::TOPK}.HFUNC(IncrBy) + << CI{"TOPK.ADD", CO::WRITE | CO::DENYOOM, -3, 1, 1, acl::TOPK}.HFUNC(Add); +}; + +} // namespace dfly diff --git a/src/server/topk_family.h b/src/server/topk_family.h new file mode 100644 index 000000000000..7881aa394f71 --- /dev/null +++ b/src/server/topk_family.h @@ -0,0 +1,29 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include "server/common.h" + +namespace dfly { + +class CommandRegistry; +struct CommandContext; + +class TopKeysFamily { + public: + static void Register(CommandRegistry* registry); + + private: + static void Reserve(CmdArgList args, const CommandContext& cmd_cntx); + static void Add(CmdArgList args, const CommandContext& cmd_cntx); + static void List(CmdArgList args, const CommandContext& cmd_cntx); + static void Query(CmdArgList args, const CommandContext& cmd_cntx); + static void Info(CmdArgList args, const CommandContext& cmd_cntx); + static void IncrBy(CmdArgList args, const CommandContext& cmd_cntx); + // The following are deprecated + // static void Count(CmdArgList args, const CommandContext& cmd_cntx); +}; + +} // namespace dfly diff --git a/src/server/topk_family_test.cc b/src/server/topk_family_test.cc new file mode 100644 index 000000000000..5acc55c1944d --- /dev/null +++ b/src/server/topk_family_test.cc @@ -0,0 +1,71 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. 
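//
// Two behaviours the assertions below depend on: TOPK.ADD currently always
// replies with nil entries (see the TODO about the Touch() signature in
// TopKeysFamily::Add), and an element only starts reporting 1 from TOPK.QUERY
// after it has been added more than once, which is why the first QUERY below
// returns all zeros and the second returns ones.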
+// + +#include "server/topk_family.h" + +#include + +#include "base/gtest.h" +#include "base/logging.h" +#include "facade/facade_test.h" +#include "server/test_utils.h" + +using namespace testing; +using namespace util; + +namespace dfly { + +class TopKFamilyTest : public BaseFamilyTest { + protected: +}; + +TEST_F(TopKFamilyTest, Basic) { + // errors + std::string err = "ERR wrong number of arguments for 'topk.reserve' command"; + auto resp = Run({"TOPK.RESERVE", "k1"}); + ASSERT_THAT(resp, ErrArg(err)); + + resp = Run({"TOPK.RESERVE", "k1", "12"}); + EXPECT_EQ(resp, "OK"); + + err = "ERR wrong number of arguments for 'topk.info' command"; + resp = Run({"TOPK.INFO"}); + ASSERT_THAT(resp, ErrArg(err)); + + resp = Run({"TOPK.INFO", "k1"}); + auto v = resp.GetVec(); + ASSERT_THAT(v, ElementsAre("k", 12, "width", 3, "depth", 4, "decay", DoubleArg(1.08))); + + err = "ERR wrong number of arguments for 'topk.add' command"; + resp = Run({"TOPK.ADD", "k1"}); + ASSERT_THAT(resp, ErrArg(err)); + + resp = Run({"TOPK.ADD", "k1", "foo", "bar", "fooz"}); + v = resp.GetVec(); + // TODO fix reply of command when an element of a cell is replaced with the one added here + ASSERT_THAT(v, + ElementsAre(ArgType(RespExpr::NIL), ArgType(RespExpr::NIL), ArgType(RespExpr::NIL))); + // First time nothing is added + resp = Run({"TOPK.QUERY", "k1", "foo", "bar", "fooz"}); + v = resp.GetVec(); + ASSERT_THAT(v, ElementsAre(0, 0, 0)); + + // Second time elements are added + resp = Run({"TOPK.ADD", "k1", "foo", "bar", "fooz"}); + v = resp.GetVec(); + ASSERT_THAT(v, + ElementsAre(ArgType(RespExpr::NIL), ArgType(RespExpr::NIL), ArgType(RespExpr::NIL))); + + resp = Run({"TOPK.QUERY", "k1", "foo", "bar", "fooz"}); + v = resp.GetVec(); + ASSERT_THAT(v, ElementsAre(1, 1, 1)); + + resp = Run({"TOPK.LIST", "k1"}); + v = resp.GetVec(); + ASSERT_THAT(v, UnorderedElementsAre("foo", "bar", "fooz")); + + // TODO add TOPK.INCRBY +} + +} // namespace dfly diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 578d9426d2c7..e92ad7827e29 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -1582,7 +1582,7 @@ OpResult DetermineKeys(const CommandId* cid, CmdArgList args) { return OpStatus::SYNTAX_ERR; } - if (absl::EndsWith(name, "STORE")) + if (absl::EndsWith(name, "STORE") || absl::EqualsIgnoreCase(name, "TDIGEST.MERGE")) bonus = 0; // ZSTORE commands unsigned num_keys_index; @@ -1597,7 +1597,7 @@ OpResult DetermineKeys(const CommandId* cid, CmdArgList args) { if (num_custom_keys == 0 && (absl::StartsWith(name, "ZDIFF") || absl::StartsWith(name, "ZUNION") || - absl::StartsWith(name, "ZINTER"))) { + absl::StartsWith(name, "ZINTER") || absl::EqualsIgnoreCase(name, "TDIGEST.MERGE"))) { return OpStatus::AT_LEAST_ONE_KEY; }
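The transaction.cc change above folds TDIGEST.MERGE into the existing numkeys-style key determination: like the ZDIFF/ZUNION/ZINTER STORE variants it reads a destination key followed by an explicit source-key count, and an empty source list is rejected with AT_LEAST_ONE_KEY. The shape exercised by the tests is

    TDIGEST.MERGE res 2 k1 k2 [COMPRESSION n] [OVERRIDE]

the intent being that the destination and every listed source key are picked up as transaction keys.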