19
19
from cusim .config_pb2 import CuW2VConfigProto
20
20
21
21
EPS = 1e-10
22
+ WARP_SIZE = 32
22
23
23
24
class CuW2V :
24
25
def __init__ (self , opt = None ):
25
26
self .opt = aux .get_opt_as_proto (opt or {}, CuW2VConfigProto )
26
27
self .logger = aux .get_logger ("culda" , level = self .opt .py_log_level )
27
28
29
+ assert self .opt .block_dim <= WARP_SIZE ** 2 and \
30
+ self .opt .block_dim % WARP_SIZE == 0 , \
31
+ f"invalid block dim ({ self .opt .block_dim } , warp size: { WARP_SIZE } )"
32
+
28
33
tmp = tempfile .NamedTemporaryFile (mode = 'w' , delete = False )
29
34
opt_content = json .dumps (aux .proto_to_dict (self .opt ), indent = 2 )
30
35
tmp .write (opt_content )
@@ -61,6 +66,7 @@ def init_model(self):
61
66
dtype = np .float32 )
62
67
self .word_count = np .power (self .word_count , self .opt .count_power )
63
68
self .num_words = len (self .words )
69
+ assert len (self .words ) == len (self .word_count )
64
70
65
71
# count number of docs
66
72
h5f = h5py .File (pjoin (data_dir , "token.h5" ), "r" )
@@ -70,6 +76,12 @@ def init_model(self):
70
76
self .logger .info ("number of words: %d, docs: %d" ,
71
77
self .num_words , self .num_docs )
72
78
79
+ if self .opt .neg :
80
+ self .obj .build_random_table ( \
81
+ self .word_count , self .opt .random_size , self .opt .num_threads )
82
+ else :
83
+ self .obj .build_huffman_tree (self .word_count )
84
+
73
85
# randomly initialize the embedding weights
74
86
np .random .seed (self .opt .seed )
75
87
self .emb_in = np .random .normal ( \
@@ -86,11 +98,6 @@ def init_model(self):
86
98
def train_model (self ):
87
99
self .preprocess_data ()
88
100
self .init_model ()
89
- if self .opt .neg :
90
- self .obj .build_random_table ( \
91
- self .word_count , self .opt .random_size , self .opt .num_threads )
92
- else :
93
- self .obj .build_huffman_tree (self .word_count )
94
101
h5f = h5py .File (pjoin (self .opt .processed_data_dir , "token.h5" ), "r" )
95
102
for epoch in range (1 , self .opt .epochs + 1 ):
96
103
self .logger .info ("Epoch %d / %d" , epoch , self .opt .epochs )
0 commit comments