
Commit c064f69

use mutex

1 parent: 38c83db

4 files changed (+33 -5 lines)


cpp/include/culda/cuda_lda_kernels.cuh

Lines changed: 21 additions & 3 deletions
@@ -31,7 +31,8 @@ __global__ void EstepKernel(
   const int num_topics, const int num_iters,
   float* gamma, float* new_gamma, float* phi,
   const float* alpha, const float* beta,
-  float* grad_alpha, float* new_beta, float* train_losses, float* vali_losses) {
+  float* grad_alpha, float* new_beta,
+  float* train_losses, float* vali_losses, int* mutex) {
 
   // storage for block
   float* _gamma = gamma + num_topics * blockIdx.x;
@@ -57,6 +58,7 @@ __global__ void EstepKernel(
       for (int k = beg; k < end; ++k) {
         const int w = cols[k];
         const bool _vali = vali[k];
+
         // compute phi
         if (not _vali or j + 1 == num_iters) {
           for (int l = threadIdx.x; l < num_topics; l += blockDim.x)
@@ -65,17 +67,33 @@ __global__ void EstepKernel(
 
           // normalize phi and add it to new gamma and new beta
           float phi_sum = ReduceSum(_phi, num_topics);
+
           for (int l = threadIdx.x; l < num_topics; l += blockDim.x) {
             _phi[l] /= phi_sum;
             if (not _vali) _new_gamma[l] += _phi[l];
+          }
+          __syncthreads();
+        }
+
+        if (j + 1 == num_iters) {
+          // write access of w th vector of new_beta
+          if (threadIdx.x == 0) {
+            while (atomicCAS(&mutex[w], 0, 1)) {}
+          }
+
+          __syncthreads();
+          for (int l = threadIdx.x; l < num_topics; l += blockDim.x) {
             if (j + 1 == num_iters) {
               if (not _vali) new_beta[w * num_topics + l] += _phi[l];
               _phi[l] *= beta[w * num_topics + l];
             }
           }
           __syncthreads();
-        }
-        if (j + 1 == num_iters) {
+
+          // release lock
+          if (threadIdx.x == 0) mutex[w] = 0;
+          __syncthreads();
+
           float p = fmaxf(EPS, ReduceSum(_phi, num_topics));
           if (threadIdx.x == 0) {
             if (_vali)
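
The change wraps the final-iteration update of new_beta in a per-word spin lock: thread 0 of each block claims mutex[w] via atomicCAS(0 -> 1), __syncthreads() holds the other threads at the gate, and after the row update thread 0 stores 0 to release. Having a single thread per block spin avoids intra-warp divergence deadlocks, while the barriers extend the exclusion to the whole block. A minimal self-contained sketch of the same pattern (AddRowKernel and its parameters are illustrative names, not part of this repo):

#include <cuda_runtime.h>

// Each block accumulates its private vector into row rows[blockIdx.x]
// of `out`; mutex[w] == 0 means row w is free, 1 means some block holds it.
__global__ void AddRowKernel(const float* src, float* out, int* mutex,
                             const int* rows, int num_cols) {
  const int w = rows[blockIdx.x];

  // acquire: atomicCAS returns the previous value, so the loop exits
  // exactly when the 0 -> 1 transition succeeds
  if (threadIdx.x == 0) {
    while (atomicCAS(&mutex[w], 0, 1)) {}
  }
  __syncthreads();

  // critical section: the whole block updates row w exclusively
  for (int l = threadIdx.x; l < num_cols; l += blockDim.x)
    out[w * num_cols + l] += src[blockIdx.x * num_cols + l];
  __syncthreads();

  // release: a plain store of 0 reopens the row
  if (threadIdx.x == 0) mutex[w] = 0;
}

One caveat on this pattern: without a __threadfence() between the row update and the releasing store, another block may in principle observe the unlock before the updated values; per-element atomicAdd on new_beta would sidestep the lock entirely at the cost of one atomic per write.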

cpp/include/culda/culda.hpp

Lines changed: 3 additions & 0 deletions
@@ -70,13 +70,16 @@ class CuLDA {
   void Pull();
   void Push();
   int GetBlockCnt();
+
  private:
   DeviceInfo dev_info_;
   json11::Json opt_;
   std::shared_ptr<spdlog::logger> logger_;
   thrust::device_vector<float> dev_alpha_, dev_beta_;
   thrust::device_vector<float> dev_grad_alpha_, dev_new_beta_;
   thrust::device_vector<float> dev_gamma_, dev_new_gamma_, dev_phi_;
+  thrust::device_vector<int> dev_mutex_;
+
   float *alpha_, *beta_, *grad_alpha_, *new_beta_;
   int block_cnt_, block_dim_;
   int num_topics_, num_words_;

cpp/src/culda/culda.cu

Lines changed: 8 additions & 2 deletions
@@ -51,13 +51,18 @@ void CuLDA::LoadModel(float* alpha, float* beta,
   new_beta_ = new_beta;
   dev_grad_alpha_.resize(num_topics_ * block_cnt_);
   dev_new_beta_.resize(num_topics_ * num_words_);
-
   // copy to device
   thrust::copy(grad_alpha_, grad_alpha_ + block_cnt_ * num_topics_, dev_grad_alpha_.begin());
   thrust::copy(new_beta_, new_beta_ + num_words_ * num_topics_, dev_new_beta_.begin());
   dev_gamma_.resize(num_topics_ * block_cnt_);
   dev_new_gamma_.resize(num_topics_ * block_cnt_);
   dev_phi_.resize(num_topics_ * block_cnt_);
+
+  // set mutex
+  dev_mutex_.resize(num_words_);
+  std::vector<int> host_mutex(num_words_, 0);
+  thrust::copy(host_mutex.begin(), host_mutex.end(), dev_mutex_.begin());
+
   CHECK_CUDA(cudaDeviceSynchronize());
 }
 
@@ -91,7 +96,8 @@ std::pair<float, float> CuLDA::FeedData(
     thrust::raw_pointer_cast(dev_grad_alpha_.data()),
     thrust::raw_pointer_cast(dev_new_beta_.data()),
     thrust::raw_pointer_cast(dev_train_losses.data()),
-    thrust::raw_pointer_cast(dev_vali_losses.data()));
+    thrust::raw_pointer_cast(dev_vali_losses.data()),
+    thrust::raw_pointer_cast(dev_mutex_.data()));
   CHECK_CUDA(cudaDeviceSynchronize());
   DEBUG0("run E step in GPU");
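
A side note on the zero initialization above: thrust::device_vector value-initializes new elements on resize, so the lock words could also be cleared on the device without the intermediate host vector. A sketch under that assumption (ResetMutex is a hypothetical helper, not from this commit):

#include <thrust/device_vector.h>
#include <thrust/fill.h>

// hypothetical helper: (re)allocate one lock word per vocabulary term,
// leaving every word in the unlocked (0) state
void ResetMutex(thrust::device_vector<int>& dev_mutex, int num_words) {
  dev_mutex.resize(num_words);  // new elements are value-initialized to 0
  // explicit device-side fill, in case the vector is reused across loads
  thrust::fill(dev_mutex.begin(), dev_mutex.end(), 0);
}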

examples/example1.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ def run_lda():
   opt = {
     "data_path": DATA_PATH,
     "processed_data_dir": PROCESSED_DATA_DIR,
+    "skip_preprocess":True,
   }
   lda = CuLDA(opt)
   lda.train_model()
