
Commit 469d0ff

Merge pull request #6 from js1010/task/add-benchmark
Task/add benchmark
2 parents ec13e7c + 5898669 commit 469d0ff

28 files changed, +2362 −289 lines

MANIFEST.in

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+include cuda_setup.py
+include requirements.txt
+include pyproject.toml
+recursive-include cpp/src/cuw2v/ *.cu
+recursive-include cpp/src/culda/ *.cu
+recursive-include cpp/src/ioutils/ *.cc
+recursive-include cpp/include/cuw2v/ *.cuh
+recursive-include cpp/include/cuw2v/ *.hpp
+recursive-include cpp/include/culda/ *.cuh
+recursive-include cpp/include/culda/ *.hpp
+recursive-include cpp/include/ioutils/ *.cuh
+recursive-include cpp/include/ioutils/ *.hpp
+recursive-include 3rd/json11/ *
+recursive-include 3rd/spdlog/ *

README.md

Lines changed: 65 additions & 0 deletions
@@ -1,5 +1,10 @@
+### Introduction
+
+This project aims to speed up various ML models (e.g. topic modeling, word embedding) with CUDA. You can think of it as a GPU version of [gensim](https://github.com/RaRe-Technologies/gensim). As a starting step, I implemented the most widely used word embedding model, [word2vec](https://arxiv.org/pdf/1301.3781.pdf), and the most representative topic model, [LDA (Latent Dirichlet Allocation)](https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf).
+
 ### How to install
 
+- install from source
 
 ```shell
 # clone repo and submodules
@@ -14,3 +19,63 @@ python -m grpc_tools.protoc --python_out cusim/ --proto_path cusim/proto/ config
 # install
 python setup.py install
 ```
+
+- pip installation will be available soon
+
+### How to use
+
+- `examples/example_w2v.py`, `examples/example_lda.py` and `examples/README.md` will be very helpful for understanding the usage.
+- parameter descriptions can be found in `cusim/proto/config.proto`
+
+### Performance
+
+- An [AWS g4dn.2xlarge instance](https://aws.amazon.com/ec2/instance-types/g4/) was used for the experiments (one NVIDIA T4 GPU and 8 vCPUs, Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz).
+- Results can be reproduced by simply running `examples/example_w2v.py` and `examples/example_lda.py`.
+- To evaluate the w2v model, I used gensim's `evaluate_word_pairs` function ([ref link](https://radimrehurek.com/gensim/auto_examples/tutorials/run_word2vec.html#evaluating)). Note that better performance on the WS-353 test set does not necessarily mean the model will work better in applications, as described at that link. Still, it is a useful quantitative measure, and fast training time is at least a very objective measure of performance.
+- I trained the w2v model on the `quora-duplicate-questions` dataset from the gensim downloader API on GPU with cusim and compared the performance (both speed and model quality) with gensim.
+- To evaluate the LDA model, I found no good way to measure the quality of the training results quantitatively, but we can check the model by looking at the top words of each topic. We can also compare the training time quantitatively.
+- W2V (skip gram, hierarchical softmax)
+
+| attr                | 1 worker (gensim) | 2 workers (gensim) | 4 workers (gensim) | 8 workers (gensim) | NVIDIA T4 (cusim) |
+|:--------------------|------------------:|-------------------:|-------------------:|-------------------:|------------------:|
+| training time (sec) | 892.596           | 544.212            | 310.727            | 226.472            | **16.162**        |
+| pearson             | 0.487832          | 0.487696           | 0.482821           | 0.487136           | **0.492101**      |
+| spearman            | 0.500846          | 0.506214           | 0.501048           | **0.506718**       | 0.479468          |
+
+- W2V (skip gram, negative sampling)
+
+| attr                | 1 worker (gensim) | 2 workers (gensim) | 4 workers (gensim) | 8 workers (gensim) | NVIDIA T4 (cusim) |
+|:--------------------|------------------:|-------------------:|-------------------:|-------------------:|------------------:|
+| training time (sec) | 586.545           | 340.489            | 220.804            | 146.23             | **33.9173**       |
+| pearson             | 0.354448          | 0.353952           | 0.352398           | 0.352925           | **0.360436**      |
+| spearman            | 0.369146          | 0.369365           | **0.370565**       | 0.365822           | 0.355204          |
+
+- W2V (CBOW, hierarchical softmax)
+
+| attr                | 1 worker (gensim) | 2 workers (gensim) | 4 workers (gensim) | 8 workers (gensim) | NVIDIA T4 (cusim) |
+|:--------------------|------------------:|-------------------:|-------------------:|-------------------:|------------------:|
+| training time (sec) | 250.135           | 155.121            | 103.57             | 73.8073            | **6.20787**       |
+| pearson             | 0.309651          | 0.321803           | 0.324854           | 0.314255           | **0.480298**      |
+| spearman            | 0.294047          | 0.308723           | 0.318293           | 0.300591           | **0.480971**      |
+
+- W2V (CBOW, negative sampling)
+
+| attr                | 1 worker (gensim) | 2 workers (gensim) | 4 workers (gensim) | 8 workers (gensim) | NVIDIA T4 (cusim) |
+|:--------------------|------------------:|-------------------:|-------------------:|-------------------:|------------------:|
+| training time (sec) | 176.923           | 100.369            | 69.7829            | 49.9274            | **9.90391**       |
+| pearson             | 0.18772           | 0.193152           | 0.204509           | 0.187924           | **0.368202**      |
+| spearman            | 0.243975          | 0.24587            | 0.260531           | 0.237441           | **0.358042**      |
+
+- LDA (`nytimes` dataset from https://archive.ics.uci.edu/ml/datasets/bag+of+words)
+- I found that setting the `workers` variable in gensim's LdaMulticore does not work properly (it uses all the cores on the instance anyway), so I just compared the speed of cusim on a single GPU against gensim on 8 vCPUs.
+- One can compare the quality of the models by looking at `examples/cusim.topics.txt` and `examples/gensim.topics.txt`.
+
+| attr                | gensim (8 vCPUs) | cusim (NVIDIA T4) |
+|:--------------------|-----------------:|------------------:|
+| training time (sec) | 447.376          | **76.6972**       |
+
+### Future tasks
+
+- support half precision
+- support multi-device training (a multi-device implementation of the LDA model should not be that hard, while multi-device training of w2v may require more consideration)
+- implement other models such as FastText, BERT, etc.
cpp/include/culda/cuda_lda_kernels.cuh

Lines changed: 73 additions & 34 deletions
@@ -26,26 +26,36 @@ float Digamma(float x) {
 }
 
 __global__ void EstepKernel(
-  const int* cols, const int* indptr, const bool* vali,
-  const int num_cols, const int num_indptr,
+  const int* cols, const int* indptr,
+  const bool* vali, const float* counts,
+  const bool init_gamma, const int num_cols, const int num_indptr,
   const int num_topics, const int num_iters,
-  float* gamma, float* new_gamma, float* phi,
   const float* alpha, const float* beta,
-  float* grad_alpha, float* new_beta,
-  float* train_losses, float* vali_losses, int* mutex) {
+  float* gamma, float* grad_alpha, float* new_beta,
+  float* train_losses, float* vali_losses, int* locks) {
 
   // storage for block
-  float* _gamma = gamma + num_topics * blockIdx.x;
-  float* _new_gamma = new_gamma + num_topics * blockIdx.x;
-  float* _phi = phi + num_topics * blockIdx.x;
+  extern __shared__ float shared_memory[];
+  float* _new_gamma = &shared_memory[0];
+  float* _phi = &shared_memory[num_topics];
+  float* _loss_vec = &shared_memory[num_topics * 2];
+  float* _vali_phi_sum = &shared_memory[num_topics * 3];
+
   float* _grad_alpha = grad_alpha + num_topics * blockIdx.x;
 
   for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
     int beg = indptr[i], end = indptr[i + 1];
-    // initialize gamma
-    for (int j = threadIdx.x; j < num_topics; j += blockDim.x)
-      _gamma[j] = alpha[j] + (end - beg) / num_topics;
+    float* _gamma = gamma + num_topics * i;
+    if (init_gamma) {
+      for (int j = threadIdx.x; j < num_topics; j += blockDim.x) {
+        _gamma[j] = alpha[j] + (end - beg) / num_topics;
+      }
+    }
     __syncthreads();
+
+    // initialize phi sum for validation data for computing vali loss
+    for (int j = threadIdx.x; j < num_topics; j += blockDim.x)
+      _vali_phi_sum[j] = 0.0f;
 
     // iterate E step
     for (int j = 0; j < num_iters; ++j) {
@@ -58,7 +68,7 @@ __global__ void EstepKernel(
       for (int k = beg; k < end; ++k) {
        const int w = cols[k];
        const bool _vali = vali[k];
-
+       const float c = counts[k];
        // compute phi
        if (not _vali or j + 1 == num_iters) {
          for (int l = threadIdx.x; l < num_topics; l += blockDim.x)
@@ -70,37 +80,52 @@ __global__ void EstepKernel(
 
          for (int l = threadIdx.x; l < num_topics; l += blockDim.x) {
            _phi[l] /= phi_sum;
-           if (not _vali) _new_gamma[l] += _phi[l];
+
+           // update gamma for train data and phi sum for computing loss
+           if (_vali)
+             _vali_phi_sum[l] += _phi[l] * c;
+           else
+             _new_gamma[l] += _phi[l] * c;
+
          }
          __syncthreads();
        }
 
        if (j + 1 == num_iters) {
-         // write access of w th vector of new_beta
-         if (threadIdx.x == 0) {
-           while (atomicCAS(&mutex[w], 0, 1)) {}
-         }
+         // update beta for train data
+         if (not _vali) {
+           // write access of w-th vector of new_beta
+           if (threadIdx.x == 0) {
+             while (atomicCAS(&locks[w], 0, 1)) {}
+           }
 
-         __syncthreads();
+           __syncthreads();
+           for (int l = threadIdx.x; l < num_topics; l += blockDim.x)
+             new_beta[w * num_topics + l] += _phi[l] * c;
+           __syncthreads();
+
+           // release lock
+           if (threadIdx.x == 0) locks[w] = 0;
+           __syncthreads();
+         }
+
+         // compute loss and reset shared mem
+         // see Eq (15) in https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf
         for (int l = threadIdx.x; l < num_topics; l += blockDim.x) {
-           if (j + 1 == num_iters) {
-             if (not _vali) new_beta[w * num_topics + l] += _phi[l];
-             _phi[l] *= beta[w * num_topics + l];
-           }
+           _loss_vec[l] = logf(fmaxf(beta[w * num_topics + l], EPS));
+           _loss_vec[l] -= logf(fmaxf(_phi[l], EPS));
+           _loss_vec[l] *= _phi[l];
         }
         __syncthreads();
-
-         // release lock
-         if (threadIdx.x == 0) mutex[w] = 0;
-         __syncthreads();
-
-         float p = fmaxf(EPS, ReduceSum(_phi, num_topics));
+         float _loss = ReduceSum(_loss_vec, num_topics) * c;
         if (threadIdx.x == 0) {
-           if (_vali)
-             vali_losses[blockIdx.x] += logf(p);
+           if (_vali)
+             vali_losses[blockIdx.x] += _loss;
           else
-             train_losses[blockIdx.x] += logf(p);
-         }
+             train_losses[blockIdx.x] += _loss;
+         }
+         __syncthreads();
+
        }
        __syncthreads();
      }
@@ -110,9 +135,23 @@ __global__ void EstepKernel(
        _gamma[k] = _new_gamma[k] + alpha[k];
      __syncthreads();
    }
+
+    // update gradient of alpha and loss from E[log(theta)]
    float gamma_sum = ReduceSum(_gamma, num_topics);
-    for (int j = threadIdx.x; j < num_topics; j += blockDim.x)
-      _grad_alpha[j] += (Digamma(_gamma[j]) - Digamma(gamma_sum));
+    for (int j = threadIdx.x; j < num_topics; j += blockDim.x) {
+      float Elogthetad = Digamma(_gamma[j]) - Digamma(gamma_sum);
+      _grad_alpha[j] += Elogthetad;
+      _new_gamma[j] *= Elogthetad;
+      _vali_phi_sum[j] *= Elogthetad;
+    }
+
+    // see Eq (15) in https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf
+    float train_loss = ReduceSum(_new_gamma, num_topics);
+    float vali_loss = ReduceSum(_vali_phi_sum, num_topics);
+    if (threadIdx.x == 0) {
+      train_losses[blockIdx.x] += train_loss;
+      vali_losses[blockIdx.x] += vali_loss;
+    }
 
    __syncthreads();
  }
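The loss comments above point to Eq. (15) of the LDA paper. As a reading aid (my reconstruction from the diff, not text from the repository), the per-document quantity the kernel appears to accumulate is the z- and w-dependent part of that bound, with every token weighted by its bag-of-words count `c`:

```latex
% Assumed notation: phi_{nk} is the token-topic variational posterior,
% gamma the document's Dirichlet parameter, beta the topic-word matrix,
% c_n the count of token n, Psi the digamma function.
\mathcal{L}_d \;\approx\; \sum_{n} c_n \sum_{k=1}^{K} \phi_{nk}
  \Big[\, \log \beta_{k, w_n}
        + \Psi(\gamma_k) - \Psi\Big(\sum_{j=1}^{K} \gamma_j\Big)
        - \log \phi_{nk} \,\Big]
```

The `_loss_vec` reduction contributes the log β − log φ part per word, and the final `ReduceSum` over `_new_gamma` and `_vali_phi_sum` (each entry pre-multiplied by `Elogthetad`) contributes the Ψ terms; the α-only terms of Eq. (15) do not appear in this accumulation.

`locks` replaces the old `mutex` buffer but keeps the same mechanism: thread 0 of a block spins on `atomicCAS` to acquire a per-word lock before the whole block adds its φ contributions into that word's row of `new_beta`. A minimal stand-alone sketch of the pattern (illustrative only; all names are hypothetical, and the real kernel interleaves this with the loss computation above):

```cuda
// Per-row spin-lock accumulation, as used around the new_beta update above.
// locks holds one int per word: 0 = free, 1 = held.
__global__ void AccumulateRows(const int* word_ids, const float* phi,
                               float* out, int* locks,
                               const int num_topics, const int num_items) {
  for (int i = blockIdx.x; i < num_items; i += gridDim.x) {
    const int w = word_ids[i];
    if (threadIdx.x == 0) {
      while (atomicCAS(&locks[w], 0, 1)) {}   // spin until this block owns row w
    }
    __syncthreads();                          // everyone waits for the lock
    for (int k = threadIdx.x; k < num_topics; k += blockDim.x)
      out[w * num_topics + k] += phi[i * num_topics + k];
    __syncthreads();                          // writes done before release
    if (threadIdx.x == 0) locks[w] = 0;       // release, mirroring the kernel
    __syncthreads();
  }
}
```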

cpp/include/culda/culda.hpp

Lines changed: 6 additions & 4 deletions
@@ -65,8 +65,11 @@ class CuLDA {
   void LoadModel(float* alpha, float* beta,
     float* grad_alpha, float* new_beta, const int num_words);
   std::pair<float, float> FeedData(
-    const int* indices, const int* indptr, const bool* vali,
-    const int num_indices, const int num_indptr, const int num_iters);
+    const int* indices, const int* indptr,
+    const bool* vali, const float* counts,
+    float* gamma, const bool init_gamma,
+    const int num_indices, const int num_indptr,
+    const int num_iters);
   void Pull();
   void Push();
   int GetBlockCnt();
@@ -78,8 +81,7 @@ class CuLDA {
   std::unique_ptr<CuSimLogger> logger_container_;
   thrust::device_vector<float> dev_alpha_, dev_beta_;
   thrust::device_vector<float> dev_grad_alpha_, dev_new_beta_;
-  thrust::device_vector<float> dev_gamma_, dev_new_gamma_, dev_phi_;
-  thrust::device_vector<int> dev_mutex_;
+  thrust::device_vector<int> dev_locks_;
 
   float *alpha_, *beta_, *grad_alpha_, *new_beta_;
   int block_cnt_, block_dim_;
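Since `gamma` is now allocated by the caller (one row per document) and `init_gamma` controls whether the E step re-initializes it, a call site can keep the document-topic statistics across repeated passes. A hypothetical host-side sketch against the signature above (everything except `FeedData` and its parameter list is illustrative, not from the repository):

```cuda
#include <utility>
#include <vector>

// Hypothetical wrapper; assumes the culda.hpp header above is included and
// `lda` is a CuLDA with the model already loaded. The returned pair is
// presumably (train loss, validation loss).
std::pair<float, float> RunEStep(CuLDA& lda,
                                 const int* indices, const int* indptr,
                                 const bool* vali, const float* counts,
                                 const int num_indices, const int num_docs,
                                 const int num_topics, const int num_iters) {
  // one gamma row per document, reusable across calls with init_gamma = false
  std::vector<float> gamma(static_cast<size_t>(num_docs) * num_topics);
  return lda.FeedData(indices, indptr, vali, counts,
                      gamma.data(), /*init_gamma=*/true,
                      num_indices, num_docs, num_iters);
}
```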

cpp/include/cuw2v/cuda_w2v_base_kernels.cuh

Lines changed: 4 additions & 2 deletions
@@ -6,14 +6,16 @@
 #pragma once
 #include "utils/cuda_utils_kernels.cuh"
 
+#define MAX_EXP 20
+
 namespace cusim {
 
 
 __inline__ __device__
 void PositiveFeedback(const float* vec1, float* vec2, float* grad,
   float& loss_nume, float& loss_deno, const int num_dims, const float lr) {
   static __shared__ float g;
-  float dot = Dot(vec1, vec2, num_dims);
+  float dot = fmaxf(-MAX_EXP, fminf(MAX_EXP, Dot(vec1, vec2, num_dims)));
   if (threadIdx.x == 0) {
     float exp_dot = expf(-dot);
     g = exp_dot / (1 + exp_dot) * lr;
@@ -32,7 +34,7 @@ __inline__ __device__
 void NegativeFeedback(const float* vec1, float* vec2, float* grad,
   float& loss_nume, float& loss_deno, const int num_dims, const float lr) {
   static __shared__ float g;
-  float dot = Dot(vec1, vec2, num_dims);
+  float dot = fmaxf(-MAX_EXP, fminf(MAX_EXP, Dot(vec1, vec2, num_dims)));
   if (threadIdx.x == 0) {
     float exp_dot = expf(dot);
     g = exp_dot / (1 + exp_dot) * lr;
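`MAX_EXP` bounds the logit before it reaches `expf`. This is the usual word2vec-style clamp: `expf` overflows float32 for arguments above roughly 88, in which case `g = exp_dot / (1 + exp_dot)` becomes `inf / inf = NaN`, while the sigmoid is already saturated well before |dot| = 20, so the clamp changes the value negligibly. A tiny stand-alone illustration (not repository code):

```cuda
// Host-side demonstration of the clamp used in PositiveFeedback/NegativeFeedback.
#include <math.h>
#include <stdio.h>

#define MAX_EXP 20  // same bound as above

static float Gradient(float dot, float lr) {
  dot = fmaxf(-MAX_EXP, fminf(MAX_EXP, dot));  // clamp the logit
  float exp_dot = expf(-dot);
  return exp_dot / (1 + exp_dot) * lr;         // mirrors PositiveFeedback's g
}

int main() {
  printf("g(dot =   20) = %g\n", Gradient(20.0f, 0.025f));
  printf("g(dot = -200) = %g\n", Gradient(-200.0f, 0.025f));  // finite only because of the clamp
  return 0;
}
```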

cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh

Lines changed: 9 additions & 9 deletions
@@ -39,11 +39,11 @@ __global__ void W2VHsSgKernel(
     __syncthreads();
     int beg2 = max(beg, j - window_size + reduced_windows);
     int end2 = min(end, j + window_size - reduced_windows + 1);
-    float* _emb_in = emb_in + num_dims * cols[j];
     for (int k = beg2; k < end2; ++k) {
       if (k == j) continue;
-      int beg3 = hs_indptr[cols[k]];
-      int end3 = hs_indptr[cols[k] + 1];
+      float* _emb_in = emb_in + num_dims * cols[k];
+      int beg3 = hs_indptr[cols[j]];
+      int end3 = hs_indptr[cols[j] + 1];
       for (int l = beg3; l < end3; ++l) {
         if (codes[l]) {
           PositiveFeedback(_emb_in, emb_out + num_dims * points[l],
@@ -55,7 +55,7 @@ __global__ void W2VHsSgKernel(
         __syncthreads();
       }
       for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
-        emb_in[num_dims * cols[j] + l] += grad[l];
+        _emb_in[l] += grad[l];
         grad[l] = 0.0f;
       }
       __syncthreads();
@@ -70,7 +70,7 @@ __global__ void W2VHsCbowKernel(
   const int num_indptr, const int num_dims, const int window_size, default_random_engine* rngs,
   float* emb_in, float* emb_out,
   float* loss_nume, float* loss_deno,
-  const bool use_mean, const float lr) {
+  const bool cbow_mean, const float lr) {
 
   default_random_engine& rng = rngs[blockIdx.x];
   float& _loss_nume = loss_nume[blockIdx.x];
@@ -98,15 +98,15 @@ __global__ void W2VHsCbowKernel(
       grad[k] = 0.0f;
       cbow[k] = 0.0f;
     }
-
+
     // compute cbow
     for (int k = beg2; k < end2; ++k) {
       if (k == j) continue;
       for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
         cbow[l] += emb_in[num_dims * cols[k] + l];
       }
     }
-    if (use_mean) {
+    if (cbow_mean) {
       for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
         cbow[k] /= (end2 - beg2 - 1);
       }
@@ -126,8 +126,8 @@ __global__ void W2VHsCbowKernel(
       __syncthreads();
     }
 
-    // normalize grad if use_mean = true
-    if (use_mean) {
+    // normalize grad if cbow_mean = true
+    if (cbow_mean) {
      for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
        grad[k] /= (end2 - beg2 - 1);
      }
    }
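The rename from `use_mean` to `cbow_mean` matches gensim's parameter name. Dividing both `cbow` and `grad` by the context size is just the chain rule when the hidden vector is the mean rather than the sum of the context embeddings; a quick check of that step (my derivation, not repository text), with C = `end2 - beg2 - 1` context words:

```latex
h = \frac{1}{C} \sum_{c=1}^{C} v_c
\qquad\Longrightarrow\qquad
\frac{\partial \ell}{\partial v_c}
  = \frac{\partial \ell}{\partial h} \cdot \frac{\partial h}{\partial v_c}
  = \frac{1}{C} \, \frac{\partial \ell}{\partial h}
```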
