From 9dfdf69da537a334eed34ab57d9f90d41265e65c Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Fri, 28 Mar 2025 10:52:22 +0100
Subject: [PATCH 1/3] use processed data instead of processed-main for BCE

---
 chebai/loss/bce_weighted.py | 25 +++++++++----------------
 1 file changed, 9 insertions(+), 16 deletions(-)

diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py
index 9ff7917e..33df8bb5 100644
--- a/chebai/loss/bce_weighted.py
+++ b/chebai/loss/bce_weighted.py
@@ -50,32 +50,25 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
             and self.data_extractor is not None
             and all(
                 os.path.exists(
-                    os.path.join(self.data_extractor.processed_dir_main, file_name)
+                    os.path.join(self.data_extractor.processed_dir, file_name)
                 )
-                for file_name in self.data_extractor.processed_main_file_names
+                for file_name in self.data_extractor.processed_file_names
             )
             and self.pos_weight is None
         ):
             print(
                 f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
             )
-            complete_data = pd.concat(
+            complete_labels = torch.concat(
                 [
-                    pd.read_pickle(
-                        open(
-                            os.path.join(
-                                self.data_extractor.processed_dir_main,
-                                file_name,
-                            ),
-                            "rb",
-                        )
-                    )
-                    for file_name in self.data_extractor.processed_main_file_names
+                    torch.stack([
+                        torch.Tensor(row["labels"]) for row in
+                        self.data_extractor.load_processed_data(filename=file_name)
+                    ])
+                    for file_name in self.data_extractor.processed_file_names
                 ]
             )
-            value_counts = []
-            for c in complete_data.columns[3:]:
-                value_counts.append(len([v for v in complete_data[c] if v]))
+            value_counts = complete_labels.sum(dim=0)
             weights = [
                 (1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts
             ]

From 22a7b795688ff2d84519f424007e4894de878d05 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Fri, 28 Mar 2025 10:53:07 +0100
Subject: [PATCH 2/3] add safeguards to evaluate_model, fix typo

---
 chebai/result/classification.py |  2 +-
 chebai/result/utils.py          | 11 ++++++-----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/chebai/result/classification.py b/chebai/result/classification.py
index c75c7b29..bb23dea1 100644
--- a/chebai/result/classification.py
+++ b/chebai/result/classification.py
@@ -78,7 +78,7 @@ def print_metrics(
     print(f"Micro-Recall: {recall_micro(preds, labels):3f}")
     if markdown_output:
         print(
-            f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy"
+            f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy |"
         )
         print(f"| --- | --- | --- | --- | --- | --- | --- | --- |")
         print(
diff --git a/chebai/result/utils.py b/chebai/result/utils.py
index 35dbc319..991960d6 100644
--- a/chebai/result/utils.py
+++ b/chebai/result/utils.py
@@ -156,11 +156,12 @@ def evaluate_model(
             return test_preds, test_labels
         return test_preds, None
     elif len(preds_list) < 0:
-        torch.save(
-            _concat_tuple(preds_list),
-            os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
-        )
-        if labels_list[0] is not None:
+        if len(preds_list) > 0 and preds_list[0] is not None:
+            torch.save(
+                _concat_tuple(preds_list),
+                os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
+            )
+        if len(labels_list) > 0 and labels_list[0] is not None:
             torch.save(
                 _concat_tuple(labels_list),
                 os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),

From 0c2b99d6f65d2ffaf185ccd55d1250256dbc855f Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Fri, 28 Mar 2025 10:57:12 +0100
Subject: [PATCH 3/3] reformat using black

---
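Note (placed after the "---", so `git am` drops it from the commit): the block
reformatted below stacks every label row into one (N, C) multi-hot tensor and
counts positives per class, feeding the effective-number weight formula from
patch 1. A minimal standalone sketch of that computation, assuming `rows`
stands in for the output of `load_processed_data`; `rows`, `beta`, and
`weights` here are illustrative placeholders, not repo API:

    import torch

    # Hypothetical stand-in for self.data_extractor.load_processed_data(...):
    # each row carries a multi-hot "labels" list, as in the patched code.
    rows = [{"labels": [1, 0, 1]}, {"labels": [0, 1, 1]}]

    # Stack all rows into an (N, C) tensor, then count positives per class.
    labels = torch.stack([torch.Tensor(row["labels"]) for row in rows])
    value_counts = labels.sum(dim=0)  # here: tensor([1., 1., 2.])

    # Effective-number weighting as in set_pos_weight: rarer classes get
    # larger weights. A class with zero positives would give 1 - beta**0 == 0
    # and an infinite weight, so counts are assumed positive here.
    beta = 0.99  # illustrative value
    weights = [(1 - beta) / (1 - pow(beta, v)) for v in value_counts]
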
 chebai/loss/bce_weighted.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py
index 33df8bb5..b4fb8634 100644
--- a/chebai/loss/bce_weighted.py
+++ b/chebai/loss/bce_weighted.py
@@ -61,10 +61,14 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
             )
             complete_labels = torch.concat(
                 [
-                    torch.stack([
-                        torch.Tensor(row["labels"]) for row in
-                        self.data_extractor.load_processed_data(filename=file_name)
-                    ])
+                    torch.stack(
+                        [
+                            torch.Tensor(row["labels"])
+                            for row in self.data_extractor.load_processed_data(
+                                filename=file_name
+                            )
+                        ]
+                    )
                     for file_name in self.data_extractor.processed_file_names
                 ]
             )
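
Note (trailing text after the final hunk, ignored by `git am`): the safeguard
in patch 2 replaces an unconditional `torch.save` with a check that a buffer
list is non-empty and holds real tensors before writing; the surrounding
`elif len(preds_list) < 0:` context line looks unreachable as written, since
`len()` is never negative. A minimal sketch of the guard pattern, assuming
hypothetical placeholder data; `save_if_present` is an illustrative helper and
`torch.cat` stands in for the repo's `_concat_tuple`:

    import os
    import torch

    def save_if_present(tensors, path):
        # Guard both failure modes: an empty list (IndexError on tensors[0])
        # and a leading None (nothing was collected for this buffer).
        if len(tensors) > 0 and tensors[0] is not None:
            torch.save(torch.cat(tensors), path)

    buffer_dir, save_ind = "buffer", 0  # placeholders
    os.makedirs(buffer_dir, exist_ok=True)
    preds_list = [torch.rand(4, 3)]  # one batch of predictions
    labels_list = []                 # e.g. no labels were collected

    # Writes preds000.pt, silently skips the empty labels buffer.
    save_if_present(preds_list, os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"))
    save_if_present(labels_list, os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"))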