
Commit 1b517b0

ProdLDA example with Haiku + Flax (#1056)
* Add ProdLDA example
* Format source
* Fix docs build and add local image
* Fixes
* Add command line configuration
* Add Flax
* Address further comments
* Add prodLDA to docs and tests
* Explicitly set momentum / decay_rate in batch norm layers
* Add wordcloud dependency to examples requirements
* Fix tests
* Test specification
* Use Dirichlet distribution
* Minor fixes
* Add dirichlet-dirichlet KL test
* Fix test >_<
1 parent f368d98 commit 1b517b0

File tree

7 files changed: +384 lines, -2 lines

(binary image file, 367 KB; preview not shown)

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ NumPyro documentation
    examples/neutra
    examples/covtype
    examples/thompson_sampling
+   examples/prodlda


 Indices and tables

examples/prodlda.py

Lines changed: 348 additions & 0 deletions
@@ -0,0 +1,348 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
Example: ProdLDA
================
In this example, we will follow [1] to implement the ProdLDA topic model from
Autoencoding Variational Inference For Topic Models by Akash Srivastava and Charles
Sutton [2]. This model returns consistently better topics than vanilla LDA and trains
much more quickly. Furthermore, it does not require a custom inference algorithm that
relies on complex mathematical derivations. This example also serves as an
introduction to Flax and Haiku modules in NumPyro.

Note that unlike [1, 2], this implementation uses a Dirichlet prior directly rather
than approximating it with a softmax-normal distribution.

For the interested reader, a nice extension of this model is the CombinedTM model [3]
which utilizes a pre-trained sentence transformer (like https://www.sbert.net/) to
generate a better representation of the encoded latent vector.

**References:**
    1. http://pyro.ai/examples/prodlda.html
    2. Akash Srivastava, & Charles Sutton. (2017). Autoencoding Variational Inference
       For Topic Models.
    3. Federico Bianchi, Silvia Terragni, and Dirk Hovy (2021), "Pre-training is a Hot
       Topic: Contextualized Document Embeddings Improve Topic Coherence"
       (https://arxiv.org/abs/2004.03974)

.. image:: ../_static/img/examples/prodlda.png
    :align: center
"""
import argparse

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from wordcloud import WordCloud

import flax.linen as nn
import haiku as hk
import jax
from jax import device_put, random
import jax.numpy as jnp

import numpyro
from numpyro.contrib.module import flax_module, haiku_module
import numpyro.distributions as dist
from numpyro.infer import SVI, TraceMeanField_ELBO


class HaikuEncoder:
    def __init__(self, vocab_size, num_topics, hidden, dropout_rate):
        self._vocab_size = vocab_size
        self._num_topics = num_topics
        self._hidden = hidden
        self._dropout_rate = dropout_rate

    def __call__(self, inputs, is_training):
        dropout_rate = self._dropout_rate if is_training else 0.0

        h = jax.nn.softplus(hk.Linear(self._hidden)(inputs))
        h = jax.nn.softplus(hk.Linear(self._hidden)(h))
        h = hk.dropout(hk.next_rng_key(), dropout_rate, h)
        h = hk.Linear(self._num_topics)(h)

        # NB: here we set `create_scale=False` and `create_offset=False` to reduce
        # the number of learning parameters
        log_concentration = hk.BatchNorm(
            create_scale=False, create_offset=False, decay_rate=0.9
        )(h, is_training)
        return jnp.exp(log_concentration)


class HaikuDecoder:
    def __init__(self, vocab_size, dropout_rate):
        self._vocab_size = vocab_size
        self._dropout_rate = dropout_rate

    def __call__(self, inputs, is_training):
        dropout_rate = self._dropout_rate if is_training else 0.0
        h = hk.dropout(hk.next_rng_key(), dropout_rate, inputs)
        h = hk.Linear(self._vocab_size, with_bias=False)(h)
        return hk.BatchNorm(create_scale=False, create_offset=False, decay_rate=0.9)(
            h, is_training
        )


class FlaxEncoder(nn.Module):
    vocab_size: int
    num_topics: int
    hidden: int
    dropout_rate: float

    @nn.compact
    def __call__(self, inputs, is_training):
        h = nn.softplus(nn.Dense(self.hidden)(inputs))
        h = nn.softplus(nn.Dense(self.hidden)(h))
        h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(h)
        h = nn.Dense(self.num_topics)(h)

        log_concentration = nn.BatchNorm(
            use_bias=False,
            use_scale=False,
            momentum=0.9,
            use_running_average=not is_training,
        )(h)
        return jnp.exp(log_concentration)


class FlaxDecoder(nn.Module):
    vocab_size: int
    dropout_rate: float

    @nn.compact
    def __call__(self, inputs, is_training):
        h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(inputs)
        h = nn.Dense(self.vocab_size, use_bias=False)(h)
        return nn.BatchNorm(
            use_bias=False,
            use_scale=False,
            momentum=0.9,
            use_running_average=not is_training,
        )(h)


def model(docs, hyperparams, is_training=False, nn_framework="flax"):
    if nn_framework == "flax":
        decoder = flax_module(
            "decoder",
            FlaxDecoder(hyperparams["vocab_size"], hyperparams["dropout_rate"]),
            input_shape=(1, hyperparams["num_topics"]),
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    elif nn_framework == "haiku":
        decoder = haiku_module(
            "decoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(
                HaikuDecoder(hyperparams["vocab_size"], hyperparams["dropout_rate"])
            ),
            input_shape=(1, hyperparams["num_topics"]),
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:
        raise ValueError(f"Invalid choice {nn_framework} for argument nn_framework")

    with numpyro.plate(
        "documents", docs.shape[0], subsample_size=hyperparams["batch_size"]
    ):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        theta = numpyro.sample(
            "theta", dist.Dirichlet(jnp.ones(hyperparams["num_topics"]))
        )

        if nn_framework == "flax":
            logits = decoder(theta, is_training, rngs={"dropout": numpyro.prng_key()})
        elif nn_framework == "haiku":
            logits = decoder(numpyro.prng_key(), theta, is_training)

        total_count = batch_docs.sum(-1)
        numpyro.sample(
            "obs", dist.Multinomial(total_count, logits=logits), obs=batch_docs
        )


def guide(docs, hyperparams, is_training=False, nn_framework="flax"):
    if nn_framework == "flax":
        encoder = flax_module(
            "encoder",
            FlaxEncoder(
                hyperparams["vocab_size"],
                hyperparams["num_topics"],
                hyperparams["hidden"],
                hyperparams["dropout_rate"],
            ),
            input_shape=(1, hyperparams["vocab_size"]),
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    elif nn_framework == "haiku":
        encoder = haiku_module(
            "encoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(
                HaikuEncoder(
                    hyperparams["vocab_size"],
                    hyperparams["num_topics"],
                    hyperparams["hidden"],
                    hyperparams["dropout_rate"],
                )
            ),
            input_shape=(1, hyperparams["vocab_size"]),
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:
        raise ValueError(f"Invalid choice {nn_framework} for argument nn_framework")

    with numpyro.plate(
        "documents", docs.shape[0], subsample_size=hyperparams["batch_size"]
    ):
        batch_docs = numpyro.subsample(docs, event_dim=1)

        if nn_framework == "flax":
            concentration = encoder(
                batch_docs, is_training, rngs={"dropout": numpyro.prng_key()}
            )
        elif nn_framework == "haiku":
            concentration = encoder(numpyro.prng_key(), batch_docs, is_training)

        numpyro.sample("theta", dist.Dirichlet(concentration))


def load_data():
    news = fetch_20newsgroups(subset="all")
    vectorizer = CountVectorizer(max_df=0.5, min_df=20, stop_words="english")
    docs = jnp.array(vectorizer.fit_transform(news["data"]).toarray())

    vocab = pd.DataFrame(columns=["word", "index"])
    vocab["word"] = vectorizer.get_feature_names()
    vocab["index"] = vocab.index

    return docs, vocab


def run_inference(docs, args):
    rng_key = random.PRNGKey(0)
    docs = device_put(docs)

    hyperparams = dict(
        vocab_size=docs.shape[1],
        num_topics=args.num_topics,
        hidden=args.hidden,
        dropout_rate=args.dropout_rate,
        batch_size=args.batch_size,
    )

    optimizer = numpyro.optim.Adam(args.learning_rate)
    svi = SVI(model, guide, optimizer, loss=TraceMeanField_ELBO())

    return svi.run(
        rng_key,
        args.num_steps,
        docs,
        hyperparams,
        is_training=True,
        progress_bar=not args.disable_progbar,
        nn_framework=args.nn_framework,
    )


def plot_word_cloud(b, ax, vocab, n):
    indices = jnp.argsort(b)[::-1]
    top20 = indices[:20]
    df = pd.DataFrame(top20, columns=["index"])
    words = pd.merge(df, vocab[["index", "word"]], how="left", on="index")[
        "word"
    ].values.tolist()
    sizes = b[top20].tolist()
    freqs = {words[i]: sizes[i] for i in range(len(words))}
    wc = WordCloud(background_color="white", width=800, height=500)
    wc = wc.generate_from_frequencies(freqs)
    ax.set_title(f"Topic {n + 1}")
    ax.imshow(wc, interpolation="bilinear")
    ax.axis("off")


def main(args):
    docs, vocab = load_data()
    print(f"Dictionary size: {len(vocab)}")
    print(f"Corpus size: {docs.shape}")

    svi_result = run_inference(docs, args)

    if args.nn_framework == "flax":
        beta = svi_result.params["decoder$params"]["Dense_0"]["kernel"]
    elif args.nn_framework == "haiku":
        beta = svi_result.params["decoder$params"]["linear"]["w"]

    beta = jax.nn.softmax(beta)

    # the number of plots depends on the chosen number of topics.
    # add 2 to num topics to ensure we create a row for any remainder after division
    nrows = (args.num_topics + 2) // 3
    fig, axs = plt.subplots(nrows, 3, figsize=(14, 3 + 3 * nrows))
    axs = axs.flatten()

    for n in range(beta.shape[0]):
        plot_word_cloud(beta[n], axs[n], vocab, n)

    # hide any unused axes
    for i in range(n, len(axs)):
        axs[i].axis("off")

    fig.savefig("wordclouds.png")


if __name__ == "__main__":
    assert numpyro.__version__.startswith("0.6.0")
    parser = argparse.ArgumentParser(
        description="Probabilistic topic modelling with Flax and Haiku"
    )
    parser.add_argument("-n", "--num-steps", nargs="?", default=30_000, type=int)
    parser.add_argument("-t", "--num-topics", nargs="?", default=12, type=int)
    parser.add_argument("--batch-size", nargs="?", default=32, type=int)
    parser.add_argument("--learning-rate", nargs="?", default=1e-3, type=float)
    parser.add_argument("--hidden", nargs="?", default=200, type=int)
    parser.add_argument("--dropout-rate", nargs="?", default=0.2, type=float)
    parser.add_argument(
        "-dp",
        "--disable-progbar",
        action="store_true",
        default=False,
        help="Whether to disable progress bar",
    )
    parser.add_argument(
        "--device", default="cpu", type=str, help='use "cpu", "gpu" or "tpu".'
    )
    parser.add_argument(
        "--nn-framework",
        nargs="?",
        default="flax",
        help=(
            "The framework to use for constructing encoder / decoder. Options are "
            '"flax" or "haiku".'
        ),
    )
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    main(args)
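
Stripped of the Flax/Haiku plumbing (module registration, dropout RNG keys, and mutable BatchNorm state), the generative story encoded by model() above is quite small. The sketch below is illustrative only: toy_prodlda_model and the array beta are hypothetical stand-ins, with beta playing the role of the trained decoder network; the Dirichlet prior and Multinomial likelihood mirror the example. The full script is run from the command line, e.g. python examples/prodlda.py --num-topics 12 --nn-framework haiku.

    # Minimal sketch of the ProdLDA generative story, under the assumptions above.
    import jax.numpy as jnp

    import numpyro
    import numpyro.distributions as dist


    def toy_prodlda_model(docs, num_topics, beta):
        # beta: hypothetical (num_topics, vocab_size) array of per-topic word weights
        with numpyro.plate("documents", docs.shape[0]):
            # per-document topic proportions drawn from a Dirichlet prior
            theta = numpyro.sample("theta", dist.Dirichlet(jnp.ones(num_topics)))
            # product of experts: topics are combined in logit space,
            # not probability space as in vanilla LDA
            logits = theta @ beta
            numpyro.sample(
                "obs", dist.Multinomial(docs.sum(-1), logits=logits), obs=docs
            )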

numpyro/distributions/kl.py

Lines changed: 14 additions & 1 deletion
@@ -30,8 +30,9 @@
 from jax import lax
 import jax.numpy as jnp
+from jax.scipy.special import digamma, gammaln

-from numpyro.distributions.continuous import Normal
+from numpyro.distributions.continuous import Dirichlet, Normal
 from numpyro.distributions.distribution import (
     Delta,
     Distribution,
@@ -196,3 +197,15 @@ def _kl_normal_normal(p, q):
     var_ratio = jnp.square(p.scale / q.scale)
     t1 = jnp.square((p.loc - q.loc) / q.scale)
     return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio))
+
+
+@register_kl(Dirichlet, Dirichlet)
+def _kl_dirichlet_dirichlet(p, q):
+    # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
+    sum_p_concentration = p.concentration.sum(-1)
+    sum_q_concentration = q.concentration.sum(-1)
+    t1 = gammaln(sum_p_concentration) - gammaln(sum_q_concentration)
+    t2 = (gammaln(p.concentration) - gammaln(q.concentration)).sum(-1)
+    t3 = p.concentration - q.concentration
+    t4 = digamma(p.concentration) - digamma(sum_p_concentration)[..., None]
+    return t1 - t2 + (t3 * t4).sum(-1)
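
For reference, the closed form implemented by _kl_dirichlet_dirichlet is the standard Dirichlet-Dirichlet KL divergence, with alpha_0 = sum_k alpha_k, beta_0 = sum_k beta_k, and psi the digamma function:

    \operatorname{KL}\bigl(\mathrm{Dir}(\alpha)\,\|\,\mathrm{Dir}(\beta)\bigr)
      = \log\Gamma(\alpha_0) - \log\Gamma(\beta_0)
        - \sum_k \bigl[\log\Gamma(\alpha_k) - \log\Gamma(\beta_k)\bigr]
        + \sum_k (\alpha_k - \beta_k)\bigl[\psi(\alpha_k) - \psi(\alpha_0)\bigr]

A quick Monte Carlo sanity check, a sketch in the spirit of the KL test mentioned in the commit message rather than the test itself, can compare this closed form against the sample average of log p(x) - log q(x) under p; the concentration values below are chosen only for illustration.

    import jax.numpy as jnp
    from jax import random

    import numpyro.distributions as dist
    from numpyro.distributions.kl import kl_divergence

    # hypothetical concentrations for two Dirichlet distributions
    p = dist.Dirichlet(jnp.array([2.0, 3.0, 5.0]))
    q = dist.Dirichlet(jnp.array([1.0, 1.0, 1.0]))

    # Monte Carlo estimate of KL(p || q) from samples of p
    x = p.sample(random.PRNGKey(0), (100_000,))
    mc_kl = jnp.mean(p.log_prob(x) - q.log_prob(x))

    print(mc_kl, kl_divergence(p, q))  # the two values should agree closely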

0 commit comments
