|
| 1 | +# Copyright Contributors to the Pyro project. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +""" |
| 5 | +Example: ProdLDA |
| 6 | +================ |
| 7 | +In this example, we will follow [1] to implement the ProdLDA topic model from |
| 8 | +Autoencoding Variational Inference For Topic Models by Akash Srivastava and Charles |
| 9 | +Sutton [2]. This model returns consistently better topics than vanilla LDA and trains |
| 10 | +much more quickly. Furthermore, it does not require a custom inference algorithm that |
| 11 | +relies on complex mathematical derivations. This example also serves as an |
| 12 | +introduction to Flax and Haiku modules in NumPyro. |
| 13 | +
|
| 14 | +Note that unlike [1, 2], this implementation uses a Dirichlet prior directly rather |
| 15 | +than approximating it with a softmax-normal distribution. |
| 16 | +
|
| 17 | +For the interested reader, a nice extension of this model is the CombinedTM model [3] |
| 18 | +which utilizes a pre-trained sentence transformer (like https://www.sbert.net/) to |
| 19 | +generate a better representation of the encoded latent vector. |
| 20 | +
|
| 21 | +**References:** |
| 22 | + 1. http://pyro.ai/examples/prodlda.html |
| 23 | + 2. Akash Srivastava, & Charles Sutton. (2017). Autoencoding Variational Inference |
| 24 | + For Topic Models. |
| 25 | + 3. Federico Bianchi, Silvia Terragni, and Dirk Hovy (2021), "Pre-training is a Hot |
| 26 | + Topic: Contextualized Document Embeddings Improve Topic Coherence" |
| 27 | + (https://arxiv.org/abs/2004.03974) |
| 28 | +
|
| 29 | +.. image:: ../_static/img/examples/prodlda.png |
| 30 | + :align: center |
| 31 | +""" |
| 32 | +import argparse |
| 33 | + |
| 34 | +import matplotlib.pyplot as plt |
| 35 | +import pandas as pd |
| 36 | +from sklearn.datasets import fetch_20newsgroups |
| 37 | +from sklearn.feature_extraction.text import CountVectorizer |
| 38 | +from wordcloud import WordCloud |
| 39 | + |
| 40 | +import flax.linen as nn |
| 41 | +import haiku as hk |
| 42 | +import jax |
| 43 | +from jax import device_put, random |
| 44 | +import jax.numpy as jnp |
| 45 | + |
| 46 | +import numpyro |
| 47 | +from numpyro.contrib.module import flax_module, haiku_module |
| 48 | +import numpyro.distributions as dist |
| 49 | +from numpyro.infer import SVI, TraceMeanField_ELBO |
| 50 | + |
| 51 | + |
| 52 | +class HaikuEncoder: |
| 53 | + def __init__(self, vocab_size, num_topics, hidden, dropout_rate): |
| 54 | + self._vocab_size = vocab_size |
| 55 | + self._num_topics = num_topics |
| 56 | + self._hidden = hidden |
| 57 | + self._dropout_rate = dropout_rate |
| 58 | + |
| 59 | + def __call__(self, inputs, is_training): |
| 60 | + dropout_rate = self._dropout_rate if is_training else 0.0 |
| 61 | + |
| 62 | + h = jax.nn.softplus(hk.Linear(self._hidden)(inputs)) |
| 63 | + h = jax.nn.softplus(hk.Linear(self._hidden)(h)) |
| 64 | + h = hk.dropout(hk.next_rng_key(), dropout_rate, h) |
| 65 | + h = hk.Linear(self._num_topics)(h) |
| 66 | + |
| 67 | + # NB: here we set `create_scale=False` and `create_offset=False` to reduce |
| 68 | + # the number of learning parameters |
| 69 | + log_concentration = hk.BatchNorm( |
| 70 | + create_scale=False, create_offset=False, decay_rate=0.9 |
| 71 | + )(h, is_training) |
| 72 | + return jnp.exp(log_concentration) |
| 73 | + |
| 74 | + |
| 75 | +class HaikuDecoder: |
| 76 | + def __init__(self, vocab_size, dropout_rate): |
| 77 | + self._vocab_size = vocab_size |
| 78 | + self._dropout_rate = dropout_rate |
| 79 | + |
| 80 | + def __call__(self, inputs, is_training): |
| 81 | + dropout_rate = self._dropout_rate if is_training else 0.0 |
| 82 | + h = hk.dropout(hk.next_rng_key(), dropout_rate, inputs) |
| 83 | + h = hk.Linear(self._vocab_size, with_bias=False)(h) |
| 84 | + return hk.BatchNorm(create_scale=False, create_offset=False, decay_rate=0.9)( |
| 85 | + h, is_training |
| 86 | + ) |
| 87 | + |
| 88 | + |
| 89 | +class FlaxEncoder(nn.Module): |
| 90 | + vocab_size: int |
| 91 | + num_topics: int |
| 92 | + hidden: int |
| 93 | + dropout_rate: float |
| 94 | + |
| 95 | + @nn.compact |
| 96 | + def __call__(self, inputs, is_training): |
| 97 | + h = nn.softplus(nn.Dense(self.hidden)(inputs)) |
| 98 | + h = nn.softplus(nn.Dense(self.hidden)(h)) |
| 99 | + h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(h) |
| 100 | + h = nn.Dense(self.num_topics)(h) |
| 101 | + |
| 102 | + log_concentration = nn.BatchNorm( |
| 103 | + use_bias=False, |
| 104 | + use_scale=False, |
| 105 | + momentum=0.9, |
| 106 | + use_running_average=not is_training, |
| 107 | + )(h) |
| 108 | + return jnp.exp(log_concentration) |
| 109 | + |
| 110 | + |
| 111 | +class FlaxDecoder(nn.Module): |
| 112 | + vocab_size: int |
| 113 | + dropout_rate: float |
| 114 | + |
| 115 | + @nn.compact |
| 116 | + def __call__(self, inputs, is_training): |
| 117 | + h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(inputs) |
| 118 | + h = nn.Dense(self.vocab_size, use_bias=False)(h) |
| 119 | + return nn.BatchNorm( |
| 120 | + use_bias=False, |
| 121 | + use_scale=False, |
| 122 | + momentum=0.9, |
| 123 | + use_running_average=not is_training, |
| 124 | + )(h) |
| 125 | + |
| 126 | + |
| 127 | +def model(docs, hyperparams, is_training=False, nn_framework="flax"): |
| 128 | + if nn_framework == "flax": |
| 129 | + decoder = flax_module( |
| 130 | + "decoder", |
| 131 | + FlaxDecoder(hyperparams["vocab_size"], hyperparams["dropout_rate"]), |
| 132 | + input_shape=(1, hyperparams["num_topics"]), |
| 133 | + # ensure PRNGKey is made available to dropout layers |
| 134 | + apply_rng=["dropout"], |
| 135 | + # indicate mutable state due to BatchNorm layers |
| 136 | + mutable=["batch_stats"], |
| 137 | + # to ensure proper initialisation of BatchNorm we must |
| 138 | + # initialise with is_training=True |
| 139 | + is_training=True, |
| 140 | + ) |
| 141 | + elif nn_framework == "haiku": |
| 142 | + decoder = haiku_module( |
| 143 | + "decoder", |
| 144 | + # use `transform_with_state` for BatchNorm |
| 145 | + hk.transform_with_state( |
| 146 | + HaikuDecoder(hyperparams["vocab_size"], hyperparams["dropout_rate"]) |
| 147 | + ), |
| 148 | + input_shape=(1, hyperparams["num_topics"]), |
| 149 | + apply_rng=True, |
| 150 | + # to ensure proper initialisation of BatchNorm we must |
| 151 | + # initialise with is_training=True |
| 152 | + is_training=True, |
| 153 | + ) |
| 154 | + else: |
| 155 | + raise ValueError(f"Invalid choice {nn_framework} for argument nn_framework") |
| 156 | + |
| 157 | + with numpyro.plate( |
| 158 | + "documents", docs.shape[0], subsample_size=hyperparams["batch_size"] |
| 159 | + ): |
| 160 | + batch_docs = numpyro.subsample(docs, event_dim=1) |
| 161 | + theta = numpyro.sample( |
| 162 | + "theta", dist.Dirichlet(jnp.ones(hyperparams["num_topics"])) |
| 163 | + ) |
| 164 | + |
| 165 | + if nn_framework == "flax": |
| 166 | + logits = decoder(theta, is_training, rngs={"dropout": numpyro.prng_key()}) |
| 167 | + elif nn_framework == "haiku": |
| 168 | + logits = decoder(numpyro.prng_key(), theta, is_training) |
| 169 | + |
| 170 | + total_count = batch_docs.sum(-1) |
| 171 | + numpyro.sample( |
| 172 | + "obs", dist.Multinomial(total_count, logits=logits), obs=batch_docs |
| 173 | + ) |
| 174 | + |
| 175 | + |
| 176 | +def guide(docs, hyperparams, is_training=False, nn_framework="flax"): |
| 177 | + if nn_framework == "flax": |
| 178 | + encoder = flax_module( |
| 179 | + "encoder", |
| 180 | + FlaxEncoder( |
| 181 | + hyperparams["vocab_size"], |
| 182 | + hyperparams["num_topics"], |
| 183 | + hyperparams["hidden"], |
| 184 | + hyperparams["dropout_rate"], |
| 185 | + ), |
| 186 | + input_shape=(1, hyperparams["vocab_size"]), |
| 187 | + # ensure PRNGKey is made available to dropout layers |
| 188 | + apply_rng=["dropout"], |
| 189 | + # indicate mutable state due to BatchNorm layers |
| 190 | + mutable=["batch_stats"], |
| 191 | + # to ensure proper initialisation of BatchNorm we must |
| 192 | + # initialise with is_training=True |
| 193 | + is_training=True, |
| 194 | + ) |
| 195 | + elif nn_framework == "haiku": |
| 196 | + encoder = haiku_module( |
| 197 | + "encoder", |
| 198 | + # use `transform_with_state` for BatchNorm |
| 199 | + hk.transform_with_state( |
| 200 | + HaikuEncoder( |
| 201 | + hyperparams["vocab_size"], |
| 202 | + hyperparams["num_topics"], |
| 203 | + hyperparams["hidden"], |
| 204 | + hyperparams["dropout_rate"], |
| 205 | + ) |
| 206 | + ), |
| 207 | + input_shape=(1, hyperparams["vocab_size"]), |
| 208 | + apply_rng=True, |
| 209 | + # to ensure proper initialisation of BatchNorm we must |
| 210 | + # initialise with is_training=True |
| 211 | + is_training=True, |
| 212 | + ) |
| 213 | + else: |
| 214 | + raise ValueError(f"Invalid choice {nn_framework} for argument nn_framework") |
| 215 | + |
| 216 | + with numpyro.plate( |
| 217 | + "documents", docs.shape[0], subsample_size=hyperparams["batch_size"] |
| 218 | + ): |
| 219 | + batch_docs = numpyro.subsample(docs, event_dim=1) |
| 220 | + |
| 221 | + if nn_framework == "flax": |
| 222 | + concentration = encoder( |
| 223 | + batch_docs, is_training, rngs={"dropout": numpyro.prng_key()} |
| 224 | + ) |
| 225 | + elif nn_framework == "haiku": |
| 226 | + concentration = encoder(numpyro.prng_key(), batch_docs, is_training) |
| 227 | + |
| 228 | + numpyro.sample("theta", dist.Dirichlet(concentration)) |
| 229 | + |
| 230 | + |
| 231 | +def load_data(): |
| 232 | + news = fetch_20newsgroups(subset="all") |
| 233 | + vectorizer = CountVectorizer(max_df=0.5, min_df=20, stop_words="english") |
| 234 | + docs = jnp.array(vectorizer.fit_transform(news["data"]).toarray()) |
| 235 | + |
| 236 | + vocab = pd.DataFrame(columns=["word", "index"]) |
| 237 | + vocab["word"] = vectorizer.get_feature_names() |
| 238 | + vocab["index"] = vocab.index |
| 239 | + |
| 240 | + return docs, vocab |
| 241 | + |
| 242 | + |
| 243 | +def run_inference(docs, args): |
| 244 | + rng_key = random.PRNGKey(0) |
| 245 | + docs = device_put(docs) |
| 246 | + |
| 247 | + hyperparams = dict( |
| 248 | + vocab_size=docs.shape[1], |
| 249 | + num_topics=args.num_topics, |
| 250 | + hidden=args.hidden, |
| 251 | + dropout_rate=args.dropout_rate, |
| 252 | + batch_size=args.batch_size, |
| 253 | + ) |
| 254 | + |
| 255 | + optimizer = numpyro.optim.Adam(args.learning_rate) |
| 256 | + svi = SVI(model, guide, optimizer, loss=TraceMeanField_ELBO()) |
| 257 | + |
| 258 | + return svi.run( |
| 259 | + rng_key, |
| 260 | + args.num_steps, |
| 261 | + docs, |
| 262 | + hyperparams, |
| 263 | + is_training=True, |
| 264 | + progress_bar=not args.disable_progbar, |
| 265 | + nn_framework=args.nn_framework, |
| 266 | + ) |
| 267 | + |
| 268 | + |
| 269 | +def plot_word_cloud(b, ax, vocab, n): |
| 270 | + indices = jnp.argsort(b)[::-1] |
| 271 | + top20 = indices[:20] |
| 272 | + df = pd.DataFrame(top20, columns=["index"]) |
| 273 | + words = pd.merge(df, vocab[["index", "word"]], how="left", on="index")[ |
| 274 | + "word" |
| 275 | + ].values.tolist() |
| 276 | + sizes = b[top20].tolist() |
| 277 | + freqs = {words[i]: sizes[i] for i in range(len(words))} |
| 278 | + wc = WordCloud(background_color="white", width=800, height=500) |
| 279 | + wc = wc.generate_from_frequencies(freqs) |
| 280 | + ax.set_title(f"Topic {n + 1}") |
| 281 | + ax.imshow(wc, interpolation="bilinear") |
| 282 | + ax.axis("off") |
| 283 | + |
| 284 | + |
| 285 | +def main(args): |
| 286 | + docs, vocab = load_data() |
| 287 | + print(f"Dictionary size: {len(vocab)}") |
| 288 | + print(f"Corpus size: {docs.shape}") |
| 289 | + |
| 290 | + svi_result = run_inference(docs, args) |
| 291 | + |
| 292 | + if args.nn_framework == "flax": |
| 293 | + beta = svi_result.params["decoder$params"]["Dense_0"]["kernel"] |
| 294 | + elif args.nn_framework == "haiku": |
| 295 | + beta = svi_result.params["decoder$params"]["linear"]["w"] |
| 296 | + |
| 297 | + beta = jax.nn.softmax(beta) |
| 298 | + |
| 299 | + # the number of plots depends on the chosen number of topics. |
| 300 | + # add 2 to num topics to ensure we create a row for any remainder after division |
| 301 | + nrows = (args.num_topics + 2) // 3 |
| 302 | + fig, axs = plt.subplots(nrows, 3, figsize=(14, 3 + 3 * nrows)) |
| 303 | + axs = axs.flatten() |
| 304 | + |
| 305 | + for n in range(beta.shape[0]): |
| 306 | + plot_word_cloud(beta[n], axs[n], vocab, n) |
| 307 | + |
| 308 | + # hide any unused axes |
| 309 | + for i in range(n, len(axs)): |
| 310 | + axs[i].axis("off") |
| 311 | + |
| 312 | + fig.savefig("wordclouds.png") |
| 313 | + |
| 314 | + |
| 315 | +if __name__ == "__main__": |
| 316 | + assert numpyro.__version__.startswith("0.6.0") |
| 317 | + parser = argparse.ArgumentParser( |
| 318 | + description="Probabilistic topic modelling with Flax and Haiku" |
| 319 | + ) |
| 320 | + parser.add_argument("-n", "--num-steps", nargs="?", default=30_000, type=int) |
| 321 | + parser.add_argument("-t", "--num-topics", nargs="?", default=12, type=int) |
| 322 | + parser.add_argument("--batch-size", nargs="?", default=32, type=int) |
| 323 | + parser.add_argument("--learning-rate", nargs="?", default=1e-3, type=float) |
| 324 | + parser.add_argument("--hidden", nargs="?", default=200, type=int) |
| 325 | + parser.add_argument("--dropout-rate", nargs="?", default=0.2, type=float) |
| 326 | + parser.add_argument( |
| 327 | + "-dp", |
| 328 | + "--disable-progbar", |
| 329 | + action="store_true", |
| 330 | + default=False, |
| 331 | + help="Whether to disable progress bar", |
| 332 | + ) |
| 333 | + parser.add_argument( |
| 334 | + "--device", default="cpu", type=str, help='use "cpu", "gpu" or "tpu".' |
| 335 | + ) |
| 336 | + parser.add_argument( |
| 337 | + "--nn-framework", |
| 338 | + nargs="?", |
| 339 | + default="flax", |
| 340 | + help=( |
| 341 | + "The framework to use for constructing encoder / decoder. Options are " |
| 342 | + '"flax" or "haiku".' |
| 343 | + ), |
| 344 | + ) |
| 345 | + args = parser.parse_args() |
| 346 | + |
| 347 | + numpyro.set_platform(args.device) |
| 348 | + main(args) |
0 commit comments