Skip to content

Commit 38c83db

Browse files
committed
change proto field name
1 parent c7df6e4 commit 38c83db

File tree

3 files changed

+22
-19
lines changed

3 files changed

+22
-19
lines changed

cusim/culda/pyculda.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,21 @@ def preprocess_data(self):
4444
if self.opt.skip_preprocess:
4545
return
4646
iou = IoUtils()
47-
if not self.opt.data_dir:
48-
self.opt.data_dir = tempfile.TemporaryDirectory().name
47+
if not self.opt.processed_data_dir:
48+
self.opt.processed_data_dir = tempfile.TemporaryDirectory().name
4949
iou.convert_stream_to_h5(self.opt.data_path, self.opt.word_min_count,
50-
self.opt.data_dir)
50+
self.opt.processed_data_dir)
5151

5252
def init_model(self):
5353
# load voca
54-
self.logger.info("load key from %s", pjoin(self.opt.data_dir, "keys.txt"))
55-
with open(pjoin(self.opt.data_dir, "keys.txt"), "rb") as fin:
54+
data_dir = self.opt.processed_data_dir
55+
self.logger.info("load key from %s", pjoin(data_dir, "keys.txt"))
56+
with open(pjoin(data_dir, "keys.txt"), "rb") as fin:
5657
self.words = [line.strip() for line in fin]
5758
self.num_words = len(self.words)
5859

5960
# count number of docs
60-
h5f = h5py.File(pjoin(self.opt.data_dir, "token.h5"), "r")
61+
h5f = h5py.File(pjoin(data_dir, "token.h5"), "r")
6162
self.num_docs = h5f["indptr"].shape[0] - 1
6263
h5f.close()
6364

@@ -88,7 +89,7 @@ def init_model(self):
8889
def train_model(self):
8990
self.preprocess_data()
9091
self.init_model()
91-
h5f = h5py.File(pjoin(self.opt.data_dir, "token.h5"), "r")
92+
h5f = h5py.File(pjoin(self.opt.processed_data_dir, "token.h5"), "r")
9293
for epoch in range(1, self.opt.epochs + 1):
9394
self.logger.info("Epoch %d / %d", epoch, self.opt.epochs)
9495
self._train_e_step(h5f)

cusim/proto/config.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ message CuLDAConfigProto {
2222
optional int32 num_topics = 3 [default = 10];
2323
optional int32 block_dim = 4 [default = 32];
2424
optional int32 hyper_threads = 5 [default = 10];
25-
optional string data_dir = 6;
25+
optional string processed_data_dir = 6;
2626
optional bool skip_preprocess = 8;
2727
optional int32 word_min_count = 9 [default = 5];
2828
optional int32 batch_size = 10 [default = 100000];

examples/example1.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
LOGGER = aux.get_logger()
1818
DOWNLOAD_PATH = "./res"
1919
# DATASET = "wiki-english-20171001"
20-
DATASET = "fake-news"
20+
DATASET = "quora-duplicate-questions"
2121
DATA_PATH = f"./res/{DATASET}.stream.txt"
22-
DATA_PATH2 = f"./res/{DATASET}-converted"
22+
LDA_PATH = f"./res/{DATASET}-lda.h5"
23+
PROCESSED_DATA_DIR = f"./res/{DATASET}-converted"
2324
MIN_COUNT = 5
25+
TOPK = 10
2426

2527
def download():
2628
if os.path.exists(DATA_PATH):
@@ -37,28 +39,28 @@ def download():
3739
def run_io():
3840
download()
3941
iou = IoUtils(opt={"chunk_lines": 10000, "num_threads": 8})
40-
iou.convert_stream_to_h5(DATA_PATH, 5, DATA_PATH2)
42+
iou.convert_stream_to_h5(DATA_PATH, 5, PROCESSED_DATA_DIR)
4143

4244

4345
def run_lda():
4446
opt = {
4547
"data_path": DATA_PATH,
46-
"data_dir": DATA_PATH2,
47-
# "skip_preprocess": True,
48-
# "c_log_level": 3,
48+
"processed_data_dir": PROCESSED_DATA_DIR,
4949
}
5050
lda = CuLDA(opt)
5151
lda.train_model()
52-
lda.save_model("res/lda.h5")
53-
h5f = h5py.File("res/lda.h5", "r")
52+
lda.save_model(LDA_PATH)
53+
h5f = h5py.File(LDA_PATH, "r")
5454
beta = h5f["beta"][:]
55-
for i in range(lda.opt.num_topics):
55+
word_list = h5f["keys"][:]
56+
num_topics = h5f["alpha"].shape[0]
57+
for i in range(num_topics):
5658
print("=" * 50)
5759
print(f"topic {i + 1}")
5860
words = np.argsort(-beta.T[i])[:10]
5961
print("-" * 50)
60-
for j in range(10):
61-
word = lda.words[words[j]].decode("utf8")
62+
for j in range(TOPK):
63+
word = word_list[words[j]].decode("utf8")
6264
prob = beta[words[j], i]
6365
print(f"rank {j + 1}. word: {word}, prob: {prob}")
6466
h5f.close()

0 commit comments

Comments
 (0)