diff --git a/octis/models/CTM.py b/octis/models/CTM.py
index 49c25516..f2d44819 100644
--- a/octis/models/CTM.py
+++ b/octis/models/CTM.py
@@ -134,7 +134,7 @@ def train_model(self, dataset, hyperparameters=None, top_words=10):
                 topic_prior_variance=self.hyperparameters["prior_variance"],
                 top_words=top_words)

-            self.model.fit(x_train, x_valid, verbose=False)
+            self.model.fit(x_train, x_valid, verbose=self.hyperparameters["verbose"], save_dir=self.hyperparameters["save_dir"])
             result = self.inference(x_test)
             return result

@@ -161,7 +161,7 @@ def train_model(self, dataset, hyperparameters=None, top_words=10):
                 topic_prior_variance=self.hyperparameters["prior_variance"],
                 top_words=top_words)

-            self.model.fit(x_train, None, verbose=False)
+            self.model.fit(x_train, None, verbose=self.hyperparameters["verbose"], save_dir=self.hyperparameters["save_dir"])
             result = self.model.get_info()
             return result
diff --git a/octis/models/contextualized_topic_models/models/ctm.py b/octis/models/contextualized_topic_models/models/ctm.py
index dc93f7c5..6e9fb90f 100644
--- a/octis/models/contextualized_topic_models/models/ctm.py
+++ b/octis/models/contextualized_topic_models/models/ctm.py
@@ -273,7 +273,7 @@ def fit(self, train_dataset, validation_dataset=None,
         train_loader = DataLoader(
             self.train_data, batch_size=self.batch_size, shuffle=True,
-            num_workers=self.num_data_loader_workers)
+            num_workers=self.num_data_loader_workers, drop_last=True)

         # init training variables
         train_loss = 0
@@ -301,7 +301,7 @@ def fit(self, train_dataset, validation_dataset=None,
             if self.validation_data is not None:
                 validation_loader = DataLoader(
                     self.validation_data, batch_size=self.batch_size,
-                    shuffle=True, num_workers=self.num_data_loader_workers)
+                    shuffle=True, num_workers=self.num_data_loader_workers, drop_last=True)
                 # train epoch
                 s = datetime.datetime.now()
                 val_samples_processed, val_loss = self._validation(
diff --git a/octis/models/pytorchavitm/AVITM.py b/octis/models/pytorchavitm/AVITM.py
index 5602e12b..954906c6 100644
--- a/octis/models/pytorchavitm/AVITM.py
+++ b/octis/models/pytorchavitm/AVITM.py
@@ -98,14 +98,14 @@ def train_model(self, dataset, hyperparameters=None, top_words=10):
             solver=self.hyperparameters['solver'],
             num_epochs=self.hyperparameters['num_epochs'],
             reduce_on_plateau=self.hyperparameters['reduce_on_plateau'], num_samples=self.hyperparameters[
                 'num_samples'], topic_prior_mean=self.hyperparameters["prior_mean"],
-            topic_prior_variance=self.hyperparameters["prior_variance"]
+            topic_prior_variance=self.hyperparameters["prior_variance"], verbose=self.hyperparameters["verbose"], top_words=top_words,
         )

         if self.use_partitions:
-            self.model.fit(x_train, x_valid)
+            self.model.fit(x_train, x_valid, save_dir=self.hyperparameters["save_dir"])
             result = self.inference(x_test)
         else:
-            self.model.fit(x_train, None)
+            self.model.fit(x_train, None, save_dir=self.hyperparameters["save_dir"])
             result = self.model.get_info()
         return result
diff --git a/octis/models/pytorchavitm/avitm/avitm_model.py b/octis/models/pytorchavitm/avitm/avitm_model.py
index 2b9e4f21..056a54d3 100644
--- a/octis/models/pytorchavitm/avitm/avitm_model.py
+++ b/octis/models/pytorchavitm/avitm/avitm_model.py
@@ -19,7 +19,7 @@ class AVITM_model(object):

    def __init__(self, input_size, num_topics=10, model_type='prodLDA',
                 hidden_sizes=(100, 100), activation='softplus', dropout=0.2,
                 learn_priors=True, batch_size=64, lr=2e-3, momentum=0.99,
                 solver='adam', num_epochs=100, reduce_on_plateau=False, topic_prior_mean=0.0,
-                 topic_prior_variance=None, num_samples=10, num_data_loader_workers=0, verbose=False):
+                 topic_prior_variance=None, num_samples=10, num_data_loader_workers=0, verbose=False, top_words=10):
        """
        Initialize AVITM model.
@@ -68,6 +68,7 @@ def __init__(self, input_size, num_topics=10, model_type='prodLDA', hidden_sizes
         # assert isinstance(topic_prior_variance, float), \
         #     "topic prior_variance must be type float"

+        self.top_words = top_words
         self.input_size = input_size
         self.num_topics = num_topics
         self.verbose = verbose
@@ -240,7 +241,7 @@ def fit(self, train_dataset, validation_dataset, save_dir=None):
         self.validation_data = validation_dataset
         train_loader = DataLoader(
             self.train_data, batch_size=self.batch_size, shuffle=True,
-            num_workers=self.num_data_loader_workers)
+            num_workers=self.num_data_loader_workers, drop_last=True)

         # init training variables
         train_loss = 0
@@ -267,7 +268,7 @@ def fit(self, train_dataset, validation_dataset, save_dir=None):
         if self.validation_data is not None:
             validation_loader = DataLoader(
                 self.validation_data, batch_size=self.batch_size, shuffle=True,
-                num_workers=self.num_data_loader_workers)
+                num_workers=self.num_data_loader_workers, drop_last=True)
             # train epoch
             s = datetime.datetime.now()
             val_samples_processed, val_loss = self._validation(validation_loader)
@@ -347,7 +348,7 @@ def get_topics(self, k=10):

     def get_info(self):
         info = {}
-        topic_word = self.get_topics()
+        topic_word = self.get_topics(k=self.top_words)  # or self.input_size
         topic_word_dist = self.get_topic_word_mat()
         # topic_document_dist = self.get_topic_document_mat()
         info['topics'] = topic_word
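A note on the `drop_last=True` additions: this diff does not state the motivation, but a common reason in models like CTM and AVITM is that a trailing batch of size 1 makes `BatchNorm1d` raise `ValueError: Expected more than 1 value per channel` in training mode. A minimal, self-contained sketch of what the flag changes (toy tensors, not the OCTIS datasets):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 9 samples with batch_size=4 leaves a trailing batch of 1,
# exactly the case that crashes a BatchNorm1d layer during training.
data = TensorDataset(torch.randn(9, 5))

# Default behavior: the last, smaller batch is kept.
print([batch[0].shape[0] for batch in DataLoader(data, batch_size=4)])
# -> [4, 4, 1]

# With drop_last=True (as in this diff), the incomplete batch is discarded.
print([batch[0].shape[0] for batch in DataLoader(data, batch_size=4, drop_last=True)])
# -> [4, 4]
```

One trade-off worth noting: with `drop_last=True`, up to `batch_size - 1` samples are skipped each epoch (shuffling varies which ones), and a validation set smaller than `batch_size` would yield zero batches, so the validation loaders touched here are sensitive to small splits.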
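For context, a hypothetical end-to-end usage of the updated wrapper. This sketch assumes OCTIS's standard `Dataset.fetch_dataset` API, that `train_model` merges the passed dict into `self.hyperparameters`, and that the wrapper registers defaults for the new `"verbose"` and `"save_dir"` keys (none of which is shown in this diff; without defaults, the unconditional `self.hyperparameters["verbose"]` lookups would raise `KeyError` when no dict is passed):

```python
from octis.dataset.dataset import Dataset
from octis.models.CTM import CTM

# Hypothetical sketch: "verbose" and "save_dir" are assumed to be merged
# into self.hyperparameters by train_model before fit() reads them.
dataset = Dataset()
dataset.fetch_dataset("20NewsGroup")  # standard OCTIS sample dataset

model = CTM(num_topics=20, num_epochs=5)
output = model.train_model(
    dataset,
    hyperparameters={"verbose": True, "save_dir": "ctm_checkpoints"},
    top_words=10,
)
print(output["topics"][:3])  # first three topics as lists of top words
```

The same pattern applies to the AVITM wrapper, where the new `top_words` constructor argument also flows through to `get_info()` via `self.get_topics(k=self.top_words)`, so the returned topics respect the requested word count instead of the hard-coded default of 10.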