diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py
index 9ff7917e..b4fb8634 100644
--- a/chebai/loss/bce_weighted.py
+++ b/chebai/loss/bce_weighted.py
@@ -50,32 +50,29 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
             and self.data_extractor is not None
             and all(
                 os.path.exists(
-                    os.path.join(self.data_extractor.processed_dir_main, file_name)
+                    os.path.join(self.data_extractor.processed_dir, file_name)
                 )
-                for file_name in self.data_extractor.processed_main_file_names
+                for file_name in self.data_extractor.processed_file_names
             )
             and self.pos_weight is None
         ):
             print(
                 f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
             )
-            complete_data = pd.concat(
+            complete_labels = torch.concat(
                 [
-                    pd.read_pickle(
-                        open(
-                            os.path.join(
-                                self.data_extractor.processed_dir_main,
-                                file_name,
-                            ),
-                            "rb",
-                        )
+                    torch.stack(
+                        [
+                            torch.Tensor(row["labels"])
+                            for row in self.data_extractor.load_processed_data(
+                                filename=file_name
+                            )
+                        ]
                     )
-                    for file_name in self.data_extractor.processed_main_file_names
+                    for file_name in self.data_extractor.processed_file_names
                 ]
             )
-            value_counts = []
-            for c in complete_data.columns[3:]:
-                value_counts.append(len([v for v in complete_data[c] if v]))
+            value_counts = complete_labels.sum(dim=0)
             weights = [
                 (1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts
             ]
diff --git a/chebai/result/classification.py b/chebai/result/classification.py
index c75c7b29..bb23dea1 100644
--- a/chebai/result/classification.py
+++ b/chebai/result/classification.py
@@ -78,7 +78,7 @@ def print_metrics(
     print(f"Micro-Recall: {recall_micro(preds, labels):3f}")
     if markdown_output:
         print(
-            f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy"
+            f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy |"
         )
         print(f"| --- | --- | --- | --- | --- | --- | --- | --- |")
         print(
diff --git a/chebai/result/utils.py b/chebai/result/utils.py
index 35dbc319..991960d6 100644
--- a/chebai/result/utils.py
+++ b/chebai/result/utils.py
@@ -156,11 +156,12 @@ def evaluate_model(
             return test_preds, test_labels
         return test_preds, None
-    elif len(preds_list) < 0:
-        torch.save(
-            _concat_tuple(preds_list),
-            os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
-        )
-        if labels_list[0] is not None:
+    elif len(preds_list) > 0:
+        if len(preds_list) > 0 and preds_list[0] is not None:
+            torch.save(
+                _concat_tuple(preds_list),
+                os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
+            )
+        if len(labels_list) > 0 and labels_list[0] is not None:
             torch.save(
                 _concat_tuple(labels_list),
                 os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),