# Copyright (c) 2021 Jisang Yoon
# All rights reserved.
#
# This source code is licensed under the Apache 2.0 license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=no-name-in-module,too-few-public-methods,no-member
import os
from os.path import join as pjoin

import json
import tempfile

import h5py
import numpy as np
from scipy.special import polygamma as pg

from cusim import aux, IoUtils
from cusim.culda.culda_bind import CuLDABind
from cusim.config_pb2 import CuLDAConfigProto

EPS = 1e-10

class CuLDA:
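  """GPU-accelerated Latent Dirichlet Allocation trainer.

  The per-document E-step runs in CUDA through the CuLDABind extension;
  the M-step (renormalizing beta and a Newton update for alpha) runs on
  the host in NumPy.
  """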
  def __init__(self, opt=None):
    self.opt = aux.get_opt_as_proto(opt or {}, CuLDAConfigProto)
    self.logger = aux.get_logger("culda", level=self.opt.py_log_level)

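    # the C++/CUDA binding reads its configuration from a JSON file, so
    # the resolved options are serialized to a temporary file whose path
    # is handed to CuLDABind.init() below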
    tmp = tempfile.NamedTemporaryFile(mode='w', delete=False)
    opt_content = json.dumps(aux.proto_to_dict(self.opt), indent=2)
    tmp.write(opt_content)
    tmp.close()

    self.logger.info("opt: %s", opt_content)
    self.obj = CuLDABind()
    assert self.obj.init(bytes(tmp.name, "utf8")), f"failed to load {tmp.name}"
    os.remove(tmp.name)

    self.words, self.num_words, self.num_docs = None, None, None
    self.alpha, self.beta, self.grad_alpha, self.new_beta = \
      None, None, None, None

  def preprocess_data(self):
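    """Convert the raw text stream into the HDF5 token format.

    No-op when opt.skip_preprocess is set; otherwise IoUtils writes
    keys.txt and token.h5 into opt.processed_data_dir.
    """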
    if self.opt.skip_preprocess:
      return
    iou = IoUtils()
    if not self.opt.processed_data_dir:
      # mkdtemp creates a directory that persists for the whole run
      self.opt.processed_data_dir = tempfile.mkdtemp()
    iou.convert_stream_to_h5(self.opt.data_path, self.opt.word_min_count,
                             self.opt.processed_data_dir)

  def init_model(self):
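    """Load the vocabulary, count documents, and initialize parameters.

    alpha (num_topics,) and beta (num_words, num_topics) are drawn
    uniformly at random; the gradient buffers are zeroed, and all four
    arrays are shared with the GPU via load_model().
    """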
    # load vocabulary
    data_dir = self.opt.processed_data_dir
    self.logger.info("load key from %s", pjoin(data_dir, "keys.txt"))
    with open(pjoin(data_dir, "keys.txt"), "rb") as fin:
      self.words = [line.strip() for line in fin]
    self.num_words = len(self.words)

    # count number of docs
    h5f = h5py.File(pjoin(data_dir, "token.h5"), "r")
    self.num_docs = h5f["indptr"].shape[0] - 1
    h5f.close()

    self.logger.info("number of words: %d, docs: %d",
                     self.num_words, self.num_docs)

    # randomly initialize alpha and beta
    np.random.seed(self.opt.seed)
    self.alpha = np.random.uniform(
      size=(self.opt.num_topics,)).astype(np.float32)
    self.beta = np.random.uniform(
      size=(self.num_words, self.opt.num_topics)).astype(np.float32)
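    # normalize each column so every topic is a distribution over words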
    self.beta /= np.sum(self.beta, axis=0)[None, :]
    self.logger.info("alpha %s, beta %s initialized",
                     self.alpha.shape, self.beta.shape)

    # zero-initialize grad alpha and new beta
    block_cnt = self.obj.get_block_cnt()
    self.grad_alpha = np.zeros(shape=(block_cnt, self.opt.num_topics),
                               dtype=np.float32)
    self.new_beta = np.zeros(shape=self.beta.shape, dtype=np.float32)
    self.logger.info("grad alpha %s, new beta %s initialized",
                     self.grad_alpha.shape, self.new_beta.shape)

    # push the parameters to the GPU
    self.obj.load_model(self.alpha, self.beta, self.grad_alpha, self.new_beta)

  def train_model(self):
    self.preprocess_data()
    self.init_model()
    h5f = h5py.File(pjoin(self.opt.processed_data_dir, "token.h5"), "r")
    for epoch in range(1, self.opt.epochs + 1):
      self.logger.info("Epoch %d / %d", epoch, self.opt.epochs)
      self._train_e_step(h5f)
      self._train_m_step()
    h5f.close()

  def _train_e_step(self, h5f):
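    """Run one E-step pass over the corpus.

    Streams token.h5 in roughly batch_size-token chunks aligned to
    document boundaries, feeds each chunk to the CUDA kernel, and
    accumulates training and held-out (vali) losses.
    """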
    offset, size = 0, h5f["cols"].shape[0]
    pbar = aux.Progbar(size, stateful_metrics=["train_loss", "vali_loss"])
    train_loss_nume, train_loss_deno = 0, 0
    vali_loss_nume, vali_loss_deno = 0, 0
    while True:
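      # indptr[d] is the first token index of document d and rows[t] is
      # the document owning token t (assuming the CSR-style layout
      # written by IoUtils), so step ~batch_size tokens forward and snap
      # the batch end to a document boundary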
      target = h5f["indptr"][offset] + self.opt.batch_size
      if target < size:
        next_offset = h5f["rows"][target]
      else:
        next_offset = h5f["indptr"].shape[0] - 1
      indptr = h5f["indptr"][offset:next_offset + 1]
      beg, end = indptr[0], indptr[-1]
      indptr -= beg
      cols = h5f["cols"][beg:end]
      vali = (h5f["vali"][beg:end] < self.opt.vali_p).astype(bool)
      offset = next_offset

      # call the CUDA kernel
      train_loss, vali_loss = \
        self.obj.feed_data(cols, indptr, vali, self.opt.num_iters_in_e_step)

      # accumulate loss
      train_loss_nume -= train_loss
      vali_loss_nume -= vali_loss
      vali_cnt = np.count_nonzero(vali)
      train_cnt = len(vali) - vali_cnt
      train_loss_deno += train_cnt
      vali_loss_deno += vali_cnt
      train_loss = train_loss_nume / (train_loss_deno + EPS)
      vali_loss = vali_loss_nume / (vali_loss_deno + EPS)

      # update progress bar
      pbar.update(end, values=[("train_loss", train_loss),
                               ("vali_loss", vali_loss)])
      if end == size:
        break

  def _train_m_step(self):
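    """Pull the accumulated statistics from the GPU, update beta in
    closed form and alpha with a Newton step, then push the new
    parameters back."""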
    self.obj.pull()

    # update beta
    self.new_beta[:, :] = np.maximum(self.new_beta, EPS)
    self.beta[:, :] = self.new_beta / np.sum(self.new_beta, axis=0)[None, :]
    self.new_beta[:, :] = 0

    # update alpha
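    # Newton step on the Dirichlet likelihood of alpha: the Hessian is a
    # diagonal (hvec) plus a rank-one trigamma term (z_0), so it can be
    # inverted in closed form with c_0 as the shared rank-one correction
    # (cf. Minka, "Estimating a Dirichlet distribution")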
    alpha_sum = np.sum(self.alpha)
    gvec = np.sum(self.grad_alpha, axis=0)
    gvec += self.num_docs * (pg(0, alpha_sum) - pg(0, self.alpha))
    hvec = self.num_docs * pg(1, self.alpha)
    z_0 = pg(1, alpha_sum)
    c_nume = np.sum(gvec / hvec)
    c_deno = 1 / z_0 + np.sum(1 / hvec)
    c_0 = c_nume / c_deno
    delta = (gvec - c_0) / hvec
    self.alpha -= delta
    self.alpha[:] = np.maximum(self.alpha, EPS)
    self.grad_alpha[:, :] = 0

    self.obj.push()

  def save_model(self, model_path):
    self.logger.info("save model path: %s", model_path)
    h5f = h5py.File(model_path, "w")
    h5f.create_dataset("alpha", data=self.alpha)
    h5f.create_dataset("beta", data=self.beta)
    h5f.create_dataset("keys", data=np.array(self.words))
    h5f.close()
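
# Minimal usage sketch (values are illustrative; see CuLDAConfigProto
# for the full option set):
#
#   lda = CuLDA({"data_path": "corpus.txt", "num_topics": 100})
#   lda.train_model()
#   lda.save_model("lda_model.h5")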