Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions include/merlin/core_kernels/lookup_ptr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ namespace nv {
namespace merlin {

// Use 1 thread to deal with a KV-pair, including copying value.
template <typename K, typename V, typename S>
template <typename K, typename V, typename S, int Strategy>
__global__ void tlp_lookup_ptr_kernel_with_filter(
Bucket<K, V, S>* __restrict__ buckets, const uint64_t buckets_num,
uint32_t bucket_capacity, const uint32_t dim, const K* __restrict__ keys,
V** __restrict values, S* __restrict scores, bool* __restrict founds,
uint64_t n) {
uint64_t n, bool update_score, const S global_epoch) {
using BUCKET = Bucket<K, V, S>;
using ScoreFunctor = ScoreFunctor<K, V, S, Strategy>;
// Load `STRIDE` digests every time.
constexpr uint32_t STRIDE = sizeof(VecD_Load) / sizeof(D);

Expand All @@ -43,6 +44,9 @@ __global__ void tlp_lookup_ptr_kernel_with_filter(
uint32_t key_pos = {0};
if (kv_idx < n) {
key = keys[kv_idx];
if (update_score) {
score = ScoreFunctor::desired_when_missed(scores, kv_idx, global_epoch);
}
if (!IS_RESERVED_KEY<K>(key)) {
const K hashed_key = Murmur3HashDevice(key);
target_digests = digests_from_hashed<K>(hashed_key);
Expand Down Expand Up @@ -86,12 +90,32 @@ __global__ void tlp_lookup_ptr_kernel_with_filter(
uint32_t index = (__ffs(cmp_result) - 1) >> 3;
cmp_result &= (cmp_result - 1);
possible_pos = pos_cur + i * 4 + index;
auto current_key = bucket_keys_ptr[possible_pos];
score = *BUCKET::scores(bucket_keys_ptr, bucket_capacity, possible_pos);
if (current_key == key) {
key_pos = possible_pos;
occupy_result = OccupyResult::DUPLICATE;
goto WRITE_BACK;
if (update_score) {
auto current_key = BUCKET::keys(bucket_keys_ptr, possible_pos);
K expected_key = key;
// Acquire ordering: subsequent modifications to the bucket cannot be
// reordered before this instruction.
bool result = current_key->compare_exchange_strong(
expected_key, static_cast<K>(LOCKED_KEY),
cuda::std::memory_order_acquire, cuda::std::memory_order_relaxed);
if (result) {
occupy_result = OccupyResult::DUPLICATE;
key_pos = possible_pos;
ScoreFunctor::update_with_digest(bucket_keys_ptr, key_pos, scores,
kv_idx, score, bucket_capacity,
get_digest<K>(key), false);
current_key->store(key, cuda::std::memory_order_release);
score = *BUCKET::scores(bucket_keys_ptr, bucket_capacity, key_pos);
goto WRITE_BACK;
}
} else {
auto current_key = bucket_keys_ptr[possible_pos];
score =
*BUCKET::scores(bucket_keys_ptr, bucket_capacity, possible_pos);
if (current_key == key) {
key_pos = possible_pos;
occupy_result = OccupyResult::DUPLICATE;
goto WRITE_BACK;
}
}
} while (true);
VecD_Comp empty_digests_ = empty_digests<K>();
Expand Down
20 changes: 16 additions & 4 deletions include/merlin_hashtable.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -645,13 +645,16 @@ class HashTableBase {
* @endparblock
* @param stream The CUDA stream that is used to execute the operation.
* @param unique_key If all keys in the same batch are unique.
* @param update_score If true, update the scores of the found keys in the
* table, using the `scores` argument as input.
*
*/
virtual void find(const size_type n, const key_type* keys, // (n)
value_type** values, // (n)
bool* founds, // (n)
score_type* scores = nullptr, // (n)
cudaStream_t stream = 0, bool unique_key = true) const = 0;
cudaStream_t stream = 0, bool unique_key = true,
bool update_score = false) = 0;

/**
* @brief Checks if there are elements with key equivalent to `keys` in the
Expand Down Expand Up @@ -2556,13 +2559,16 @@ class HashTable : public HashTableBase<K, V, S> {
* @endparblock
* @param stream The CUDA stream that is used to execute the operation.
* @param unique_key If all keys in the same batch are unique.
* @param update_score If true, update the scores of the found keys in the
* table, using the `scores` argument as input.
*
*/
void find(const size_type n, const key_type* keys, // (n)
value_type** values, // (n)
bool* founds, // (n)
score_type* scores = nullptr, // (n)
cudaStream_t stream = 0, bool unique_key = true) const {
cudaStream_t stream = 0, bool unique_key = true,
bool update_score = false) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend declaring a new `find_and_update` API instead of reusing this `find`.

if (n == 0) {
return;
}
Expand All @@ -2572,13 +2578,19 @@ class HashTable : public HashTableBase<K, V, S> {
lock_ptr = std::make_unique<read_shared_lock>(mutex_, stream);
}

if (update_score) {
check_evict_strategy(scores);
}

constexpr uint32_t MinBucketCapacityFilter = sizeof(VecD_Load) / sizeof(D);
if (unique_key && options_.max_bucket_size >= MinBucketCapacityFilter) {
constexpr uint32_t BLOCK_SIZE = 128U;
tlp_lookup_ptr_kernel_with_filter<key_type, value_type, score_type>
tlp_lookup_ptr_kernel_with_filter<key_type, value_type, score_type,
evict_strategy>
<<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
table_->buckets, table_->buckets_num, options_.max_bucket_size,
options_.dim, keys, values, scores, founds, n);
options_.dim, keys, values, scores, founds, n, update_score,
global_epoch_);
} else {
using Selector = SelectLookupPtrKernel<key_type, value_type, score_type>;
static thread_local int step_counter = 0;
Expand Down