Skip to content

Commit dd7e62b

Browse files
authored
Merge pull request #114 from AnFreTh/tntm_fix
TNTM fix
2 parents 8119e91 + a6ec2c5 commit dd7e62b

File tree

3 files changed

+88
-72
lines changed

3 files changed

+88
-72
lines changed

stream_topic/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""Version information."""
22

33
# The following line *must* be the last in the module, exactly as formatted:
4-
__version__ = "0.2.0"
4+
__version__ = "0.2.1"

stream_topic/models/neural_base_models/tntm_base.py

Lines changed: 85 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -10,54 +10,58 @@
1010

1111
class TNTMBase(CTMBase):
1212

13-
#@override
13+
# @override
1414
def __init__(
15-
self,
16-
dataset,
17-
mus_init : torch.Tensor,
18-
L_lower_init: torch.Tensor,
19-
log_diag_init: torch.Tensor,
20-
word_embeddings_projected: torch.Tensor,
21-
n_topics: int = 50,
22-
encoder_dim: int = 128,
23-
inference_type="zeroshot",
24-
dropout: float = 0.1,
25-
inference_activation = nn.Softplus(),
26-
n_layers_inference_network: int = 1,
15+
self,
16+
dataset,
17+
mus_init: torch.Tensor,
18+
L_lower_init: torch.Tensor,
19+
log_diag_init: torch.Tensor,
20+
word_embeddings_projected: torch.Tensor,
21+
n_topics: int = 50,
22+
encoder_dim: int = 128,
23+
inference_type="zeroshot",
24+
dropout: float = 0.1,
25+
inference_activation=nn.Softplus(),
26+
n_layers_inference_network: int = 1,
2727
):
2828
"""
29-
Initialize the topic model parameters.
30-
31-
Parameters
32-
----------
33-
dataset : object
34-
The dataset containing bag-of-words (BoW) and embeddings.
35-
mus_init : torch.Tensor
36-
Initial value for the topic means. Shape: (n_topics, vocab_size).
37-
L_lower_init : torch.Tensor
38-
Initial value for the lower triangular matrix. Shape: (n_topics, vocab_size, vocab_size).
39-
log_diag_init : torch.Tensor
40-
Initial value for the diagonal of the covariance matrix (log of the diagonal). Shape: (n_topics, vocab_size).
41-
word_embeddings_projected : torch.Tensor
42-
Projected word embeddings. Shape: (vocab_size, encoder_dim).
43-
n_topics : int, optional
44-
Number of topics, by default 50.
45-
encoder_dim : int, optional
46-
Dimension of the encoder, by default 200.
47-
inference_type : str, optional
48-
Type of inference, either "combined", "zeroshot", or "avitm". By default "zeroshot".
49-
dropout : float, optional
50-
Dropout rate, by default 0.1.
51-
inference_activation : nn.Module, optional
52-
Activation function for inference, by default nn.Softplus().
53-
n_layers_inference_network : int, optional
54-
Number of layers in the inference network, by default 3.
29+
Initialize the topic model parameters.
30+
31+
Parameters
32+
----------
33+
dataset : object
34+
The dataset containing bag-of-words (BoW) and embeddings.
35+
mus_init : torch.Tensor
36+
Initial value for the topic means. Shape: (n_topics, vocab_size).
37+
L_lower_init : torch.Tensor
38+
Initial value for the lower triangular matrix. Shape: (n_topics, vocab_size, vocab_size).
39+
log_diag_init : torch.Tensor
40+
Initial value for the diagonal of the covariance matrix (log of the diagonal). Shape: (n_topics, vocab_size).
41+
word_embeddings_projected : torch.Tensor
42+
Projected word embeddings. Shape: (vocab_size, encoder_dim).
43+
n_topics : int, optional
44+
Number of topics, by default 50.
45+
encoder_dim : int, optional
46+
Dimension of the encoder, by default 128.
47+
inference_type : str, optional
48+
Type of inference, either "combined", "zeroshot", or "avitm". By default "zeroshot".
49+
dropout : float, optional
50+
Dropout rate, by default 0.1.
51+
inference_activation : nn.Module, optional
52+
Activation function for inference, by default nn.Softplus().
53+
n_layers_inference_network : int, optional
54+
Number of layers in the inference network, by default 1.
5555
"""
56-
super().__init__(dataset = dataset, n_topics = n_topics, encoder_dim = encoder_dim, dropout = dropout)
56+
super().__init__(
57+
dataset=dataset, n_topics=n_topics, encoder_dim=encoder_dim, dropout=dropout
58+
)
5759

58-
self.mus = nn.Parameter(mus_init) #create topic means as learnable parameter
59-
self.L_lower = nn.Parameter(L_lower_init) # factor of covariance per topic
60-
self.log_diag = nn.Parameter(log_diag_init) # summand for diagonal of covariance
60+
self.mus = nn.Parameter(mus_init)  # create topic means as learnable parameter
61+
self.L_lower = nn.Parameter(L_lower_init) # factor of covariance per topic
62+
self.log_diag = nn.Parameter(
63+
log_diag_init
64+
) # summand for diagonal of covariance
6165
self.word_embeddings_projected = torch.tensor(word_embeddings_projected)
6266

6367
emb_dim = word_embeddings_projected.shape[1]
@@ -69,10 +73,23 @@ def __init__(
6973
self.inference_type = inference_type
7074
self.dropout = dropout
7175

72-
assert self.mus.shape == (n_topics, emb_dim), f"Shape of mus is {self.mus.shape} but expected {(n_topics, emb_dim)}"
73-
assert self.L_lower.shape == (n_topics, emb_dim, emb_dim), f"Shape of L_lower is {self.L_lower.shape} but expected {(n_topics, emb_dim, emb_dim)}"
74-
assert self.log_diag.shape == (n_topics, emb_dim), f"Shape of log_diag is {self.log_diag.shape} but expected {(n_topics, emb_dim)}"
75-
assert word_embeddings_projected.shape == (self.vocab_size, emb_dim), f"Shape of word_embeddings_projected is {word_embeddings_projected.shape} but expected {(self.vocab_size, emb_dim)}"
76+
assert self.mus.shape == (
77+
n_topics,
78+
emb_dim,
79+
), f"Shape of mus is {self.mus.shape} but expected {(n_topics, emb_dim)}"
80+
assert self.L_lower.shape == (
81+
n_topics,
82+
emb_dim,
83+
emb_dim,
84+
), f"Shape of L_lower is {self.L_lower.shape} but expected {(n_topics, emb_dim, emb_dim)}"
85+
assert self.log_diag.shape == (
86+
n_topics,
87+
emb_dim,
88+
), f"Shape of log_diag is {self.log_diag.shape} but expected {(n_topics, emb_dim)}"
89+
assert word_embeddings_projected.shape == (
90+
self.vocab_size,
91+
emb_dim,
92+
), f"Shape of word_embeddings_projected is {word_embeddings_projected.shape} but expected {(self.vocab_size, emb_dim)}"
7693

7794
contextual_embed_size = dataset.embeddings.shape[1]
7895

@@ -92,7 +109,7 @@ def __init__(
92109
input_size=self.vocab_size,
93110
bert_size=contextual_embed_size,
94111
output_size=n_topics,
95-
hidden_sizes=[encoder_dim]*n_layers_inference_network,
112+
hidden_sizes=[encoder_dim] * n_layers_inference_network,
96113
activation=inference_activation,
97114
dropout=dropout,
98115
inference_type=inference_type,
@@ -105,7 +122,10 @@ def calc_log_beta(self):
105122

106123
diag = torch.exp(self.log_diag)
107124

108-
normal_dis_lis = [LowRankMultivariateNormal(mu, cov_factor= lower, cov_diag = D) for mu, lower, D in zip(self.mus, self.L_lower, diag)]
125+
normal_dis_lis = [
126+
LowRankMultivariateNormal(mu, cov_factor=lower, cov_diag=D)
127+
for mu, lower, D in zip(self.mus, self.L_lower, diag)
128+
]
109129
log_probs = torch.zeros(self.n_topics, self.vocab_size)
110130

111131
for i, dis in enumerate(normal_dis_lis):
@@ -119,7 +139,8 @@ def get_beta(self):
119139

120140
log_beta = self.calc_log_beta()
121141
return torch.exp(log_beta)
122-
#@override
142+
143+
# @override
123144
def forward(self, x):
124145
"""
125146
Forward pass through the network.
@@ -142,17 +163,18 @@ def forward(self, x):
142163

143164
log_beta = self.calc_log_beta()
144165

145-
146-
147166
# prodLDA vs LDA
148167
# use numerical trick to compute log(beta @ theta )
149-
log_theta = torch.nn.LogSoftmax(dim=-1)(theta) #calculate log theta = log_softmax(theta_hat)
150-
A = log_beta + log_theta.unsqueeze(-1) #calculate (log (beta @ theta))[i] = (log (exp(log_beta) @ exp(log_theta)))[i] = log(\sum_k exp (log_beta[i,k] + log_theta[k]))
151-
log_recon = torch.logsumexp(A, dim = 1)
168+
log_theta = torch.nn.LogSoftmax(dim=-1)(
169+
theta
170+
) # calculate log theta = log_softmax(theta_hat)
171+
A = log_beta + log_theta.unsqueeze(
172+
-1
173+
) # calculate (log (beta @ theta))[i] = (log (exp(log_beta) @ exp(log_theta)))[i] = log(\sum_k exp (log_beta[i,k] + log_theta[k]))
174+
log_recon = torch.logsumexp(A, dim=1)
152175

153176
return log_recon, posterior_mean, posterior_logvar
154177

155-
156178
def loss_function(self, x_bow, log_recon, posterior_mean, posterior_logvar):
157179
"""
158180
Computes the reconstruction and KL divergence loss.
@@ -173,13 +195,13 @@ def loss_function(self, x_bow, log_recon, posterior_mean, posterior_logvar):
173195
torch.Tensor
174196
The computed loss.
175197
"""
176-
#Negative log-likelihood: - (u^d)^T @ log(beta @ \theta^d)
198+
# Negative log-likelihood: - (u^d)^T @ log(beta @ \theta^d)
177199
NL = -(x_bow * log_recon).sum(1)
178200

179201
prior_mean = self.mu2
180202
prior_var = self.var2
181203

182-
#KLD between variational posterior p(\theta|d) and prior p(\theta)
204+
# KLD between variational posterior p(\theta|d) and prior p(\theta)
183205
posterior_var = posterior_logvar.exp()
184206
prior_mean = prior_mean.expand_as(posterior_mean)
185207
prior_var = prior_var.expand_as(posterior_mean)
@@ -188,11 +210,12 @@ def loss_function(self, x_bow, log_recon, posterior_mean, posterior_logvar):
188210
var_division = posterior_var / prior_var
189211

190212
diff = posterior_mean - prior_mean
191-
diff_term = diff*diff / prior_var
213+
diff_term = diff * diff / prior_var
192214
logvar_division = prior_logvar - posterior_logvar
193215

194-
195-
KLD = 0.5 * ( (var_division + diff_term + logvar_division).sum(1) - self.n_topics)
216+
KLD = 0.5 * (
217+
(var_division + diff_term + logvar_division).sum(1) - self.n_topics
218+
)
196219

197220
loss = (NL + KLD).mean()
198221
return loss
@@ -211,14 +234,7 @@ def compute_loss(self, x):
211234
torch.Tensor
212235
The computed loss.
213236
"""
214-
x_bow = x['bow']
237+
x_bow = x["bow"]
215238
log_recon, posterior_mean, posterior_logvar = self.forward(x)
216239
loss = self.loss_function(x_bow, log_recon, posterior_mean, posterior_logvar)
217240
return loss
218-
219-
220-
221-
222-
223-
224-

stream_topic/models/tntm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
1212
from loguru import logger
1313
from optuna.integration import PyTorchLightningPruningCallback
14-
from sentence_transformers import SentenceTransformer
14+
from .abstract_helper_models.mixins import SentenceEncodingMixin
1515
from sklearn.mixture import GaussianMixture
1616

1717
from ..utils.datamodule import TMDataModule
@@ -29,7 +29,7 @@
2929
)
3030

3131

32-
class TNTM(BaseModel):
32+
class TNTM(BaseModel, SentenceEncodingMixin):
3333
def __init__(
3434
self,
3535
word_embedding_model_name: str = WORD_EMBEDDING_MODEL_NAME,

0 commit comments

Comments
 (0)