Skip to content

Commit 1009614

Browse files
authored
Merge pull request #5 from js1010/dev
Dev
2 parents a324cb2 + ec13e7c commit 1009614

23 files changed

+1115
-28
lines changed

cpp/include/culda/culda.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class CuLDA {
7575
DeviceInfo dev_info_;
7676
json11::Json opt_;
7777
std::shared_ptr<spdlog::logger> logger_;
78+
std::unique_ptr<CuSimLogger> logger_container_;
7879
thrust::device_vector<float> dev_alpha_, dev_beta_;
7980
thrust::device_vector<float> dev_grad_alpha_, dev_new_beta_;
8081
thrust::device_vector<float> dev_gamma_, dev_new_gamma_, dev_phi_;
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright (c) 2021 Jisang Yoon
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the Apache 2.0 license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
#pragma once
7+
#include "utils/cuda_utils_kernels.cuh"
8+
9+
namespace cusim {
10+
11+
12+
// Positive (label = 1) sigmoid update: pulls vec2 toward vec1 and accumulates
// the matching gradient for vec1 into the shared buffer `grad`.
// Contract: must be called by ALL threads of the block (contains
// __syncthreads()); `grad` and `vec2` hold at least num_dims floats.
// Loss bookkeeping (loss_nume / loss_deno) is done by thread 0 only.
__inline__ __device__
void PositiveFeedback(const float* vec1, float* vec2, float* grad,
  float& loss_nume, float& loss_deno, const int num_dims, const float lr) {
  __shared__ float g;  // block-wide step size; `static` is redundant with __shared__
  float dot = Dot(vec1, vec2, num_dims);
  if (threadIdx.x == 0) {
    // sigmoid(-dot) * lr written as lr / (1 + exp(dot)).
    // The algebraically equal form expf(-dot) / (1 + expf(-dot)) * lr
    // overflows to inf / inf = NaN once -dot exceeds ~88; this form saturates
    // cleanly to lr instead.
    g = lr / (1.0f + expf(dot));
    // log(1 + exp(-dot)) via log1pf for accuracy when exp(-dot) is tiny.
    loss_nume += log1pf(expf(-dot));
    loss_deno++;
  }
  __syncthreads();
  for (int i = threadIdx.x; i < num_dims; i += blockDim.x) {
    grad[i] += vec2[i] * g;
    vec2[i] += vec1[i] * g;
  }
  __syncthreads();
}
30+
31+
// Negative (label = 0) sigmoid update: pushes vec2 away from vec1 and
// accumulates the matching gradient for vec1 into the shared buffer `grad`.
// Contract: must be called by ALL threads of the block (contains
// __syncthreads()); `grad` and `vec2` hold at least num_dims floats.
// Loss bookkeeping (loss_nume / loss_deno) is done by thread 0 only.
__inline__ __device__
void NegativeFeedback(const float* vec1, float* vec2, float* grad,
  float& loss_nume, float& loss_deno, const int num_dims, const float lr) {
  __shared__ float g;  // block-wide step size; `static` is redundant with __shared__
  float dot = Dot(vec1, vec2, num_dims);
  if (threadIdx.x == 0) {
    // sigmoid(dot) * lr written as lr / (1 + exp(-dot)).
    // The algebraically equal form expf(dot) / (1 + expf(dot)) * lr
    // overflows to inf / inf = NaN once dot exceeds ~88; this form saturates
    // cleanly to lr instead.
    g = lr / (1.0f + expf(-dot));
    // log(1 + exp(dot)) via log1pf for accuracy when exp(dot) is tiny.
    loss_nume += log1pf(expf(dot));
    loss_deno++;
  }
  __syncthreads();
  for (int i = threadIdx.x; i < num_dims; i += blockDim.x) {
    grad[i] -= vec2[i] * g;
    vec2[i] -= vec1[i] * g;
  }
  __syncthreads();
}
49+
50+
} // cusim
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// Copyright (c) 2021 Jisang Yoon
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the Apache 2.0 license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
#pragma once
7+
#include "utils/cuda_utils_kernels.cuh"
8+
#include "cuw2v/cuda_w2v_base_kernels.cuh"
9+
10+
11+
namespace cusim {
12+
13+
// Hierarchical-softmax skip-gram kernel.
// One block processes one CSR row (sentence) at a time, grid-striding over
// rows. rngs / loss_nume / loss_deno are indexed by blockIdx.x, so each must
// hold at least gridDim.x entries.
// Requires num_dims * sizeof(float) of dynamic shared memory (gradient buffer).
__global__ void W2VHsSgKernel(
  const int* cols, const int* indptr,
  const bool* codes, const int* points, const int* hs_indptr,
  const int num_indptr, const int num_dims, const int window_size,
  default_random_engine* rngs,
  float* emb_in, float* emb_out,
  float* loss_nume, float* loss_deno, const float lr) {

  default_random_engine& rng = rngs[blockIdx.x];
  float& _loss_nume = loss_nume[blockIdx.x];
  float& _loss_deno = loss_deno[blockIdx.x];

  uniform_int_distribution<int> dist_window(0, window_size - 1);
  __shared__ int reduced_windows;  // dropped redundant `static` (matches W2VNegSgKernel)
  extern __shared__ float shared_memory[];
  float* grad = &shared_memory[0];  // num_dims floats

  // zero-initialize shared mem
  for (int i = threadIdx.x; i < num_dims; i += blockDim.x)
    grad[i] = 0.0f;
  __syncthreads();

  for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
    int beg = indptr[i], end = indptr[i + 1];
    for (int j = beg; j < end; ++j) {
      // randomly shrink the effective window (word2vec convention);
      // reduced_windows is shared, so beg2/end2 are uniform across the block
      if (threadIdx.x == 0) reduced_windows = dist_window(rng);
      __syncthreads();
      int beg2 = max(beg, j - window_size + reduced_windows);
      int end2 = min(end, j + window_size - reduced_windows + 1);
      float* _emb_in = emb_in + num_dims * cols[j];
      for (int k = beg2; k < end2; ++k) {
        if (k == j) continue;
        // walk the Huffman path of the context word
        int beg3 = hs_indptr[cols[k]];
        int end3 = hs_indptr[cols[k] + 1];
        for (int l = beg3; l < end3; ++l) {
          if (codes[l]) {
            PositiveFeedback(_emb_in, emb_out + num_dims * points[l],
                grad, _loss_nume, _loss_deno, num_dims, lr);
          } else {
            NegativeFeedback(_emb_in, emb_out + num_dims * points[l],
                grad, _loss_nume, _loss_deno, num_dims, lr);
          }
          __syncthreads();
        }
        // flush the accumulated gradient into the center embedding and reset
        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
          emb_in[num_dims * cols[j] + l] += grad[l];
          grad[l] = 0.0f;
        }
        __syncthreads();
      }
    }
  }
}
66+
67+
// Hierarchical-softmax CBOW kernel.
// One block processes one CSR row (sentence) at a time, grid-striding over
// rows. rngs / loss_nume / loss_deno are indexed by blockIdx.x, so each must
// hold at least gridDim.x entries.
// Requires 2 * num_dims * sizeof(float) of dynamic shared memory:
// [0, num_dims) = gradient buffer, [num_dims, 2 * num_dims) = context sum/mean.
__global__ void W2VHsCbowKernel(
  const int* cols, const int* indptr,
  const bool* codes, const int* points, const int* hs_indptr,
  const int num_indptr, const int num_dims, const int window_size, default_random_engine* rngs,
  float* emb_in, float* emb_out,
  float* loss_nume, float* loss_deno,
  const bool use_mean, const float lr) {

  default_random_engine& rng = rngs[blockIdx.x];
  float& _loss_nume = loss_nume[blockIdx.x];
  float& _loss_deno = loss_deno[blockIdx.x];

  uniform_int_distribution<int> dist_window(0, window_size - 1);
  __shared__ int reduced_windows;  // dropped redundant `static` (matches W2VNegSgKernel)
  extern __shared__ float shared_memory[];
  float* grad = &shared_memory[0];
  float* cbow = &shared_memory[num_dims];

  __syncthreads();

  for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
    int beg = indptr[i], end = indptr[i + 1];
    for (int j = beg; j < end; ++j) {
      // randomly shrink the effective window (word2vec convention);
      // reduced_windows is shared, so beg2/end2 are uniform across the block
      if (threadIdx.x == 0) reduced_windows = dist_window(rng);
      __syncthreads();
      int beg2 = max(beg, j - window_size + reduced_windows);
      int end2 = min(end, j + window_size - reduced_windows + 1);
      // no context words -> nothing to train; uniform across the block,
      // so this continue does not diverge around the barriers below
      if (end2 - beg2 <= 1) continue;
      const int num_ctx = end2 - beg2 - 1;  // context count excluding center j

      // zero-initialize shared mem (each thread owns a strided slice,
      // so no barrier is needed until another thread must read it)
      for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
        grad[k] = 0.0f;
        cbow[k] = 0.0f;
      }

      // compute cbow: sum (optionally mean) of the context embeddings
      for (int k = beg2; k < end2; ++k) {
        if (k == j) continue;
        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
          cbow[l] += emb_in[num_dims * cols[k] + l];
        }
      }
      if (use_mean) {
        for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
          cbow[k] /= num_ctx;
        }
      }
      __syncthreads();

      // walk the Huffman path of the center word
      int beg3 = hs_indptr[cols[j]];
      int end3 = hs_indptr[cols[j] + 1];
      for (int k = beg3; k < end3; ++k) {
        if (codes[k]) {
          PositiveFeedback(cbow, emb_out + num_dims * points[k],
              grad, _loss_nume, _loss_deno, num_dims, lr);
        } else {
          NegativeFeedback(cbow, emb_out + num_dims * points[k],
              grad, _loss_nume, _loss_deno, num_dims, lr);
        }
        __syncthreads();
      }

      // normalize grad if use_mean = true
      if (use_mean) {
        for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
          grad[k] /= num_ctx;
        }
      }
      __syncthreads();

      // scatter the shared gradient back into every context embedding
      for (int k = beg2; k < end2; ++k) {
        if (k == j) continue;
        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
          emb_in[num_dims * cols[k] + l] += grad[l];
        }
        __syncthreads();
      }
    }
  }
}
148+
149+
} // cusim
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// Copyright (c) 2021 Jisang Yoon
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the Apache 2.0 license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
#pragma once
7+
#include "utils/cuda_utils_kernels.cuh"
8+
#include "cuw2v/cuda_w2v_base_kernels.cuh"
9+
10+
11+
namespace cusim {
12+
13+
// Negative-sampling skip-gram kernel.
// One block consumes one CSR row (sentence) at a time, grid-striding over
// rows; rngs / loss_nume / loss_deno are indexed by blockIdx.x, so each must
// hold at least gridDim.x entries.
// Requires num_dims * sizeof(float) of dynamic shared memory (gradient buffer).
__global__ void W2VNegSgKernel(
  const int* cols, const int* indptr,
  const int* random_table, default_random_engine* rngs, const int random_size,
  const int num_indptr, const int num_dims, const int neg, const int window_size,
  float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, const float lr) {

  default_random_engine& block_rng = rngs[blockIdx.x];
  float& block_loss_nume = loss_nume[blockIdx.x];
  float& block_loss_deno = loss_deno[blockIdx.x];

  uniform_int_distribution<int> sample_neg(0, random_size - 1);
  uniform_int_distribution<int> sample_window(0, window_size - 1);
  __shared__ int reduced_windows;
  __shared__ int neg_word;
  extern __shared__ float shared_memory[];
  float* grad = shared_memory;

  // clear the gradient buffer once; it is re-cleared after every flush below
  for (int d = threadIdx.x; d < num_dims; d += blockDim.x)
    grad[d] = 0.0f;
  __syncthreads();

  for (int row = blockIdx.x; row < num_indptr; row += gridDim.x) {
    const int row_beg = indptr[row];
    const int row_end = indptr[row + 1];
    for (int center = row_beg; center < row_end; ++center) {
      // randomly shrink the effective window (word2vec convention);
      // the value lives in shared memory, so the bounds below are uniform
      if (threadIdx.x == 0) reduced_windows = sample_window(block_rng);
      __syncthreads();
      const int ctx_beg = max(row_beg, center - window_size + reduced_windows);
      const int ctx_end = min(row_end, center + window_size - reduced_windows + 1);
      float* center_emb = emb_in + num_dims * cols[center];
      for (int ctx = ctx_beg; ctx < ctx_end; ++ctx) {
        if (ctx == center) continue;
        // one positive pair, then `neg` sampled negative pairs
        PositiveFeedback(center_emb, emb_out + num_dims * cols[ctx],
            grad, block_loss_nume, block_loss_deno, num_dims, lr);
        for (int ns = 0; ns < neg; ++ns) {
          if (threadIdx.x == 0) neg_word = random_table[sample_neg(block_rng)];
          __syncthreads();
          NegativeFeedback(center_emb, emb_out + num_dims * neg_word,
              grad, block_loss_nume, block_loss_deno, num_dims, lr);
        }
        __syncthreads();
        // flush the accumulated gradient into the center embedding and reset
        for (int d = threadIdx.x; d < num_dims; d += blockDim.x) {
          emb_in[num_dims * cols[center] + d] += grad[d];
          grad[d] = 0.0f;
        }
        __syncthreads();
      }
    }
  }
}
63+
64+
// Negative-sampling CBOW kernel.
// One block processes one CSR row (sentence) at a time, grid-striding over
// rows. rngs / loss_nume / loss_deno are indexed by blockIdx.x, so each must
// hold at least gridDim.x entries.
// Requires 2 * num_dims * sizeof(float) of dynamic shared memory:
// [0, num_dims) = gradient buffer, [num_dims, 2 * num_dims) = context sum/mean.
__global__ void W2VNegCbowKernel(
  const int* cols, const int* indptr,
  const int* random_table, default_random_engine* rngs, const int random_size,
  const int num_indptr, const int num_dims, const int neg, const int window_size,
  float* emb_in, float* emb_out,
  float* loss_nume, float* loss_deno, const bool use_mean, const float lr) {

  default_random_engine& rng = rngs[blockIdx.x];
  float& _loss_nume = loss_nume[blockIdx.x];
  float& _loss_deno = loss_deno[blockIdx.x];

  uniform_int_distribution<int> dist_neg(0, random_size - 1);
  uniform_int_distribution<int> dist_window(0, window_size - 1);
  __shared__ int reduced_windows;  // dropped redundant `static` (matches W2VNegSgKernel)
  __shared__ int neg_word;
  extern __shared__ float shared_memory[];
  float* grad = &shared_memory[0];
  float* cbow = &shared_memory[num_dims];

  __syncthreads();

  for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
    int beg = indptr[i], end = indptr[i + 1];
    for (int j = beg; j < end; ++j) {
      // randomly shrink the effective window (word2vec convention);
      // reduced_windows is shared, so beg2/end2 are uniform across the block
      if (threadIdx.x == 0) reduced_windows = dist_window(rng);
      __syncthreads();
      int beg2 = max(beg, j - window_size + reduced_windows);
      int end2 = min(end, j + window_size - reduced_windows + 1);
      // no context words -> nothing to train; uniform across the block,
      // so this continue does not diverge around the barriers below
      if (end2 - beg2 <= 1) continue;
      const int num_ctx = end2 - beg2 - 1;  // context count excluding center j

      // zero-initialize shared mem (each thread owns a strided slice,
      // so no barrier is needed until another thread must read it)
      for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
        grad[k] = 0.0f;
        cbow[k] = 0.0f;
      }

      // compute cbow: sum (optionally mean) of the context embeddings
      for (int k = beg2; k < end2; ++k) {
        if (k == j) continue;
        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
          cbow[l] += emb_in[num_dims * cols[k] + l];
        }
      }
      if (use_mean) {
        for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
          cbow[k] /= num_ctx;
        }
      }
      __syncthreads();

      // one positive pair against the center word
      PositiveFeedback(cbow, emb_out + num_dims * cols[j], grad,
          _loss_nume, _loss_deno, num_dims, lr);
      __syncthreads();

      // `neg` sampled negative pairs
      for (int k = 0; k < neg; ++k) {
        if (threadIdx.x == 0) neg_word = random_table[dist_neg(rng)];
        __syncthreads();
        NegativeFeedback(cbow, emb_out + num_dims * neg_word,
            grad, _loss_nume, _loss_deno, num_dims, lr);
      }
      __syncthreads();

      // normalize grad if use_mean = true
      if (use_mean) {
        for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
          grad[k] /= num_ctx;
        }
      }
      __syncthreads();

      // scatter the shared gradient back into every context embedding
      for (int k = beg2; k < end2; ++k) {
        if (k == j) continue;
        for (int l = threadIdx.x; l < num_dims; l += blockDim.x)
          emb_in[num_dims * cols[k] + l] += grad[l];
      }
      __syncthreads();

    }
  }
}
146+
147+
} // cusim

0 commit comments

Comments
 (0)