Skip to content

Commit f93eadf

Browse files
committed
Initial design of the external storage interface.
1 parent 985df72 commit f93eadf

File tree

2 files changed

+156
-5
lines changed

2 files changed

+156
-5
lines changed
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright (c) 2022, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <cstdint>
19+
#include <type_traits>
20+
#include "merlin/memory_pool.cuh"
21+
22+
namespace nv {
23+
namespace merlin {
24+
25+
template <class Key, class Value>
26+
class ExternalStorage {
27+
public:
28+
using size_type = size_t;
29+
using key_type = Key;
30+
using value_type = Value;
31+
32+
using dev_mem_pool_type = MemoryPool<DeviceAllocator<char>>;
33+
using host_mem_pool_type = MemoryPool<HostAllocator<char>>;
34+
35+
const size_type value_dim;
36+
37+
ExternalStorage() = delete;
38+
39+
/**
40+
* Constructs external storage object.
41+
*
42+
* @param value_dim The dimensionality of the values. In other words, each
43+
* value stored is exactly `value_dim * sizeof(value_type)` bytes large.
44+
*/
45+
ExternalStorage(const size_type value_dim) : value_dim{value_dim} {}
46+
47+
/**
48+
* @brief Inserts key/value pairs into the external storage that are about to
49+
* be evicted from the Merlin hashtable. If a key/value pair already exists,
50+
* overwrites the current value.
51+
*
52+
* @param dev_mem_pool Memory pool for temporarily allocating device memory.
53+
* @param host_mem_pool Memory pool for temporarily allocating host memory.
54+
* @param hkvs_is_pure_hbm True if the Merlin hashtable store is currently
55+
* operating in pure HBM mode, false otherwise. In pure HBM mode, all `values`
56+
* pointers are GUARANTEED to point to device memory.
57+
* @param n Number of key/value slots provided in other arguments.
58+
* @param d_masked_keys Device pointer to an (n)-sized array of keys.
59+
* Key-Value slots that should be ignored have the key set tO `EMPTY_KEY`.
60+
* @param d_values Device pointer to an (n)-sized array containing pointers to
61+
* respectively a memory location where the current values for a key are
62+
* stored. Each pointer points to a vector of length `value_dim`. Pointers
63+
* *can* be set to `nullptr` for slots where the corresponding key equated to
64+
* the `EMPTY_KEY`. The memory locations can be device or host memory (see
65+
* also `hkvs_is_pure_hbm`).
66+
* @param stream Stream that MUST be used for queuing asynchronous CUDA
67+
* operations. If only the input arguments or resources obtained from
68+
* respectively `dev_mem_pool` and `host_mem_pool` are used for such
69+
* operations, it is not necessary to synchronize the stream prior to
70+
* returning from the function.
71+
*/
72+
virtual void insert_or_assign(dev_mem_pool_type& dev_mem_pool,
73+
host_mem_pool_type& host_mem_pool,
74+
bool hkvs_is_pure_hbm, size_type n,
75+
const key_type* d_masked_keys, // (n)
76+
const value_type* const* d_values, // (n)
77+
cudaStream_t stream) = 0;
78+
79+
/**
80+
* @brief Attempts to find the supplied `d_keys` if the corresponding
81+
* `d_founds`-flag is `false` and fills the stored into the supplied memory
82+
* locations (i.e. in `d_values`).
83+
*
84+
* @param dev_mem_pool Memory pool for temporarily allocating device memory.
85+
* @param host_mem_pool Memory pool for temporarily allocating host memory.
86+
* @param n Number of key/value slots provided in other arguments.
87+
* @param d_keys Device pointer to an (n)-sized array of keys.
88+
* @param d_values Device pointer to an (n * value_dim)-sized array to store
89+
* the retrieved `d_values`. For slots where the corresponding `d_founds`-flag
90+
* is not `false`, the value may already have been assigned and, thus, MUST
91+
* not be altered.
92+
* @param d_founds Device pointer to an (n)-sized array which indicates
93+
* whether the corresponding `d_values` slot is already filled or not. So, if
94+
* and only if `d_founds` is still false, the implementation shall attempt to
95+
* retrieve and fill in the value for the corresponding key. If a key/value
96+
* was retrieved successfully from external storage, the implementation MUST
97+
* also set `d_founds` to `true`.
98+
* @param stream Stream that MUST be used for queuing asynchronous CUDA
99+
* operations. If only the input arguments or resources obtained from
100+
* respectively `dev_mem_pool` and `host_mem_pool` are used for such
101+
* operations, it is not necessary to synchronize the stream prior to
102+
* returning from the function.
103+
*/
104+
virtual void find(dev_mem_pool_type& dev_mem_pool,
105+
host_mem_pool_type& host_mem_pool, size_type n,
106+
const key_type* d_keys, // (n)
107+
value_type* d_values, // (n * value_dim)
108+
bool* d_founds, // (n)
109+
cudaStream_t stream) = 0;
110+
};
111+
112+
} // namespace merlin
113+
} // namespace nv

include/merlin_hashtable.cuh

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <shared_mutex>
2626
#include <type_traits>
2727
#include "merlin/core_kernels.cuh"
28+
#include "merlin/external_storage.cuh"
2829
#include "merlin/flexible_buffer.cuh"
2930
#include "merlin/memory_pool.cuh"
3031
#include "merlin/types.cuh"
@@ -160,6 +161,8 @@ class HashTable {
160161
using DeviceMemoryPool = MemoryPool<DeviceAllocator<char>>;
161162
using HostMemoryPool = MemoryPool<HostAllocator<char>>;
162163

164+
using external_storage_type = ExternalStorage<K, V>;
165+
163166
#if THRUST_VERSION >= 101600
164167
static constexpr auto thrust_par = thrust::cuda::par_nosync;
165168
#else
@@ -179,6 +182,8 @@ class HashTable {
179182
~HashTable() {
180183
CUDA_CHECK(cudaDeviceSynchronize());
181184

185+
unlink_external_storage();
186+
182187
// Erase table.
183188
if (initialized_) {
184189
destroy_table<key_type, vector_type, meta_type, DIM>(&table_);
@@ -308,9 +313,12 @@ class HashTable {
308313
}
309314
} else {
310315
const size_t dev_ws_size = n * (sizeof(vector_type*) + sizeof(int));
311-
auto dev_ws = dev_mem_pool_->get_workspace<1>(dev_ws_size, stream);
316+
auto dev_ws = dev_mem_pool_->get_workspace<1>(
317+
dev_ws_size + (ext_store_ ? n * sizeof(key_type) : 0), stream);
312318
auto d_dst = dev_ws.get<vector_type**>(0);
313319
auto d_src_offset = reinterpret_cast<int*>(d_dst + n);
320+
auto d_evicted_keys =
321+
ext_store_ ? reinterpret_cast<key_type*>(d_src_offset + n) : nullptr;
314322

315323
CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_size, stream));
316324

@@ -322,18 +330,26 @@ class HashTable {
322330
if (metas == nullptr) {
323331
upsert_kernel<key_type, vector_type, meta_type, DIM, TILE_SIZE>
324332
<<<grid_size, block_size, 0, stream>>>(
325-
table_, keys, d_dst, table_->buckets, table_->buckets_size,
326-
table_->bucket_max_size, table_->buckets_num, d_src_offset,
327-
N);
333+
table_, keys,
334+
/* d_evicted_keys, */ d_dst, table_->buckets,
335+
table_->buckets_size, table_->bucket_max_size,
336+
table_->buckets_num, d_src_offset, N);
328337
} else {
329338
upsert_kernel<key_type, vector_type, meta_type, DIM, TILE_SIZE>
330339
<<<grid_size, block_size, 0, stream>>>(
331-
table_, keys, d_dst, metas, table_->buckets,
340+
table_, keys,
341+
/* d_evicted_keys, */ d_dst, metas, table_->buckets,
332342
table_->buckets_size, table_->bucket_max_size,
333343
table_->buckets_num, d_src_offset, N);
334344
}
335345
}
336346

347+
if (ext_store_) {
348+
ext_store_->insert_or_assign(
349+
*dev_mem_pool_, *host_mem_pool_, table_->is_pure_hbm, n,
350+
d_evicted_keys, reinterpret_cast<value_type**>(d_dst), stream);
351+
}
352+
337353
{
338354
thrust::device_ptr<uintptr_t> d_dst_ptr(
339355
reinterpret_cast<uintptr_t*>(d_dst));
@@ -575,6 +591,11 @@ class HashTable {
575591
}
576592
}
577593

594+
if (ext_store_) {
595+
ext_store_->find(*dev_mem_pool_, *host_mem_pool_, n, keys, values, founds,
596+
stream);
597+
}
598+
578599
CudaCheckError();
579600
}
580601

@@ -1113,6 +1134,21 @@ class HashTable {
11131134
return total_count;
11141135
}
11151136

1137+
void link_external_storage(
1138+
std::shared_ptr<external_storage_type>& ext_store) {
1139+
MERLIN_CHECK(
1140+
ext_store->value_dim == DIM,
1141+
"Provided external storage value dimension is not incompatible!");
1142+
1143+
std::unique_lock<std::shared_timed_mutex> lock(mutex_);
1144+
ext_store_ = ext_store;
1145+
}
1146+
1147+
void unlink_external_storage() {
1148+
std::unique_lock<std::shared_timed_mutex> lock(mutex_);
1149+
ext_store_.reset();
1150+
}
1151+
11161152
private:
11171153
inline bool is_fast_mode() const noexcept { return table_->is_pure_hbm; }
11181154

@@ -1171,6 +1207,8 @@ class HashTable {
11711207

11721208
std::unique_ptr<DeviceMemoryPool> dev_mem_pool_;
11731209
std::unique_ptr<HostMemoryPool> host_mem_pool_;
1210+
1211+
std::shared_ptr<external_storage_type> ext_store_;
11741212
};
11751213

11761214
} // namespace merlin

0 commit comments

Comments
 (0)