From a91746f2731a6ac5a5ca84f202451a2d809b1129 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 18 Sep 2024 11:54:51 +0200 Subject: [PATCH 01/47] add documentation for pubchem kmeans, update tokens.txt --- chebai/preprocessing/datasets/pubchem.py | 37 +++++++++++++++++++----- configs/data/pubchem_kmeans.yml | 1 + 2 files changed, 31 insertions(+), 7 deletions(-) create mode 100644 configs/data/pubchem_kmeans.yml diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index 5ba76cc4..e5366563 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -290,13 +290,14 @@ def download(self): class PubChemKMeans(PubChem): """ Dataset class representing a subset of PubChem dataset clustered using K-Means algorithm. + The idea is to create distinct distributions where pretraining and test sets are formed from dissimilar data. """ def __init__( self, *args, - n_clusters: int = 1e4, - random_size: int = 1e6, + n_clusters: int = 10000, + random_size: int = 1000000, exclude_data_from: _ChEBIDataExtractor = None, validation_size_limit: int = 4000, include_min_n_clusters: int = 100, @@ -306,9 +307,11 @@ def __init__( Args: n_clusters (int): Number of clusters to create using K-Means. random_size (int): Size of random dataset to download. - exclude_data_from (_ChEBIDataExtractor): Dataset extractor to exclude data clusters from. - validation_size_limit (int): Size limit for validation dataset. - include_min_n_clusters (int): Minimum number of clusters to include after exclusion. + exclude_data_from (_ChEBIDataExtractor): Dataset which should not overlap with selected clusters + (remove all clusters that contain data from this dataset). + validation_size_limit (int): Validation set will contain at most this number of instances. + include_min_n_clusters (int): Minimum number of clusters to keep if there are not enough clusters that don't + overlap with the `exclude_data_from` dataset. *args: Additional arguments for superclass initialization. **kwargs: Additional keyword arguments for superclass initialization. @@ -354,6 +357,9 @@ def raw_file_names(self) -> List[str]: @property def fingerprints(self) -> pd.DataFrame: """ + Creates random dataset, sanitises, creates Mol objects, generates fingerprints (RDKit) + Saves `fingerprints_df` to `fingerprints.pkl` + Returns: pd.DataFrame: DataFrame containing SMILES and corresponding fingerprints. """ @@ -424,7 +430,13 @@ def _build_clusters(self) -> tuple[pd.DataFrame, pd.DataFrame]: def _exclude_clusters(self, cluster_centers: pd.DataFrame) -> pd.DataFrame: """ - Excludes clusters based on data from an exclusion dataset. + Excludes clusters based on data from an exclusion dataset (in a training setup, this is the labeled dataset, + usually ChEBI). The goal is to avoid having similar data in the labeled training and the PubChem evaluation. + + Loads data from `exclude_data_from` dataset, generates mols, fingerprints, finds closest cluster centre for + each fingerprint, saves data to `exclusion_data_clustered.pkl`, returns all clusters with no instances from the + exclusion data (or the n clusters with the lowest number of instances if there are less than n clusters with no + instances, n being the minimum number of clusters to include) Args: cluster_centers (pd.DataFrame): DataFrame of cluster centers. 
@@ -491,6 +503,7 @@ def _exclude_clusters(self, cluster_centers: pd.DataFrame) -> pd.DataFrame: @property def cluster_centers(self) -> pd.DataFrame: """ + Loads cluster centers from file if possible, otherwise calls `self._build_clusters()`. Returns: pd.DataFrame: DataFrame of cluster centers. """ @@ -505,6 +518,7 @@ def cluster_centers(self) -> pd.DataFrame: @property def fingerprints_clustered(self) -> pd.DataFrame: """ + Loads fingerprints with assigned clusters from file if possible, otherwise calls `self._build_clusters()`. Returns: pd.DataFrame: DataFrame of clustered fingerprints. """ @@ -521,6 +535,10 @@ def fingerprints_clustered(self) -> pd.DataFrame: @property def cluster_centers_superclustered(self) -> pd.DataFrame: """ + Calls `_exclude_clusters()` which removes all clusters that contain data from the exclusion set (usually the + ChEBI, i.e., the labeled dataset). + Runs KMeans with 3 clusters on remaining data, saves cluster centres with assigned supercluster-labels to + `cluster_centers_superclustered.pkl` Returns: pd.DataFrame: DataFrame of superclustered cluster centers. """ @@ -559,7 +577,12 @@ def cluster_centers_superclustered(self) -> pd.DataFrame: def download(self): """ - Downloads the PubChemKMeans dataset. + Downloads the PubChemKMeans dataset. This function creates the complete dataset (including train, test, and + validation splits). Most of the steps are hidden in properties (e.g., `self.fingerprints_clustered` triggers + the download of a random dataset, the calculation of fingerprints for it and the KMeans clustering) + The final splits are created by assigning all fingerprints that belong to a cluster of a certain supercluster + to a dataset. This creates 3 datasets (for each of the 3 superclusters), the datasets are saved as validation, + test and train based on their size. The validation set is limited to `self.validation_size_limit` entries. 
""" if self._k == PubChem.FULL: super().download() diff --git a/configs/data/pubchem_kmeans.yml b/configs/data/pubchem_kmeans.yml new file mode 100644 index 00000000..05a0171d --- /dev/null +++ b/configs/data/pubchem_kmeans.yml @@ -0,0 +1 @@ +class_path: chebai.preprocessing.datasets.pubchem.PubChemKMeans From 30d194cebd970bed67049ae4b5f97e7153c3e2bb Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 18 Sep 2024 11:55:10 +0200 Subject: [PATCH 02/47] update SMILES tokens --- .../preprocessing/bin/smiles_token/tokens.txt | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/chebai/preprocessing/bin/smiles_token/tokens.txt b/chebai/preprocessing/bin/smiles_token/tokens.txt index 6f4c338e..15d5af14 100644 --- a/chebai/preprocessing/bin/smiles_token/tokens.txt +++ b/chebai/preprocessing/bin/smiles_token/tokens.txt @@ -769,3 +769,44 @@ p [14CH3] [HH] [CH3-] +[PH+] +[Zr+4] +[Zr+2] +[Zr+3] +[Th+4] +[Sn+2] +[ClH+] +[Ti+] +[Ir+2] +[Si@] +[Pd+] +[16NH] +[SH2+] +[Hf+2] +[Hf+4] +[Ti+2] +[Ru+2] +[BH-] +[Sc+3] +[Sb+5] +[Pd+2] +[Si@@] +[Cr+] +[W+2] +[PH-] +[BH] +[PH2+] +[3HH] +[CH2+] +[BiH2] +[SH-] +[2HH] +[GeH] +[PH2-] +[PbH2] +[Ru+] +[U+2] +[SiH-] +[AlH2] +[Fe+5] +[Rh+2] From af54a11308905ebd90e7dcd9705e4f4df9e29a2c Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 19 Sep 2024 11:43:17 +0200 Subject: [PATCH 03/47] replace tnorm by fuzzy implication names --- chebai/cli.py | 8 ++++++++ chebai/loss/semantic.py | 23 +++++++++++++++-------- configs/loss/semantic_loss.yml | 1 + 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/chebai/cli.py b/chebai/cli.py index f2ad1072..7dae6f11 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -50,6 +50,14 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser): parser.link_arguments( "data", "model.init_args.criterion.init_args.data_extractor" ) + parser.link_arguments( + "data.init_args.chebi_version", + "model.criterion.init_args.data_extractor.init_args.chebi_version", + ) + parser.link_arguments( + "data.init_args.chebi_version", + "model.criterion.init_args.base_loss.init_args.data_extractor.init_args.chebi_version", + ) @staticmethod def subcommands() -> Dict[str, Set[str]]: diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 78938d22..8d50bf21 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -16,7 +16,7 @@ class ImplicationLoss(torch.nn.Module): Implication Loss module. Args: - data_extractor (Union[_ChEBIDataExtractor, LabeledUnlabeledMixed]): Data extractor for labels. + data_extractor _ChEBIDataExtractor: Data extractor for labels. base_loss (torch.nn.Module, optional): Base loss function. Defaults to None. tnorm (Literal["product", "lukasiewicz", "xu19"], optional): T-norm type. Defaults to "product". impl_loss_weight (float, optional): Weight of implication loss relative to base loss. Defaults to 0.1. 
@@ -27,9 +27,11 @@ class ImplicationLoss(torch.nn.Module): def __init__( self, - data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed], + data_extractor: _ChEBIDataExtractor, base_loss: torch.nn.Module = None, - tnorm: Literal["product", "lukasiewicz", "xu19"] = "product", + fuzzy_implication: Literal[ + "reichenbach", "rb", "lukasiewicz", "lk", "xu19" + ] = "reichenbach", impl_loss_weight: float = 0.1, pos_scalar: int = 1, pos_epsilon: float = 0.01, @@ -39,6 +41,7 @@ def __init__( # automatically choose labeled subset for implication filter in case of mixed dataset if isinstance(data_extractor, LabeledUnlabeledMixed): data_extractor = data_extractor.labeled + assert isinstance(data_extractor, _ChEBIDataExtractor) self.data_extractor = data_extractor # propagate data_extractor to base loss if isinstance(base_loss, BCEWeighted): @@ -54,7 +57,7 @@ def __init__( implication_filter = _build_implication_filter(self.label_names, self.hierarchy) self.implication_filter_l = implication_filter[:, 0] self.implication_filter_r = implication_filter[:, 1] - self.tnorm = tnorm + self.fuzzy_implication = fuzzy_implication self.impl_weight = impl_loss_weight self.pos_scalar = pos_scalar self.eps = pos_epsilon @@ -119,14 +122,18 @@ def _calculate_implication_loss( one_min_r = torch.pow(1 - r, self.pos_scalar) else: one_min_r = 1 - r - if self.tnorm == "product": + if self.fuzzy_implication in ["reichenbach", "rb"]: individual_loss = l * one_min_r - elif self.tnorm == "xu19": + # xu19 (from Xu et al., 2019: Semantic loss) is not a fuzzy implication, but behaves similar to the Reichenbach + # implication + elif self.fuzzy_implication == "xu19": individual_loss = -torch.log(1 - l * one_min_r) - elif self.tnorm == "lukasiewicz": + elif self.fuzzy_implication in ["lukasiewicz", "lk"]: individual_loss = torch.relu(l + one_min_r - 1) else: - raise NotImplementedError(f"Unknown tnorm {self.tnorm}") + raise NotImplementedError( + f"Unknown fuzzy implication {self.fuzzy_implication}" + ) if self.multiply_by_softmax: individual_loss = individual_loss * individual_loss.softmax(dim=-1) diff --git a/configs/loss/semantic_loss.yml b/configs/loss/semantic_loss.yml index cc254d4f..8f619f17 100644 --- a/configs/loss/semantic_loss.yml +++ b/configs/loss/semantic_loss.yml @@ -5,5 +5,6 @@ init_args: class_path: chebai.loss.bce_weighted.BCEWeighted init_args: beta: 0.99 + multiply_by_softmax: true tnorm: product impl_loss_weight: 0.01 From 31c8107c31a7e8ffd8ce0e27359f687e69475536 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 19 Sep 2024 12:00:17 +0200 Subject: [PATCH 04/47] fix argument linking --- chebai/cli.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/chebai/cli.py b/chebai/cli.py index 7dae6f11..36245aa0 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -52,11 +52,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser): ) parser.link_arguments( "data.init_args.chebi_version", - "model.criterion.init_args.data_extractor.init_args.chebi_version", - ) - parser.link_arguments( - "data.init_args.chebi_version", - "model.criterion.init_args.base_loss.init_args.data_extractor.init_args.chebi_version", + "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version", ) @staticmethod From 04226a52bcdbc5a273308b5962fe115bf2ad9e02 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 19 Sep 2024 15:13:31 +0200 Subject: [PATCH 05/47] fix semantic loss --- chebai/loss/bce_weighted.py | 18 ++++++++++++++---- chebai/loss/semantic.py | 7 ++++--- 
configs/loss/semantic_loss.yml | 1 - 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index b69fff43..b1b66995 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -5,14 +5,16 @@ import torch from chebai.preprocessing.datasets.base import XYBaseDataModule +from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed class BCEWeighted(torch.nn.BCEWithLogitsLoss): """ BCEWithLogitsLoss with weights automatically computed according to the beta parameter. + If beta is None or data_extractor is None, the loss is unweighted. - This class computes weights based on the formula from the paper: + This class computes weights based on the formula from the paper by Cui et al. (2019): https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf Args: @@ -29,7 +31,10 @@ def __init__( if isinstance(data_extractor, LabeledUnlabeledMixed): data_extractor = data_extractor.labeled self.data_extractor = data_extractor - + assert ( + isinstance(self.data_extractor, _ChEBIDataExtractor) + or self.data_extractor is None + ) super().__init__() def set_pos_weight(self, input: torch.Tensor) -> None: @@ -43,17 +48,22 @@ def set_pos_weight(self, input: torch.Tensor) -> None: self.beta is not None and self.data_extractor is not None and all( - os.path.exists(os.path.join(self.data_extractor.raw_dir, raw_file)) + os.path.exists( + os.path.join(self.data_extractor.processed_dir_main, raw_file) + ) for raw_file in self.data_extractor.raw_file_names ) and self.pos_weight is None ): + print( + f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})" + ) complete_data = pd.concat( [ pd.read_pickle( open( os.path.join( - self.data_extractor.raw_dir, + self.data_extractor.processed_dir_main, raw_file_name, ), "rb", diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 8d50bf21..c4b1abac 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -7,6 +7,7 @@ import torch from chebai.loss.bce_weighted import BCEWeighted +from chebai.preprocessing.datasets import XYBaseDataModule from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed @@ -27,7 +28,7 @@ class ImplicationLoss(torch.nn.Module): def __init__( self, - data_extractor: _ChEBIDataExtractor, + data_extractor: XYBaseDataModule, base_loss: torch.nn.Module = None, fuzzy_implication: Literal[ "reichenbach", "rb", "lukasiewicz", "lk", "xu19" @@ -41,7 +42,7 @@ def __init__( # automatically choose labeled subset for implication filter in case of mixed dataset if isinstance(data_extractor, LabeledUnlabeledMixed): data_extractor = data_extractor.labeled - assert isinstance(data_extractor, _ChEBIDataExtractor) + assert isinstance(data_extractor, _ChEBIDataExtractor) self.data_extractor = data_extractor # propagate data_extractor to base loss if isinstance(base_loss, BCEWeighted): @@ -49,7 +50,7 @@ def __init__( self.base_loss = base_loss self.implication_cache_file = f"implications_{self.data_extractor.name}.cache" self.label_names = _load_label_names( - os.path.join(data_extractor.raw_dir, "classes.txt") + os.path.join(data_extractor.processed_dir_main, "classes.txt") ) self.hierarchy = self._load_implications( os.path.join(data_extractor.raw_dir, "chebi.obo") diff --git 
a/configs/loss/semantic_loss.yml b/configs/loss/semantic_loss.yml index 8f619f17..015f6619 100644 --- a/configs/loss/semantic_loss.yml +++ b/configs/loss/semantic_loss.yml @@ -6,5 +6,4 @@ init_args: init_args: beta: 0.99 multiply_by_softmax: true - tnorm: product impl_loss_weight: 0.01 From 094ade2bee9ab6b8f1834ad6c1f5b0785ab72686 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 20 Sep 2024 16:55:53 +0200 Subject: [PATCH 06/47] add kleene-dienes implication and sigmoidal implications --- chebai/loss/semantic.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index c4b1abac..970deeb8 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -19,11 +19,14 @@ class ImplicationLoss(torch.nn.Module): Args: data_extractor _ChEBIDataExtractor: Data extractor for labels. base_loss (torch.nn.Module, optional): Base loss function. Defaults to None. - tnorm (Literal["product", "lukasiewicz", "xu19"], optional): T-norm type. Defaults to "product". + fuzzy_implication (Literal["product", "lukasiewicz", "xu19"], optional): T-norm type. Defaults to "product". impl_loss_weight (float, optional): Weight of implication loss relative to base loss. Defaults to 0.1. pos_scalar (int, optional): Positive scalar exponent. Defaults to 1. pos_epsilon (float, optional): Epsilon value for numerical stability. Defaults to 0.01. multiply_by_softmax (bool, optional): Whether to multiply by softmax. Defaults to False. + use_sigmoidal_implication (bool, optional): Whether to use the sigmoidal fuzzy implication based on the + specified fuzzy_implication (as defined by van Krieken et al., 2022: Analyzing Differentiable Fuzzy Logic + Operators). Defaults to False. """ def __init__( @@ -31,12 +34,13 @@ def __init__( data_extractor: XYBaseDataModule, base_loss: torch.nn.Module = None, fuzzy_implication: Literal[ - "reichenbach", "rb", "lukasiewicz", "lk", "xu19" + "reichenbach", "rb", "lukasiewicz", "lk", "xu19", "kleene_dienes", "kd" ] = "reichenbach", impl_loss_weight: float = 0.1, pos_scalar: int = 1, pos_epsilon: float = 0.01, multiply_by_softmax: bool = False, + use_sigmoidal_implication: bool = False, ): super().__init__() # automatically choose labeled subset for implication filter in case of mixed dataset @@ -63,6 +67,7 @@ def __init__( self.pos_scalar = pos_scalar self.eps = pos_epsilon self.multiply_by_softmax = multiply_by_softmax + self.use_sigmoidal_implication = use_sigmoidal_implication def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: """ @@ -123,6 +128,8 @@ def _calculate_implication_loss( one_min_r = torch.pow(1 - r, self.pos_scalar) else: one_min_r = 1 - r + # for each implication I, calculate 1 - I(l, 1-one_min_r) + # for S-implications, this is equivalent to the t-norm if self.fuzzy_implication in ["reichenbach", "rb"]: individual_loss = l * one_min_r # xu19 (from Xu et al., 2019: Semantic loss) is not a fuzzy implication, but behaves similar to the Reichenbach @@ -131,11 +138,22 @@ def _calculate_implication_loss( individual_loss = -torch.log(1 - l * one_min_r) elif self.fuzzy_implication in ["lukasiewicz", "lk"]: individual_loss = torch.relu(l + one_min_r - 1) + elif self.fuzzy_implication in ["kleene_dienes", "kd"]: + individual_loss = torch.min(l, 1 - r) else: raise NotImplementedError( f"Unknown fuzzy implication {self.fuzzy_implication}" ) + if self.use_sigmoidal_implication: + # formula by van Krieken, 2022, applied to fuzzy implication with default parameters: b_0 
= 0.5, s = 9 + # parts that only depend on b_0 and s are pre-calculated + implication = 1 - individual_loss + sigmoidal_implication = 90.028 * ( + 1.011 * torch.sigmoid(9 * (implication + 0.5)) - 1 + ) + individual_loss = 1 - sigmoidal_implication + if self.multiply_by_softmax: individual_loss = individual_loss * individual_loss.softmax(dim=-1) return torch.mean( From 74a4e121bb51241eae9cbfb9646d9ef536904880 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 20 Sep 2024 17:01:34 +0200 Subject: [PATCH 07/47] fix reichenbach abbreviation rb->rc --- chebai/loss/semantic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 970deeb8..e4dede08 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -34,7 +34,7 @@ def __init__( data_extractor: XYBaseDataModule, base_loss: torch.nn.Module = None, fuzzy_implication: Literal[ - "reichenbach", "rb", "lukasiewicz", "lk", "xu19", "kleene_dienes", "kd" + "reichenbach", "rc", "lukasiewicz", "lk", "xu19", "kleene_dienes", "kd" ] = "reichenbach", impl_loss_weight: float = 0.1, pos_scalar: int = 1, @@ -130,7 +130,7 @@ def _calculate_implication_loss( one_min_r = 1 - r # for each implication I, calculate 1 - I(l, 1-one_min_r) # for S-implications, this is equivalent to the t-norm - if self.fuzzy_implication in ["reichenbach", "rb"]: + if self.fuzzy_implication in ["reichenbach", "rc"]: individual_loss = l * one_min_r # xu19 (from Xu et al., 2019: Semantic loss) is not a fuzzy implication, but behaves similar to the Reichenbach # implication From f751bc5e4e9ec0057c2c26eafdce51c578b0c95f Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 24 Sep 2024 15:21:29 +0200 Subject: [PATCH 08/47] fix fuzzy loss evaluation --- chebai/result/analyse_sem.py | 73 ++++++++++++++++++++++++------------ chebai/result/utils.py | 27 +++++++------ 2 files changed, 63 insertions(+), 37 deletions(-) diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py index 64ac87a1..aab6324b 100644 --- a/chebai/result/analyse_sem.py +++ b/chebai/result/analyse_sem.py @@ -15,7 +15,7 @@ from chebai.preprocessing.datasets.chebi import ChEBIOver100 from chebai.preprocessing.datasets.pubchem import Hazardous -DEVICE = "cpu" # torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") def binary(left, right): @@ -78,11 +78,9 @@ def get_best_epoch(run): def load_preds_labels_from_wandb( run, epoch, - chebi_version, - test_on_data_cls=ChEBIOver100, # use data from this class + data_module, # use data from this class kind="test", # specify segment of test_on_data_cls ): - data_module = test_on_data_cls(chebi_version=chebi_version) buffer_dir = os.path.join( "results_buffer", @@ -107,10 +105,7 @@ def load_preds_labels_from_wandb( return preds, labels -def load_preds_labels_from_nonwandb( - name, epoch, chebi_version, test_on_data_cls=ChEBIOver100, kind="test" -): - data_module = test_on_data_cls(chebi_version=chebi_version) +def load_preds_labels_from_nonwandb(name, epoch, data_module, kind="test"): buffer_dir = os.path.join( "results_buffer", @@ -390,11 +385,14 @@ def run_all( ): # evaluate a list of runs on Hazardous and ChEBIOver100 datasets if datasets is None: - datasets = [(Hazardous, "all"), (ChEBIOver100, "test")] + datasets = [ + (Hazardous(), "all"), + (ChEBIOver100(chebi_version=chebi_version), "test"), + ] timestamp = datetime.now().strftime("%y%m%d-%H%M") results_path = os.path.join( - "_semloss_eval", - 
f"semloss_results_pc-dis-200k_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", + "_fuzzy_loss_eval", + f"results_pc-kmeans_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", ) if remove_violations: @@ -414,14 +412,14 @@ def run_all( "run-id": run_id, "epoch": int(epoch), "kind": kind, - "data_module": test_on.__name__, + "data_module": test_on.__class__.__name__, "chebi_version": chebi_version, } buffer_dir_smoothed = os.path.join( "results_buffer", "smoothed3step", f"{run.name}_ep{epoch}", - f"{test_on.__name__}_{kind}", + f"{test_on.__class__.__name__}_{kind}", ) if remove_violations and os.path.exists( os.path.join(buffer_dir_smoothed, "preds000.pt") @@ -433,13 +431,13 @@ def run_all( else: if not skip_preds: preds, labels = load_preds_labels_from_wandb( - run, epoch, chebi_version, test_on, kind + run, epoch, test_on, kind ) else: buffer_dir = os.path.join( "results_buffer", f"{run.name}_ep{epoch}", - f"{test_on.__name__}_{kind}", + f"{test_on.__class__.__name__}_{kind}", ) preds, labels = load_results_from_buffer( buffer_dir, device=DEVICE @@ -455,7 +453,7 @@ def run_all( "results_buffer", "smoothed3step", f"{run.name}_ep{epoch}", - f"{test_on.__name__}_{kind}", + f"{test_on.__class__.__name__}_{kind}", ) os.makedirs(buffer_dir_smoothed, exist_ok=True) torch.save( @@ -463,7 +461,7 @@ def run_all( ) if not skip_analyse: print( - f"Calculating metrics for run {run.name} on {test_on.__name__} ({kind})" + f"Calculating metrics for run {run.name} on {test_on.__class__.__name__} ({kind})" ) analyse_run( preds, @@ -472,11 +470,10 @@ def run_all( chebi_version=chebi_version, results_path=results_path, violation_metrics=violation_metrics, - verbose_violation_output=True, ) except Exception as e: print(f"Failed for run {run_id}: {e}") - # print(traceback.format_exc()) + print(traceback.format_exc()) if nonwandb_runs: for run_name, epoch in nonwandb_runs: @@ -486,18 +483,18 @@ def run_all( "run-id": run_name, "epoch": int(epoch), "kind": kind, - "data_module": test_on.__name__, + "data_module": test_on.__class__.__name__, "chebi_version": chebi_version, } if not skip_preds: preds, labels = load_preds_labels_from_nonwandb( - run_name, epoch, chebi_version, test_on, kind + run_name, epoch, test_on, kind ) else: buffer_dir = os.path.join( "results_buffer", f"{run_name}_ep{epoch}", - f"{test_on.__name__}_{kind}", + f"{test_on.__class__.__name__}_{kind}", ) preds, labels = load_results_from_buffer( buffer_dir, device=DEVICE @@ -511,7 +508,7 @@ def run_all( ) if not skip_analyse: print( - f"Calculating metrics for run {run_name} on {test_on.__name__} ({kind})" + f"Calculating metrics for run {run_name} on {test_on.__class__.__name__} ({kind})" ) analyse_run( preds, @@ -570,8 +567,34 @@ def run_semloss_eval(mode="eval"): ) +# follow-up to NeSy submission +def run_fuzzy_loss(tag="fuzzy_loss"): + api = wandb.Api() + runs = api.runs("chebai/chebai", filters={"tags": tag}) + print(f"Found {len(runs)} wandb runs tagged with '{tag}'") + ids = [run.id for run in runs] + chebi_version = 231 + run_all( + ids, + violation_metrics=[binary], + chebi_version=chebi_version, + datasets=[ + ( + ChEBIOver100( + chebi_version=chebi_version, + splits_file_path=os.path.join( + ChEBIOver100(chebi_version=chebi_version).processed_dir_main, + "splits.csv", + ), + ), + "test", + ) + ], + ) + + if __name__ == "__main__": if len(sys.argv) > 1: - run_semloss_eval(sys.argv[1]) + run_fuzzy_loss(sys.argv[1]) else: - run_semloss_eval() + run_fuzzy_loss() diff --git a/chebai/result/utils.py 
b/chebai/result/utils.py index 31063747..6fad80ab 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -140,21 +140,24 @@ def evaluate_model( n_saved = 0 n_saved += 1 - if buffer_dir is None: - test_preds = torch.cat(preds_list) - if labels_list is not None: - test_labels = torch.cat(labels_list) + concat_preds = None + if preds_list is not None and len(preds_list) > 0: + concat_preds = torch.cat(preds_list) + concat_labels = None + if labels_list is not None and len(labels_list) > 0 and labels_list[0] is not None: + concat_labels = torch.cat(labels_list) - return test_preds, test_labels - return test_preds, None + if buffer_dir is None: + return concat_preds, concat_labels else: - torch.save( - torch.cat(preds_list), - os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), - ) - if labels_list[0] is not None: + if concat_preds is not None: + torch.save( + concat_preds, + os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), + ) + if concat_labels is not None: torch.save( - torch.cat(labels_list), + concat_labels, os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), ) From 4e4a8fb84eacf4f21c76775322e4da8120db321f Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 25 Sep 2024 16:51:15 +0200 Subject: [PATCH 09/47] add epoch-dependent weighting of semantic terms --- chebai/loss/semantic.py | 32 +++++++++++++++++++------------- chebai/models/base.py | 19 +++++++++++++++++-- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index e4dede08..67f8437d 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -41,6 +41,7 @@ def __init__( pos_epsilon: float = 0.01, multiply_by_softmax: bool = False, use_sigmoidal_implication: bool = False, + weight_epoch_dependent: bool = False, ): super().__init__() # automatically choose labeled subset for implication filter in case of mixed dataset @@ -68,6 +69,7 @@ def __init__( self.eps = pos_epsilon self.multiply_by_softmax = multiply_by_softmax self.use_sigmoidal_implication = use_sigmoidal_implication + self.weight_epoch_dependent = weight_epoch_dependent def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: """ @@ -95,12 +97,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: r = pred[:, self.implication_filter_r] # implication_loss = torch.sqrt(torch.mean(torch.sum(l*(1-r), dim=-1), dim=0)) implication_loss = self._calculate_implication_loss(l, r) - - return ( - base_loss + self.impl_weight * implication_loss, - base_loss, - implication_loss, - ) + loss_components = { + "base_loss": base_loss, + "unweighted_implication_loss": implication_loss, + } + if "current_epoch" in kwargs and self.weight_epoch_dependent: + # sigmoid function centered around epoch 50 + implication_loss *= torch.sigmoid((kwargs["current_epoch"] - 50) / 10) + implication_loss *= self.impl_weight + loss_components["weighted_implication_loss"] = implication_loss + return (base_loss + implication_loss, loss_components) def _calculate_implication_loss( self, l: torch.Tensor, r: torch.Tensor @@ -219,17 +225,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: Returns: tuple: Tuple containing total loss, base loss, implication loss, and disjointness loss. 
""" - loss, base_loss, impl_loss = super().forward(input, target, **kwargs) + loss, loss_components = super().forward(input, target, **kwargs) pred = torch.sigmoid(input) l = pred[:, self.disjoint_filter_l] r = pred[:, self.disjoint_filter_r] disjointness_loss = self._calculate_implication_loss(l, 1 - r) - return ( - loss + self.disjoint_weight * disjointness_loss, - base_loss, - impl_loss, - disjointness_loss, - ) + loss_components["unweighted_disjointness_loss"] = disjointness_loss + if "current_epoch" in kwargs and self.weight_epoch_dependent: + disjointness_loss *= torch.sigmoid((kwargs["current_epoch"] - 50) / 10) + disjointness_loss *= self.disjoint_weight + loss_components["weighted_disjointness_loss"] = disjointness_loss + return (loss + disjointness_loss, loss_components) def _load_label_names(path_to_label_names: str) -> List: diff --git a/chebai/models/base.py b/chebai/models/base.py index 362731df..34d32134 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -242,16 +242,31 @@ def _execute( loss_kwargs = dict() if self.pass_loss_kwargs: loss_kwargs = loss_kwargs_candidates + loss_kwargs["current_epoch"] = self.trainer.current_epoch loss = self.criterion(loss_data, loss_labels, **loss_kwargs) if isinstance(loss, tuple): - loss_additional = loss[1:] + unnamed_loss_index = 1 + if isinstance(loss[1], dict): + unnamed_loss_index = 2 + for key, value in loss[1].items(): + self.log( + key, + value if isinstance(value, int) else value.item(), + batch_size=len(batch), + on_step=True, + on_epoch=True, + prog_bar=False, + logger=True, + sync_dist=sync_dist, + ) + loss_additional = loss[unnamed_loss_index:] for i, loss_add in enumerate(loss_additional): self.log( f"{prefix}loss_{i}", loss_add if isinstance(loss_add, int) else loss_add.item(), batch_size=len(batch), on_step=True, - on_epoch=False, + on_epoch=True, prog_bar=False, logger=True, sync_dist=sync_dist, From 5fa3ac5a2a31594e40404850f36eafa0ac365b4c Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 25 Sep 2024 18:08:28 +0200 Subject: [PATCH 10/47] fix tensor handling --- chebai/loss/semantic.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 67f8437d..928bdec3 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -103,7 +103,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: } if "current_epoch" in kwargs and self.weight_epoch_dependent: # sigmoid function centered around epoch 50 - implication_loss *= torch.sigmoid((kwargs["current_epoch"] - 50) / 10) + implication_loss *= torch.sigmoid( + (torch.tensor([kwargs["current_epoch"]]) - 50) / 10 + ) implication_loss *= self.impl_weight loss_components["weighted_implication_loss"] = implication_loss return (base_loss + implication_loss, loss_components) @@ -235,7 +237,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: disjointness_loss *= torch.sigmoid((kwargs["current_epoch"] - 50) / 10) disjointness_loss *= self.disjoint_weight loss_components["weighted_disjointness_loss"] = disjointness_loss - return (loss + disjointness_loss, loss_components) + return loss + disjointness_loss, loss_components def _load_label_names(path_to_label_names: str) -> List: From 8acf6baf6f1cccebef881d1632d24b89c9b6d6be Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 25 Sep 2024 22:27:01 +0200 Subject: [PATCH 11/47] fix dynamic loss weights, log weighted and unweighted loss components --- chebai/loss/semantic.py | 20 
++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 928bdec3..3b6c99fa 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -101,14 +101,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: "base_loss": base_loss, "unweighted_implication_loss": implication_loss, } + implication_loss_weighted = implication_loss if "current_epoch" in kwargs and self.weight_epoch_dependent: # sigmoid function centered around epoch 50 - implication_loss *= torch.sigmoid( - (torch.tensor([kwargs["current_epoch"]]) - 50) / 10 + implication_loss_weighted = implication_loss_weighted / ( + 1 + math.exp(-(kwargs["current_epoch"] - 50) / 10) ) - implication_loss *= self.impl_weight - loss_components["weighted_implication_loss"] = implication_loss - return (base_loss + implication_loss, loss_components) + implication_loss_weighted *= self.impl_weight + loss_components["weighted_implication_loss"] = implication_loss_weighted + return base_loss + implication_loss, loss_components def _calculate_implication_loss( self, l: torch.Tensor, r: torch.Tensor @@ -233,10 +234,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: r = pred[:, self.disjoint_filter_r] disjointness_loss = self._calculate_implication_loss(l, 1 - r) loss_components["unweighted_disjointness_loss"] = disjointness_loss + disjointness_loss_weighted = disjointness_loss if "current_epoch" in kwargs and self.weight_epoch_dependent: - disjointness_loss *= torch.sigmoid((kwargs["current_epoch"] - 50) / 10) - disjointness_loss *= self.disjoint_weight - loss_components["weighted_disjointness_loss"] = disjointness_loss + disjointness_loss_weighted = disjointness_loss_weighted / ( + 1 + math.exp(-(kwargs["current_epoch"] - 50) / 10) + ) + disjointness_loss_weighted *= self.disjoint_weight + loss_components["weighted_disjointness_loss"] = disjointness_loss_weighted return loss + disjointness_loss, loss_components From 72b001124bda236e4d1dcfea8dbeac1713d3635a Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 26 Sep 2024 10:47:04 +0200 Subject: [PATCH 12/47] no fuzzy loss for epoch<=10 --- chebai/loss/semantic.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 3b6c99fa..4a0e814d 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -102,7 +102,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: "unweighted_implication_loss": implication_loss, } implication_loss_weighted = implication_loss - if "current_epoch" in kwargs and self.weight_epoch_dependent: + if ( + "current_epoch" in kwargs + and self.weight_epoch_dependent + and kwargs["current_epoch"] > 10 + ): # sigmoid function centered around epoch 50 implication_loss_weighted = implication_loss_weighted / ( 1 + math.exp(-(kwargs["current_epoch"] - 50) / 10) @@ -235,7 +239,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: disjointness_loss = self._calculate_implication_loss(l, 1 - r) loss_components["unweighted_disjointness_loss"] = disjointness_loss disjointness_loss_weighted = disjointness_loss - if "current_epoch" in kwargs and self.weight_epoch_dependent: + if ( + "current_epoch" in kwargs + and self.weight_epoch_dependent + and kwargs["current_epoch"] > 10 + ): disjointness_loss_weighted = disjointness_loss_weighted / ( 1 + math.exp(-(kwargs["current_epoch"] - 50) / 10) ) 
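
The epoch-dependent weighting added in patches 09-13 scales each fuzzy-loss term by a logistic schedule centered at epoch 50 (scale 10) before applying the static `impl_loss_weight` / disjointness weight. Below is a minimal standalone sketch of that schedule, not part of the patch series itself: the function name and example values are illustrative, the constants 50 and 10 mirror the diffs above, and 0.01 matches `impl_loss_weight` in configs/loss/semantic_loss.yml.

import math

def fuzzy_loss_schedule(unweighted_loss: float, current_epoch: int,
                        static_weight: float = 0.01) -> float:
    # mirrors ImplicationLoss.forward with weight_epoch_dependent=True:
    # logistic ramp centered at epoch 50 with scale 10, then the static weight
    ramp = 1 / (1 + math.exp(-(current_epoch - 50) / 10))
    return unweighted_loss * ramp * static_weight

# The ramp is roughly 0.02 at epoch 10, 0.5 at epoch 50 and 0.98 at epoch 90,
# so the fuzzy term is faded in gradually rather than switched on at once.
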
From 5d80272b8380f082af5c7ce14d36a6738c2f3578 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 26 Sep 2024 12:59:49 +0200 Subject: [PATCH 13/47] actually set fuzzy loss to 0 for epoch<=10 --- chebai/loss/semantic.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 4a0e814d..b042c70d 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -92,6 +92,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: base_loss = self.base_loss(labeled_input, target.float()) else: base_loss = 0 + if "current_epoch" in kwargs and kwargs["current_epoch"] < 10: + return base_loss, {"base_loss": base_loss} pred = torch.sigmoid(input) l = pred[:, self.implication_filter_l] r = pred[:, self.implication_filter_r] @@ -102,11 +104,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: "unweighted_implication_loss": implication_loss, } implication_loss_weighted = implication_loss - if ( - "current_epoch" in kwargs - and self.weight_epoch_dependent - and kwargs["current_epoch"] > 10 - ): + if "current_epoch" in kwargs and self.weight_epoch_dependent: # sigmoid function centered around epoch 50 implication_loss_weighted = implication_loss_weighted / ( 1 + math.exp(-(kwargs["current_epoch"] - 50) / 10) @@ -233,17 +231,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: tuple: Tuple containing total loss, base loss, implication loss, and disjointness loss. """ loss, loss_components = super().forward(input, target, **kwargs) + if "current_epoch" in kwargs and kwargs["current_epoch"] < 10: + return loss, loss_components pred = torch.sigmoid(input) l = pred[:, self.disjoint_filter_l] r = pred[:, self.disjoint_filter_r] disjointness_loss = self._calculate_implication_loss(l, 1 - r) loss_components["unweighted_disjointness_loss"] = disjointness_loss disjointness_loss_weighted = disjointness_loss - if ( - "current_epoch" in kwargs - and self.weight_epoch_dependent - and kwargs["current_epoch"] > 10 - ): + if "current_epoch" in kwargs and self.weight_epoch_dependent: disjointness_loss_weighted = disjointness_loss_weighted / ( 1 + math.exp(-(kwargs["current_epoch"] - 50) / 10) ) From e57aa8deae74b7f6d753786deff2d54864d056a3 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 27 Sep 2024 11:36:53 +0200 Subject: [PATCH 14/47] fix bce loss --- chebai/loss/bce_weighted.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index b1b66995..6cde8bb6 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -83,7 +83,9 @@ def set_pos_weight(self, input: torch.Tensor) -> None: [w / mean for w in weights], device=input.device ) - def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def forward( + self, input: torch.Tensor, target: torch.Tensor, **kwargs + ) -> torch.Tensor: """ Forward pass for the loss calculation. 
From 94039c0d85b6603f0860a359e596c3963ec848e8 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 30 Sep 2024 10:06:37 +0200 Subject: [PATCH 15/47] remove skipping first 10 epochs --- chebai/loss/semantic.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index b042c70d..3b6c99fa 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -92,8 +92,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: base_loss = self.base_loss(labeled_input, target.float()) else: base_loss = 0 - if "current_epoch" in kwargs and kwargs["current_epoch"] < 10: - return base_loss, {"base_loss": base_loss} pred = torch.sigmoid(input) l = pred[:, self.implication_filter_l] r = pred[:, self.implication_filter_r] @@ -231,8 +229,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: tuple: Tuple containing total loss, base loss, implication loss, and disjointness loss. """ loss, loss_components = super().forward(input, target, **kwargs) - if "current_epoch" in kwargs and kwargs["current_epoch"] < 10: - return loss, loss_components pred = torch.sigmoid(input) l = pred[:, self.disjoint_filter_l] r = pred[:, self.disjoint_filter_r] From bdef653cbdd227729c346516a9270214a140281d Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 1 Oct 2024 10:34:14 +0200 Subject: [PATCH 16/47] fix fuzzy loss (now passing the weighted loss components) --- chebai/loss/semantic.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 3b6c99fa..3284c8f0 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -42,6 +42,7 @@ def __init__( multiply_by_softmax: bool = False, use_sigmoidal_implication: bool = False, weight_epoch_dependent: bool = False, + start_at_epoch: int = 0, ): super().__init__() # automatically choose labeled subset for implication filter in case of mixed dataset @@ -70,6 +71,7 @@ def __init__( self.multiply_by_softmax = multiply_by_softmax self.use_sigmoidal_implication = use_sigmoidal_implication self.weight_epoch_dependent = weight_epoch_dependent + self.start_at_epoch = start_at_epoch def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: """ @@ -92,10 +94,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: base_loss = self.base_loss(labeled_input, target.float()) else: base_loss = 0 + if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: + return base_loss, {"base_loss": base_loss} pred = torch.sigmoid(input) l = pred[:, self.implication_filter_l] r = pred[:, self.implication_filter_r] - # implication_loss = torch.sqrt(torch.mean(torch.sum(l*(1-r), dim=-1), dim=0)) implication_loss = self._calculate_implication_loss(l, r) loss_components = { "base_loss": base_loss, @@ -109,7 +112,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: ) implication_loss_weighted *= self.impl_weight loss_components["weighted_implication_loss"] = implication_loss_weighted - return base_loss + implication_loss, loss_components + return base_loss + implication_loss_weighted, loss_components def _calculate_implication_loss( self, l: torch.Tensor, r: torch.Tensor @@ -229,6 +232,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: tuple: Tuple containing total loss, base loss, implication loss, and disjointness loss. 
""" loss, loss_components = super().forward(input, target, **kwargs) + if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: + return loss, loss_components pred = torch.sigmoid(input) l = pred[:, self.disjoint_filter_l] r = pred[:, self.disjoint_filter_r] @@ -241,7 +246,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: ) disjointness_loss_weighted *= self.disjoint_weight loss_components["weighted_disjointness_loss"] = disjointness_loss_weighted - return loss + disjointness_loss, loss_components + return loss + disjointness_loss_weighted, loss_components def _load_label_names(path_to_label_names: str) -> List: From 7f1c4686acd447f93f2e7cf799ae5dcf6169081a Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 1 Oct 2024 13:13:12 +0200 Subject: [PATCH 17/47] add epoch to analyse_sem output --- chebai/result/analyse_sem.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py index aab6324b..64d1b83f 100644 --- a/chebai/result/analyse_sem.py +++ b/chebai/result/analyse_sem.py @@ -461,7 +461,7 @@ def run_all( ) if not skip_analyse: print( - f"Calculating metrics for run {run.name} on {test_on.__class__.__name__} ({kind})" + f"Calculating metrics for run {run.name} on {test_on.__class__.__name__} (epoch {epoch}, {kind})" ) analyse_run( preds, @@ -576,6 +576,8 @@ def run_fuzzy_loss(tag="fuzzy_loss"): chebi_version = 231 run_all( ids, + # nonwandb_runs=[("chebi100_semrc_epoch-dependent100-100k_weighted_v231_pc_kmeans_240927-1219", 97), + # ("chebi100_semrc_epoch-dependent100-100k_weighted_v231_pc_kmeans_240927-1220", 99)], violation_metrics=[binary], chebi_version=chebi_version, datasets=[ From 0264c8539046352eea809dcd33b6c0fbe8c73778 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 1 Oct 2024 18:20:23 +0200 Subject: [PATCH 18/47] add right-aggregated macro-FNR --- chebai/result/analyse_sem.py | 37 +++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py index 64d1b83f..24d0683e 100644 --- a/chebai/result/analyse_sem.py +++ b/chebai/result/analyse_sem.py @@ -309,26 +309,41 @@ def analyse_run( f", {label_names[dl_filter_r[j]]} -> {preds[k, dl_filter_r[j]]:.3f})" ) - m_cls = {} + m_l_agg = {} for key, value in m.items(): - m_cls[key] = _sort_results_by_label( + m_l_agg[key] = _sort_results_by_label( n_labels, value, - (dl_filter_l), + dl_filter_l, + ) + m_r_agg = {} + for key, value in m.items(): + m_r_agg[key] = _sort_results_by_label( + n_labels, + value, + dl_filter_r, ) - df_new[i][f"micro-sem-recall-{filter_type}"] = ( - torch.sum(m["tps"]) / (torch.sum(m[f"tps"]) + torch.sum(m[f"fns"])) - ).item() - macro_recall = m_cls[f"tps"] / (m_cls[f"tps"] + m_cls[f"fns"]) - df_new[i][f"macro-sem-recall-{filter_type}"] = torch.mean( - macro_recall[~macro_recall.isnan()] - ).item() + df_new[i][f"micro-fnr-{filter_type}"] = ( + 1 + - ( + torch.sum(m["tps"]) / (torch.sum(m[f"tps"]) + torch.sum(m[f"fns"])) + ).item() + ) + macro_recall_l = m_l_agg[f"tps"] / (m_l_agg[f"tps"] + m_l_agg[f"fns"]) + df_new[i][f"lmacro-fnr-{filter_type}"] = ( + 1 - torch.mean(macro_recall_l[~macro_recall_l.isnan()]).item() + ) + macro_recall_r = m_r_agg[f"tps"] / (m_r_agg[f"tps"] + m_r_agg[f"fns"]) + df_new[i][f"rmacro-fnr-{filter_type}"] = ( + 1 - torch.mean(macro_recall_r[~macro_recall_r.isnan()]).item() + ) df_new[i][f"fn-sum-{filter_type}"] = torch.sum(m["fns"]).item() 
df_new[i][f"tp-sum-{filter_type}"] = torch.sum(m["tps"]).item() del m - del m_cls + del m_l_agg + del m_r_agg gc.collect() df_new[i] = pd.DataFrame(df_new[i], index=[0]) From cfda8c68def270c2ec6a5e9a15a4d7ee14db3811 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 2 Oct 2024 16:36:13 +0200 Subject: [PATCH 19/47] download ckpt without returning model (return ckpt path instead), improve structure of analyse_sem --- chebai/result/analyse_sem.py | 639 +++++++++++++++-------------------- chebai/result/utils.py | 8 +- 2 files changed, 283 insertions(+), 364 deletions(-) diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py index 24d0683e..6b81eaf3 100644 --- a/chebai/result/analyse_sem.py +++ b/chebai/result/analyse_sem.py @@ -1,19 +1,13 @@ import gc -import os import sys -import traceback from datetime import datetime -from typing import List, Union +from typing import List, LiteralString -import pandas as pd -import torch -import wandb from torchmetrics.functional.classification import multilabel_auroc, multilabel_f1_score from utils import * from chebai.loss.semantic import DisjointLoss from chebai.preprocessing.datasets.chebi import ChEBIOver100 -from chebai.preprocessing.datasets.pubchem import Hazardous DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -42,6 +36,9 @@ def apply_metric(metric, left, right): return torch.sum(metric(left, right), dim=0) +ALL_CONSISTENCY_METRICS = [product, lukasiewicz, weak, strict, binary] + + def _filter_to_one_hot(preds, idx_filter): """Takes list of indices (e.g. [1, 3, 0]) and returns a one-hot filter with these indices (e.g. [[0,1,0,0], [0,0,0,1], [1,0,0,0]])""" @@ -75,72 +72,46 @@ def get_best_epoch(run): return best_ep -def load_preds_labels_from_wandb( - run, - epoch, - data_module, # use data from this class - kind="test", # specify segment of test_on_data_cls +def download_model_from_wandb( + run_id, base_dir=os.path.join("logs", "downloaded_ckpts") ): - - buffer_dir = os.path.join( - "results_buffer", - f"{run.name}_ep{epoch}", - f"{data_module.__class__.__name__}_{kind}", - ) - - model = get_checkpoint_from_wandb(epoch, run, map_device_to="cuda:0") - print(f"Calculating predictions...") - evaluate_model( - model, - data_module, - buffer_dir=buffer_dir, - filename=f"{kind}.pt", - skip_existing_preds=True, - batch_size=1, + api = wandb.Api() + run = api.run(f"chebai/chebai/{run_id}") + epoch = get_best_epoch(run) + return ( + get_checkpoint_from_wandb(epoch, run, root=base_dir, map_device_to="cuda:0"), + epoch, ) - preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE) - del model - gc.collect() - - return preds, labels -def load_preds_labels_from_nonwandb(name, epoch, data_module, kind="test"): - - buffer_dir = os.path.join( - "results_buffer", - f"{name}_ep{epoch}", - f"{data_module.__class__.__name__}_{kind}", - ) - ckpt_path = None - for file in os.listdir(os.path.join("logs", "downloaded_ckpts", name)): - if file.startswith(f"best_epoch={epoch}"): - ckpt_path = os.path.join( - os.path.join("logs", "downloaded_ckpts", name, file) - ) - assert ( - ckpt_path is not None - ), f"Could not find ckpt for epoch {epoch} in directory {os.path.join('logs', 'downloaded_ckpts', name)}" +def load_preds_labels( + ckpt_path: LiteralString, data_module, data_subset_key="test", buffer_dir=None +): + if buffer_dir is None: + buffer_dir = os.path.join( + "results_buffer", + *ckpt_path.split(os.path.sep)[-2:], + f"{data_module.__class__.__name__}_{data_subset_key}", + ) model = 
Electra.load_from_checkpoint(ckpt_path, map_location="cuda:0", strict=False) print(f"Calculating predictions...") evaluate_model( model, data_module, buffer_dir=buffer_dir, - filename=f"{kind}.pt", + kind=data_subset_key, skip_existing_preds=True, ) - preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE) - del model - gc.collect() - - return preds, labels + return load_results_from_buffer(buffer_dir, device=DEVICE) def get_label_names(data_module): - if os.path.exists(os.path.join(data_module.raw_dir, "classes.txt")): - with open(os.path.join(data_module.raw_dir, "classes.txt")) as fin: + if os.path.exists(os.path.join(data_module.processed_dir_main, "classes.txt")): + with open(os.path.join(data_module.processed_dir_main, "classes.txt")) as fin: return [int(line.strip()) for line in fin] + print( + f"Failed to retrieve label names, {os.path.join(data_module.processed_dir_main, 'classes.txt')} not found" + ) return None @@ -150,6 +121,9 @@ def get_chebi_graph(data_module, label_names): os.path.join(data_module.raw_dir, "chebi.obo") ) return chebi_graph.subgraph(label_names) + print( + f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found" + ) return None @@ -181,79 +155,89 @@ def get_disjoint_groups(): return disjoint_all -def smooth_preds(preds, label_names, chebi_graph, disjoint_groups): - preds_sum_orig = torch.sum(preds) - print(f"Preds sum: {preds_sum_orig}") - # eliminate implication violations by setting each prediction to maximum of its successors - for i, label in enumerate(label_names): - succs = [label_names.index(p) for p in chebi_graph.successors(label)] + [i] - if len(succs) > 0: - preds[:, i] = torch.max(preds[:, succs], dim=1).values - print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}") - preds_sum_orig = torch.sum(preds) - # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower) - preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49) - for disj_group in disjoint_groups: - disj_group = [label_names.index(g) for g in disj_group if g in label_names] - if len(disj_group) > 1: - old_preds = preds[:, disj_group] - disj_max = torch.max(preds[:, disj_group], dim=1) - for i, row in enumerate(preds): - for l in range(len(preds[i])): - if l in disj_group and l != disj_group[disj_max.indices[i]]: - preds[i, l] = preds_bounded[i, l] - samples_changed = 0 - for i, row in enumerate(preds[:, disj_group]): - if any(r != o for r, o in zip(row, old_preds[i])): - samples_changed += 1 - if samples_changed != 0: - print( - f"disjointness group {[label_names[d] for d in disj_group]} changed {samples_changed} samples" - ) - print( - f"Preds change after disjointness (step 2): {torch.sum(preds) - preds_sum_orig}" - ) - preds_sum_orig = torch.sum(preds) - # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors - for i, label in enumerate(label_names): - predecessors = [i] + [ - label_names.index(p) for p in chebi_graph.predecessors(label) - ] - lowest_predecessors = torch.min(preds[:, predecessors], dim=1) - preds[:, i] = lowest_predecessors.values - for idx_idx, idx in enumerate(lowest_predecessors.indices): - if idx > 0: +class PredictionSmoother: + """Removes implication and disjointness violations from predictions""" + + def __init__(self, dataset): + self.label_names = get_label_names(dataset) + self.chebi_graph = get_chebi_graph(dataset, self.label_names) + 
self.disjoint_groups = get_disjoint_groups() + + def __call__(self, preds): + + preds_sum_orig = torch.sum(preds) + print(f"Preds sum: {preds_sum_orig}") + # eliminate implication violations by setting each prediction to maximum of its successors + for i, label in enumerate(self.label_names): + succs = [ + self.label_names.index(p) for p in self.chebi_graph.successors(label) + ] + [i] + if len(succs) > 0: + preds[:, i] = torch.max(preds[:, succs], dim=1).values + print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}") + preds_sum_orig = torch.sum(preds) + # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower) + preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49) + for disj_group in self.disjoint_groups: + disj_group = [ + self.label_names.index(g) for g in disj_group if g in self.label_names + ] + if len(disj_group) > 1: + old_preds = preds[:, disj_group] + disj_max = torch.max(preds[:, disj_group], dim=1) + for i, row in enumerate(preds): + for l in range(len(preds[i])): + if l in disj_group and l != disj_group[disj_max.indices[i]]: + preds[i, l] = preds_bounded[i, l] + samples_changed = 0 + for i, row in enumerate(preds[:, disj_group]): + if any(r != o for r, o in zip(row, old_preds[i])): + samples_changed += 1 + if samples_changed != 0: + print( + f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples" + ) + print( + f"Preds change after disjointness (step 2): {torch.sum(preds) - preds_sum_orig}" + ) + preds_sum_orig = torch.sum(preds) + # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors + for i, label in enumerate(self.label_names): + predecessors = [i] + [ + self.label_names.index(p) for p in self.chebi_graph.predecessors(label) + ] + lowest_predecessors = torch.min(preds[:, predecessors], dim=1) + preds[:, i] = lowest_predecessors.values + for idx_idx, idx in enumerate(lowest_predecessors.indices): + if idx > 0: + print( + f"class {label}: changed prediction of sample {idx_idx} to value of class " + f"{self.label_names[predecessors[idx]]} ({preds[idx_idx, i].item():.2f})" + ) + if torch.sum(preds) != preds_sum_orig: print( - f"class {label}: changed prediction of sample {idx_idx} to value of class " - f"{label_names[predecessors[idx]]} ({preds[idx_idx, i].item():.2f})" + f"Preds change (step 3) for {label}: {torch.sum(preds) - preds_sum_orig}" ) - if torch.sum(preds) != preds_sum_orig: - print( - f"Preds change (step 3) for {label}: {torch.sum(preds) - preds_sum_orig}" - ) - preds_sum_orig = torch.sum(preds) - return preds + preds_sum_orig = torch.sum(preds) + return preds -def analyse_run( +def run_consistency_metrics( preds, - labels, - df_hyperparams, # parameters that are the independent of the semantic loss function used - labeled_data_cls=ChEBIOver100, # use labels from this dataset for violations - chebi_version=231, - results_path=os.path.join("_semantic", "eval_results.csv"), - violation_metrics: Union[str, List[callable]] = "all", + data_module_labeled=None, # use labels from this dataset for violations + violation_metrics=None, verbose_violation_output=False, ): - """Calculates all semantic metrics for given predictions (and supervised metrics if labels are provided), - saves results to csv""" - if violation_metrics == "all": - violation_metrics = [product, lukasiewicz, weak, strict, binary] - data_module_labeled = 
labeled_data_cls(chebi_version=chebi_version) + """Calculates all semantic metrics for given predictions (and supervised metrics if labels are provided)""" + if violation_metrics is None: + violation_metrics = ALL_CONSISTENCY_METRICS + if data_module_labeled is None: + data_module_labeled = ChEBIOver100(chebi_version=231) + n_labels = preds.size(1) print(f"Found {preds.shape[0]} predictions ({n_labels} classes)") - df_new = [] + results = {} # prepare filters print(f"Loading & rescaling implication / disjointness filters...") @@ -274,27 +258,26 @@ def analyse_run( gc.collect() for i, metric in enumerate(violation_metrics): - if filter_type == "impl": - df_new.append(df_hyperparams.copy()) - df_new[-1]["metric"] = metric.__name__ - print( - f"Calculating metric {metric.__name__ if metric is not None else 'supervised'} on {filter_type}" - ) + if metric.__name__ not in results: + results[metric.__name__] = {} + print(f"Calculating metrics {metric.__name__} on {filter_type}") - m = {} - m["tps"] = apply_metric( + metric_results = {} + metric_results["tps"] = apply_metric( metric, l_preds, r_preds if filter_type == "impl" else 1 - r_preds ) - m["fns"] = apply_metric( + metric_results["fns"] = apply_metric( metric, l_preds, 1 - r_preds if filter_type == "impl" else r_preds ) if verbose_violation_output: label_names = get_label_names(data_module_labeled) - print(f"Found {torch.sum(m['fns'])} {filter_type}-violations") - # for k, fn_cls in enumerate(m['fns']): + print( + f"Found {torch.sum(metric_results['fns'])} {filter_type}-violations" + ) + # for k, fn_cls in enumerate(metric_results['fns']): # if fn_cls > 0: # print(f"\tThereof, {fn_cls.item()} belong to class {label_names[k]}") - if torch.sum(m["fns"]) != 0: + if torch.sum(metric_results["fns"]) != 0: fns = metric( l_preds, 1 - r_preds if filter_type == "impl" else r_preds ) @@ -310,276 +293,209 @@ def analyse_run( ) m_l_agg = {} - for key, value in m.items(): + for key, value in metric_results.items(): m_l_agg[key] = _sort_results_by_label( n_labels, value, dl_filter_l, ) m_r_agg = {} - for key, value in m.items(): + for key, value in metric_results.items(): m_r_agg[key] = _sort_results_by_label( n_labels, value, dl_filter_r, ) - df_new[i][f"micro-fnr-{filter_type}"] = ( + results[metric.__name__][f"micro-fnr-{filter_type}"] = ( 1 - ( - torch.sum(m["tps"]) / (torch.sum(m[f"tps"]) + torch.sum(m[f"fns"])) + torch.sum(metric_results["tps"]) + / ( + torch.sum(metric_results[f"tps"]) + + torch.sum(metric_results[f"fns"]) + ) ).item() ) macro_recall_l = m_l_agg[f"tps"] / (m_l_agg[f"tps"] + m_l_agg[f"fns"]) - df_new[i][f"lmacro-fnr-{filter_type}"] = ( + results[metric.__name__][f"lmacro-fnr-{filter_type}"] = ( 1 - torch.mean(macro_recall_l[~macro_recall_l.isnan()]).item() ) macro_recall_r = m_r_agg[f"tps"] / (m_r_agg[f"tps"] + m_r_agg[f"fns"]) - df_new[i][f"rmacro-fnr-{filter_type}"] = ( + results[metric.__name__][f"rmacro-fnr-{filter_type}"] = ( 1 - torch.mean(macro_recall_r[~macro_recall_r.isnan()]).item() ) - df_new[i][f"fn-sum-{filter_type}"] = torch.sum(m["fns"]).item() - df_new[i][f"tp-sum-{filter_type}"] = torch.sum(m["tps"]).item() - - del m + results[metric.__name__][f"fn-sum-{filter_type}"] = torch.sum( + metric_results["fns"] + ).item() + results[metric.__name__][f"tp-sum-{filter_type}"] = torch.sum( + metric_results["tps"] + ).item() + + del metric_results del m_l_agg del m_r_agg gc.collect() - df_new[i] = pd.DataFrame(df_new[i], index=[0]) del l_preds del r_preds gc.collect() + return results + + +def run_supervised_metrics(preds, 
labels): # calculate supervised metrics + results = {} if labels is not None: - df_supervised = df_hyperparams.copy() - df_supervised["micro-f1"] = multilabel_f1_score( + results["micro-f1"] = multilabel_f1_score( preds, labels, num_labels=preds.size(1), average="micro" ).item() - df_supervised["macro-f1"] = multilabel_f1_score( + results["macro-f1"] = multilabel_f1_score( preds, labels, num_labels=preds.size(1), average="macro" ).item() - df_supervised["micro-roc-auc"] = multilabel_auroc( + results["micro-roc-auc"] = multilabel_auroc( preds, labels, num_labels=preds.size(1), average="micro" ).item() - df_supervised["macro-roc-auc"] = multilabel_auroc( + results["macro-roc-auc"] = multilabel_auroc( preds, labels, num_labels=preds.size(1), average="macro" ).item() - df_new.append(pd.DataFrame(df_supervised, index=[0])) - - if os.path.exists(results_path): - df_previous = pd.read_csv(results_path) - else: - df_previous = None - if df_previous is not None: - df_new = [df_previous] + df_new - del df_previous - - df_new = pd.concat(df_new, ignore_index=True) - print(f"Saving results to {results_path}") - df_new.to_csv(results_path, index=False) - - del df_new del preds del labels - del dl gc.collect() + return results + + +# run predictions / metrics calculations for semantic loss paper runs (NeSy 2024 submission) +def run_semloss_eval(): + # runs from wandb + non_wandb_runs = [] + api = wandb.Api() + runs = api.runs("chebai/chebai", filters={"tags": "eval_semloss_paper"}) + print(f"Found {len(runs)} tagged wandb runs") + ids_wandb = [run.id for run in runs] + + # ids used in the NeSy submission + prod = ["tk15yznc", "uke62a8m", "w0h3zr5s"] + xu19 = ["5ko8knb4", "061fd85t", "r50ioujs"] + prod_mixed = ["hk8555ff", "e0lxw8py", "lig23cmg"] + luka = ["0c0s48nh", "lfg384bp", "qeghvubh"] + baseline = ["i4wtz1k4", "zd020wkv", "rc1q3t49"] + prodk2 = ["ng3usn0p", "rp0wwzjv", "8fma1q7r"] + ids = baseline + prod + prodk2 + xu19 + luka + prod_mixed + # ids = ids_wandb + run_all( + ids, + non_wandb_runs, + prediction_datasets=[(ChEBIOver100(chebi_version=231), "test")], + consistency_metrics=[binary], + ) def run_all( - run_ids, - datasets=None, - chebi_version=231, - skip_analyse=False, - skip_preds=False, - nonwandb_runs=None, - violation_metrics="all", - remove_violations=False, + wandb_ids=None, + local_ckpts: List[Tuple] = None, + consistency_metrics: Optional[List[callable]] = None, + prediction_datasets: List[Tuple] = None, + remove_violations: bool = False, + results_dir="_fuzzy_loss_eval", + check_consistency_on=None, + verbose_violation_output=False, ): - # evaluate a list of runs on Hazardous and ChEBIOver100 datasets - if datasets is None: - datasets = [ - (Hazardous(), "all"), - (ChEBIOver100(chebi_version=chebi_version), "test"), + if wandb_ids is None: + wandb_ids = [] + if local_ckpts is None: + local_ckpts = [] + if consistency_metrics is None: + consistency_metrics = ALL_CONSISTENCY_METRICS + if prediction_datasets is None: + prediction_datasets = [ + (ChEBIOver100(chebi_version=231), "test"), ] - timestamp = datetime.now().strftime("%y%m%d-%H%M") - results_path = os.path.join( - "_fuzzy_loss_eval", - f"results_pc-kmeans_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", - ) + if check_consistency_on is None: + check_consistency_on = ChEBIOver100(chebi_version=231) if remove_violations: - label_names = get_label_names(ChEBIOver100(chebi_version=chebi_version)) - chebi_graph = get_chebi_graph( - ChEBIOver100(chebi_version=chebi_version), label_names - ) - disjoint_groups 
= get_disjoint_groups() + smooth_preds = PredictionSmoother(check_consistency_on) + else: + smooth_preds = lambda x: x - api = wandb.Api() - for run_id in run_ids: - try: - run = api.run(f"chebai/chebai/{run_id}") - epoch = get_best_epoch(run) - for test_on, kind in datasets: - df = { - "run-id": run_id, - "epoch": int(epoch), - "kind": kind, - "data_module": test_on.__class__.__name__, - "chebi_version": chebi_version, - } - buffer_dir_smoothed = os.path.join( - "results_buffer", - "smoothed3step", - f"{run.name}_ep{epoch}", - f"{test_on.__class__.__name__}_{kind}", - ) - if remove_violations and os.path.exists( - os.path.join(buffer_dir_smoothed, "preds000.pt") + timestamp = datetime.now().strftime("%y%m%d-%H%M%S") + results_path_consistency = os.path.join( + results_dir, + f"consistency_metrics_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", + ) + consistency_keys = [ + "micro-fnr-impl", + "lmacro-fnr-impl", + "rmacro-fnr-impl", + "fn-sum-impl", + "tp-sum-impl", + "micro-fnr-disj", + "lmacro-fnr-disj", + "rmacro-fnr-disj", + "fn-sum-disj", + "tp-sum-disj", + ] + with open(results_path_consistency, "x") as f: + f.write("run-id,epoch,datamodule,metric," + ",".join(consistency_keys) + "\n") + results_path_supervised = os.path.join( + results_dir, + f"supervised_metrics_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", + ) + supervised_keys = ["micro-f1", "macro-f1", "micro-roc-auc", "macro-roc-auc"] + with open(results_path_supervised, "x") as f: + f.write("run-id,epoch,datamodule," + ",".join(supervised_keys) + "\n") + + ckpts = [(run_name, ep, None) for run_name, ep in local_ckpts] + [ + (None, None, wandb_id) for wandb_id in wandb_ids + ] + + for run_name, epoch, wandb_id in ckpts: + ckpt_dir = os.path.join("logs", "downloaded_ckpts") + if wandb_id is not None: + ckpt_path, epoch = download_model_from_wandb(wandb_id, ckpt_dir) + else: + ckpt_path = None + for file in os.listdir(os.path.join(ckpt_dir, run_name)): + if file.startswith(f"best_epoch={epoch}_") or file.startswith( + f"per_epoch={epoch}_" ): - preds = torch.load( - os.path.join(buffer_dir_smoothed, "preds000.pt"), DEVICE - ) - labels = None - else: - if not skip_preds: - preds, labels = load_preds_labels_from_wandb( - run, epoch, test_on, kind - ) - else: - buffer_dir = os.path.join( - "results_buffer", - f"{run.name}_ep{epoch}", - f"{test_on.__class__.__name__}_{kind}", - ) - preds, labels = load_results_from_buffer( - buffer_dir, device=DEVICE - ) - assert ( - preds is not None - ), f"Did not find predictions in dir {buffer_dir}" - if remove_violations: - preds = smooth_preds( - preds, label_names, chebi_graph, disjoint_groups - ) - buffer_dir_smoothed = os.path.join( - "results_buffer", - "smoothed3step", - f"{run.name}_ep{epoch}", - f"{test_on.__class__.__name__}_{kind}", - ) - os.makedirs(buffer_dir_smoothed, exist_ok=True) - torch.save( - preds, os.path.join(buffer_dir_smoothed, "preds000.pt") - ) - if not skip_analyse: - print( - f"Calculating metrics for run {run.name} on {test_on.__class__.__name__} (epoch {epoch}, {kind})" - ) - analyse_run( - preds, - labels, - df_hyperparams=df, - chebi_version=chebi_version, - results_path=results_path, - violation_metrics=violation_metrics, + ckpt_path = os.path.join(os.path.join(ckpt_dir, run_name, file)) + assert ( + ckpt_path is not None + ), f"Failed to find checkpoint for epoch {epoch} in {os.path.join(ckpt_dir, run_name)}" + + for dataset, dataset_key in prediction_datasets: + preds, labels = load_preds_labels(ckpt_path, dataset, 
dataset_key) + + # identity function if remove_violations is False + smooth_preds(preds) + + # for wandb runs, use short id as name, otherwise use ckpt dir name + if wandb_id is not None: + run_name = wandb_id + metrics_dict = run_consistency_metrics( + preds, + check_consistency_on, + consistency_metrics, + verbose_violation_output, + ) + with open(results_path_consistency, "a") as f: + for metric in metrics_dict: + values = metrics_dict[metric] + f.write( + f"{run_name},{epoch},{dataset.__class__.__name__},{metric}," + f"{','.join([str(values[k]) for k in consistency_keys])}\n" ) - except Exception as e: - print(f"Failed for run {run_id}: {e}") - print(traceback.format_exc()) - - if nonwandb_runs: - for run_name, epoch in nonwandb_runs: - try: - for test_on, kind in datasets: - df = { - "run-id": run_name, - "epoch": int(epoch), - "kind": kind, - "data_module": test_on.__class__.__name__, - "chebi_version": chebi_version, - } - if not skip_preds: - preds, labels = load_preds_labels_from_nonwandb( - run_name, epoch, test_on, kind - ) - else: - buffer_dir = os.path.join( - "results_buffer", - f"{run_name}_ep{epoch}", - f"{test_on.__class__.__name__}_{kind}", - ) - preds, labels = load_results_from_buffer( - buffer_dir, device=DEVICE - ) - assert ( - preds is not None - ), f"Did not find predictions in dir {buffer_dir}" - if remove_violations: - preds = smooth_preds( - preds, label_names, chebi_graph, disjoint_groups - ) - if not skip_analyse: - print( - f"Calculating metrics for run {run_name} on {test_on.__class__.__name__} ({kind})" - ) - analyse_run( - preds, - labels, - df_hyperparams=df, - chebi_version=chebi_version, - results_path=results_path, - violation_metrics=violation_metrics, - ) - except Exception as e: - print(f"Failed for run {run_name}: {e}") - print(traceback.format_exc()) - -# run predictions / metrics calculations for semantic loss paper runs (NeSy 2024 submission) -def run_semloss_eval(mode="eval"): - non_wandb_runs = [] - if mode == "preds": - api = wandb.Api() - runs = api.runs("chebai/chebai", filters={"tags": "eval_semloss_paper"}) - print(f"Found {len(runs)} tagged wandb runs") - ids = [run.id for run in runs] - run_all(ids, skip_analyse=True, nonwandb_runs=non_wandb_runs) - - if mode == "eval": - prod = [ - "tk15yznc", - "uke62a8m", - "w0h3zr5s", - ] - xu19 = [ - "5ko8knb4", - "061fd85t", - "r50ioujs", - ] - prod_mixed = [ - "hk8555ff", - "e0lxw8py", - "lig23cmg", - ] - luka = [ - "0c0s48nh", - "lfg384bp", - "qeghvubh", - ] - baseline = ["i4wtz1k4", "zd020wkv", "rc1q3t49"] - prodk2 = ["ng3usn0p", "rp0wwzjv", "8fma1q7r"] - ids = baseline + prod + prodk2 + xu19 + luka + prod_mixed - run_all( - ids, - skip_preds=True, - nonwandb_runs=non_wandb_runs, - datasets=[(ChEBIOver100, "test")], - violation_metrics=[binary], - remove_violations=True, - ) + metrics_dict = run_supervised_metrics(preds, labels) + with open(results_path_supervised, "a") as f: + f.write( + f"{run_name},{epoch},{dataset.__class__.__name__},{metric}," + f"{','.join([str(metrics_dict[k]) for k in supervised_keys])}\n" + ) # follow-up to NeSy submission @@ -588,22 +504,27 @@ def run_fuzzy_loss(tag="fuzzy_loss"): runs = api.runs("chebai/chebai", filters={"tags": tag}) print(f"Found {len(runs)} wandb runs tagged with '{tag}'") ids = [run.id for run in runs] - chebi_version = 231 + chebi100 = ChEBIOver100( + chebi_version=231, + splits_file_path=os.path.join( + "data", "chebi_v231", "ChEBI100", "fuzzy_loss_splits.csv" + ), + ) run_all( ids, - # 
nonwandb_runs=[("chebi100_semrc_epoch-dependent100-100k_weighted_v231_pc_kmeans_240927-1219", 97), + [ + ( + "chebi100_semrc_epoch-dependent100-1m_start-at10_weighted_v231_pc_kmeans_241001-0836", + 196, + ) + ], + # [("chebi100_semrc_epoch-dependent100-100k_weighted_v231_pc_kmeans_240927-1219", 97), # ("chebi100_semrc_epoch-dependent100-100k_weighted_v231_pc_kmeans_240927-1220", 99)], - violation_metrics=[binary], - chebi_version=chebi_version, - datasets=[ + consistency_metrics=[binary], + check_consistency_on=chebi100, + prediction_datasets=[ ( - ChEBIOver100( - chebi_version=chebi_version, - splits_file_path=os.path.join( - ChEBIOver100(chebi_version=chebi_version).processed_dir_main, - "splits.csv", - ), - ), + chebi100, "test", ) ], diff --git a/chebai/result/utils.py b/chebai/result/utils.py index 6fad80ab..e640bc11 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -18,7 +18,7 @@ def get_checkpoint_from_wandb( root: str = os.path.join("logs", "downloaded_ckpts"), model_class: Optional[Union[Electra, ChebaiBaseNet]] = None, map_device_to: Optional[torch.device] = None, -) -> Optional[ChebaiBaseNet]: +): """ Gets a wandb checkpoint based on run and epoch, downloads it if necessary. @@ -30,7 +30,7 @@ def get_checkpoint_from_wandb( map_device_to: The device to map the model to. Returns: - The loaded model or None if no checkpoint is found. + The location of the downloaded checkpoint. """ api = wandb.Api() if model_class is None: @@ -45,9 +45,7 @@ def get_checkpoint_from_wandb( if not os.path.isfile(dest_path): print(f"Downloading checkpoint to {dest_path}") wandb_util.download_file_from_url(dest_path, file.url, api.api_key) - return model_class.load_from_checkpoint( - dest_path, strict=False, map_location=map_device_to - ) + return dest_path print(f"No model found for epoch {epoch}") return None From 2a7a10fb5de0dc6d5fe55ece4734907cc76f0b36 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 4 Oct 2024 10:48:17 +0200 Subject: [PATCH 20/47] fix sigmoidal implication --- chebai/loss/semantic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 3284c8f0..9aac98f6 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -161,8 +161,8 @@ def _calculate_implication_loss( # formula by van Krieken, 2022, applied to fuzzy implication with default parameters: b_0 = 0.5, s = 9 # parts that only depend on b_0 and s are pre-calculated implication = 1 - individual_loss - sigmoidal_implication = 90.028 * ( - 1.011 * torch.sigmoid(9 * (implication + 0.5)) - 1 + sigmoidal_implication = 0.0112338 * ( + 91.0171 * torch.sigmoid(9 * (implication - 0.5)) - 1 ) individual_loss = 1 - sigmoidal_implication From a85498c81ed665d1a4945bd1c51cec6d05cea4b3 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 4 Oct 2024 11:16:30 +0200 Subject: [PATCH 21/47] add goedel implication --- chebai/loss/semantic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 9aac98f6..2bb8e186 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -152,6 +152,8 @@ def _calculate_implication_loss( individual_loss = torch.relu(l + one_min_r - 1) elif self.fuzzy_implication in ["kleene_dienes", "kd"]: individual_loss = torch.min(l, 1 - r) + elif self.fuzzy_implication in ["goedel", "g"]: + individual_loss = 0 if l <= r else one_min_r else: raise NotImplementedError( f"Unknown fuzzy implication {self.fuzzy_implication}" From eef32d2480a81f858d83e1379ee95fed53ba6ae6 
Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 7 Oct 2024 11:29:40 +0200 Subject: [PATCH 22/47] fix implication loss signature --- chebai/loss/semantic.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 2bb8e186..98b4c19d 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -34,7 +34,15 @@ def __init__( data_extractor: XYBaseDataModule, base_loss: torch.nn.Module = None, fuzzy_implication: Literal[ - "reichenbach", "rc", "lukasiewicz", "lk", "xu19", "kleene_dienes", "kd" + "reichenbach", + "rc", + "lukasiewicz", + "lk", + "xu19", + "kleene_dienes", + "kd", + "goedel", + "g", ] = "reichenbach", impl_loss_weight: float = 0.1, pos_scalar: int = 1, From 6f6c6a0c37f56b9c75d912cb2df9002ff7b87af9 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 7 Oct 2024 13:25:03 +0200 Subject: [PATCH 23/47] fix goedel loss --- chebai/loss/semantic.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 98b4c19d..bdb4ab4c 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -161,7 +161,9 @@ def _calculate_implication_loss( elif self.fuzzy_implication in ["kleene_dienes", "kd"]: individual_loss = torch.min(l, 1 - r) elif self.fuzzy_implication in ["goedel", "g"]: - individual_loss = 0 if l <= r else one_min_r + individual_loss = ( + torch.relu(l - r) / (l - r) * one_min_r + ) # 0 if l <= r else one_min_r else: raise NotImplementedError( f"Unknown fuzzy implication {self.fuzzy_implication}" From d83957cd8e24bd6035beffb654c3fd551aedada1 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 7 Oct 2024 17:20:25 +0200 Subject: [PATCH 24/47] add error messages --- chebai/loss/semantic.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index bdb4ab4c..8550fe80 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -135,8 +135,14 @@ def _calculate_implication_loss( Returns: torch.Tensor: Calculated implication loss. 
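For orientation when reading the fuzzy-implication hunks above and below: every variant enters the loss as 1 - I(l, r), where l and r are the predicted memberships of the left- and right-hand class of an implication pair. The summary below uses the standard textbook definitions of the named implications (the xu19 variant is omitted), together with a reconstruction of the sigmoidal transformation from van Krieken et al. (2022) that matches the pre-computed constants in PATCH 20/47 and PATCH 25/47, assuming the default parameters b_0 = 0.5 and s = 9 stated in the code comment.

\begin{align*}
I_{\mathrm{RC}}(l, r) &= 1 - l + l\,r,  & 1 - I_{\mathrm{RC}} &= l\,(1 - r)  && \text{(Reichenbach)}\\
I_{\mathrm{LK}}(l, r) &= \min(1,\, 1 - l + r),  & 1 - I_{\mathrm{LK}} &= \max(0,\, l - r)  && \text{(Lukasiewicz)}\\
I_{\mathrm{KD}}(l, r) &= \max(1 - l,\, r),  & 1 - I_{\mathrm{KD}} &= \min(l,\, 1 - r)  && \text{(Kleene-Dienes)}\\
I_{\mathrm{G}}(l, r)  &= \begin{cases}1 & l \le r\\ r & l > r\end{cases},  & 1 - I_{\mathrm{G}} &= \begin{cases}0 & l \le r\\ 1 - r & l > r\end{cases}  && \text{(Goedel)}
\end{align*}

\[
I_{\sigma}(x) = \frac{\bigl(1 + e^{s b_0}\bigr)\,\sigma\bigl(s\,(x - b_0)\bigr) - 1}{e^{s b_0} - 1},
\qquad 1 + e^{4.5} \approx 91.0171,
\qquad \frac{1}{e^{4.5} - 1} \approx 0.01123379 .
\]

Applying I_sigma to any implication keeps the endpoints I = 0 and I = 1 fixed while concentrating the gradient around I = b_0, which is exactly what the two pre-computed constants in the corrected hunks encode.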
""" - assert not l.isnan().any() - assert not r.isnan().any() + assert not l.isnan().any(), ( + f"l contains NaN values - l.shape: {l.shape}, l.isnan().sum(): {l.isnan().sum()}, " + f"l: {l}" + ) + assert not r.isnan().any(), ( + f"r contains NaN values - r.shape: {r.shape}, r.isnan().sum(): {r.isnan().sum()}, " + f"r: {r}" + ) if self.pos_scalar != 1: l = ( torch.pow(l + self.eps, 1 / self.pos_scalar) From e4b707648a569d16607617320764310b7674097e Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 8 Oct 2024 11:12:42 +0200 Subject: [PATCH 25/47] fix sigmoidal implication --- chebai/loss/semantic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 8550fe80..bc6df01c 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -179,7 +179,7 @@ def _calculate_implication_loss( # formula by van Krieken, 2022, applied to fuzzy implication with default parameters: b_0 = 0.5, s = 9 # parts that only depend on b_0 and s are pre-calculated implication = 1 - individual_loss - sigmoidal_implication = 0.0112338 * ( + sigmoidal_implication = 0.01123379 * ( 91.0171 * torch.sigmoid(9 * (implication - 0.5)) - 1 ) individual_loss = 1 - sigmoidal_implication From 9d619a71ea191cf3ff5d06f78e7539096a9bc29c Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 9 Oct 2024 14:11:18 +0200 Subject: [PATCH 26/47] add ap metric, results by class to analyse_sem --- chebai/result/analyse_sem.py | 113 ++++++++++++++++++++++++++++++----- chebai/result/utils.py | 3 - 2 files changed, 99 insertions(+), 17 deletions(-) diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py index 6b81eaf3..d2ac8ad4 100644 --- a/chebai/result/analyse_sem.py +++ b/chebai/result/analyse_sem.py @@ -3,7 +3,11 @@ from datetime import datetime from typing import List, LiteralString -from torchmetrics.functional.classification import multilabel_auroc, multilabel_f1_score +from torchmetrics.functional.classification import ( + multilabel_auroc, + multilabel_average_precision, + multilabel_f1_score, +) from utils import * from chebai.loss.semantic import DisjointLoss @@ -49,7 +53,7 @@ def _filter_to_one_hot(preds, idx_filter): def _sort_results_by_label(n_labels, results, filter): - by_label = torch.zeros(n_labels, device=DEVICE) + by_label = torch.zeros(n_labels, device=DEVICE, dtype=torch.int) for r, filter_l in zip(results, filter): by_label[filter_l] += r return by_label @@ -66,9 +70,9 @@ def get_best_epoch(run): best_ep = int(file.name.split("=")[1].split("_")[0]) best_micro_f1 = micro_f1 if best_ep is None: - raise Exception(f"Could not find any 'best' checkpoint for run {run.name}") + raise Exception(f"Could not find any 'best' checkpoint for run {run.id}") else: - print(f"Best epoch for run {run.name}: {best_ep}") + print(f"Best epoch for run {run.id}: {best_ep}") return best_ep @@ -227,12 +231,15 @@ def run_consistency_metrics( data_module_labeled=None, # use labels from this dataset for violations violation_metrics=None, verbose_violation_output=False, + save_details_to=None, ): """Calculates all semantic metrics for given predictions (and supervised metrics if labels are provided)""" if violation_metrics is None: violation_metrics = ALL_CONSISTENCY_METRICS if data_module_labeled is None: data_module_labeled = ChEBIOver100(chebi_version=231) + if save_details_to is not None: + os.makedirs(save_details_to, exist_ok=True) n_labels = preds.size(1) print(f"Found {preds.shape[0]} predictions ({n_labels} classes)") @@ -307,6 +314,47 @@ def 
run_consistency_metrics( dl_filter_r, ) + if save_details_to is not None: + with open( + os.path.join( + save_details_to, f"{metric.__name__}_{filter_type}_all.csv" + ), + "w+", + ) as f: + f.write("left,right,tps,fns\n") + for left, right, tps, fns in zip( + dl_filter_l, + dl_filter_r, + metric_results["tps"], + metric_results["fns"], + ): + f.write(f"{left},{right},{tps},{fns}\n") + with open( + os.path.join( + save_details_to, f"{metric.__name__}_{filter_type}_l.csv" + ), + "w+", + ) as f: + f.write("left,tps,fns\n") + for left in range(n_labels): + f.write( + f"{left},{m_l_agg['tps'][left].item()},{m_l_agg['fns'][left].item()}\n" + ) + with open( + os.path.join( + save_details_to, f"{metric.__name__}_{filter_type}_r.csv" + ), + "w+", + ) as f: + f.write("right,tps,fns\n") + for right in range(n_labels): + f.write( + f"{right},{m_r_agg['tps'][right].item()},{m_r_agg['fns'][right].item()}\n" + ) + print( + f"Saved unaggregated consistency metrics ({metric.__name__}, {filter_type}) to {save_details_to}" + ) + results[metric.__name__][f"micro-fnr-{filter_type}"] = ( 1 - ( @@ -344,7 +392,7 @@ def run_consistency_metrics( return results -def run_supervised_metrics(preds, labels): +def run_supervised_metrics(preds, labels, save_details_to=None): # calculate supervised metrics results = {} if labels is not None: @@ -361,6 +409,31 @@ def run_supervised_metrics(preds, labels): preds, labels, num_labels=preds.size(1), average="macro" ).item() + results["micro-ap"] = multilabel_average_precision( + preds, labels, num_labels=preds.size(1), average="micro" + ).item() + results["macro-ap"] = multilabel_average_precision( + preds, labels, num_labels=preds.size(1), average="macro" + ).item() + + if save_details_to is not None: + f1_by_label = multilabel_f1_score( + preds, labels, num_labels=preds.size(1), average=None + ) + roc_by_label = multilabel_auroc( + preds, labels, num_labels=preds.size(1), average=None + ) + ap_by_label = multilabel_average_precision( + preds, labels, num_labels=preds.size(1), average=None + ) + with open(os.path.join(save_details_to, f"supervised.csv"), "w+") as f: + f.write("label,f1,roc-auc,ap\n") + for right in range(preds.size(1)): + f.write( + f"{right},{f1_by_label[right].item()},{roc_by_label[right].item()},{ap_by_label[right].item()}\n" + ) + print(f"Saved class-wise supervised metrics to {save_details_to}") + del preds del labels gc.collect() @@ -444,7 +517,14 @@ def run_all( results_dir, f"supervised_metrics_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", ) - supervised_keys = ["micro-f1", "macro-f1", "micro-roc-auc", "macro-roc-auc"] + supervised_keys = [ + "micro-f1", + "macro-f1", + "micro-roc-auc", + "macro-roc-auc", + "micro-ap", + "macro-ap", + ] with open(results_path_supervised, "x") as f: f.write("run-id,epoch,datamodule," + ",".join(supervised_keys) + "\n") @@ -476,11 +556,16 @@ def run_all( # for wandb runs, use short id as name, otherwise use ckpt dir name if wandb_id is not None: run_name = wandb_id + details_path = os.path.join( + results_dir, + f"{run_name}_ep{epoch}_{dataset.__class__.__name__}_{dataset_key}", + ) metrics_dict = run_consistency_metrics( preds, check_consistency_on, consistency_metrics, verbose_violation_output, + save_details_to=details_path, ) with open(results_path_consistency, "a") as f: for metric in metrics_dict: @@ -489,13 +574,19 @@ def run_all( f"{run_name},{epoch},{dataset.__class__.__name__},{metric}," f"{','.join([str(values[k]) for k in consistency_keys])}\n" ) + print( + f"Consistency metrics have 
been written to {results_path_consistency}" + ) - metrics_dict = run_supervised_metrics(preds, labels) + metrics_dict = run_supervised_metrics( + preds, labels, save_details_to=details_path + ) with open(results_path_supervised, "a") as f: f.write( - f"{run_name},{epoch},{dataset.__class__.__name__},{metric}," + f"{run_name},{epoch},{dataset.__class__.__name__}," f"{','.join([str(metrics_dict[k]) for k in supervised_keys])}\n" ) + print(f"Supervised metrics have been written to {results_path_supervised}") # follow-up to NeSy submission @@ -512,12 +603,6 @@ def run_fuzzy_loss(tag="fuzzy_loss"): ) run_all( ids, - [ - ( - "chebi100_semrc_epoch-dependent100-1m_start-at10_weighted_v231_pc_kmeans_241001-0836", - 196, - ) - ], # [("chebi100_semrc_epoch-dependent100-100k_weighted_v231_pc_kmeans_240927-1219", 97), # ("chebi100_semrc_epoch-dependent100-100k_weighted_v231_pc_kmeans_240927-1220", 99)], consistency_metrics=[binary], diff --git a/chebai/result/utils.py b/chebai/result/utils.py index e640bc11..49147a61 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -16,7 +16,6 @@ def get_checkpoint_from_wandb( epoch: int, run: wandb.apis.public.Run, root: str = os.path.join("logs", "downloaded_ckpts"), - model_class: Optional[Union[Electra, ChebaiBaseNet]] = None, map_device_to: Optional[torch.device] = None, ): """ @@ -33,8 +32,6 @@ def get_checkpoint_from_wandb( The location of the downloaded checkpoint. """ api = wandb.Api() - if model_class is None: - model_class = Electra files = run.files() for file in files: From cc1bba7257635b76f392c773dddaba1ee242e626 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 9 Oct 2024 14:11:34 +0200 Subject: [PATCH 27/47] fix goedel loss --- chebai/loss/semantic.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index bc6df01c..8a78a55f 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -167,9 +167,10 @@ def _calculate_implication_loss( elif self.fuzzy_implication in ["kleene_dienes", "kd"]: individual_loss = torch.min(l, 1 - r) elif self.fuzzy_implication in ["goedel", "g"]: - individual_loss = ( - torch.relu(l - r) / (l - r) * one_min_r - ) # 0 if l <= r else one_min_r + individual_loss = torch.where(l <= r, 0, one_min_r) + # individual_loss = ( + # torch.relu(l - r) / (l - r) * one_min_r + # ) # 0 if l <= r else one_min_r else: raise NotImplementedError( f"Unknown fuzzy implication {self.fuzzy_implication}" From 235abdbcfb7e4e082d3664f9f425bfd675103df3 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 10 Oct 2024 08:53:51 +0200 Subject: [PATCH 28/47] fix pos scalar typehint --- chebai/loss/semantic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 8a78a55f..9aa64af0 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -45,7 +45,7 @@ def __init__( "g", ] = "reichenbach", impl_loss_weight: float = 0.1, - pos_scalar: int = 1, + pos_scalar: Union[int, float] = 1, pos_epsilon: float = 0.01, multiply_by_softmax: bool = False, use_sigmoidal_implication: bool = False, From e892967336599b2125d5d45ecb2c73610fb20794 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 10 Oct 2024 10:47:08 +0200 Subject: [PATCH 29/47] add epsilon to consequent (for balanced loss with k < 1) --- chebai/loss/semantic.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 9aa64af0..13874469 100644 --- 
a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -151,7 +151,13 @@ def _calculate_implication_loss( math.pow(1 + self.eps, 1 / self.pos_scalar) - math.pow(self.eps, 1 / self.pos_scalar) ) - one_min_r = torch.pow(1 - r, self.pos_scalar) + one_min_r = ( + torch.pow(1 - r + self.eps, 1 / self.pos_scalar) + - math.pow(self.eps, 1 / self.pos_scalar) + ) / ( + math.pow(1 + self.eps, 1 / self.pos_scalar) + - math.pow(self.eps, 1 / self.pos_scalar) + ) else: one_min_r = 1 - r # for each implication I, calculate 1 - I(l, 1-one_min_r) @@ -168,9 +174,8 @@ def _calculate_implication_loss( individual_loss = torch.min(l, 1 - r) elif self.fuzzy_implication in ["goedel", "g"]: individual_loss = torch.where(l <= r, 0, one_min_r) - # individual_loss = ( - # torch.relu(l - r) / (l - r) * one_min_r - # ) # 0 if l <= r else one_min_r + elif self.fuzzy_implication in ["reverse-goedel", "rg"]: + individual_loss = torch.where(l <= r, 0, l) else: raise NotImplementedError( f"Unknown fuzzy implication {self.fuzzy_implication}" From f34a5ad83696d13cb81bd97a62a925434f1523d4 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 10 Oct 2024 15:30:14 +0200 Subject: [PATCH 30/47] add parameters to epoch-dependent weighting --- chebai/loss/semantic.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 13874469..56a758e9 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -43,13 +43,15 @@ def __init__( "kd", "goedel", "g", + "reverse-goedel", + "rg", ] = "reichenbach", impl_loss_weight: float = 0.1, pos_scalar: Union[int, float] = 1, pos_epsilon: float = 0.01, multiply_by_softmax: bool = False, use_sigmoidal_implication: bool = False, - weight_epoch_dependent: bool = False, + weight_epoch_dependent: Union[bool | tuple[int, int]] = False, start_at_epoch: int = 0, ): super().__init__() @@ -114,9 +116,20 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: } implication_loss_weighted = implication_loss if "current_epoch" in kwargs and self.weight_epoch_dependent: + sigmoid_center = ( + self.weight_epoch_dependent[0] + if isinstance(self.weight_epoch_dependent, tuple) + else 50 + ) + sigmoid_spread = ( + self.weight_epoch_dependent[1] + if isinstance(self.weight_epoch_dependent, tuple) + else 10 + ) # sigmoid function centered around epoch 50 implication_loss_weighted = implication_loss_weighted / ( - 1 + math.exp(-(kwargs["current_epoch"] - 50) / 10) + 1 + + math.exp(-(kwargs["current_epoch"] - sigmoid_center) / sigmoid_spread) ) implication_loss_weighted *= self.impl_weight loss_components["weighted_implication_loss"] = implication_loss_weighted From 436c8977dfd45e653a4e1b0fca6a72d263073d7c Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 10 Oct 2024 16:03:03 +0200 Subject: [PATCH 31/47] disable strict checkpoint loading --- chebai/models/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/chebai/models/base.py b/chebai/models/base.py index 34d32134..a0e46d73 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -58,6 +58,10 @@ def __init__( self.test_metrics = test_metrics self.pass_loss_kwargs = pass_loss_kwargs + # allows resuming training without strict loading (e.g., ignoring loss weights), + # see https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#resume-from-a-partial-checkpoint + self.strict_loading = False + def __init_subclass__(cls, **kwargs): """ Automatically registers subclasses in the model registry to prevent 
duplicates. From e7c9ff899f4f676de8bec157ca140e3f29cb76db Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 10 Oct 2024 16:32:14 +0200 Subject: [PATCH 32/47] fix checkpoint loading --- chebai/models/base.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index a0e46d73..288bdbf5 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -58,9 +58,13 @@ def __init__( self.test_metrics = test_metrics self.pass_loss_kwargs = pass_loss_kwargs - # allows resuming training without strict loading (e.g., ignoring loss weights), - # see https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#resume-from-a-partial-checkpoint - self.strict_loading = False + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + # avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a + # different loss) + if "criterion.base_loss.pos_weight" in checkpoint["state_dict"]: + del checkpoint["state_dict"]["criterion.base_loss.pos_weight"] + if "criterion.pos_weight" in checkpoint["state_dict"]: + del checkpoint["state_dict"]["criterion.pos_weight"] def __init_subclass__(cls, **kwargs): """ From 7f7695d6a4777d7e12f2ef7c73d6e3b453aa463e Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 14 Oct 2024 14:47:01 +0200 Subject: [PATCH 33/47] add raw file names (temporary fix), add pubchem data to fuzzy eval, --- chebai/preprocessing/datasets/base.py | 19 +++- chebai/preprocessing/datasets/chebi.py | 4 +- chebai/preprocessing/datasets/pubchem.py | 6 +- chebai/result/analyse_sem.py | 138 +++++++++++++---------- 4 files changed, 101 insertions(+), 66 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index f163a9e6..fa2bcf10 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -172,7 +172,8 @@ def load_processed_data( self, kind: Optional[str] = None, filename: Optional[str] = None ) -> List: """ - Load processed data from a file. + Load processed data from a file. Either the kind or the filename has to be provided. If both are provided, the + filename is used. Args: kind (str, optional): The kind of dataset to load such as "train", "val" or "test". Defaults to None. @@ -705,7 +706,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: """ print("Checking for processed data in", self.processed_dir_main) - processed_name = self.processed_dir_main_file_names_dict["data"] + processed_name = self.raw_file_names_dict["data"] if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): print("Missing processed data file (`data.pkl` file)") os.makedirs(self.processed_dir_main, exist_ok=True) @@ -796,7 +797,7 @@ def setup_processed(self) -> None: self._load_data_from_file( os.path.join( self.processed_dir_main, - self.processed_dir_main_file_names_dict["data"], + self.raw_file_names_dict["data"], ) ), os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), @@ -1131,7 +1132,7 @@ def processed_dir(self) -> str: ) @property - def processed_dir_main_file_names_dict(self) -> dict: + def raw_file_names_dict(self) -> dict: """ Returns a dictionary mapping processed data file names, processed by `prepare_data` method. @@ -1141,6 +1142,16 @@ def processed_dir_main_file_names_dict(self) -> dict: """ return {"data": "data.pkl"} + @property + def raw_file_names(self) -> List[str]: + """ + Returns a list of raw file names. 
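The epoch-dependent weighting parametrised in PATCH 30/47 above amounts to a logistic ramp-up of the fuzzy-loss weight over training. As a sketch, assuming the defaults hard-coded in that hunk (centre c = 50, spread s = 10, overridable by passing weight_epoch_dependent = (c, s)) and writing e for the current_epoch handed to forward():

\[
L(e) \;=\; L_{\mathrm{base}} \;+\; \lambda\,\sigma\!\Bigl(\tfrac{e - c}{s}\Bigr)\, L_{\mathrm{impl}},
\qquad \sigma(x) = \frac{1}{1 + e^{-x}},
\]

with lambda = impl_loss_weight; the implication term is almost switched off at the start of training and approaches its full weight lambda once e is well past c.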
+ + Returns: + List[str]: A list of file names corresponding to the raw data. + """ + return list(self.raw_file_names_dict.values()) + @property def processed_file_names_dict(self) -> dict: """ diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 9d80929a..885332b0 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -185,9 +185,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: if not os.path.isfile( os.path.join( self._chebi_version_train_obj.processed_dir_main, - self._chebi_version_train_obj.processed_dir_main_file_names_dict[ - "data" - ], + self._chebi_version_train_obj.raw_file_names_dict["data"], ) ): print( diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index fafa2561..9e43302a 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -349,10 +349,12 @@ def split_label(self) -> str: @property def raw_file_names(self) -> List[str]: """ + Clusters generated by K-Means, sorted by size (cluster0 is the largest). + cluster0 is the training cluster (will be split into train/val/test in processed, used for pretraining) Returns: List[str]: List of raw file names expected in the raw directory. """ - return ["train.txt", "validation.txt", "test.txt"] + return ["cluster0.txt", "cluster1.txt", "cluster2.txt"] @property def fingerprints(self) -> pd.DataFrame: @@ -603,7 +605,7 @@ def download(self): splits = [fp_grouped.get_group(g) for g in fp_grouped.groups if g != -1] splits[0] = splits[0][: self.validation_size_limit] splits.sort(key=lambda x: len(x)) - for i, name in enumerate(["validation", "test", "train"]): + for i, name in enumerate(["cluster2", "cluster1", "cluster0"]): if not os.path.exists(os.path.join(self.raw_dir, f"{name}.txt")): open(os.path.join(self.raw_dir, f"{name}.txt"), "x").close() with open(os.path.join(self.raw_dir, f"{name}.txt"), "w") as f: diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py index d2ac8ad4..8d83c097 100644 --- a/chebai/result/analyse_sem.py +++ b/chebai/result/analyse_sem.py @@ -11,7 +11,9 @@ from utils import * from chebai.loss.semantic import DisjointLoss +from chebai.preprocessing.datasets.base import _DynamicDataset from chebai.preprocessing.datasets.chebi import ChEBIOver100 +from chebai.preprocessing.datasets.pubchem import PubChemKMeans DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -98,11 +100,17 @@ def load_preds_labels( f"{data_module.__class__.__name__}_{data_subset_key}", ) model = Electra.load_from_checkpoint(ckpt_path, map_location="cuda:0", strict=False) - print(f"Calculating predictions...") + print( + f"Calculating predictions on {data_module.__class__.__name__} ({data_subset_key})..." 
+ ) evaluate_model( model, data_module, buffer_dir=buffer_dir, + # for chebi, use kinds, otherwise use file names + filename=( + data_subset_key if not isinstance(buffer_dir, _DynamicDataset) else None + ), kind=data_subset_key, skip_existing_preds=True, ) @@ -512,7 +520,11 @@ def run_all( "tp-sum-disj", ] with open(results_path_consistency, "x") as f: - f.write("run-id,epoch,datamodule,metric," + ",".join(consistency_keys) + "\n") + f.write( + "run-id,epoch,datamodule,data_key,metric," + + ",".join(consistency_keys) + + "\n" + ) results_path_supervised = os.path.join( results_dir, f"supervised_metrics_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", @@ -526,68 +538,74 @@ def run_all( "macro-ap", ] with open(results_path_supervised, "x") as f: - f.write("run-id,epoch,datamodule," + ",".join(supervised_keys) + "\n") + f.write("run-id,epoch,datamodule,data_key" + ",".join(supervised_keys) + "\n") ckpts = [(run_name, ep, None) for run_name, ep in local_ckpts] + [ (None, None, wandb_id) for wandb_id in wandb_ids ] for run_name, epoch, wandb_id in ckpts: - ckpt_dir = os.path.join("logs", "downloaded_ckpts") - if wandb_id is not None: - ckpt_path, epoch = download_model_from_wandb(wandb_id, ckpt_dir) - else: - ckpt_path = None - for file in os.listdir(os.path.join(ckpt_dir, run_name)): - if file.startswith(f"best_epoch={epoch}_") or file.startswith( - f"per_epoch={epoch}_" - ): - ckpt_path = os.path.join(os.path.join(ckpt_dir, run_name, file)) - assert ( - ckpt_path is not None - ), f"Failed to find checkpoint for epoch {epoch} in {os.path.join(ckpt_dir, run_name)}" - - for dataset, dataset_key in prediction_datasets: - preds, labels = load_preds_labels(ckpt_path, dataset, dataset_key) - - # identity function if remove_violations is False - smooth_preds(preds) - + try: + ckpt_dir = os.path.join("logs", "downloaded_ckpts") # for wandb runs, use short id as name, otherwise use ckpt dir name if wandb_id is not None: run_name = wandb_id - details_path = os.path.join( - results_dir, - f"{run_name}_ep{epoch}_{dataset.__class__.__name__}_{dataset_key}", - ) - metrics_dict = run_consistency_metrics( - preds, - check_consistency_on, - consistency_metrics, - verbose_violation_output, - save_details_to=details_path, - ) - with open(results_path_consistency, "a") as f: - for metric in metrics_dict: - values = metrics_dict[metric] + ckpt_path, epoch = download_model_from_wandb(run_name, ckpt_dir) + else: + ckpt_path = None + for file in os.listdir(os.path.join(ckpt_dir, run_name)): + if file.startswith(f"best_epoch={epoch}_") or file.startswith( + f"per_epoch={epoch}_" + ): + ckpt_path = os.path.join(os.path.join(ckpt_dir, run_name, file)) + assert ( + ckpt_path is not None + ), f"Failed to find checkpoint for epoch {epoch} in {os.path.join(ckpt_dir, run_name)}" + print(f"Starting run {run_name} (epoch {epoch})") + for dataset, dataset_key in prediction_datasets: + preds, labels = load_preds_labels(ckpt_path, dataset, dataset_key) + + # identity function if remove_violations is False + smooth_preds(preds) + + details_path = os.path.join( + results_dir, + f"{run_name}_ep{epoch}_{dataset.__class__.__name__}_{dataset_key}", + ) + metrics_dict = run_consistency_metrics( + preds, + check_consistency_on, + consistency_metrics, + verbose_violation_output, + save_details_to=details_path, + ) + with open(results_path_consistency, "a") as f: + for metric in metrics_dict: + values = metrics_dict[metric] + f.write( + f"{run_name},{epoch},{dataset.__class__.__name__},{dataset_key},{metric}," + 
f"{','.join([str(values[k]) for k in consistency_keys])}\n" + ) + print( + f"Consistency metrics have been written to {results_path_consistency}" + ) + + metrics_dict = run_supervised_metrics( + preds, labels, save_details_to=details_path + ) + with open(results_path_supervised, "a") as f: f.write( - f"{run_name},{epoch},{dataset.__class__.__name__},{metric}," - f"{','.join([str(values[k]) for k in consistency_keys])}\n" + f"{run_name},{epoch},{dataset.__class__.__name__},{dataset_key}," + f"{','.join([str(metrics_dict[k]) for k in supervised_keys])}\n" ) + print( + f"Supervised metrics have been written to {results_path_supervised}" + ) + except Exception as e: print( - f"Consistency metrics have been written to {results_path_consistency}" + f"Error during run {wandb_id if wandb_id is not None else run_name}: {e}" ) - metrics_dict = run_supervised_metrics( - preds, labels, save_details_to=details_path - ) - with open(results_path_supervised, "a") as f: - f.write( - f"{run_name},{epoch},{dataset.__class__.__name__}," - f"{','.join([str(metrics_dict[k]) for k in supervised_keys])}\n" - ) - print(f"Supervised metrics have been written to {results_path_supervised}") - # follow-up to NeSy submission def run_fuzzy_loss(tag="fuzzy_loss"): @@ -601,17 +619,23 @@ def run_fuzzy_loss(tag="fuzzy_loss"): "data", "chebi_v231", "ChEBI100", "fuzzy_loss_splits.csv" ), ) + pubchem_kmeans = PubChemKMeans() run_all( - ids, - # [("chebi100_semrc_epoch-dependent100-100k_weighted_v231_pc_kmeans_240927-1219", 97), - # ("chebi100_semrc_epoch-dependent100-100k_weighted_v231_pc_kmeans_240927-1220", 99)], + [], # ids, + [ + ( + "chebi100_semg_epoch-dependent1-1k_start-at=10_batch3_weighted_v231_pc_kmeans_241010-0814", + 199, + ) + ], consistency_metrics=[binary], check_consistency_on=chebi100, prediction_datasets=[ - ( - chebi100, - "test", - ) + (chebi100, "test"), + (pubchem_kmeans, "cluster1.pt"), + (pubchem_kmeans, "cluster2.pt"), + (pubchem_kmeans, "chebi_close.pt"), + (pubchem_kmeans, "ten_from_each_cluster.pt"), ], ) From 33c2d64830871fd253db46aca66955afc7948bac Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 17 Oct 2024 13:13:02 +0200 Subject: [PATCH 34/47] efficiency, minor fixes, changed paths --- chebai/loss/semantic.py | 7 +- chebai/result/analyse_sem.py | 214 +++++++++++++++++++++++------------ chebai/result/utils.py | 19 +++- 3 files changed, 161 insertions(+), 79 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 56a758e9..8975ee38 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -136,7 +136,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: return base_loss + implication_loss_weighted, loss_components def _calculate_implication_loss( - self, l: torch.Tensor, r: torch.Tensor + self, l: torch.Tensor, r: torch.Tensor, target: torch.Tensor ) -> torch.Tensor: """ Calculate implication loss based on T-norm and other parameters. 
@@ -205,6 +205,11 @@ def _calculate_implication_loss( if self.multiply_by_softmax: individual_loss = individual_loss * individual_loss.softmax(dim=-1) + + # aggregate for classes, mask with ground truth labels + target_l = target[:, self.implication_filter_l] + target_r = target[:, self.implication_filter_r] + return torch.mean( torch.sum(individual_loss, dim=-1), dim=0, diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py index 8d83c097..71d23c89 100644 --- a/chebai/result/analyse_sem.py +++ b/chebai/result/analyse_sem.py @@ -1,5 +1,6 @@ import gc import sys +import traceback from datetime import datetime from typing import List, LiteralString @@ -85,7 +86,7 @@ def download_model_from_wandb( run = api.run(f"chebai/chebai/{run_id}") epoch = get_best_epoch(run) return ( - get_checkpoint_from_wandb(epoch, run, root=base_dir, map_device_to="cuda:0"), + get_checkpoint_from_wandb(epoch, run, root=base_dir), epoch, ) @@ -113,8 +114,9 @@ def load_preds_labels( ), kind=data_subset_key, skip_existing_preds=True, + batch_size=1, ) - return load_results_from_buffer(buffer_dir, device=DEVICE) + return load_results_from_buffer(buffer_dir, device=torch.device("cpu")) def get_label_names(data_module): @@ -234,8 +236,24 @@ def __call__(self, preds): return preds +def build_prediction_filter(data_module_labeled=None): + if data_module_labeled is None: + data_module_labeled = ChEBIOver100(chebi_version=231) + # prepare filters + print(f"Loading implication / disjointness filters...") + dl = DisjointLoss( + path_to_disjointness=os.path.join("data", "disjoint.csv"), + data_extractor=data_module_labeled, + ) + return [ + (dl.implication_filter_l, dl.implication_filter_r, "impl"), + (dl.disjoint_filter_l, dl.disjoint_filter_r, "disj"), + ] + + def run_consistency_metrics( preds, + consistency_filters, data_module_labeled=None, # use labels from this dataset for violations violation_metrics=None, verbose_violation_output=False, @@ -249,40 +267,55 @@ def run_consistency_metrics( if save_details_to is not None: os.makedirs(save_details_to, exist_ok=True) + preds.to("cpu") + n_labels = preds.size(1) print(f"Found {preds.shape[0]} predictions ({n_labels} classes)") results = {} - # prepare filters - print(f"Loading & rescaling implication / disjointness filters...") - dl = DisjointLoss( - path_to_disjointness=os.path.join("data", "disjoint.csv"), - data_extractor=data_module_labeled, - ) - for dl_filter_l, dl_filter_r, filter_type in [ - (dl.implication_filter_l, dl.implication_filter_r, "impl"), - (dl.disjoint_filter_l, dl.disjoint_filter_r, "disj"), - ]: - # prepare predictions - n_loss_terms = dl_filter_l.shape[0] - preds_exp = preds.unsqueeze(2).expand((-1, -1, n_loss_terms)).swapaxes(1, 2) - l_preds = preds_exp[:, _filter_to_one_hot(preds, dl_filter_l)] - r_preds = preds_exp[:, _filter_to_one_hot(preds, dl_filter_r)] - del preds_exp - gc.collect() - + for dl_filter_l, dl_filter_r, filter_type in consistency_filters: + l_preds = preds[:, dl_filter_l] + r_preds = preds[:, dl_filter_r] for i, metric in enumerate(violation_metrics): if metric.__name__ not in results: results[metric.__name__] = {} print(f"Calculating metrics {metric.__name__} on {filter_type}") metric_results = {} - metric_results["tps"] = apply_metric( - metric, l_preds, r_preds if filter_type == "impl" else 1 - r_preds + metric_results["tps"] = torch.sum( + torch.stack( + [ + apply_metric( + metric, + l_preds[i : i + 1000], + ( + r_preds[i : i + 1000] + if filter_type == "impl" + else 1 - r_preds[i : i + 1000] + ), + ) + for i 
in range(0, r_preds.shape[0], 1000) + ] + ), + dim=0, ) - metric_results["fns"] = apply_metric( - metric, l_preds, 1 - r_preds if filter_type == "impl" else r_preds + metric_results["fns"] = torch.sum( + torch.stack( + [ + apply_metric( + metric, + l_preds[i : i + 1000], + ( + 1 - r_preds[i : i + 1000] + if filter_type == "impl" + else r_preds[i : i + 1000] + ), + ) + for i in range(0, r_preds.shape[0], 1000) + ] + ), + dim=0, ) if verbose_violation_output: label_names = get_label_names(data_module_labeled) @@ -363,23 +396,29 @@ def run_consistency_metrics( f"Saved unaggregated consistency metrics ({metric.__name__}, {filter_type}) to {save_details_to}" ) + fns_sum = torch.sum(metric_results["fns"]).item() results[metric.__name__][f"micro-fnr-{filter_type}"] = ( - 1 - - ( - torch.sum(metric_results["tps"]) + 0 + if fns_sum == 0 + else ( + torch.sum(metric_results["fns"]) / ( torch.sum(metric_results[f"tps"]) + torch.sum(metric_results[f"fns"]) ) ).item() ) - macro_recall_l = m_l_agg[f"tps"] / (m_l_agg[f"tps"] + m_l_agg[f"fns"]) + macro_fnr_l = m_l_agg[f"fns"] / (m_l_agg[f"tps"] + m_l_agg[f"fns"]) results[metric.__name__][f"lmacro-fnr-{filter_type}"] = ( - 1 - torch.mean(macro_recall_l[~macro_recall_l.isnan()]).item() + 0 + if fns_sum == 0 + else torch.mean(macro_fnr_l[~macro_fnr_l.isnan()]).item() ) - macro_recall_r = m_r_agg[f"tps"] / (m_r_agg[f"tps"] + m_r_agg[f"fns"]) + macro_fnr_r = m_r_agg[f"fns"] / (m_r_agg[f"tps"] + m_r_agg[f"fns"]) results[metric.__name__][f"rmacro-fnr-{filter_type}"] = ( - 1 - torch.mean(macro_recall_r[~macro_recall_r.isnan()]).item() + 0 + if fns_sum == 0 + else torch.mean(macro_fnr_r[~macro_fnr_r.isnan()]).item() ) results[metric.__name__][f"fn-sum-{filter_type}"] = torch.sum( metric_results["fns"] @@ -424,23 +463,23 @@ def run_supervised_metrics(preds, labels, save_details_to=None): preds, labels, num_labels=preds.size(1), average="macro" ).item() - if save_details_to is not None: - f1_by_label = multilabel_f1_score( - preds, labels, num_labels=preds.size(1), average=None - ) - roc_by_label = multilabel_auroc( - preds, labels, num_labels=preds.size(1), average=None - ) - ap_by_label = multilabel_average_precision( - preds, labels, num_labels=preds.size(1), average=None - ) - with open(os.path.join(save_details_to, f"supervised.csv"), "w+") as f: - f.write("label,f1,roc-auc,ap\n") - for right in range(preds.size(1)): - f.write( - f"{right},{f1_by_label[right].item()},{roc_by_label[right].item()},{ap_by_label[right].item()}\n" - ) - print(f"Saved class-wise supervised metrics to {save_details_to}") + if save_details_to is not None: + f1_by_label = multilabel_f1_score( + preds, labels, num_labels=preds.size(1), average=None + ) + roc_by_label = multilabel_auroc( + preds, labels, num_labels=preds.size(1), average=None + ) + ap_by_label = multilabel_average_precision( + preds, labels, num_labels=preds.size(1), average=None + ) + with open(os.path.join(save_details_to, f"supervised.csv"), "w+") as f: + f.write("label,f1,roc-auc,ap\n") + for right in range(preds.size(1)): + f.write( + f"{right},{f1_by_label[right].item()},{roc_by_label[right].item()},{ap_by_label[right].item()}\n" + ) + print(f"Saved class-wise supervised metrics to {save_details_to}") del preds del labels @@ -503,6 +542,8 @@ def run_all( smooth_preds = lambda x: x timestamp = datetime.now().strftime("%y%m%d-%H%M%S") + prediction_filters = build_prediction_filter(check_consistency_on) + results_path_consistency = os.path.join( results_dir, f"consistency_metrics_{timestamp}{'_violations_removed' if 
remove_violations else ''}.csv", @@ -538,7 +579,7 @@ def run_all( "macro-ap", ] with open(results_path_supervised, "x") as f: - f.write("run-id,epoch,datamodule,data_key" + ",".join(supervised_keys) + "\n") + f.write("run-id,epoch,datamodule,data_key," + ",".join(supervised_keys) + "\n") ckpts = [(run_name, ep, None) for run_name, ep in local_ckpts] + [ (None, None, wandb_id) for wandb_id in wandb_ids @@ -562,9 +603,32 @@ def run_all( ckpt_path is not None ), f"Failed to find checkpoint for epoch {epoch} in {os.path.join(ckpt_dir, run_name)}" print(f"Starting run {run_name} (epoch {epoch})") - for dataset, dataset_key in prediction_datasets: - preds, labels = load_preds_labels(ckpt_path, dataset, dataset_key) + for dataset, dataset_key in prediction_datasets: + # copy data from legacy buffer dir if possible + old_buffer_dir = os.path.join( + "results_buffer", + *ckpt_path.split(os.path.sep)[-2:], + f"{dataset.__class__.__name__}_{dataset_key}", + ) + buffer_dir = os.path.join( + "results_buffer", + run_name, + f"epoch={epoch}", + f"{dataset.__class__.__name__}_{dataset_key}", + ) + print("Checking for buffer dir", old_buffer_dir) + if os.path.isdir(old_buffer_dir): + from distutils.dir_util import copy_tree, remove_tree + + os.makedirs(buffer_dir, exist_ok=True) + copy_tree(old_buffer_dir, buffer_dir) + remove_tree(old_buffer_dir, dry_run=True) + print(f"Moved buffer from {old_buffer_dir} to {buffer_dir}") + print(f"Using buffer_dir {buffer_dir}") + preds, labels = load_preds_labels( + ckpt_path, dataset, dataset_key, buffer_dir + ) # identity function if remove_violations is False smooth_preds(preds) @@ -574,6 +638,7 @@ def run_all( ) metrics_dict = run_consistency_metrics( preds, + prediction_filters, check_consistency_on, consistency_metrics, verbose_violation_output, @@ -589,26 +654,27 @@ def run_all( print( f"Consistency metrics have been written to {results_path_consistency}" ) - - metrics_dict = run_supervised_metrics( - preds, labels, save_details_to=details_path - ) - with open(results_path_supervised, "a") as f: - f.write( - f"{run_name},{epoch},{dataset.__class__.__name__},{dataset_key}," - f"{','.join([str(metrics_dict[k]) for k in supervised_keys])}\n" + if labels is not None: + metrics_dict = run_supervised_metrics( + preds, labels, save_details_to=details_path + ) + with open(results_path_supervised, "a") as f: + f.write( + f"{run_name},{epoch},{dataset.__class__.__name__},{dataset_key}," + f"{','.join([str(metrics_dict[k]) for k in supervised_keys])}\n" + ) + print( + f"Supervised metrics have been written to {results_path_supervised}" ) - print( - f"Supervised metrics have been written to {results_path_supervised}" - ) except Exception as e: print( f"Error during run {wandb_id if wandb_id is not None else run_name}: {e}" ) + print(traceback.format_exc()) # follow-up to NeSy submission -def run_fuzzy_loss(tag="fuzzy_loss"): +def run_fuzzy_loss(tag="fuzzy_loss", skip_first_n=0): api = wandb.Api() runs = api.runs("chebai/chebai", filters={"tags": tag}) print(f"Found {len(runs)} wandb runs tagged with '{tag}'") @@ -619,29 +685,33 @@ def run_fuzzy_loss(tag="fuzzy_loss"): "data", "chebi_v231", "ChEBI100", "fuzzy_loss_splits.csv" ), ) + local_ckpts = [("dd1r2kfb", 179)][skip_first_n:] pubchem_kmeans = PubChemKMeans() run_all( - [], # ids, - [ - ( - "chebi100_semg_epoch-dependent1-1k_start-at=10_batch3_weighted_v231_pc_kmeans_241010-0814", - 199, - ) + [], # ids[max(0, skip_first_n-len(local_ckpts)):], # ids, + local_ckpts + + [ + # ( + # 
"chebi100_semg_epoch-dependent1-1k_start-at=10_batch3_weighted_v231_pc_kmeans_241010-0814", + # 199, + # ) ], consistency_metrics=[binary], check_consistency_on=chebi100, prediction_datasets=[ (chebi100, "test"), - (pubchem_kmeans, "cluster1.pt"), + (pubchem_kmeans, "cluster1_cutoff2k.pt"), (pubchem_kmeans, "cluster2.pt"), - (pubchem_kmeans, "chebi_close.pt"), (pubchem_kmeans, "ten_from_each_cluster.pt"), + (pubchem_kmeans, "chebi_close.pt"), ], ) if __name__ == "__main__": - if len(sys.argv) > 1: + if len(sys.argv) > 2: + run_fuzzy_loss(sys.argv[1], int(sys.argv[2])) + elif len(sys.argv) > 1: run_fuzzy_loss(sys.argv[1]) else: run_fuzzy_loss() diff --git a/chebai/result/utils.py b/chebai/result/utils.py index 50956810..a5da8501 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -1,4 +1,5 @@ import os +import shutil from typing import Optional, Tuple, Union import torch @@ -16,7 +17,6 @@ def get_checkpoint_from_wandb( epoch: int, run: wandb.apis.public.Run, root: str = os.path.join("logs", "downloaded_ckpts"), - map_device_to: Optional[torch.device] = None, ): """ Gets a wandb checkpoint based on run and epoch, downloads it if necessary. @@ -25,8 +25,6 @@ def get_checkpoint_from_wandb( epoch: The epoch number of the checkpoint to retrieve. run: The wandb run object. root: The root directory to save the downloaded checkpoint. - model_class: The class of the model to load. - map_device_to: The device to map the model to. Returns: The location of the downloaded checkpoint. @@ -38,10 +36,19 @@ def get_checkpoint_from_wandb( if file.name.startswith( f"checkpoints/per_epoch={epoch}" ) or file.name.startswith(f"checkpoints/best_epoch={epoch}"): - dest_path = os.path.join(root, run.name, file.name.split("/")[-1]) + dest_path = os.path.join( + root, run.id, file.name.split("/")[-1].split("_")[1] + ".ckpt" + ) + # legacy: also look for ckpts in the old format + old_dest_path = os.path.join(root, run.name, file.name.split("/")[-1]) if not os.path.isfile(dest_path): - print(f"Downloading checkpoint to {dest_path}") - wandb_util.download_file_from_url(dest_path, file.url, api.api_key) + if os.path.isfile(old_dest_path): + print(f"Copying checkpoint from {old_dest_path} to {dest_path}") + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + shutil.copy2(old_dest_path, dest_path) + else: + print(f"Downloading checkpoint to {dest_path}") + wandb_util.download_file_from_url(dest_path, file.url, api.api_key) return dest_path print(f"No model found for epoch {epoch}") return None From 0e32978bf74c0e0660b6d619d91efc660701d017 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 21 Oct 2024 12:08:28 +0200 Subject: [PATCH 35/47] add elementwise multiplicative fuzzy loss --- chebai/loss/bce_weighted.py | 3 +- chebai/loss/semantic.py | 193 +++++++++++++++++++++++---------- configs/loss/semantic_loss.yml | 3 +- 3 files changed, 141 insertions(+), 58 deletions(-) diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index 6cde8bb6..10f607dc 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -26,6 +26,7 @@ def __init__( self, beta: Optional[float] = None, data_extractor: Optional[XYBaseDataModule] = None, + **kwargs, ): self.beta = beta if isinstance(data_extractor, LabeledUnlabeledMixed): @@ -35,7 +36,7 @@ def __init__( isinstance(self.data_extractor, _ChEBIDataExtractor) or self.data_extractor is None ) - super().__init__() + super().__init__(**kwargs) def set_pos_weight(self, input: torch.Tensor) -> None: """ diff --git a/chebai/loss/semantic.py 
b/chebai/loss/semantic.py index 8975ee38..b2433bf7 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -12,6 +12,25 @@ from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed +def _filter_by_ground_truth(individual_loss, target, filter_l, filter_r): + # mask of ground truth labels for implications, shape (batch_size, num_implications/num_disjointnesses) + target_l = target[:, filter_l] + target_r = target[:, filter_r] + + # filter individual loss: for an implication A->B, only apply loss to A if A is labeled false, + # only apply loss to B if B is labeled true + applicable_individual_loss_l = individual_loss * (1 - target_l) + applicable_individual_loss_r = individual_loss * target_r + class_loss = torch.zeros(target.shape, device=target.device) + # for each class, sum up the losses for all implication antecedents and consequents + for cls in range(class_loss.shape[1]): + class_loss[:, cls] = applicable_individual_loss_l[:, filter_l == cls].sum(dim=1) + class_loss[:, cls] += applicable_individual_loss_r[:, filter_r == cls].sum( + dim=1 + ) + return class_loss + + class ImplicationLoss(torch.nn.Module): """ Implication Loss module. @@ -63,6 +82,9 @@ def __init__( # propagate data_extractor to base loss if isinstance(base_loss, BCEWeighted): base_loss.data_extractor = self.data_extractor + base_loss.reduction = ( + "none" # needed to multiply fuzzy loss with base loss for each sample + ) self.base_loss = base_loss self.implication_cache_file = f"implications_{self.data_extractor.name}.cache" self.label_names = _load_label_names( @@ -83,37 +105,19 @@ def __init__( self.weight_epoch_dependent = weight_epoch_dependent self.start_at_epoch = start_at_epoch - def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: - """ - Forward pass of the implication loss module. - - Args: - input (torch.Tensor): Input tensor. - target (torch.Tensor): Target tensor. - **kwargs: Additional arguments. - - Returns: - tuple: Tuple containing total loss, base loss, and implication loss. 
- """ - nnl = kwargs.pop("non_null_labels", None) - if nnl: - labeled_input = input[nnl] - else: - labeled_input = input - if target is not None: - base_loss = self.base_loss(labeled_input, target.float()) - else: - base_loss = 0 - if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: - return base_loss, {"base_loss": base_loss} - pred = torch.sigmoid(input) - l = pred[:, self.implication_filter_l] - r = pred[:, self.implication_filter_r] - implication_loss = self._calculate_implication_loss(l, r) - loss_components = { - "base_loss": base_loss, - "unweighted_implication_loss": implication_loss, - } + def _calculate_unaggregated_fuzzy_loss( + self, pred, target: torch.Tensor, weight, filter_l, filter_r, **kwargs + ): + l = pred[:, filter_l] + r = pred[:, filter_r] + individual_loss = self._calculate_implication_loss(l, r, target) + implication_loss = _filter_by_ground_truth( + individual_loss, + target, + filter_l, + filter_r, + ) + unweighted_mean = implication_loss.sum(dim=-1).mean() implication_loss_weighted = implication_loss if "current_epoch" in kwargs and self.weight_epoch_dependent: sigmoid_center = ( @@ -131,9 +135,56 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: 1 + math.exp(-(kwargs["current_epoch"] - sigmoid_center) / sigmoid_spread) ) - implication_loss_weighted *= self.impl_weight - loss_components["weighted_implication_loss"] = implication_loss_weighted - return base_loss + implication_loss_weighted, loss_components + implication_loss_weighted *= weight + weighted_mean = implication_loss_weighted.sum(dim=-1).mean() + + return implication_loss_weighted, unweighted_mean, weighted_mean + + def _calculate_unaggregated_base_loss(self, input, target, **kwargs): + nnl = kwargs.pop("non_null_labels", None) + labeled_input = input[nnl] if nnl else input + + if target is not None and self.base_loss is not None: + return self.base_loss(labeled_input, target.float()) + else: + return torch.zeros(input.shape, device=input.device) + + def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: + """ + Forward pass of the implication loss module. + + Args: + input (torch.Tensor): Input tensor. + target (torch.Tensor): Target tensor. + **kwargs: Additional arguments. + + Returns: + tuple: Tuple containing total loss, base loss, and implication loss. 
+ """ + base_loss = self._calculate_unaggregated_base_loss(input, target, **kwargs) + loss_components = {"base_loss": base_loss.sum(dim=-1).mean()} + + if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: + return base_loss, loss_components + + pred = torch.sigmoid(input) + fuzzy_loss, unweighted_fuzzy_mean, weighted_fuzzy_mean = ( + self._calculate_unaggregated_fuzzy_loss( + pred, + target, + self.impl_weight, + self.implication_filter_l, + self.implication_filter_r, + **kwargs, + ) + ) + loss_components["unweighted_fuzzy_loss"] = unweighted_fuzzy_mean + loss_components["weighted_fuzzy_loss"] = weighted_fuzzy_mean + if self.base_loss is None or target is None: + return self.impl_weight * fuzzy_loss, loss_components + else: + total_loss = base_loss * (1 + self.impl_weight * fuzzy_loss) + return total_loss.sum(dim=-1).mean(), loss_components def _calculate_implication_loss( self, l: torch.Tensor, r: torch.Tensor, target: torch.Tensor @@ -206,14 +257,7 @@ def _calculate_implication_loss( if self.multiply_by_softmax: individual_loss = individual_loss * individual_loss.softmax(dim=-1) - # aggregate for classes, mask with ground truth labels - target_l = target[:, self.implication_filter_l] - target_r = target[:, self.implication_filter_r] - - return torch.mean( - torch.sum(individual_loss, dim=-1), - dim=0, - ) + return individual_loss def _load_implications(self, path_to_chebi: str) -> dict: """ @@ -273,22 +317,49 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: Returns: tuple: Tuple containing total loss, base loss, implication loss, and disjointness loss. """ - loss, loss_components = super().forward(input, target, **kwargs) + base_loss = self._calculate_unaggregated_base_loss(input, target, **kwargs) + loss_components = {"base_loss": base_loss.sum(dim=-1).mean()} + if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: - return loss, loss_components + return base_loss, loss_components + pred = torch.sigmoid(input) - l = pred[:, self.disjoint_filter_l] - r = pred[:, self.disjoint_filter_r] - disjointness_loss = self._calculate_implication_loss(l, 1 - r) - loss_components["unweighted_disjointness_loss"] = disjointness_loss - disjointness_loss_weighted = disjointness_loss - if "current_epoch" in kwargs and self.weight_epoch_dependent: - disjointness_loss_weighted = disjointness_loss_weighted / ( - 1 + math.exp(-(kwargs["current_epoch"] - 50) / 10) + impl_loss, unweighted_impl_mean, weighted_impl_mean = ( + self._calculate_unaggregated_fuzzy_loss( + pred, + target, + self.impl_weight, + self.implication_filter_l, + self.implication_filter_r, + **kwargs, + ) + ) + loss_components["unweighted_implication_loss"] = unweighted_impl_mean + loss_components["weighted_implication_loss"] = weighted_impl_mean + + disj_loss, unweighted_disj_mean, weighted_disj_mean = ( + self._calculate_unaggregated_fuzzy_loss( + pred, + target, + self.disjoint_weight, + self.disjoint_filter_l, + self.disjoint_filter_r, + **kwargs, ) - disjointness_loss_weighted *= self.disjoint_weight - loss_components["weighted_disjointness_loss"] = disjointness_loss_weighted - return loss + disjointness_loss_weighted, loss_components + ) + loss_components["unweighted_disjointness_loss"] = unweighted_disj_mean + loss_components["weighted_disjointness_loss"] = weighted_disj_mean + + if self.base_loss is None or target is None: + return ( + self.impl_weight * impl_loss + self.disjoint_weight * disj_loss, + loss_components, + ) + else: + total_loss = 
base_loss + ( + 1 + self.impl_weight * impl_loss + self.disjoint_weight * disj_loss + ) + return total_loss.sum(dim=-1).mean(), loss_components def _load_label_names(path_to_label_names: str) -> List: @@ -368,5 +439,15 @@ def _build_disjointness_filter( if __name__ == "__main__": loss = DisjointLoss( - os.path.join("data", "disjoint.csv"), ChEBIOver100(chebi_version=227) + os.path.join("data", "disjoint.csv"), ChEBIOver100(chebi_version=231) + ) + l = loss(torch.randn(10, 997), torch.randn(10, 997)) + + loss_with_base = DisjointLoss( + os.path.join("data", "disjoint.csv"), + ChEBIOver100(chebi_version=231), + base_loss=BCEWeighted(beta=0.99), ) + lb = loss_with_base(torch.randn(10, 997), torch.randn(10, 997)) + print(l) + print(lb) diff --git a/configs/loss/semantic_loss.yml b/configs/loss/semantic_loss.yml index 015f6619..5434084b 100644 --- a/configs/loss/semantic_loss.yml +++ b/configs/loss/semantic_loss.yml @@ -6,4 +6,5 @@ init_args: init_args: beta: 0.99 multiply_by_softmax: true - impl_loss_weight: 0.01 + impl_loss_weight: 100 + disjoint_loss_weight: 1000000 From 2037c20f1384c06fccf12755c8cff2d720619de8 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 23 Oct 2024 09:48:23 +0200 Subject: [PATCH 36/47] clean up analyse_sem.py --- chebai/result/analyse_sem.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py index 71d23c89..4046df87 100644 --- a/chebai/result/analyse_sem.py +++ b/chebai/result/analyse_sem.py @@ -685,17 +685,11 @@ def run_fuzzy_loss(tag="fuzzy_loss", skip_first_n=0): "data", "chebi_v231", "ChEBI100", "fuzzy_loss_splits.csv" ), ) - local_ckpts = [("dd1r2kfb", 179)][skip_first_n:] + local_ckpts = [][skip_first_n:] pubchem_kmeans = PubChemKMeans() run_all( - [], # ids[max(0, skip_first_n-len(local_ckpts)):], # ids, - local_ckpts - + [ - # ( - # "chebi100_semg_epoch-dependent1-1k_start-at=10_batch3_weighted_v231_pc_kmeans_241010-0814", - # 199, - # ) - ], + ids[max(0, skip_first_n - len(local_ckpts)) :], # ids, + local_ckpts, consistency_metrics=[binary], check_consistency_on=chebi100, prediction_datasets=[ From a16a065f6d7c48833c755a239cfd1bc55e3b0922 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 23 Oct 2024 10:02:19 +0200 Subject: [PATCH 37/47] fix fuzzy loss mean aggregation --- chebai/loss/semantic.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index b2433bf7..49b1a6d4 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -117,7 +117,7 @@ def _calculate_unaggregated_fuzzy_loss( filter_l, filter_r, ) - unweighted_mean = implication_loss.sum(dim=-1).mean() + unweighted_mean = implication_loss.mean() implication_loss_weighted = implication_loss if "current_epoch" in kwargs and self.weight_epoch_dependent: sigmoid_center = ( @@ -136,7 +136,7 @@ def _calculate_unaggregated_fuzzy_loss( + math.exp(-(kwargs["current_epoch"] - sigmoid_center) / sigmoid_spread) ) implication_loss_weighted *= weight - weighted_mean = implication_loss_weighted.sum(dim=-1).mean() + weighted_mean = implication_loss_weighted.mean() return implication_loss_weighted, unweighted_mean, weighted_mean @@ -162,10 +162,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: tuple: Tuple containing total loss, base loss, and implication loss. 
""" base_loss = self._calculate_unaggregated_base_loss(input, target, **kwargs) - loss_components = {"base_loss": base_loss.sum(dim=-1).mean()} + loss_components = {"base_loss": base_loss.mean()} if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: - return base_loss, loss_components + return base_loss.mean(), loss_components pred = torch.sigmoid(input) fuzzy_loss, unweighted_fuzzy_mean, weighted_fuzzy_mean = ( @@ -181,10 +181,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: loss_components["unweighted_fuzzy_loss"] = unweighted_fuzzy_mean loss_components["weighted_fuzzy_loss"] = weighted_fuzzy_mean if self.base_loss is None or target is None: - return self.impl_weight * fuzzy_loss, loss_components + total_loss = self.impl_weight * fuzzy_loss else: total_loss = base_loss * (1 + self.impl_weight * fuzzy_loss) - return total_loss.sum(dim=-1).mean(), loss_components + return total_loss.mean(), loss_components def _calculate_implication_loss( self, l: torch.Tensor, r: torch.Tensor, target: torch.Tensor @@ -318,10 +318,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: tuple: Tuple containing total loss, base loss, implication loss, and disjointness loss. """ base_loss = self._calculate_unaggregated_base_loss(input, target, **kwargs) - loss_components = {"base_loss": base_loss.sum(dim=-1).mean()} + loss_components = {"base_loss": base_loss.mean()} if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: - return base_loss, loss_components + return base_loss.mean(), loss_components pred = torch.sigmoid(input) impl_loss, unweighted_impl_mean, weighted_impl_mean = ( @@ -351,15 +351,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: loss_components["weighted_disjointness_loss"] = weighted_disj_mean if self.base_loss is None or target is None: - return ( - self.impl_weight * impl_loss + self.disjoint_weight * disj_loss, - loss_components, - ) + total_loss = self.impl_weight * impl_loss + self.disjoint_weight * disj_loss else: - total_loss = base_loss + ( + total_loss = base_loss * ( 1 + self.impl_weight * impl_loss + self.disjoint_weight * disj_loss ) - return total_loss.sum(dim=-1).mean(), loss_components + return total_loss.mean(), loss_components def _load_label_names(path_to_label_names: str) -> List: From 70d0f29e74272ca3d59b2a765d035c58a7361f45 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 23 Oct 2024 16:30:47 +0200 Subject: [PATCH 38/47] make fuzzy loss implementation more efficient --- chebai/loss/semantic.py | 59 ++++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 49b1a6d4..b5001c3a 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -93,9 +93,12 @@ def __init__( self.hierarchy = self._load_implications( os.path.join(data_extractor.raw_dir, "chebi.obo") ) - implication_filter = _build_implication_filter(self.label_names, self.hierarchy) - self.implication_filter_l = implication_filter[:, 0] - self.implication_filter_r = implication_filter[:, 1] + implication_filter_dense = _build_dense_filter( + _build_implication_filter(self.label_names, self.hierarchy), + len(self.label_names), + ) + self.implication_filter_l = implication_filter_dense + self.implication_filter_r = self.implication_filter_l.transpose(0, 1) self.fuzzy_implication = fuzzy_implication self.impl_weight = impl_loss_weight self.pos_scalar = 
pos_scalar @@ -108,17 +111,30 @@ def __init__( def _calculate_unaggregated_fuzzy_loss( self, pred, target: torch.Tensor, weight, filter_l, filter_r, **kwargs ): - l = pred[:, filter_l] - r = pred[:, filter_r] - individual_loss = self._calculate_implication_loss(l, r, target) - implication_loss = _filter_by_ground_truth( - individual_loss, - target, - filter_l, - filter_r, + # for each batch, get all pairwise losses: [a1, a2, a3] -> [[a1*a1, a1*a2, a1*a3],[a2*a1,...],[a3*a1,...]] + preds_expanded1 = pred.unsqueeze(1).expand(-1, pred.shape[1], -1) + preds_expanded2 = pred.unsqueeze(2).expand(-1, -1, pred.shape[1]) + # filter by implication relations and labels + + label_filter = target.unsqueeze(2).expand(-1, -1, pred.shape[1]) + filter_l = filter_l.unsqueeze(0).expand(pred.shape[0], -1, -1) + filter_r = filter_r.unsqueeze(0).expand(pred.shape[0], -1, -1) + + loss_impl_l = ( + self._calculate_implication_loss(preds_expanded2, preds_expanded1, target) + * filter_l + * (1 - label_filter) + ) + loss_impl_r = ( + self._calculate_implication_loss(preds_expanded1, preds_expanded2, target) + * filter_r + * label_filter ) - unweighted_mean = implication_loss.mean() - implication_loss_weighted = implication_loss + + loss_by_cls = (loss_impl_l + loss_impl_r).sum(dim=-1) + + unweighted_mean = loss_by_cls.mean() + implication_loss_weighted = loss_by_cls if "current_epoch" in kwargs and self.weight_epoch_dependent: sigmoid_center = ( self.weight_epoch_dependent[0] @@ -376,7 +392,8 @@ def _load_label_names(path_to_label_names: str) -> List: def _build_implication_filter(label_names: List, hierarchy: dict) -> torch.Tensor: """ - Build implication filter based on label names and hierarchy. + Build implication filter based on label names and hierarchy. Results in list of pairs (A,B) for each implication + A->B (including indirect implications). Args: label_names (list): List of label names. 
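# --- Illustrative sketch, not part of the patch ---
# The hunks of this commit replace the per-pair index filters with dense boolean
# matrices and compute all pairwise implication losses in one go; the next hunk adds
# the _build_dense_filter helper that produces such a matrix. The toy example below
# mirrors the masking logic with invented tensor values, using a Reichenbach-style
# violation l * (1 - r) as a stand-in for _calculate_implication_loss.
import torch

pred = torch.tensor([[0.2, 0.9, 0.4]])        # (batch=1, num_labels=3)
target = torch.tensor([[0.0, 1.0, 0.0]])      # ground-truth labels
# dense filter: filter_l[i, j] is True iff class i implies class j (here: 0 -> 1)
filter_l = torch.tensor([[False, True, False],
                         [False, False, False],
                         [False, False, False]])
filter_r = filter_l.transpose(0, 1)

# all pairwise violations: entry (i, j) measures how strongly "i implies j" is violated
l = pred.unsqueeze(2).expand(-1, -1, pred.shape[1])   # antecedent, constant along dim 2
r = pred.unsqueeze(1).expand(-1, pred.shape[1], -1)   # consequent, constant along dim 1
violation = l * (1 - r)

# mask: penalise the antecedent only where it is labelled 0, the consequent only where it is labelled 1
label_filter = target.unsqueeze(2).expand(-1, -1, pred.shape[1])
loss_l = violation * filter_l * (1 - label_filter)
loss_r = violation.transpose(1, 2) * filter_r * label_filter
print((loss_l + loss_r).sum(dim=-1))                  # per-class loss, shape (1, 3)
# --- end of sketch ---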
@@ -395,6 +412,13 @@ def _build_implication_filter(label_names: List, hierarchy: dict) -> torch.Tenso ) +def _build_dense_filter(sparse_filter: torch.Tensor, n_labels: int) -> torch.Tensor: + res = torch.zeros((n_labels, n_labels), dtype=torch.bool) + for l, r in sparse_filter: + res[l, r] = True + return res + + def _build_disjointness_filter( path_to_disjointness: str, label_names: List, hierarchy: dict ) -> tuple: @@ -431,14 +455,16 @@ def _build_disjointness_filter( ) dis_filter = torch.tensor(list(disjoints)) - return dis_filter[:, 0], dis_filter[:, 1] + dense = _build_dense_filter(dis_filter, len(label_names)) + dense_r = dense.transpose(0, 1) + return dense, dense_r if __name__ == "__main__": loss = DisjointLoss( os.path.join("data", "disjoint.csv"), ChEBIOver100(chebi_version=231) ) - l = loss(torch.randn(10, 997), torch.randn(10, 997)) + l = loss(torch.randn(10, 997), torch.randint(0, 2, (10, 997))) loss_with_base = DisjointLoss( os.path.join("data", "disjoint.csv"), @@ -448,3 +474,4 @@ def _build_disjointness_filter( lb = loss_with_base(torch.randn(10, 997), torch.randn(10, 997)) print(l) print(lb) + print() From 93a4aae15aed75c44dee52aa0b76c1bcd399eaa1 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 23 Oct 2024 16:51:54 +0200 Subject: [PATCH 39/47] fix device --- chebai/loss/semantic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index b5001c3a..0dba832b 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -117,8 +117,8 @@ def _calculate_unaggregated_fuzzy_loss( # filter by implication relations and labels label_filter = target.unsqueeze(2).expand(-1, -1, pred.shape[1]) - filter_l = filter_l.unsqueeze(0).expand(pred.shape[0], -1, -1) - filter_r = filter_r.unsqueeze(0).expand(pred.shape[0], -1, -1) + filter_l = filter_l.to(pred.device).unsqueeze(0).expand(pred.shape[0], -1, -1) + filter_r = filter_r.to(pred.device).unsqueeze(0).expand(pred.shape[0], -1, -1) loss_impl_l = ( self._calculate_implication_loss(preds_expanded2, preds_expanded1, target) From fbd2c653600771497259c355154396588055d5ee Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 24 Oct 2024 14:37:36 +0200 Subject: [PATCH 40/47] add max aggregation for fuzzy loss --- chebai/loss/semantic.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 0dba832b..7f5f1810 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -46,6 +46,13 @@ class ImplicationLoss(torch.nn.Module): use_sigmoidal_implication (bool, optional): Whether to use the sigmoidal fuzzy implication based on the specified fuzzy_implication (as defined by van Krieken et al., 2022: Analyzing Differentiable Fuzzy Logic Operators). Defaults to False. + weight_epoch_dependent (Union[bool, tuple[int, int]], optional): Whether to weight the implication loss + depending on the current epoch with the sigmoid function sigmoid((epoch-c)/s). If True, c=50 and s=10, + otherwise, a tuple of integers (c,s) can be supplied. Defaults to False. + start_at_epoch (int, optional): Epoch at which to start applying the loss. Defaults to 0. + violations_per_cls_aggregator (Literal["sum", "max"], optional): How to aggregate violations for each class. + If a class is involved in several implications / disjointnesses, the loss value for this class will be + aggregated with this method. Defaults to "sum". 
""" def __init__( @@ -72,6 +79,7 @@ def __init__( use_sigmoidal_implication: bool = False, weight_epoch_dependent: Union[bool | tuple[int, int]] = False, start_at_epoch: int = 0, + violations_per_cls_aggregator: Literal["sum", "max"] = "sum", ): super().__init__() # automatically choose labeled subset for implication filter in case of mixed dataset @@ -107,6 +115,7 @@ def __init__( self.use_sigmoidal_implication = use_sigmoidal_implication self.weight_epoch_dependent = weight_epoch_dependent self.start_at_epoch = start_at_epoch + self.violations_per_cls_aggregator = violations_per_cls_aggregator def _calculate_unaggregated_fuzzy_loss( self, pred, target: torch.Tensor, weight, filter_l, filter_r, **kwargs @@ -131,7 +140,10 @@ def _calculate_unaggregated_fuzzy_loss( * label_filter ) - loss_by_cls = (loss_impl_l + loss_impl_r).sum(dim=-1) + if self.violations_per_cls_aggregator == "sum": + loss_by_cls = (loss_impl_l + loss_impl_r).sum(dim=-1) + else: + loss_by_cls = torch.max(loss_impl_l + loss_impl_r, dim=-1).values unweighted_mean = loss_by_cls.mean() implication_loss_weighted = loss_by_cls @@ -461,17 +473,18 @@ def _build_disjointness_filter( if __name__ == "__main__": - loss = DisjointLoss( - os.path.join("data", "disjoint.csv"), ChEBIOver100(chebi_version=231) - ) - l = loss(torch.randn(10, 997), torch.randint(0, 2, (10, 997))) - loss_with_base = DisjointLoss( os.path.join("data", "disjoint.csv"), ChEBIOver100(chebi_version=231), base_loss=BCEWeighted(beta=0.99), ) - lb = loss_with_base(torch.randn(10, 997), torch.randn(10, 997)) - print(l) + lb = loss_with_base(torch.randn(10, 997), torch.randint(0, 2, (10, 997))) print(lb) - print() + loss_max = DisjointLoss( + os.path.join("data", "disjoint.csv"), + ChEBIOver100(chebi_version=231), + base_loss=BCEWeighted(beta=0.99), + violations_per_cls_aggregator="max", + ) + lm = loss_max(torch.randn(10, 997), torch.randint(0, 2, (10, 997))) + print(lm) From acb05d77e16653da7209860a5ec4e8a2d46d3375 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 25 Oct 2024 11:10:02 +0200 Subject: [PATCH 41/47] adapt evaluation to new fuzzy loss --- chebai/loss/semantic.py | 19 ------------------- chebai/result/analyse_sem.py | 36 +++++++++++++++++++++++------------- 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 7f5f1810..118b496e 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -12,25 +12,6 @@ from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed -def _filter_by_ground_truth(individual_loss, target, filter_l, filter_r): - # mask of ground truth labels for implications, shape (batch_size, num_implications/num_disjointnesses) - target_l = target[:, filter_l] - target_r = target[:, filter_r] - - # filter individual loss: for an implication A->B, only apply loss to A if A is labeled false, - # only apply loss to B if B is labeled true - applicable_individual_loss_l = individual_loss * (1 - target_l) - applicable_individual_loss_r = individual_loss * target_r - class_loss = torch.zeros(target.shape, device=target.device) - # for each class, sum up the losses for all implication antecedents and consequents - for cls in range(class_loss.shape[1]): - class_loss[:, cls] = applicable_individual_loss_l[:, filter_l == cls].sum(dim=1) - class_loss[:, cls] += applicable_individual_loss_r[:, filter_r == cls].sum( - dim=1 - ) - return class_loss - - class ImplicationLoss(torch.nn.Module): """ Implication Loss module. 
diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py index 4046df87..51a1fb2b 100644 --- a/chebai/result/analyse_sem.py +++ b/chebai/result/analyse_sem.py @@ -236,6 +236,15 @@ def __call__(self, preds): return preds +def _filter_to_dense(filter): + filter_dense = [] + for i in range(filter.shape[0]): + for j in range(filter.shape[1]): + if filter[i, j] > 0: + filter_dense.append([i, j]) + return torch.tensor(filter_dense) + + def build_prediction_filter(data_module_labeled=None): if data_module_labeled is None: data_module_labeled = ChEBIOver100(chebi_version=231) @@ -245,9 +254,12 @@ def build_prediction_filter(data_module_labeled=None): path_to_disjointness=os.path.join("data", "disjoint.csv"), data_extractor=data_module_labeled, ) + impl = _filter_to_dense(dl.implication_filter_l) + disj = _filter_to_dense(dl.disjoint_filter_l) + return [ - (dl.implication_filter_l, dl.implication_filter_r, "impl"), - (dl.disjoint_filter_l, dl.disjoint_filter_r, "disj"), + (impl[:, 0], impl[:, 1], "impl"), + (disj[:, 0], disj[:, 1], "disj"), ] @@ -595,9 +607,7 @@ def run_all( else: ckpt_path = None for file in os.listdir(os.path.join(ckpt_dir, run_name)): - if file.startswith(f"best_epoch={epoch}_") or file.startswith( - f"per_epoch={epoch}_" - ): + if f"epoch={epoch}_" in file or f"epoch={epoch}." in file: ckpt_path = os.path.join(os.path.join(ckpt_dir, run_name, file)) assert ( ckpt_path is not None @@ -632,10 +642,10 @@ def run_all( # identity function if remove_violations is False smooth_preds(preds) - details_path = os.path.join( - results_dir, - f"{run_name}_ep{epoch}_{dataset.__class__.__name__}_{dataset_key}", - ) + details_path = None # os.path.join( + # results_dir, + # f"{run_name}_ep{epoch}_{dataset.__class__.__name__}_{dataset_key}", + # ) metrics_dict = run_consistency_metrics( preds, prediction_filters, @@ -694,10 +704,10 @@ def run_fuzzy_loss(tag="fuzzy_loss", skip_first_n=0): check_consistency_on=chebi100, prediction_datasets=[ (chebi100, "test"), - (pubchem_kmeans, "cluster1_cutoff2k.pt"), - (pubchem_kmeans, "cluster2.pt"), - (pubchem_kmeans, "ten_from_each_cluster.pt"), - (pubchem_kmeans, "chebi_close.pt"), + # (pubchem_kmeans, "cluster1_cutoff2k.pt"), + # (pubchem_kmeans, "cluster2.pt"), + # (pubchem_kmeans, "ten_from_each_cluster.pt"), + # (pubchem_kmeans, "chebi_close.pt"), ], ) From ebda76c7aba34ac1124fb475b13d08b734729594 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 25 Oct 2024 16:20:30 +0200 Subject: [PATCH 42/47] add binary implication --- chebai/loss/semantic.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 118b496e..6b6d02c0 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -52,6 +52,8 @@ def __init__( "g", "reverse-goedel", "rg", + "binary", + "b", ] = "reichenbach", impl_loss_weight: float = 0.1, pos_scalar: Union[int, float] = 1, @@ -249,6 +251,8 @@ def _calculate_implication_loss( individual_loss = torch.where(l <= r, 0, one_min_r) elif self.fuzzy_implication in ["reverse-goedel", "rg"]: individual_loss = torch.where(l <= r, 0, l) + elif self.fuzzy_implication in ["binary", "b"]: + individual_loss = torch.where(l <= r, 0, 1) else: raise NotImplementedError( f"Unknown fuzzy implication {self.fuzzy_implication}" @@ -454,18 +458,16 @@ def _build_disjointness_filter( if __name__ == "__main__": - loss_with_base = DisjointLoss( + loss = DisjointLoss( os.path.join("data", "disjoint.csv"), ChEBIOver100(chebi_version=231), - 
base_loss=BCEWeighted(beta=0.99), + base_loss=BCEWeighted(), ) - lb = loss_with_base(torch.randn(10, 997), torch.randint(0, 2, (10, 997))) - print(lb) - loss_max = DisjointLoss( - os.path.join("data", "disjoint.csv"), - ChEBIOver100(chebi_version=231), - base_loss=BCEWeighted(beta=0.99), - violations_per_cls_aggregator="max", - ) - lm = loss_max(torch.randn(10, 997), torch.randint(0, 2, (10, 997))) + lm = loss(torch.randn(10, 997), torch.randint(0, 2, (10, 997))) + print(lm) + loss.disjoint_filter_l = torch.tensor( + [[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0], [0, 0, 1, 0]] + ) + loss.disjoint_filter_r = loss.disjoint_filter_l.transpose(0, 1) + # todo From 71a3f301c9d7206e29d7fe13ead28721828b9402 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 28 Oct 2024 10:42:32 +0100 Subject: [PATCH 43/47] fix disjointness, binary implication --- chebai/loss/semantic.py | 63 +++++++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 21 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 6b6d02c0..2acf3431 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -101,7 +101,14 @@ def __init__( self.violations_per_cls_aggregator = violations_per_cls_aggregator def _calculate_unaggregated_fuzzy_loss( - self, pred, target: torch.Tensor, weight, filter_l, filter_r, **kwargs + self, + pred, + target: torch.Tensor, + weight, + filter_l, + filter_r, + mode="impl", + **kwargs, ): # for each batch, get all pairwise losses: [a1, a2, a3] -> [[a1*a1, a1*a2, a1*a3],[a2*a1,...],[a3*a1,...]] preds_expanded1 = pred.unsqueeze(1).expand(-1, pred.shape[1], -1) @@ -111,22 +118,25 @@ def _calculate_unaggregated_fuzzy_loss( label_filter = target.unsqueeze(2).expand(-1, -1, pred.shape[1]) filter_l = filter_l.to(pred.device).unsqueeze(0).expand(pred.shape[0], -1, -1) filter_r = filter_r.to(pred.device).unsqueeze(0).expand(pred.shape[0], -1, -1) - - loss_impl_l = ( - self._calculate_implication_loss(preds_expanded2, preds_expanded1, target) - * filter_l - * (1 - label_filter) - ) - loss_impl_r = ( - self._calculate_implication_loss(preds_expanded1, preds_expanded2, target) - * filter_r - * label_filter - ) + if mode == "impl": + all_implications = self._calculate_implication_loss( + preds_expanded2, preds_expanded1 + ) + else: + all_implications = self._calculate_implication_loss( + preds_expanded2, 1 - preds_expanded1 + ) + loss_impl_l = all_implications * filter_l * (1 - label_filter) + if mode == "impl": + loss_impl_r = all_implications.transpose(1, 2) * filter_r * label_filter + loss_impl_sum = loss_impl_l + loss_impl_r + else: + loss_impl_sum = loss_impl_l if self.violations_per_cls_aggregator == "sum": - loss_by_cls = (loss_impl_l + loss_impl_r).sum(dim=-1) + loss_by_cls = (loss_impl_sum).sum(dim=-1) else: - loss_by_cls = torch.max(loss_impl_l + loss_impl_r, dim=-1).values + loss_by_cls = torch.max(loss_impl_sum, dim=-1).values unweighted_mean = loss_by_cls.mean() implication_loss_weighted = loss_by_cls @@ -198,7 +208,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: return total_loss.mean(), loss_components def _calculate_implication_loss( - self, l: torch.Tensor, r: torch.Tensor, target: torch.Tensor + self, l: torch.Tensor, r: torch.Tensor ) -> torch.Tensor: """ Calculate implication loss based on T-norm and other parameters. 
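# --- Illustrative sketch, not part of the patch ---
# The "binary" option added in the previous commit replaces the fuzzy implication with
# a crisp violation indicator. With plain int literals, torch.where typically returns
# an integer tensor, which appears to be why the next hunk casts the result back to
# the prediction dtype. Toy tensors only.
import torch

l = torch.tensor([0.8, 0.3])                         # antecedent predictions
r = torch.tensor([0.2, 0.9])                         # consequent predictions
crisp = torch.where(l <= r, 0, 1)                    # tensor([1, 0]), integer dtype
crisp_cast = torch.where(l <= r, 0, 1).to(l.dtype)   # tensor([1., 0.]), matches l.dtype
print(crisp.dtype, crisp_cast.dtype)
# --- end of sketch ---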
@@ -252,7 +262,7 @@ def _calculate_implication_loss( elif self.fuzzy_implication in ["reverse-goedel", "rg"]: individual_loss = torch.where(l <= r, 0, l) elif self.fuzzy_implication in ["binary", "b"]: - individual_loss = torch.where(l <= r, 0, 1) + individual_loss = torch.where(l <= r, 0, 1).to(dtype=l.dtype) else: raise NotImplementedError( f"Unknown fuzzy implication {self.fuzzy_implication}" @@ -336,7 +346,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: return base_loss.mean(), loss_components - pred = torch.sigmoid(input) + pred = input # torch.sigmoid(input) impl_loss, unweighted_impl_mean, weighted_impl_mean = ( self._calculate_unaggregated_fuzzy_loss( pred, @@ -357,6 +367,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: self.disjoint_weight, self.disjoint_filter_l, self.disjoint_filter_r, + mode="disj", **kwargs, ) ) @@ -462,12 +473,22 @@ def _build_disjointness_filter( os.path.join("data", "disjoint.csv"), ChEBIOver100(chebi_version=231), base_loss=BCEWeighted(), + impl_loss_weight=1, + disjoint_loss_weight=1, ) - lm = loss(torch.randn(10, 997), torch.randint(0, 2, (10, 997))) - - print(lm) - loss.disjoint_filter_l = torch.tensor( + # lm = loss(torch.randn(10, 997), torch.randint(0, 2, (10, 997))) + # print(lm) + loss.implication_filter_l = torch.tensor( [[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0], [0, 0, 1, 0]] ) + loss.implication_filter_r = loss.implication_filter_l.transpose(0, 1) + loss.disjoint_filter_l = torch.tensor( + [[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 0], [0, 1, 0, 0]] + ) loss.disjoint_filter_r = loss.disjoint_filter_l.transpose(0, 1) + preds = torch.tensor([[0.1, 0.3, 0.7, 0.4], [0.5, 0.2, 0.9, 0.1]]) + labels = [[0, 1, 1, 0], [0, 0, 1, 1]] + lm = loss(preds, torch.tensor(labels)) + print(lm) + print() # todo From 4b354d3dedf79b4ee704d587c3f87205b6e2b175 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 28 Oct 2024 10:46:37 +0100 Subject: [PATCH 44/47] add sigmoid --- chebai/loss/semantic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 2acf3431..f4e67e75 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -346,7 +346,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: return base_loss.mean(), loss_components - pred = input # torch.sigmoid(input) + pred = torch.sigmoid(input) impl_loss, unweighted_impl_mean, weighted_impl_mean = ( self._calculate_unaggregated_fuzzy_loss( pred, From 45c54f41a2187607744b302e12d67eb0fc70cb59 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 4 Nov 2024 14:53:53 +0100 Subject: [PATCH 45/47] add log- and mean aggregation --- chebai/loss/semantic.py | 54 ++++++++++++++++++++++++++++++++--------- 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index f4e67e75..5c4d06df 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -62,7 +62,10 @@ def __init__( use_sigmoidal_implication: bool = False, weight_epoch_dependent: Union[bool | tuple[int, int]] = False, start_at_epoch: int = 0, - violations_per_cls_aggregator: Literal["sum", "max"] = "sum", + violations_per_cls_aggregator: Literal[ + "sum", "max", "mean", "log-sum", "log-max", "log-mean" + ] = "sum", + multiply_with_base_loss: bool = True, ): 
super().__init__() # automatically choose labeled subset for implication filter in case of mixed dataset @@ -99,6 +102,7 @@ def __init__( self.weight_epoch_dependent = weight_epoch_dependent self.start_at_epoch = start_at_epoch self.violations_per_cls_aggregator = violations_per_cls_aggregator + self.multiply_with_base_loss = multiply_with_base_loss def _calculate_unaggregated_fuzzy_loss( self, @@ -133,10 +137,21 @@ def _calculate_unaggregated_fuzzy_loss( else: loss_impl_sum = loss_impl_l - if self.violations_per_cls_aggregator == "sum": - loss_by_cls = (loss_impl_sum).sum(dim=-1) + if self.violations_per_cls_aggregator.startswith("log-"): + loss_impl_sum = -torch.log(1 - loss_impl_sum) + violations_per_cls_aggregator = self.violations_per_cls_aggregator[4:] else: - loss_by_cls = torch.max(loss_impl_sum, dim=-1).values + violations_per_cls_aggregator = self.violations_per_cls_aggregator + if violations_per_cls_aggregator == "sum": + loss_by_cls = loss_impl_sum.sum(dim=-1) + elif violations_per_cls_aggregator == "max": + loss_by_cls = loss_impl_sum.max(dim=-1).values + elif violations_per_cls_aggregator == "mean": + loss_by_cls = loss_impl_sum.mean(dim=-1) + else: + raise NotImplementedError( + f"Unknown violations_per_cls_aggregator {self.violations_per_cls_aggregator}" + ) unweighted_mean = loss_by_cls.mean() implication_loss_weighted = loss_by_cls @@ -203,8 +218,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: loss_components["weighted_fuzzy_loss"] = weighted_fuzzy_mean if self.base_loss is None or target is None: total_loss = self.impl_weight * fuzzy_loss - else: + elif self.multiply_with_base_loss: total_loss = base_loss * (1 + self.impl_weight * fuzzy_loss) + else: + total_loss = base_loss + self.impl_weight * fuzzy_loss return total_loss.mean(), loss_components def _calculate_implication_loss( @@ -376,10 +393,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: if self.base_loss is None or target is None: total_loss = self.impl_weight * impl_loss + self.disjoint_weight * disj_loss - else: + elif self.multiply_with_base_loss: total_loss = base_loss * ( 1 + self.impl_weight * impl_loss + self.disjoint_weight * disj_loss ) + else: + total_loss = ( + base_loss + + self.impl_weight * impl_loss + + self.disjoint_weight * disj_loss + ) return total_loss.mean(), loss_components @@ -476,8 +499,14 @@ def _build_disjointness_filter( impl_loss_weight=1, disjoint_loss_weight=1, ) - # lm = loss(torch.randn(10, 997), torch.randint(0, 2, (10, 997))) - # print(lm) + random_preds = torch.randn(10, 997) + random_labels = torch.randint(0, 2, (10, 997)) + for agg in ["sum", "max", "mean", "log-mean"]: + loss.violations_per_cls_aggregator = agg + l = loss(random_preds, random_labels) + print(f"Loss with {agg} aggregation for random input:", l) + + # simplified example for ontology with 4 classes, A -> B, B -> C, D -> C, B and D disjoint loss.implication_filter_l = torch.tensor( [[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0], [0, 0, 1, 0]] ) @@ -486,9 +515,10 @@ def _build_disjointness_filter( [[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 0], [0, 1, 0, 0]] ) loss.disjoint_filter_r = loss.disjoint_filter_l.transpose(0, 1) + # expected result: first sample: moderately high loss for B disj D, otherwise low, second sample: high loss for A -> B (applied to A), otherwise low preds = torch.tensor([[0.1, 0.3, 0.7, 0.4], [0.5, 0.2, 0.9, 0.1]]) labels = [[0, 1, 1, 0], [0, 0, 1, 1]] - lm = loss(preds, torch.tensor(labels)) - print(lm) - print() - # todo + 
for agg in ["sum", "max", "mean", "log-mean"]: + loss.violations_per_cls_aggregator = agg + l = loss(preds, torch.tensor(labels)) + print(f"Loss with {agg} aggregation for simple input:", l) From 62a49ebea9c26d3d06a242ad9bf7b67b80629c6e Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 4 Nov 2024 15:05:40 +0100 Subject: [PATCH 46/47] add optional detach from gradients --- chebai/loss/semantic.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 5c4d06df..271c3124 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -66,6 +66,7 @@ def __init__( "sum", "max", "mean", "log-sum", "log-max", "log-mean" ] = "sum", multiply_with_base_loss: bool = True, + no_grads: bool = False, ): super().__init__() # automatically choose labeled subset for implication filter in case of mixed dataset @@ -103,6 +104,7 @@ def __init__( self.start_at_epoch = start_at_epoch self.violations_per_cls_aggregator = violations_per_cls_aggregator self.multiply_with_base_loss = multiply_with_base_loss + self.no_grads = no_grads def _calculate_unaggregated_fuzzy_loss( self, @@ -214,6 +216,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: **kwargs, ) ) + if self.no_grads: + fuzzy_loss = fuzzy_loss.detach() loss_components["unweighted_fuzzy_loss"] = unweighted_fuzzy_mean loss_components["weighted_fuzzy_loss"] = weighted_fuzzy_mean if self.base_loss is None or target is None: @@ -374,6 +378,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: **kwargs, ) ) + if self.no_grads: + impl_loss = impl_loss.detach() loss_components["unweighted_implication_loss"] = unweighted_impl_mean loss_components["weighted_implication_loss"] = weighted_impl_mean @@ -388,6 +394,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: **kwargs, ) ) + if self.no_grads: + disj_loss = disj_loss.detach() loss_components["unweighted_disjointness_loss"] = unweighted_disj_mean loss_components["weighted_disjointness_loss"] = weighted_disj_mean From e7bd80a6dba01e7755f4383361a343ba349ee7d5 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 19 Dec 2024 15:38:17 +0100 Subject: [PATCH 47/47] fix: ignore empty list in evaluation --- chebai/result/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/result/utils.py b/chebai/result/utils.py index 2a97e5ed..35dbc319 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -155,7 +155,7 @@ def evaluate_model( test_labels = _concat_tuple(labels_list) return test_preds, test_labels return test_preds, None - elif preds_list i + elif len(preds_list) < 0: torch.save( _concat_tuple(preds_list), os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
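# --- Illustrative sketch, not part of the patch ---
# What the no_grads option from [PATCH 46/47] does: the fuzzy term is detached before
# it enters the total loss, so it is still computed (and can be logged as a loss
# component) but contributes no gradients. Toy tensors only.
import torch

pred = torch.tensor([0.7, 0.2], requires_grad=True)
base_loss = (pred - torch.tensor([1.0, 0.0])).pow(2).mean()
fuzzy_loss = pred[0] * (1 - pred[1])              # stand-in consistency violation term

total = base_loss + 0.1 * fuzzy_loss.detach()     # detached: logged value unchanged
total.backward()
print(pred.grad)                                  # gradient flows from base_loss only
# --- end of sketch ---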