1010
1111class TNTMBase (CTMBase ):
1212
13- #@override
13+ # @override
1414 def __init__ (
15- self ,
16- dataset ,
17- mus_init : torch .Tensor ,
18- L_lower_init : torch .Tensor ,
19- log_diag_init : torch .Tensor ,
20- word_embeddings_projected : torch .Tensor ,
21- n_topics : int = 50 ,
22- encoder_dim : int = 128 ,
23- inference_type = "zeroshot" ,
24- dropout : float = 0.1 ,
25- inference_activation = nn .Softplus (),
26- n_layers_inference_network : int = 1 ,
15+ self ,
16+ dataset ,
17+ mus_init : torch .Tensor ,
18+ L_lower_init : torch .Tensor ,
19+ log_diag_init : torch .Tensor ,
20+ word_embeddings_projected : torch .Tensor ,
21+ n_topics : int = 50 ,
22+ encoder_dim : int = 128 ,
23+ inference_type = "zeroshot" ,
24+ dropout : float = 0.1 ,
25+ inference_activation = nn .Softplus (),
26+ n_layers_inference_network : int = 1 ,
2727 ):
2828 """
29- Initialize the topic model parameters.
30-
31- Parameters
32- ----------
33- dataset : object
34- The dataset containing bag-of-words (BoW) and embeddings.
35- mus_init : torch.Tensor
36- Initial value for the topic means. Shape: (n_topics, vocab_size).
37- L_lower_init : torch.Tensor
38- Initial value for the lower triangular matrix. Shape: (n_topics, vocab_size, vocab_size).
39- log_diag_init : torch.Tensor
40- Initial value for the diagonal of the covariance matrix (log of the diagonal). Shape: (n_topics, vocab_size).
41- word_embeddings_projected : torch.Tensor
42- Projected word embeddings. Shape: (vocab_size, encoder_dim).
43- n_topics : int, optional
44- Number of topics, by default 50.
45- encoder_dim : int, optional
46- Dimension of the encoder, by default 200.
47- inference_type : str, optional
48- Type of inference, either "combined", "zeroshot", or "avitm". By default "zeroshot".
49- dropout : float, optional
50- Dropout rate, by default 0.1.
51- inference_activation : nn.Module, optional
52- Activation function for inference, by default nn.Softplus().
53- n_layers_inference_network : int, optional
54- Number of layers in the inference network, by default 3.
29+ Initialize the topic model parameters.
30+
31+ Parameters
32+ ----------
33+ dataset : object
34+ The dataset containing bag-of-words (BoW) and embeddings.
35+ mus_init : torch.Tensor
36+ Initial value for the topic means. Shape: (n_topics, emb_dim), where emb_dim = word_embeddings_projected.shape[1].
37+ L_lower_init : torch.Tensor
38+ Initial value for the lower triangular matrix. Shape: (n_topics, emb_dim, emb_dim).
39+ log_diag_init : torch.Tensor
40+ Initial value for the diagonal of the covariance matrix (log of the diagonal). Shape: (n_topics, emb_dim).
41+ word_embeddings_projected : torch.Tensor
42+ Projected word embeddings. Shape: (vocab_size, emb_dim).
43+ n_topics : int, optional
44+ Number of topics, by default 50.
45+ encoder_dim : int, optional
46+ Dimension of the encoder, by default 128.
47+ inference_type : str, optional
48+ Type of inference, either "combined", "zeroshot", or "avitm". By default "zeroshot".
49+ dropout : float, optional
50+ Dropout rate, by default 0.1.
51+ inference_activation : nn.Module, optional
52+ Activation function for inference, by default nn.Softplus().
53+ n_layers_inference_network : int, optional
54+ Number of layers in the inference network, by default 1.
5555 """
56- super ().__init__ (dataset = dataset , n_topics = n_topics , encoder_dim = encoder_dim , dropout = dropout )
56+ super ().__init__ (
57+ dataset = dataset , n_topics = n_topics , encoder_dim = encoder_dim , dropout = dropout
58+ )
5759
58- self .mus = nn .Parameter (mus_init ) #create topic means as learnable paramter
59- self .L_lower = nn .Parameter (L_lower_init ) # factor of covariance per topic
60- self .log_diag = nn .Parameter (log_diag_init ) # summand for diagonal of covariance
60+ self .mus = nn .Parameter (mus_init ) # create topic means as learnable parameter
61+ self .L_lower = nn .Parameter (L_lower_init ) # factor of covariance per topic
62+ self .log_diag = nn .Parameter (
63+ log_diag_init
64+ ) # summand for diagonal of covariance
6165 self .word_embeddings_projected = torch .tensor (word_embeddings_projected )
6266
6367 emb_dim = word_embeddings_projected .shape [1 ]
@@ -69,10 +73,23 @@ def __init__(
6973 self .inference_type = inference_type
7074 self .dropout = dropout
7175
72- assert self .mus .shape == (n_topics , emb_dim ), f"Shape of mus is { self .mus .shape } but expected { (n_topics , emb_dim )} "
73- assert self .L_lower .shape == (n_topics , emb_dim , emb_dim ), f"Shape of L_lower is { self .L_lower .shape } but expected { (n_topics , emb_dim , emb_dim )} "
74- assert self .log_diag .shape == (n_topics , emb_dim ), f"Shape of log_diag is { self .log_diag .shape } but expected { (n_topics , emb_dim )} "
75- assert word_embeddings_projected .shape == (self .vocab_size , emb_dim ), f"Shape of word_embeddings_projected is { word_embeddings_projected .shape } but expected { (self .vocab_size , emb_dim )} "
76+ assert self .mus .shape == (
77+ n_topics ,
78+ emb_dim ,
79+ ), f"Shape of mus is { self .mus .shape } but expected { (n_topics , emb_dim )} "
80+ assert self .L_lower .shape == (
81+ n_topics ,
82+ emb_dim ,
83+ emb_dim ,
84+ ), f"Shape of L_lower is { self .L_lower .shape } but expected { (n_topics , emb_dim , emb_dim )} "
85+ assert self .log_diag .shape == (
86+ n_topics ,
87+ emb_dim ,
88+ ), f"Shape of log_diag is { self .log_diag .shape } but expected { (n_topics , emb_dim )} "
89+ assert word_embeddings_projected .shape == (
90+ self .vocab_size ,
91+ emb_dim ,
92+ ), f"Shape of word_embeddings_projected is { word_embeddings_projected .shape } but expected { (self .vocab_size , emb_dim )} "
7693
7794 contextual_embed_size = dataset .embeddings .shape [1 ]
7895
@@ -92,7 +109,7 @@ def __init__(
92109 input_size = self .vocab_size ,
93110 bert_size = contextual_embed_size ,
94111 output_size = n_topics ,
95- hidden_sizes = [encoder_dim ]* n_layers_inference_network ,
112+ hidden_sizes = [encoder_dim ] * n_layers_inference_network ,
96113 activation = inference_activation ,
97114 dropout = dropout ,
98115 inference_type = inference_type ,
@@ -105,7 +122,10 @@ def calc_log_beta(self):
105122
106123 diag = torch .exp (self .log_diag )
107124
108- normal_dis_lis = [LowRankMultivariateNormal (mu , cov_factor = lower , cov_diag = D ) for mu , lower , D in zip (self .mus , self .L_lower , diag )]
125+ normal_dis_lis = [
126+ LowRankMultivariateNormal (mu , cov_factor = lower , cov_diag = D )
127+ for mu , lower , D in zip (self .mus , self .L_lower , diag )
128+ ]
109129 log_probs = torch .zeros (self .n_topics , self .vocab_size )
110130
111131 for i , dis in enumerate (normal_dis_lis ):
@@ -119,7 +139,8 @@ def get_beta(self):
119139
120140 log_beta = self .calc_log_beta ()
121141 return torch .exp (log_beta )
122- #@override
142+
143+ # @override
123144 def forward (self , x ):
124145 """
125146 Forward pass through the network.
@@ -142,17 +163,18 @@ def forward(self, x):
142163
143164 log_beta = self .calc_log_beta ()
144165
145-
146-
147166 # prodLDA vs LDA
148167 # use numerical trick to compute log(beta @ theta )
149- log_theta = torch .nn .LogSoftmax (dim = - 1 )(theta ) #calculate log theta = log_softmax(theta_hat)
150- A = log_beta + log_theta .unsqueeze (- 1 ) #calculate (log (beta @ theta))[i] = (log (exp(log_beta) @ exp(log_theta)))[i] = log(\sum_k exp (log_beta[i,k] + log_theta[k]))
151- log_recon = torch .logsumexp (A , dim = 1 )
168+ log_theta = torch .nn .LogSoftmax (dim = - 1 )(
169+ theta
170+ ) # calculate log theta = log_softmax(theta_hat)
171+ A = log_beta + log_theta .unsqueeze (
172+ - 1
173+ ) # calculate (log (beta @ theta))[i] = (log (exp(log_beta) @ exp(log_theta)))[i] = log(\sum_k exp (log_beta[i,k] + log_theta[k]))
174+ log_recon = torch .logsumexp (A , dim = 1 )
152175
153176 return log_recon , posterior_mean , posterior_logvar
154177
155-
156178 def loss_function (self , x_bow , log_recon , posterior_mean , posterior_logvar ):
157179 """
158180 Computes the reconstruction and KL divergence loss.
@@ -173,13 +195,13 @@ def loss_function(self, x_bow, log_recon, posterior_mean, posterior_logvar):
173195 torch.Tensor
174196 The computed loss.
175197 """
176- # Negative log-likelihood: - (u^d)^T @ log(beta @ \theta^d)
198+ # Negative log-likelihood: - (u^d)^T @ log(beta @ \theta^d)
177199 NL = - (x_bow * log_recon ).sum (1 )
178200
179201 prior_mean = self .mu2
180202 prior_var = self .var2
181203
182- #KLD between variational posterior p(\theta|d) and prior p(\theta)
204+ # KLD between variational posterior p(\theta|d) and prior p(\theta)
183205 posterior_var = posterior_logvar .exp ()
184206 prior_mean = prior_mean .expand_as (posterior_mean )
185207 prior_var = prior_var .expand_as (posterior_mean )
@@ -188,11 +210,12 @@ def loss_function(self, x_bow, log_recon, posterior_mean, posterior_logvar):
188210 var_division = posterior_var / prior_var
189211
190212 diff = posterior_mean - prior_mean
191- diff_term = diff * diff / prior_var
213+ diff_term = diff * diff / prior_var
192214 logvar_division = prior_logvar - posterior_logvar
193215
194-
195- KLD = 0.5 * ( (var_division + diff_term + logvar_division ).sum (1 ) - self .n_topics )
216+ KLD = 0.5 * (
217+ (var_division + diff_term + logvar_division ).sum (1 ) - self .n_topics
218+ )
196219
197220 loss = (NL + KLD ).mean ()
198221 return loss
@@ -211,14 +234,7 @@ def compute_loss(self, x):
211234 torch.Tensor
212235 The computed loss.
213236 """
214- x_bow = x [' bow' ]
237+ x_bow = x [" bow" ]
215238 log_recon , posterior_mean , posterior_logvar = self .forward (x )
216239 loss = self .loss_function (x_bow , log_recon , posterior_mean , posterior_logvar )
217240 return loss
218-
219-
220-
221-
222-
223-
224-
0 commit comments