Skip to content

Commit 084c4cb

Browse files
committed
implement InitRngsKernel
1 parent add8461 commit 084c4cb

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

cpp/include/cuw2v/cuw2v.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ class CuW2V {
8080

8181
// variables to construct random table
8282
thrust::device_vector<int> dev_random_table_;
83-
int random_size_, table_seed_;
84-
std::mt19937 table_rng_;
83+
int random_size_, table_seed_, cuda_seed_;
8584
thrust::device_vector<default_random_engine> dev_rngs_;
8685
};
8786

cpp/include/utils/cuda_utils_kernels.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,4 +208,8 @@ float ReduceSum(const float* vec, const int length) {
208208
return shared[0];
209209
}
210210

211+
__global__ void InitRngsKernel(default_random_engine* rngs, int rand_seed) {
212+
rngs[blockIdx.x].seed(blockIdx.x + rand_seed);
213+
}
214+
211215
} // namespace cusim

cpp/src/cuw2v/cuw2v.cu

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,13 @@ bool CuW2V::Init(std::string opt_path) {
4646
// if zero, we will use hierarchical softmax
4747
neg_ = opt_["negative_sampling"].int_value();
4848

49-
// set seed for constructing random table of negative sampling
49+
// random seed
5050
table_seed_ = opt_["table_seed"].int_value();
51-
const unsigned int table_seed = table_seed_;
52-
table_rng_.seed(table_seed);
53-
51+
cuda_seed_ = opt_["cuda_seed"].int_value();
52+
dev_rngs_.resize(block_cnt_);
53+
InitRngsKernel<<<block_cnt_, 1>>>(
54+
thrust::raw_pointer_cast(dev_rngs_.data()), cuda_seed_);
55+
5456
INFO("num_dims: {}, block_dim: {}, block_cnt: {}, objective type: {}, neg: {}",
5557
num_dims_, block_dim_, block_cnt_, sg_? "skip gram": "cbow", neg_);
5658
return true;
@@ -63,22 +65,25 @@ void CuW2V::BuildRandomTable(const float* word_count, const int num_words,
6365
std::vector<float> acc;
6466
float cumsum = 0;
6567
for (int i = 0; i < num_words; ++i) {
66-
cumsum += word_count[i];
6768
acc.push_back(cumsum);
69+
cumsum += word_count[i];
6870
}
6971

70-
std::uniform_real_distribution<float> dist(0.0f, cumsum);
7172
dev_random_table_.resize(random_size_);
7273
std::vector<int> host_random_table(table_size);
7374
#pragma omp parallel num_threads(num_threads)
7475
{
76+
const unsigned int table_seed = table_seed_ + omp_get_thread_num();
77+
std::mt19937 rng(table_seed);
78+
std::uniform_real_distribution<float> dist(0.0f, cumsum);
7579
#pragma omp for schedule(static)
7680
for (int i = 0; i < random_size_; ++i) {
77-
float r = dist(table_rng_);
81+
float r = dist(rng);
7882
int pos = std::lower_bound(acc.begin(), acc.end(), r) - acc.begin();
7983
host_random_table[i] = pos;
8084
}
8185
}
86+
table_seed_ += num_threads;
8287

8388
thrust::copy(host_random_table.begin(), host_random_table.end(), dev_random_table_.begin());
8489
CHECK_CUDA(cudaDeviceSynchronize());
@@ -148,6 +153,8 @@ void CuW2V::BuildHuffmanTree(const float* word_count, const int num_words) {
148153
thrust::copy(host_points.begin(), host_points.end(), dev_points_.begin());
149154
thrust::copy(host_hs_indptr.begin(), host_hs_indptr.end(), dev_hs_indptr_.begin());
150155
CHECK_CUDA(cudaDeviceSynchronize());
156+
157+
huffman_nodes.clear();
151158
}
152159

153160
void CuW2V::LoadModel(float* emb_in, float* emb_out) {

0 commit comments

Comments
 (0)