From 1f2d45a0fefc876616dacaaf1fb6473e86f5f9ef Mon Sep 17 00:00:00 2001 From: XavierSpycy Date: Sat, 15 Jun 2024 02:20:14 +1000 Subject: [PATCH 1/2] fix_none_check_and_device_assignment --- octis/models/ETM.py | 11 +++++++---- octis/models/LSI.py | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/octis/models/ETM.py b/octis/models/ETM.py index fe0116c8..af9a3ca1 100644 --- a/octis/models/ETM.py +++ b/octis/models/ETM.py @@ -16,7 +16,7 @@ def __init__( self, num_topics=10, num_epochs=100, t_hidden_size=800, rho_size=300, embedding_size=300, activation='relu', dropout=0.5, lr=0.005, optimizer='adam', batch_size=128, clip=0.0, wdecay=1.2e-6, bow_norm=1, - device='cpu', train_embeddings=True, embeddings_path=None, + device='cuda', train_embeddings=True, embeddings_path=None, embeddings_type='pickle', binary_embeddings=True, headerless_embeddings=False, use_partitions=True): """ @@ -131,9 +131,12 @@ def set_model(self, dataset, hyperparameters): self.train_tokens, self.train_counts = self.preprocess( vocab2id, data_corpus, None) - self.device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu") - + if isinstance(self.device, str): + self.device = torch.device(self.device) + + if (self.device.type == 'cuda' and not torch.cuda.is_available()) or (self.device.type == 'mps' and torch.backends.mps.is_available()): + self.device = torch.device('cpu') + self.set_default_hyperparameters(hyperparameters) self.load_embeddings() # define model and optimizer diff --git a/octis/models/LSI.py b/octis/models/LSI.py index 1d6229ec..1c74def4 100644 --- a/octis/models/LSI.py +++ b/octis/models/LSI.py @@ -106,10 +106,10 @@ def train_model(self, dataset, hyperparameters={}, top_words=10): else: partition = [dataset.get_corpus(), []] - if self.id2word == None: + if self.id2word is None: self.id2word = corpora.Dictionary(dataset.get_corpus()) - if self.id_corpus == None: + if self.id_corpus is None: self.id_corpus = [self.id2word.doc2bow( document) for document in partition[0]] From 25293e5f329c419e8f36d835700b0191663eacbb Mon Sep 17 00:00:00 2001 From: XavierSpycy Date: Sat, 15 Jun 2024 18:11:38 +1000 Subject: [PATCH 2/2] fix_none_check_and_device_assignment --- octis/models/ETM.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/octis/models/ETM.py b/octis/models/ETM.py index af9a3ca1..9f200e9c 100644 --- a/octis/models/ETM.py +++ b/octis/models/ETM.py @@ -134,7 +134,7 @@ def set_model(self, dataset, hyperparameters): if isinstance(self.device, str): self.device = torch.device(self.device) - if (self.device.type == 'cuda' and not torch.cuda.is_available()) or (self.device.type == 'mps' and torch.backends.mps.is_available()): + if (self.device.type == 'cuda' and not torch.cuda.is_available()) or (self.device.type == 'mps' and not torch.backends.mps.is_available()): self.device = torch.device('cpu') self.set_default_hyperparameters(hyperparameters)