Skip to content

Commit 46b31a5

Browse files
committed
bug-fix
1 parent 3859d7d commit 46b31a5

File tree

4 files changed

+19
-8
lines changed

4 files changed

+19
-8
lines changed

cpp/src/cuw2v/cuw2v.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ std::pair<float, float> CuW2V::FeedData(const int* cols, const int* indptr,
265265
// accumulate loss nume / deno
266266
std::vector<float> loss_nume(block_cnt_), loss_deno(block_cnt_);
267267
thrust::copy(dev_loss_nume.begin(), dev_loss_nume.end(), loss_nume.begin());
268-
thrust::copy(dev_loss_deno.begin(), dev_loss_deno.end(), loss_nume.begin());
268+
thrust::copy(dev_loss_deno.begin(), dev_loss_deno.end(), loss_deno.begin());
269269
CHECK_CUDA(cudaDeviceSynchronize());
270270
float loss_nume_sum = std::accumulate(loss_nume.begin(), loss_nume.end(), 0.0f);
271271
float loss_deno_sum = std::accumulate(loss_deno.begin(), loss_deno.end(), 0.0f);

cusim/culda/pyculda.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,17 @@
2020
from cusim.config_pb2 import CuLDAConfigProto
2121

2222
EPS = 1e-10
23+
WARP_SIZE = 32
2324

2425
class CuLDA:
2526
def __init__(self, opt=None):
2627
self.opt = aux.get_opt_as_proto(opt or {}, CuLDAConfigProto)
2728
self.logger = aux.get_logger("culda", level=self.opt.py_log_level)
2829

30+
assert self.opt.block_dim <= WARP_SIZE ** 2 and \
31+
self.opt.block_dim % WARP_SIZE == 0, \
32+
f"invalid block dim ({self.opt.block_dim}, warp size: {WARP_SIZE})"
33+
2934
tmp = tempfile.NamedTemporaryFile(mode='w', delete=False)
3035
opt_content = json.dumps(aux.proto_to_dict(self.opt), indent=2)
3136
tmp.write(opt_content)

cusim/cuw2v/pycuw2v.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,17 @@
1919
from cusim.config_pb2 import CuW2VConfigProto
2020

2121
EPS = 1e-10
22+
WARP_SIZE = 32
2223

2324
class CuW2V:
2425
def __init__(self, opt=None):
2526
self.opt = aux.get_opt_as_proto(opt or {}, CuW2VConfigProto)
2627
self.logger = aux.get_logger("culda", level=self.opt.py_log_level)
2728

29+
assert self.opt.block_dim <= WARP_SIZE ** 2 and \
30+
self.opt.block_dim % WARP_SIZE == 0, \
31+
f"invalid block dim ({self.opt.block_dim}, warp size: {WARP_SIZE})"
32+
2833
tmp = tempfile.NamedTemporaryFile(mode='w', delete=False)
2934
opt_content = json.dumps(aux.proto_to_dict(self.opt), indent=2)
3035
tmp.write(opt_content)
@@ -61,6 +66,7 @@ def init_model(self):
6166
dtype=np.float32)
6267
self.word_count = np.power(self.word_count, self.opt.count_power)
6368
self.num_words = len(self.words)
69+
assert len(self.words) == len(self.word_count)
6470

6571
# count number of docs
6672
h5f = h5py.File(pjoin(data_dir, "token.h5"), "r")
@@ -70,6 +76,12 @@ def init_model(self):
7076
self.logger.info("number of words: %d, docs: %d",
7177
self.num_words, self.num_docs)
7278

79+
if self.opt.neg:
80+
self.obj.build_random_table( \
81+
self.word_count, self.opt.random_size, self.opt.num_threads)
82+
else:
83+
self.obj.build_huffman_tree(self.word_count)
84+
7385
# random initialize alpha and beta
7486
np.random.seed(self.opt.seed)
7587
self.emb_in = np.random.normal( \
@@ -86,11 +98,6 @@ def init_model(self):
8698
def train_model(self):
8799
self.preprocess_data()
88100
self.init_model()
89-
if self.opt.neg:
90-
self.obj.build_random_table( \
91-
self.word_count, self.opt.random_size, self.opt.num_threads)
92-
else:
93-
self.obj.build_huffman_tree(self.word_count)
94101
h5f = h5py.File(pjoin(self.opt.processed_data_dir, "token.h5"), "r")
95102
for epoch in range(1, self.opt.epochs + 1):
96103
self.logger.info("Epoch %d / %d", epoch, self.opt.epochs)

cusim/proto/config.proto

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,5 @@ message CuW2VConfigProto {
6060
optional bool use_mean = 20 [default = true];
6161
optional double lr = 21 [default = 0.001];
6262
optional int32 window_size = 22 [default = 5];
63-
64-
63+
optional int32 num_threads = 23 [default = 4];
6564
}

0 commit comments

Comments (0)