diff --git a/chebai/cli.py b/chebai/cli.py index f2ad1072..36245aa0 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -50,6 +50,10 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser): parser.link_arguments( "data", "model.init_args.criterion.init_args.data_extractor" ) + parser.link_arguments( + "data.init_args.chebi_version", + "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version", + ) @staticmethod def subcommands() -> Dict[str, Set[str]]: diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index c00756e6..9ff7917e 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -5,14 +5,16 @@ import torch from chebai.preprocessing.datasets.base import XYBaseDataModule +from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed class BCEWeighted(torch.nn.BCEWithLogitsLoss): """ BCEWithLogitsLoss with weights automatically computed according to the beta parameter. + If beta is None or data_extractor is None, the loss is unweighted. - This class computes weights based on the formula from the paper: + This class computes weights based on the formula from the paper by Cui et al. (2019): https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf Args: @@ -24,13 +26,17 @@ def __init__( self, beta: Optional[float] = None, data_extractor: Optional[XYBaseDataModule] = None, + **kwargs, ): self.beta = beta if isinstance(data_extractor, LabeledUnlabeledMixed): data_extractor = data_extractor.labeled self.data_extractor = data_extractor - - super().__init__() + assert ( + isinstance(self.data_extractor, _ChEBIDataExtractor) + or self.data_extractor is None + ), f"data_extractor must be a _ChEBIDataExtractor or None, got {type(self.data_extractor)}" + super().__init__(**kwargs) def set_pos_weight(self, input: torch.Tensor) -> None: """ @@ -50,6 +56,9 @@ def set_pos_weight(self, input: torch.Tensor) -> None: ) and self.pos_weight is None ): + print( + f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})" + ) complete_data = pd.concat( [ pd.read_pickle( @@ -75,7 +84,9 @@ def set_pos_weight(self, input: torch.Tensor) -> None: [w / mean for w in weights], device=input.device ) - def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def forward( + self, input: torch.Tensor, target: torch.Tensor, **kwargs + ) -> torch.Tensor: """ Forward pass for the loss calculation. diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 78938d22..271c3124 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -7,6 +7,7 @@ import torch from chebai.loss.bce_weighted import BCEWeighted +from chebai.preprocessing.datasets import XYBaseDataModule from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed @@ -16,49 +17,175 @@ class ImplicationLoss(torch.nn.Module): Implication Loss module. Args: - data_extractor (Union[_ChEBIDataExtractor, LabeledUnlabeledMixed]): Data extractor for labels. + data_extractor (_ChEBIDataExtractor): Data extractor for labels. base_loss (torch.nn.Module, optional): Base loss function. Defaults to None. - tnorm (Literal["product", "lukasiewicz", "xu19"], optional): T-norm type. Defaults to "product". + fuzzy_implication (Literal["reichenbach", "lukasiewicz", "kleene_dienes", "goedel", "reverse-goedel", "binary", "xu19"], optional): Fuzzy implication type (short forms such as "rc" or "lk" are also accepted). Defaults to "reichenbach".
impl_loss_weight (float, optional): Weight of implication loss relative to base loss. Defaults to 0.1. pos_scalar (int, optional): Positive scalar exponent. Defaults to 1. pos_epsilon (float, optional): Epsilon value for numerical stability. Defaults to 0.01. multiply_by_softmax (bool, optional): Whether to multiply by softmax. Defaults to False. + use_sigmoidal_implication (bool, optional): Whether to use the sigmoidal fuzzy implication based on the + specified fuzzy_implication (as defined by van Krieken et al., 2022: Analyzing Differentiable Fuzzy Logic + Operators). Defaults to False. + weight_epoch_dependent (Union[bool, tuple[int, int]], optional): Whether to weight the implication loss + depending on the current epoch with the sigmoid function sigmoid((epoch-c)/s). If True, c=50 and s=10, + otherwise, a tuple of integers (c,s) can be supplied. Defaults to False. + start_at_epoch (int, optional): Epoch at which to start applying the loss. Defaults to 0. + violations_per_cls_aggregator (Literal["sum", "max", "mean", "log-sum", "log-max", "log-mean"], optional): + How to aggregate violations for each class. + If a class is involved in several implications / disjointnesses, the loss value for this class will be + aggregated with this method. Defaults to "sum". + multiply_with_base_loss (bool, optional): If True, the total loss is base_loss * (1 + weighted fuzzy loss); + otherwise, the weighted fuzzy loss is added to the base loss. Defaults to True. + no_grads (bool, optional): If True, the fuzzy loss is detached, i.e., it is logged but does not contribute + gradients. Defaults to False. """ def __init__( self, - data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed], + data_extractor: XYBaseDataModule, base_loss: torch.nn.Module = None, - tnorm: Literal["product", "lukasiewicz", "xu19"] = "product", + fuzzy_implication: Literal[ + "reichenbach", + "rc", + "lukasiewicz", + "lk", + "xu19", + "kleene_dienes", + "kd", + "goedel", + "g", + "reverse-goedel", + "rg", + "binary", + "b", + ] = "reichenbach", impl_loss_weight: float = 0.1, - pos_scalar: int = 1, + pos_scalar: Union[int, float] = 1, pos_epsilon: float = 0.01, multiply_by_softmax: bool = False, + use_sigmoidal_implication: bool = False, + weight_epoch_dependent: Union[bool, tuple[int, int]] = False, + start_at_epoch: int = 0, + violations_per_cls_aggregator: Literal[ + "sum", "max", "mean", "log-sum", "log-max", "log-mean" + ] = "sum", + multiply_with_base_loss: bool = True, + no_grads: bool = False, ): super().__init__() # automatically choose labeled subset for implication filter in case of mixed dataset if isinstance(data_extractor, LabeledUnlabeledMixed): data_extractor = data_extractor.labeled + assert isinstance(data_extractor, _ChEBIDataExtractor) self.data_extractor = data_extractor # propagate data_extractor to base loss if isinstance(base_loss, BCEWeighted): base_loss.data_extractor = self.data_extractor + base_loss.reduction = ( + "none" # needed to multiply fuzzy loss with base loss for each sample + ) self.base_loss = base_loss self.implication_cache_file = f"implications_{self.data_extractor.name}.cache" self.label_names = _load_label_names( - os.path.join(data_extractor.raw_dir, "classes.txt") + os.path.join(data_extractor.processed_dir_main, "classes.txt") ) self.hierarchy = self._load_implications( os.path.join(data_extractor.raw_dir, "chebi.obo") ) - implication_filter = _build_implication_filter(self.label_names, self.hierarchy) - self.implication_filter_l = implication_filter[:, 0] - self.implication_filter_r = implication_filter[:, 1] - self.tnorm = tnorm + implication_filter_dense = _build_dense_filter( + _build_implication_filter(self.label_names, self.hierarchy), + len(self.label_names), + ) + self.implication_filter_l = implication_filter_dense + self.implication_filter_r = self.implication_filter_l.transpose(0, 1) + self.fuzzy_implication = fuzzy_implication
self.impl_weight = impl_loss_weight self.pos_scalar = pos_scalar self.eps = pos_epsilon self.multiply_by_softmax = multiply_by_softmax + self.use_sigmoidal_implication = use_sigmoidal_implication + self.weight_epoch_dependent = weight_epoch_dependent + self.start_at_epoch = start_at_epoch + self.violations_per_cls_aggregator = violations_per_cls_aggregator + self.multiply_with_base_loss = multiply_with_base_loss + self.no_grads = no_grads + + def _calculate_unaggregated_fuzzy_loss( + self, + pred, + target: torch.Tensor, + weight, + filter_l, + filter_r, + mode="impl", + **kwargs, + ): + # for each batch, get all pairwise losses: [a1, a2, a3] -> [[a1*a1, a1*a2, a1*a3],[a2*a1,...],[a3*a1,...]] + preds_expanded1 = pred.unsqueeze(1).expand(-1, pred.shape[1], -1) + preds_expanded2 = pred.unsqueeze(2).expand(-1, -1, pred.shape[1]) + # filter by implication relations and labels + + label_filter = target.unsqueeze(2).expand(-1, -1, pred.shape[1]) + filter_l = filter_l.to(pred.device).unsqueeze(0).expand(pred.shape[0], -1, -1) + filter_r = filter_r.to(pred.device).unsqueeze(0).expand(pred.shape[0], -1, -1) + if mode == "impl": + all_implications = self._calculate_implication_loss( + preds_expanded2, preds_expanded1 + ) + else: + all_implications = self._calculate_implication_loss( + preds_expanded2, 1 - preds_expanded1 + ) + loss_impl_l = all_implications * filter_l * (1 - label_filter) + if mode == "impl": + loss_impl_r = all_implications.transpose(1, 2) * filter_r * label_filter + loss_impl_sum = loss_impl_l + loss_impl_r + else: + loss_impl_sum = loss_impl_l + + if self.violations_per_cls_aggregator.startswith("log-"): + loss_impl_sum = -torch.log(1 - loss_impl_sum) + violations_per_cls_aggregator = self.violations_per_cls_aggregator[4:] + else: + violations_per_cls_aggregator = self.violations_per_cls_aggregator + if violations_per_cls_aggregator == "sum": + loss_by_cls = loss_impl_sum.sum(dim=-1) + elif violations_per_cls_aggregator == "max": + loss_by_cls = loss_impl_sum.max(dim=-1).values + elif violations_per_cls_aggregator == "mean": + loss_by_cls = loss_impl_sum.mean(dim=-1) + else: + raise NotImplementedError( + f"Unknown violations_per_cls_aggregator {self.violations_per_cls_aggregator}" + ) + + unweighted_mean = loss_by_cls.mean() + implication_loss_weighted = loss_by_cls + if "current_epoch" in kwargs and self.weight_epoch_dependent: + sigmoid_center = ( + self.weight_epoch_dependent[0] + if isinstance(self.weight_epoch_dependent, tuple) + else 50 + ) + sigmoid_spread = ( + self.weight_epoch_dependent[1] + if isinstance(self.weight_epoch_dependent, tuple) + else 10 + ) + # weight the loss with sigmoid((epoch - sigmoid_center) / sigmoid_spread), centered around epoch sigmoid_center (default 50) + implication_loss_weighted = implication_loss_weighted / ( + 1 + + math.exp(-(kwargs["current_epoch"] - sigmoid_center) / sigmoid_spread) + ) + implication_loss_weighted *= weight + weighted_mean = implication_loss_weighted.mean() + + return implication_loss_weighted, unweighted_mean, weighted_mean + + def _calculate_unaggregated_base_loss(self, input, target, **kwargs): + nnl = kwargs.pop("non_null_labels", None) + labeled_input = input[nnl] if nnl else input + + if target is not None and self.base_loss is not None: + return self.base_loss(labeled_input, target.float()) + else: + return torch.zeros(input.shape, device=input.device) def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: """ @@ -72,26 +199,35 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: Returns: - tuple: Tuple containing total loss, base loss, and implication loss. + tuple: Tuple containing the aggregated total loss and a dictionary of named loss components. """ - nnl = kwargs.pop("non_null_labels", None) - if nnl: - labeled_input = input[nnl] - else: - labeled_input = input - if target is not None: - base_loss = self.base_loss(labeled_input, target.float()) - else: - base_loss = 0 + base_loss = self._calculate_unaggregated_base_loss(input, target, **kwargs) + loss_components = {"base_loss": base_loss.mean()} + + if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: + return base_loss.mean(), loss_components + pred = torch.sigmoid(input) - l = pred[:, self.implication_filter_l] - r = pred[:, self.implication_filter_r] - # implication_loss = torch.sqrt(torch.mean(torch.sum(l*(1-r), dim=-1), dim=0)) - implication_loss = self._calculate_implication_loss(l, r) - - return ( - base_loss + self.impl_weight * implication_loss, - base_loss, - implication_loss, + fuzzy_loss, unweighted_fuzzy_mean, weighted_fuzzy_mean = ( + self._calculate_unaggregated_fuzzy_loss( + pred, + target, + self.impl_weight, + self.implication_filter_l, + self.implication_filter_r, + **kwargs, + ) ) + if self.no_grads: + fuzzy_loss = fuzzy_loss.detach() + loss_components["unweighted_fuzzy_loss"] = unweighted_fuzzy_mean + loss_components["weighted_fuzzy_loss"] = weighted_fuzzy_mean + if self.base_loss is None or target is None: + total_loss = self.impl_weight * fuzzy_loss + elif self.multiply_with_base_loss: + total_loss = base_loss * (1 + self.impl_weight * fuzzy_loss) + else: + total_loss = base_loss + self.impl_weight * fuzzy_loss + return total_loss.mean(), loss_components def _calculate_implication_loss( self, l: torch.Tensor, r: torch.Tensor ) -> torch.Tensor: """ @@ -106,8 +241,14 @@ def _calculate_implication_loss( Returns: torch.Tensor: Calculated implication loss. """ - assert not l.isnan().any() - assert not r.isnan().any() + assert not l.isnan().any(), ( + f"l contains NaN values - l.shape: {l.shape}, l.isnan().sum(): {l.isnan().sum()}, " + f"l: {l}" + ) + assert not r.isnan().any(), ( + f"r contains NaN values - r.shape: {r.shape}, r.isnan().sum(): {r.isnan().sum()}, " + f"r: {r}" + ) if self.pos_scalar != 1: l = ( torch.pow(l + self.eps, 1 / self.pos_scalar) @@ -116,24 +257,51 @@ def _calculate_implication_loss( math.pow(1 + self.eps, 1 / self.pos_scalar) - math.pow(self.eps, 1 / self.pos_scalar) ) - one_min_r = torch.pow(1 - r, self.pos_scalar) + one_min_r = ( + torch.pow(1 - r + self.eps, 1 / self.pos_scalar) + - math.pow(self.eps, 1 / self.pos_scalar) + ) / ( + math.pow(1 + self.eps, 1 / self.pos_scalar) + - math.pow(self.eps, 1 / self.pos_scalar) + ) else: one_min_r = 1 - r - if self.tnorm == "product": + # for each implication I, calculate 1 - I(l, 1-one_min_r) + # for S-implications, this is equivalent to the t-norm + if self.fuzzy_implication in ["reichenbach", "rc"]: individual_loss = l * one_min_r - elif self.tnorm == "xu19": + # xu19 (from Xu et al., 2019: Semantic loss) is not a fuzzy implication, but behaves similar to the Reichenbach + # implication + elif self.fuzzy_implication == "xu19": individual_loss = -torch.log(1 - l * one_min_r) - elif self.tnorm == "lukasiewicz": + elif self.fuzzy_implication in ["lukasiewicz", "lk"]: individual_loss = torch.relu(l + one_min_r - 1) + elif self.fuzzy_implication in ["kleene_dienes", "kd"]: + individual_loss = torch.min(l, one_min_r) + elif self.fuzzy_implication in ["goedel", "g"]: + individual_loss = torch.where(l <= r, 0, one_min_r) + elif self.fuzzy_implication in ["reverse-goedel", "rg"]: + individual_loss = torch.where(l <= r, 0, l) + elif self.fuzzy_implication in ["binary", "b"]: +
individual_loss = (l > r).to(dtype=l.dtype) else: - raise NotImplementedError(f"Unknown tnorm {self.tnorm}") + raise NotImplementedError( + f"Unknown fuzzy implication {self.fuzzy_implication}" + ) + + if self.use_sigmoidal_implication: + # formula by van Krieken, 2022, applied to fuzzy implication with default parameters: b_0 = 0.5, s = 9 + # parts that only depend on b_0 and s are pre-calculated + implication = 1 - individual_loss + sigmoidal_implication = 0.01123379 * ( + 91.0171 * torch.sigmoid(9 * (implication - 0.5)) - 1 + ) + individual_loss = 1 - sigmoidal_implication if self.multiply_by_softmax: individual_loss = individual_loss * individual_loss.softmax(dim=-1) - return torch.mean( - torch.sum(individual_loss, dim=-1), - dim=0, - ) + + return individual_loss def _load_implications(self, path_to_chebi: str) -> dict: """ @@ -193,17 +361,58 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: Returns: - tuple: Tuple containing total loss, base loss, implication loss, and disjointness loss. + tuple: Tuple containing the aggregated total loss and a dictionary of named loss components. """ - loss, base_loss, impl_loss = super().forward(input, target, **kwargs) + base_loss = self._calculate_unaggregated_base_loss(input, target, **kwargs) + loss_components = {"base_loss": base_loss.mean()} + + if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: + return base_loss.mean(), loss_components + pred = torch.sigmoid(input) - l = pred[:, self.disjoint_filter_l] - r = pred[:, self.disjoint_filter_r] - disjointness_loss = self._calculate_implication_loss(l, 1 - r) - return ( - loss + self.disjoint_weight * disjointness_loss, - base_loss, - impl_loss, - disjointness_loss, + impl_loss, unweighted_impl_mean, weighted_impl_mean = ( + self._calculate_unaggregated_fuzzy_loss( + pred, + target, + self.impl_weight, + self.implication_filter_l, + self.implication_filter_r, + **kwargs, + ) ) + if self.no_grads: + impl_loss = impl_loss.detach() + loss_components["unweighted_implication_loss"] = unweighted_impl_mean + loss_components["weighted_implication_loss"] = weighted_impl_mean + + disj_loss, unweighted_disj_mean, weighted_disj_mean = ( + self._calculate_unaggregated_fuzzy_loss( + pred, + target, + self.disjoint_weight, + self.disjoint_filter_l, + self.disjoint_filter_r, + mode="disj", + **kwargs, + ) + ) + if self.no_grads: + disj_loss = disj_loss.detach() + loss_components["unweighted_disjointness_loss"] = unweighted_disj_mean + loss_components["weighted_disjointness_loss"] = weighted_disj_mean + + if self.base_loss is None or target is None: + total_loss = self.impl_weight * impl_loss + self.disjoint_weight * disj_loss + elif self.multiply_with_base_loss: + total_loss = base_loss * ( + 1 + self.impl_weight * impl_loss + self.disjoint_weight * disj_loss + ) + else: + total_loss = ( + base_loss + + self.impl_weight * impl_loss + + self.disjoint_weight * disj_loss + ) + return total_loss.mean(), loss_components def _load_label_names(path_to_label_names: str) -> List: """ @@ -223,7 +431,8 @@ def _build_implication_filter(label_names: List, hierarchy: dict) -> torch.Tensor: """ - Build implication filter based on label names and hierarchy. + Build implication filter based on label names and hierarchy. Results in list of pairs (A,B) for each implication + A->B (including indirect implications). Args: label_names (list): List of label names.
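As a minimal illustration of the implication branches above (not part of the diff; the tensor values are invented and pos_scalar is assumed to be 1, so that one_min_r == 1 - r):

import torch

# toy predictions for three implication pairs l -> r (l = subclass, r = superclass)
l = torch.tensor([0.9, 0.2, 0.7])
r = torch.tensor([0.1, 0.8, 0.7])

# "reichenbach"/"rc": 1 - I(l, r) with I(l, r) = 1 - l + l*r reduces to l * (1 - r)
print(l * (1 - r))                  # ≈ tensor([0.8100, 0.0400, 0.2100])
# "lukasiewicz"/"lk": max(0, l + (1 - r) - 1), i.e. max(0, l - r)
print(torch.relu(l + (1 - r) - 1))  # ≈ tensor([0.8000, 0.0000, 0.0000])
# "binary"/"b": 1 exactly where the implication l <= r is violated
print((l > r).to(l.dtype))          # tensor([1., 0., 0.])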
@@ -242,6 +451,14 @@ def _build_implication_filter(label_names: List, hierarchy: dict) -> torch.Tensor ) +def _build_dense_filter(sparse_filter: torch.Tensor, n_labels: int) -> torch.Tensor: + """Converts a list of index pairs (l, r) into a dense boolean matrix M with M[l, r] = True.""" + res = torch.zeros((n_labels, n_labels), dtype=torch.bool) + for l, r in sparse_filter: + res[l, r] = True + return res + + def _build_disjointness_filter( path_to_disjointness: str, label_names: List, hierarchy: dict ) -> tuple: @@ -278,10 +494,39 @@ def _build_disjointness_filter( ) dis_filter = torch.tensor(list(disjoints)) - return dis_filter[:, 0], dis_filter[:, 1] + dense = _build_dense_filter(dis_filter, len(label_names)) + dense_r = dense.transpose(0, 1) + return dense, dense_r if __name__ == "__main__": loss = DisjointLoss( - os.path.join("data", "disjoint.csv"), ChEBIOver100(chebi_version=227) + os.path.join("data", "disjoint.csv"), + ChEBIOver100(chebi_version=231), + base_loss=BCEWeighted(), + impl_loss_weight=1, + disjoint_loss_weight=1, + ) + random_preds = torch.randn(10, 997) + random_labels = torch.randint(0, 2, (10, 997)) + for agg in ["sum", "max", "mean", "log-mean"]: + loss.violations_per_cls_aggregator = agg + l = loss(random_preds, random_labels) + print(f"Loss with {agg} aggregation for random input:", l) + + # simplified example for ontology with 4 classes, A -> B, B -> C, D -> C, B and D disjoint + loss.implication_filter_l = torch.tensor( + [[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0], [0, 0, 1, 0]] + ) + loss.implication_filter_r = loss.implication_filter_l.transpose(0, 1) + loss.disjoint_filter_l = torch.tensor( + [[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 0], [0, 1, 0, 0]] ) + loss.disjoint_filter_r = loss.disjoint_filter_l.transpose(0, 1) + # expected result: first sample: moderately high loss for B disj D, otherwise low, second sample: high loss for A -> B (applied to A), otherwise low + preds = torch.tensor([[0.1, 0.3, 0.7, 0.4], [0.5, 0.2, 0.9, 0.1]]) + labels = [[0, 1, 1, 0], [0, 0, 1, 1]] + for agg in ["sum", "max", "mean", "log-mean"]: + loss.violations_per_cls_aggregator = agg + l = loss(preds, torch.tensor(labels)) + print(f"Loss with {agg} aggregation for simple input:", l) diff --git a/chebai/models/base.py b/chebai/models/base.py index b881c671..4ba27bbc 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -67,6 +67,14 @@ def __init__( self.test_metrics = test_metrics self.pass_loss_kwargs = pass_loss_kwargs + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + # avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a + # different loss) + if "criterion.base_loss.pos_weight" in checkpoint["state_dict"]: + del checkpoint["state_dict"]["criterion.base_loss.pos_weight"] + if "criterion.pos_weight" in checkpoint["state_dict"]: + del checkpoint["state_dict"]["criterion.pos_weight"] + def __init_subclass__(cls, **kwargs): """ Automatically registers subclasses in the model registry to prevent duplicates.
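A rough sketch of the failure mode that the on_load_checkpoint hook above guards against (hypothetical state_dict keys, not part of the diff): BCEWithLogitsLoss registers pos_weight as a buffer, so a checkpoint trained with BCEWeighted can carry criterion pos_weight entries that a different criterion does not define.

import torch

# hypothetical checkpoint produced by a run whose criterion was BCEWeighted
checkpoint = {
    "state_dict": {
        "model.some_weight": torch.ones(3),
        "criterion.pos_weight": torch.ones(3),  # buffer registered by BCEWithLogitsLoss
    }
}

# mimic the hook: drop criterion buffers that the current loss may not define
for key in ("criterion.base_loss.pos_weight", "criterion.pos_weight"):
    checkpoint["state_dict"].pop(key, None)

assert "criterion.pos_weight" not in checkpoint["state_dict"]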
@@ -251,16 +259,31 @@ def _execute( loss_kwargs = dict() if self.pass_loss_kwargs: loss_kwargs = loss_kwargs_candidates + loss_kwargs["current_epoch"] = self.trainer.current_epoch loss = self.criterion(loss_data, loss_labels, **loss_kwargs) if isinstance(loss, tuple): - loss_additional = loss[1:] + unnamed_loss_index = 1 + if isinstance(loss[1], dict): + unnamed_loss_index = 2 + for key, value in loss[1].items(): + self.log( + key, + value if isinstance(value, int) else value.item(), + batch_size=len(batch), + on_step=True, + on_epoch=True, + prog_bar=False, + logger=True, + sync_dist=sync_dist, + ) + loss_additional = loss[unnamed_loss_index:] for i, loss_add in enumerate(loss_additional): self.log( f"{prefix}loss_{i}", loss_add if isinstance(loss_add, int) else loss_add.item(), batch_size=len(batch), on_step=True, - on_epoch=False, + on_epoch=True, prog_bar=False, logger=True, sync_dist=sync_dist, diff --git a/chebai/preprocessing/bin/smiles_token/tokens.txt b/chebai/preprocessing/bin/smiles_token/tokens.txt index 6f4c338e..15d5af14 100644 --- a/chebai/preprocessing/bin/smiles_token/tokens.txt +++ b/chebai/preprocessing/bin/smiles_token/tokens.txt @@ -769,3 +769,44 @@ p [14CH3] [HH] [CH3-] +[PH+] +[Zr+4] +[Zr+2] +[Zr+3] +[Th+4] +[Sn+2] +[ClH+] +[Ti+] +[Ir+2] +[Si@] +[Pd+] +[16NH] +[SH2+] +[Hf+2] +[Hf+4] +[Ti+2] +[Ru+2] +[BH-] +[Sc+3] +[Sb+5] +[Pd+2] +[Si@@] +[Cr+] +[W+2] +[PH-] +[BH] +[PH2+] +[3HH] +[CH2+] +[BiH2] +[SH-] +[2HH] +[GeH] +[PH2-] +[PbH2] +[Ru+] +[U+2] +[SiH-] +[AlH2] +[Fe+5] +[Rh+2] diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index c703ae1d..857a5862 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -177,7 +177,8 @@ def load_processed_data( self, kind: Optional[str] = None, filename: Optional[str] = None ) -> List: """ - Load processed data from a file. + Load processed data from a file. Either the kind or the filename has to be provided. If both are provided, the + filename is used. Args: kind (str, optional): The kind of dataset to load such as "train", "val" or "test". Defaults to None. @@ -1147,6 +1148,16 @@ def processed_main_file_names_dict(self) -> dict: """ return {"data": "data.pkl"} + @property + def raw_file_names(self) -> List[str]: + """ + Returns a list of raw file names. + + Returns: + List[str]: A list of file names corresponding to the raw data. + """ + return list(self.raw_file_names_dict.values()) + @property def processed_file_names_dict(self) -> dict: """ diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index c82ea42f..9e43302a 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -290,13 +290,14 @@ def download(self): class PubChemKMeans(PubChem): """ Dataset class representing a subset of PubChem dataset clustered using K-Means algorithm. + The idea is to create distinct distributions where pretraining and test sets are formed from dissimilar data. """ def __init__( self, *args, - n_clusters: int = 1e4, - random_size: int = 1e6, + n_clusters: int = 10000, + random_size: int = 1000000, exclude_data_from: _ChEBIDataExtractor = None, validation_size_limit: int = 4000, include_min_n_clusters: int = 100, @@ -306,9 +307,11 @@ def __init__( Args: n_clusters (int): Number of clusters to create using K-Means. random_size (int): Size of random dataset to download. - exclude_data_from (_ChEBIDataExtractor): Dataset extractor to exclude data clusters from. 
- validation_size_limit (int): Size limit for validation dataset. - include_min_n_clusters (int): Minimum number of clusters to include after exclusion. + exclude_data_from (_ChEBIDataExtractor): Dataset which should not overlap with selected clusters + (remove all clusters that contain data from this dataset). + validation_size_limit (int): Validation set will contain at most this number of instances. + include_min_n_clusters (int): Minimum number of clusters to keep if there are not enough clusters that don't + overlap with the `exclude_data_from` dataset. *args: Additional arguments for superclass initialization. **kwargs: Additional keyword arguments for superclass initialization. @@ -346,14 +349,19 @@ def split_label(self) -> str: @property def raw_file_names(self) -> List[str]: """ + Clusters generated by K-Means, sorted by size (cluster0 is the largest). + cluster0 is the training cluster (it will be split into train/val/test in the processed directory and used for pretraining). Returns: List[str]: List of raw file names expected in the raw directory. """ - return ["train.txt", "validation.txt", "test.txt"] + return ["cluster0.txt", "cluster1.txt", "cluster2.txt"] @property def fingerprints(self) -> pd.DataFrame: """ + Creates a random dataset, sanitises it, creates RDKit Mol objects and generates fingerprints. + Saves the resulting `fingerprints_df` to `fingerprints.pkl`. + Returns: pd.DataFrame: DataFrame containing SMILES and corresponding fingerprints. """ @@ -424,7 +432,13 @@ def _build_clusters(self) -> tuple[pd.DataFrame, pd.DataFrame]: def _exclude_clusters(self, cluster_centers: pd.DataFrame) -> pd.DataFrame: """ - Excludes clusters based on data from an exclusion dataset. + Excludes clusters based on data from an exclusion dataset (in a training setup, this is the labeled dataset, + usually ChEBI). The goal is to avoid having similar data in the labeled training and the PubChem evaluation. + + Loads data from the `exclude_data_from` dataset, generates mols, fingerprints, finds the closest cluster centre for + each fingerprint, saves data to `exclusion_data_clustered.pkl`, and returns all clusters with no instances from the + exclusion data (or the n clusters with the lowest number of instances if there are fewer than n clusters with no + instances, n being the minimum number of clusters to include). Args: cluster_centers (pd.DataFrame): DataFrame of cluster centers. @@ -491,6 +505,7 @@ def _exclude_clusters(self, cluster_centers: pd.DataFrame) -> pd.DataFrame: @property def cluster_centers(self) -> pd.DataFrame: """ + Loads cluster centers from file if possible, otherwise calls `self._build_clusters()`. Returns: pd.DataFrame: DataFrame of cluster centers. """ @@ -505,6 +520,7 @@ def cluster_centers(self) -> pd.DataFrame: @property def fingerprints_clustered(self) -> pd.DataFrame: """ + Loads fingerprints with assigned clusters from file if possible, otherwise calls `self._build_clusters()`. Returns: pd.DataFrame: DataFrame of clustered fingerprints. """ @@ -521,6 +537,10 @@ def fingerprints_clustered(self) -> pd.DataFrame: @property def cluster_centers_superclustered(self) -> pd.DataFrame: """ + Calls `_exclude_clusters()`, which removes all clusters that contain data from the exclusion set (usually + ChEBI, i.e., the labeled dataset). + Runs KMeans with 3 clusters on the remaining data and saves the cluster centres with assigned supercluster-labels to + `cluster_centers_superclustered.pkl`. Returns: pd.DataFrame: DataFrame of superclustered cluster centers.
""" @@ -559,7 +579,12 @@ def cluster_centers_superclustered(self) -> pd.DataFrame: def download(self): """ - Downloads the PubChemKMeans dataset. + Downloads the PubChemKMeans dataset. This function creates the complete dataset (including train, test, and + validation splits). Most of the steps are hidden in properties (e.g., `self.fingerprints_clustered` triggers + the download of a random dataset, the calculation of fingerprints for it and the KMeans clustering) + The final splits are created by assigning all fingerprints that belong to a cluster of a certain supercluster + to a dataset. This creates 3 datasets (for each of the 3 superclusters), the datasets are saved as validation, + test and train based on their size. The validation set is limited to `self.validation_size_limit` entries. """ if self._k == PubChem.FULL: super().download() @@ -580,7 +605,7 @@ def download(self): splits = [fp_grouped.get_group(g) for g in fp_grouped.groups if g != -1] splits[0] = splits[0][: self.validation_size_limit] splits.sort(key=lambda x: len(x)) - for i, name in enumerate(["validation", "test", "train"]): + for i, name in enumerate(["cluster2", "cluster1", "cluster0"]): if not os.path.exists(os.path.join(self.raw_dir, f"{name}.txt")): open(os.path.join(self.raw_dir, f"{name}.txt"), "x").close() with open(os.path.join(self.raw_dir, f"{name}.txt"), "w") as f: diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py index 6adb1066..51a1fb2b 100644 --- a/chebai/result/analyse_sem.py +++ b/chebai/result/analyse_sem.py @@ -1,21 +1,22 @@ import gc -import os import sys import traceback from datetime import datetime -from typing import List, Union +from typing import List, LiteralString -import pandas as pd -import torch -import wandb -from torchmetrics.functional.classification import multilabel_auroc, multilabel_f1_score +from torchmetrics.functional.classification import ( + multilabel_auroc, + multilabel_average_precision, + multilabel_f1_score, +) from utils import * from chebai.loss.semantic import DisjointLoss +from chebai.preprocessing.datasets.base import _DynamicDataset from chebai.preprocessing.datasets.chebi import ChEBIOver100 -from chebai.preprocessing.datasets.pubchem import Hazardous +from chebai.preprocessing.datasets.pubchem import PubChemKMeans -DEVICE = "cpu" # torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") def binary(left, right): @@ -42,6 +43,9 @@ def apply_metric(metric, left, right): return torch.sum(metric(left, right), dim=0) +ALL_CONSISTENCY_METRICS = [product, lukasiewicz, weak, strict, binary] + + def _filter_to_one_hot(preds, idx_filter): """Takes list of indices (e.g. [1, 3, 0]) and returns a one-hot filter with these indices (e.g. 
[[0,1,0,0], [0,0,0,1], [1,0,0,0]])""" @@ -52,7 +56,7 @@ def _filter_to_one_hot(preds, idx_filter): def _sort_results_by_label(n_labels, results, filter): - by_label = torch.zeros(n_labels, device=DEVICE) + by_label = torch.zeros(n_labels, device=DEVICE, dtype=torch.int) for r, filter_l in zip(results, filter): by_label[filter_l] += r return by_label @@ -69,83 +73,59 @@ def get_best_epoch(run): best_ep = int(file.name.split("=")[1].split("_")[0]) best_micro_f1 = micro_f1 if best_ep is None: - raise Exception(f"Could not find any 'best' checkpoint for run {run.name}") + raise Exception(f"Could not find any 'best' checkpoint for run {run.id}") else: - print(f"Best epoch for run {run.name}: {best_ep}") + print(f"Best epoch for run {run.id}: {best_ep}") return best_ep -def load_preds_labels_from_wandb( - run, - epoch, - chebi_version, - test_on_data_cls=ChEBIOver100, # use data from this class - kind="test", # specify segment of test_on_data_cls +def download_model_from_wandb( + run_id, base_dir=os.path.join("logs", "downloaded_ckpts") ): - data_module = test_on_data_cls(chebi_version=chebi_version) - - buffer_dir = os.path.join( - "results_buffer", - f"{run.name}_ep{epoch}", - f"{data_module.__class__.__name__}_{kind}", - ) - - model = get_checkpoint_from_wandb(epoch, run, map_device_to="cuda:0") - print(f"Calculating predictions...") - evaluate_model( - model, - data_module, - buffer_dir=buffer_dir, - filename=f"{kind}.pt", - skip_existing_preds=True, - batch_size=1, + api = wandb.Api() + run = api.run(f"chebai/chebai/{run_id}") + epoch = get_best_epoch(run) + return ( + get_checkpoint_from_wandb(epoch, run, root=base_dir), + epoch, ) - preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE) - del model - gc.collect() - return preds, labels - -def load_preds_labels_from_nonwandb( - name, epoch, chebi_version, test_on_data_cls=ChEBIOver100, kind="test" +def load_preds_labels( + ckpt_path: LiteralString, data_module, data_subset_key="test", buffer_dir=None ): - data_module = test_on_data_cls(chebi_version=chebi_version) - - buffer_dir = os.path.join( - "results_buffer", - f"{name}_ep{epoch}", - f"{data_module.__class__.__name__}_{kind}", - ) - ckpt_path = None - for file in os.listdir(os.path.join("logs", "downloaded_ckpts", name)): - if file.startswith(f"best_epoch={epoch}"): - ckpt_path = os.path.join( - os.path.join("logs", "downloaded_ckpts", name, file) - ) - assert ( - ckpt_path is not None - ), f"Could not find ckpt for epoch {epoch} in directory {os.path.join('logs', 'downloaded_ckpts', name)}" + if buffer_dir is None: + buffer_dir = os.path.join( + "results_buffer", + *ckpt_path.split(os.path.sep)[-2:], + f"{data_module.__class__.__name__}_{data_subset_key}", + ) model = Electra.load_from_checkpoint(ckpt_path, map_location="cuda:0", strict=False) - print(f"Calculating predictions...") + print( + f"Calculating predictions on {data_module.__class__.__name__} ({data_subset_key})..." 
+ ) evaluate_model( model, data_module, buffer_dir=buffer_dir, - filename=f"{kind}.pt", + # for chebi, use kinds, otherwise use file names + filename=( + data_subset_key if not isinstance(data_module, _DynamicDataset) else None + ), + kind=data_subset_key, skip_existing_preds=True, + batch_size=1, ) - preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE) - del model - gc.collect() - - return preds, labels + return load_results_from_buffer(buffer_dir, device=torch.device("cpu")) def get_label_names(data_module): - if os.path.exists(os.path.join(data_module.raw_dir, "classes.txt")): - with open(os.path.join(data_module.raw_dir, "classes.txt")) as fin: + if os.path.exists(os.path.join(data_module.processed_dir_main, "classes.txt")): + with open(os.path.join(data_module.processed_dir_main, "classes.txt")) as fin: return [int(line.strip()) for line in fin] + print( + f"Failed to retrieve label names, {os.path.join(data_module.processed_dir_main, 'classes.txt')} not found" + ) return None @@ -155,6 +135,9 @@ def get_chebi_graph(data_module, label_names): os.path.join(data_module.raw_dir, "chebi.obo") ) return chebi_graph.subgraph(label_names) + print( + f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found" + ) return None @@ -186,120 +169,175 @@ def get_disjoint_groups(): return disjoint_all -def smooth_preds(preds, label_names, chebi_graph, disjoint_groups): - preds_sum_orig = torch.sum(preds) - print(f"Preds sum: {preds_sum_orig}") - # eliminate implication violations by setting each prediction to maximum of its successors - for i, label in enumerate(label_names): - succs = [label_names.index(p) for p in chebi_graph.successors(label)] + [i] - if len(succs) > 0: - preds[:, i] = torch.max(preds[:, succs], dim=1).values - print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}") - preds_sum_orig = torch.sum(preds) - # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower) - preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49) - for disj_group in disjoint_groups: - disj_group = [label_names.index(g) for g in disj_group if g in label_names] - if len(disj_group) > 1: - old_preds = preds[:, disj_group] - disj_max = torch.max(preds[:, disj_group], dim=1) - for i, row in enumerate(preds): - for l in range(len(preds[i])): - if l in disj_group and l != disj_group[disj_max.indices[i]]: - preds[i, l] = preds_bounded[i, l] - samples_changed = 0 - for i, row in enumerate(preds[:, disj_group]): - if any(r != o for r, o in zip(row, old_preds[i])): - samples_changed += 1 - if samples_changed != 0: - print( - f"disjointness group {[label_names[d] for d in disj_group]} changed {samples_changed} samples" - ) - print( - f"Preds change after disjointness (step 2): {torch.sum(preds) - preds_sum_orig}" - ) - preds_sum_orig = torch.sum(preds) - # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors - for i, label in enumerate(label_names): - predecessors = [i] + [ - label_names.index(p) for p in chebi_graph.predecessors(label) - ] - lowest_predecessors = torch.min(preds[:, predecessors], dim=1) - preds[:, i] = lowest_predecessors.values - for idx_idx, idx in enumerate(lowest_predecessors.indices): - if idx > 0: +class PredictionSmoother: + """Removes implication and disjointness violations from predictions""" + + def __init__(self, dataset): + self.label_names = get_label_names(dataset) + 
self.chebi_graph = get_chebi_graph(dataset, self.label_names) + self.disjoint_groups = get_disjoint_groups() + + def __call__(self, preds): + + preds_sum_orig = torch.sum(preds) + print(f"Preds sum: {preds_sum_orig}") + # eliminate implication violations by setting each prediction to maximum of its successors + for i, label in enumerate(self.label_names): + succs = [ + self.label_names.index(p) for p in self.chebi_graph.successors(label) + ] + [i] + if len(succs) > 0: + preds[:, i] = torch.max(preds[:, succs], dim=1).values + print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}") + preds_sum_orig = torch.sum(preds) + # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower) + preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49) + for disj_group in self.disjoint_groups: + disj_group = [ + self.label_names.index(g) for g in disj_group if g in self.label_names + ] + if len(disj_group) > 1: + old_preds = preds[:, disj_group] + disj_max = torch.max(preds[:, disj_group], dim=1) + for i, row in enumerate(preds): + for l in range(len(preds[i])): + if l in disj_group and l != disj_group[disj_max.indices[i]]: + preds[i, l] = preds_bounded[i, l] + samples_changed = 0 + for i, row in enumerate(preds[:, disj_group]): + if any(r != o for r, o in zip(row, old_preds[i])): + samples_changed += 1 + if samples_changed != 0: + print( + f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples" + ) + print( + f"Preds change after disjointness (step 2): {torch.sum(preds) - preds_sum_orig}" + ) + preds_sum_orig = torch.sum(preds) + # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors + for i, label in enumerate(self.label_names): + predecessors = [i] + [ + self.label_names.index(p) for p in self.chebi_graph.predecessors(label) + ] + lowest_predecessors = torch.min(preds[:, predecessors], dim=1) + preds[:, i] = lowest_predecessors.values + for idx_idx, idx in enumerate(lowest_predecessors.indices): + if idx > 0: + print( + f"class {label}: changed prediction of sample {idx_idx} to value of class " + f"{self.label_names[predecessors[idx]]} ({preds[idx_idx, i].item():.2f})" + ) + if torch.sum(preds) != preds_sum_orig: print( - f"class {label}: changed prediction of sample {idx_idx} to value of class " - f"{label_names[predecessors[idx]]} ({preds[idx_idx, i].item():.2f})" + f"Preds change (step 3) for {label}: {torch.sum(preds) - preds_sum_orig}" ) - if torch.sum(preds) != preds_sum_orig: - print( - f"Preds change (step 3) for {label}: {torch.sum(preds) - preds_sum_orig}" - ) - preds_sum_orig = torch.sum(preds) - return preds + preds_sum_orig = torch.sum(preds) + return preds -def analyse_run( - preds, - labels, - df_hyperparams, # parameters that are the independent of the semantic loss function used - labeled_data_cls=ChEBIOver100, # use labels from this dataset for violations - chebi_version=231, - results_path=os.path.join("_semantic", "eval_results.csv"), - violation_metrics: Union[str, List[callable]] = "all", - verbose_violation_output=False, -): - """Calculates all semantic metrics for given predictions (and supervised metrics if labels are provided), - saves results to csv""" - if violation_metrics == "all": - violation_metrics = [product, lukasiewicz, weak, strict, binary] - data_module_labeled = labeled_data_cls(chebi_version=chebi_version) - n_labels = preds.size(1) - print(f"Found 
{preds.shape[0]} predictions ({n_labels} classes)") +def _dense_filter_to_pairs(dense_filter): + """Converts a dense boolean filter matrix into a list of index pairs (the inverse of _build_dense_filter).""" + pairs = [] + for i in range(dense_filter.shape[0]): + for j in range(dense_filter.shape[1]): + if dense_filter[i, j] > 0: + pairs.append([i, j]) + return torch.tensor(pairs) - df_new = [] +def build_prediction_filter(data_module_labeled=None): + if data_module_labeled is None: + data_module_labeled = ChEBIOver100(chebi_version=231) # prepare filters - print(f"Loading & rescaling implication / disjointness filters...") + print(f"Loading implication / disjointness filters...") dl = DisjointLoss( path_to_disjointness=os.path.join("data", "disjoint.csv"), data_extractor=data_module_labeled, ) - for dl_filter_l, dl_filter_r, filter_type in [ - (dl.implication_filter_l, dl.implication_filter_r, "impl"), - (dl.disjoint_filter_l, dl.disjoint_filter_r, "disj"), - ]: - # prepare predictions - n_loss_terms = dl_filter_l.shape[0] - preds_exp = preds.unsqueeze(2).expand((-1, -1, n_loss_terms)).swapaxes(1, 2) - l_preds = preds_exp[:, _filter_to_one_hot(preds, dl_filter_l)] - r_preds = preds_exp[:, _filter_to_one_hot(preds, dl_filter_r)] - del preds_exp - gc.collect() + impl = _dense_filter_to_pairs(dl.implication_filter_l) + disj = _dense_filter_to_pairs(dl.disjoint_filter_l) + + return [ + (impl[:, 0], impl[:, 1], "impl"), + (disj[:, 0], disj[:, 1], "disj"), + ] - for i, metric in enumerate(violation_metrics): - if filter_type == "impl": - df_new.append(df_hyperparams.copy()) - df_new[-1]["metric"] = metric.__name__ - print( - f"Calculating metric {metric.__name__ if metric is not None else 'supervised'} on {filter_type}" - ) - m = {} - m["tps"] = apply_metric( - metric, l_preds, r_preds if filter_type == "impl" else 1 - r_preds +def run_consistency_metrics( + preds, + consistency_filters, + data_module_labeled=None, # use labels from this dataset for violations + violation_metrics=None, + verbose_violation_output=False, + save_details_to=None, +): + """Calculates consistency metrics for the given predictions.""" + if violation_metrics is None: + violation_metrics = ALL_CONSISTENCY_METRICS + if data_module_labeled is None: + data_module_labeled = ChEBIOver100(chebi_version=231) + if save_details_to is not None: + os.makedirs(save_details_to, exist_ok=True) + + preds = preds.to("cpu") + + n_labels = preds.size(1) + print(f"Found {preds.shape[0]} predictions ({n_labels} classes)") + + results = {} + + for dl_filter_l, dl_filter_r, filter_type in consistency_filters: + l_preds = preds[:, dl_filter_l] + r_preds = preds[:, dl_filter_r] + for metric in violation_metrics: + if metric.__name__ not in results: + results[metric.__name__] = {} + print(f"Calculating metric {metric.__name__} on {filter_type}") + + metric_results = {} + metric_results["tps"] = torch.sum( + torch.stack( + [ + apply_metric( + metric, + l_preds[i : i + 1000], + ( + r_preds[i : i + 1000] + if filter_type == "impl" + else 1 - r_preds[i : i + 1000] + ), + ) + for i in range(0, r_preds.shape[0], 1000) + ] + ), + dim=0, ) - m["fns"] = apply_metric( - metric, l_preds, 1 - r_preds if filter_type == "impl" else r_preds + metric_results["fns"] = torch.sum( + torch.stack( + [ + apply_metric( + metric, + l_preds[i : i + 1000], + ( + 1 - r_preds[i : i + 1000] + if filter_type == "impl" + else r_preds[i : i + 1000] + ), + ) + for i in range(0, r_preds.shape[0], 1000) + ] + ), + dim=0, ) if verbose_violation_output: label_names = get_label_names(data_module_labeled) - print(f"Found {torch.sum(m['fns'])} 
{filter_type}-violations") - # for k, fn_cls in enumerate(m['fns']): + print( + f"Found {torch.sum(metric_results['fns'])} {filter_type}-violations" + ) + # for k, fn_cls in enumerate(metric_results['fns']): # if fn_cls > 0: # print(f"\tThereof, {fn_cls.item()} belong to class {label_names[k]}") - if torch.sum(m["fns"]) != 0: + if torch.sum(metric_results["fns"]) != 0: fns = metric( l_preds, 1 - r_preds if filter_type == "impl" else r_preds ) @@ -314,266 +352,370 @@ def analyse_run( f", {label_names[dl_filter_r[j]]} -> {preds[k, dl_filter_r[j]]:.3f})" ) - m_cls = {} - for key, value in m.items(): - m_cls[key] = _sort_results_by_label( + m_l_agg = {} + for key, value in metric_results.items(): + m_l_agg[key] = _sort_results_by_label( + n_labels, + value, + dl_filter_l, + ) + m_r_agg = {} + for key, value in metric_results.items(): + m_r_agg[key] = _sort_results_by_label( n_labels, value, - (dl_filter_l), + dl_filter_r, ) - df_new[i][f"micro-sem-recall-{filter_type}"] = ( - torch.sum(m["tps"]) / (torch.sum(m[f"tps"]) + torch.sum(m[f"fns"])) + if save_details_to is not None: + with open( + os.path.join( + save_details_to, f"{metric.__name__}_{filter_type}_all.csv" + ), + "w+", + ) as f: + f.write("left,right,tps,fns\n") + for left, right, tps, fns in zip( + dl_filter_l, + dl_filter_r, + metric_results["tps"], + metric_results["fns"], + ): + f.write(f"{left},{right},{tps},{fns}\n") + with open( + os.path.join( + save_details_to, f"{metric.__name__}_{filter_type}_l.csv" + ), + "w+", + ) as f: + f.write("left,tps,fns\n") + for left in range(n_labels): + f.write( + f"{left},{m_l_agg['tps'][left].item()},{m_l_agg['fns'][left].item()}\n" + ) + with open( + os.path.join( + save_details_to, f"{metric.__name__}_{filter_type}_r.csv" + ), + "w+", + ) as f: + f.write("right,tps,fns\n") + for right in range(n_labels): + f.write( + f"{right},{m_r_agg['tps'][right].item()},{m_r_agg['fns'][right].item()}\n" + ) + print( + f"Saved unaggregated consistency metrics ({metric.__name__}, {filter_type}) to {save_details_to}" + ) + + fns_sum = torch.sum(metric_results["fns"]).item() + results[metric.__name__][f"micro-fnr-{filter_type}"] = ( + 0 + if fns_sum == 0 + else ( + torch.sum(metric_results["fns"]) + / ( + torch.sum(metric_results[f"tps"]) + + torch.sum(metric_results[f"fns"]) + ) + ).item() + ) + macro_fnr_l = m_l_agg[f"fns"] / (m_l_agg[f"tps"] + m_l_agg[f"fns"]) + results[metric.__name__][f"lmacro-fnr-{filter_type}"] = ( + 0 + if fns_sum == 0 + else torch.mean(macro_fnr_l[~macro_fnr_l.isnan()]).item() + ) + macro_fnr_r = m_r_agg[f"fns"] / (m_r_agg[f"tps"] + m_r_agg[f"fns"]) + results[metric.__name__][f"rmacro-fnr-{filter_type}"] = ( + 0 + if fns_sum == 0 + else torch.mean(macro_fnr_r[~macro_fnr_r.isnan()]).item() + ) + results[metric.__name__][f"fn-sum-{filter_type}"] = torch.sum( + metric_results["fns"] ).item() - macro_recall = m_cls[f"tps"] / (m_cls[f"tps"] + m_cls[f"fns"]) - df_new[i][f"macro-sem-recall-{filter_type}"] = torch.mean( - macro_recall[~macro_recall.isnan()] + results[metric.__name__][f"tp-sum-{filter_type}"] = torch.sum( + metric_results["tps"] ).item() - df_new[i][f"fn-sum-{filter_type}"] = torch.sum(m["fns"]).item() - df_new[i][f"tp-sum-{filter_type}"] = torch.sum(m["tps"]).item() - del m - del m_cls + del metric_results + del m_l_agg + del m_r_agg gc.collect() - df_new[i] = pd.DataFrame(df_new[i], index=[0]) del l_preds del r_preds gc.collect() + return results + + +def run_supervised_metrics(preds, labels, save_details_to=None): # calculate supervised metrics + results = {} if labels is 
not None: - df_supervised = df_hyperparams.copy() - df_supervised["micro-f1"] = multilabel_f1_score( + results["micro-f1"] = multilabel_f1_score( preds, labels, num_labels=preds.size(1), average="micro" ).item() - df_supervised["macro-f1"] = multilabel_f1_score( + results["macro-f1"] = multilabel_f1_score( preds, labels, num_labels=preds.size(1), average="macro" ).item() - df_supervised["micro-roc-auc"] = multilabel_auroc( + results["micro-roc-auc"] = multilabel_auroc( preds, labels, num_labels=preds.size(1), average="micro" ).item() - df_supervised["macro-roc-auc"] = multilabel_auroc( + results["macro-roc-auc"] = multilabel_auroc( preds, labels, num_labels=preds.size(1), average="macro" ).item() - df_new.append(pd.DataFrame(df_supervised, index=[0])) - - if os.path.exists(results_path): - df_previous = pd.read_csv(results_path) - else: - df_previous = None - if df_previous is not None: - df_new = [df_previous] + df_new - del df_previous + results["micro-ap"] = multilabel_average_precision( + preds, labels, num_labels=preds.size(1), average="micro" + ).item() + results["macro-ap"] = multilabel_average_precision( + preds, labels, num_labels=preds.size(1), average="macro" + ).item() - df_new = pd.concat(df_new, ignore_index=True) - print(f"Saving results to {results_path}") - df_new.to_csv(results_path, index=False) + if save_details_to is not None: + f1_by_label = multilabel_f1_score( + preds, labels, num_labels=preds.size(1), average=None + ) + roc_by_label = multilabel_auroc( + preds, labels, num_labels=preds.size(1), average=None + ) + ap_by_label = multilabel_average_precision( + preds, labels, num_labels=preds.size(1), average=None + ) + with open(os.path.join(save_details_to, f"supervised.csv"), "w+") as f: + f.write("label,f1,roc-auc,ap\n") + for right in range(preds.size(1)): + f.write( + f"{right},{f1_by_label[right].item()},{roc_by_label[right].item()},{ap_by_label[right].item()}\n" + ) + print(f"Saved class-wise supervised metrics to {save_details_to}") - del df_new del preds del labels - del dl gc.collect() + return results + + +# run predictions / metrics calculations for semantic loss paper runs (NeSy 2024 submission) +def run_semloss_eval(): + # runs from wandb + non_wandb_runs = [] + api = wandb.Api() + runs = api.runs("chebai/chebai", filters={"tags": "eval_semloss_paper"}) + print(f"Found {len(runs)} tagged wandb runs") + ids_wandb = [run.id for run in runs] + + # ids used in the NeSy submission + prod = ["tk15yznc", "uke62a8m", "w0h3zr5s"] + xu19 = ["5ko8knb4", "061fd85t", "r50ioujs"] + prod_mixed = ["hk8555ff", "e0lxw8py", "lig23cmg"] + luka = ["0c0s48nh", "lfg384bp", "qeghvubh"] + baseline = ["i4wtz1k4", "zd020wkv", "rc1q3t49"] + prodk2 = ["ng3usn0p", "rp0wwzjv", "8fma1q7r"] + ids = baseline + prod + prodk2 + xu19 + luka + prod_mixed + # ids = ids_wandb + run_all( + ids, + non_wandb_runs, + prediction_datasets=[(ChEBIOver100(chebi_version=231), "test")], + consistency_metrics=[binary], + ) def run_all( - run_ids, - datasets=None, - chebi_version=231, - skip_analyse=False, - skip_preds=False, - nonwandb_runs=None, - violation_metrics="all", - remove_violations=False, + wandb_ids=None, + local_ckpts: List[Tuple] = None, + consistency_metrics: Optional[List[callable]] = None, + prediction_datasets: List[Tuple] = None, + remove_violations: bool = False, + results_dir="_fuzzy_loss_eval", + check_consistency_on=None, + verbose_violation_output=False, ): - # evaluate a list of runs on Hazardous and ChEBIOver100 datasets - if datasets is None: - datasets = [(Hazardous, "all"), 
(ChEBIOver100, "test")] - timestamp = datetime.now().strftime("%y%m%d-%H%M") - results_path = os.path.join( - "_semloss_eval", - f"semloss_results_pc-dis-200k_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", - ) + if wandb_ids is None: + wandb_ids = [] + if local_ckpts is None: + local_ckpts = [] + if consistency_metrics is None: + consistency_metrics = ALL_CONSISTENCY_METRICS + if prediction_datasets is None: + prediction_datasets = [ + (ChEBIOver100(chebi_version=231), "test"), + ] + if check_consistency_on is None: + check_consistency_on = ChEBIOver100(chebi_version=231) if remove_violations: - label_names = get_label_names(ChEBIOver100(chebi_version=chebi_version)) - chebi_graph = get_chebi_graph( - ChEBIOver100(chebi_version=chebi_version), label_names - ) - disjoint_groups = get_disjoint_groups() + smooth_preds = PredictionSmoother(check_consistency_on) + else: + smooth_preds = lambda x: x - api = wandb.Api() - for run_id in run_ids: + timestamp = datetime.now().strftime("%y%m%d-%H%M%S") + prediction_filters = build_prediction_filter(check_consistency_on) + + results_path_consistency = os.path.join( + results_dir, + f"consistency_metrics_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", + ) + consistency_keys = [ + "micro-fnr-impl", + "lmacro-fnr-impl", + "rmacro-fnr-impl", + "fn-sum-impl", + "tp-sum-impl", + "micro-fnr-disj", + "lmacro-fnr-disj", + "rmacro-fnr-disj", + "fn-sum-disj", + "tp-sum-disj", + ] + with open(results_path_consistency, "x") as f: + f.write( + "run-id,epoch,datamodule,data_key,metric," + + ",".join(consistency_keys) + + "\n" + ) + results_path_supervised = os.path.join( + results_dir, + f"supervised_metrics_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", + ) + supervised_keys = [ + "micro-f1", + "macro-f1", + "micro-roc-auc", + "macro-roc-auc", + "micro-ap", + "macro-ap", + ] + with open(results_path_supervised, "x") as f: + f.write("run-id,epoch,datamodule,data_key," + ",".join(supervised_keys) + "\n") + + ckpts = [(run_name, ep, None) for run_name, ep in local_ckpts] + [ + (None, None, wandb_id) for wandb_id in wandb_ids + ] + + for run_name, epoch, wandb_id in ckpts: try: - run = api.run(f"chebai/chebai/{run_id}") - epoch = get_best_epoch(run) - for test_on, kind in datasets: - df = { - "run-id": run_id, - "epoch": int(epoch), - "kind": kind, - "data_module": test_on.__name__, - "chebi_version": chebi_version, - } - buffer_dir_smoothed = os.path.join( + ckpt_dir = os.path.join("logs", "downloaded_ckpts") + # for wandb runs, use short id as name, otherwise use ckpt dir name + if wandb_id is not None: + run_name = wandb_id + ckpt_path, epoch = download_model_from_wandb(run_name, ckpt_dir) + else: + ckpt_path = None + for file in os.listdir(os.path.join(ckpt_dir, run_name)): + if f"epoch={epoch}_" in file or f"epoch={epoch}." 
in file: + ckpt_path = os.path.join(ckpt_dir, run_name, file) + assert ( + ckpt_path is not None + ), f"Failed to find checkpoint for epoch {epoch} in {os.path.join(ckpt_dir, run_name)}" + print(f"Starting run {run_name} (epoch {epoch})") + + for dataset, dataset_key in prediction_datasets: + # copy data from legacy buffer dir if possible + old_buffer_dir = os.path.join( "results_buffer", - "smoothed3step", - f"{run.name}_ep{epoch}", - f"{test_on.__name__}_{kind}", + *ckpt_path.split(os.path.sep)[-2:], + f"{dataset.__class__.__name__}_{dataset_key}", ) - if remove_violations and os.path.exists( - os.path.join(buffer_dir_smoothed, "preds000.pt") - ): - preds = torch.load( - os.path.join(buffer_dir_smoothed, "preds000.pt"), - DEVICE, - weights_only=False, - ) - labels = None - else: - if not skip_preds: - preds, labels = load_preds_labels_from_wandb( - run, epoch, chebi_version, test_on, kind - ) - else: - buffer_dir = os.path.join( - "results_buffer", - f"{run.name}_ep{epoch}", - f"{test_on.__name__}_{kind}", + buffer_dir = os.path.join( + "results_buffer", + run_name, + f"epoch={epoch}", + f"{dataset.__class__.__name__}_{dataset_key}", + ) + print("Checking for buffer dir", old_buffer_dir) + if os.path.isdir(old_buffer_dir): + import shutil + + os.makedirs(buffer_dir, exist_ok=True) + shutil.copytree(old_buffer_dir, buffer_dir, dirs_exist_ok=True) + print(f"Copied buffer from {old_buffer_dir} to {buffer_dir}") + print(f"Using buffer_dir {buffer_dir}") + preds, labels = load_preds_labels( + ckpt_path, dataset, dataset_key, buffer_dir + ) + # smooth_preds mutates preds in place; it is the identity function if remove_violations is False + preds = smooth_preds(preds) + + details_path = None # os.path.join( + # results_dir, + # f"{run_name}_ep{epoch}_{dataset.__class__.__name__}_{dataset_key}", + # ) + metrics_dict = run_consistency_metrics( + preds, + prediction_filters, + check_consistency_on, + consistency_metrics, + verbose_violation_output, + save_details_to=details_path, + ) + with open(results_path_consistency, "a") as f: + for metric in metrics_dict: + values = metrics_dict[metric] + f.write( + f"{run_name},{epoch},{dataset.__class__.__name__},{dataset_key},{metric}," + f"{','.join([str(values[k]) for k in consistency_keys])}\n" ) - preds, labels = load_results_from_buffer( - buffer_dir, device=DEVICE + print( + f"Consistency metrics have been written to {results_path_consistency}" + ) + if labels is not None: + metrics_dict = run_supervised_metrics( + preds, labels, save_details_to=details_path + ) + with open(results_path_supervised, "a") as f: + f.write( + f"{run_name},{epoch},{dataset.__class__.__name__},{dataset_key}," + f"{','.join([str(metrics_dict[k]) for k in supervised_keys])}\n" ) - assert ( - preds is not None - ), f"Did not find predictions in dir {buffer_dir}" - if remove_violations: - preds = smooth_preds( - preds, label_names, chebi_graph, disjoint_groups - ) - buffer_dir_smoothed = os.path.join( - "results_buffer", - "smoothed3step", - f"{run.name}_ep{epoch}", - f"{test_on.__name__}_{kind}", - ) - os.makedirs(buffer_dir_smoothed, exist_ok=True) - torch.save( - preds, os.path.join(buffer_dir_smoothed, "preds000.pt") - ) - if not skip_analyse: print( - f"Calculating metrics for run {run.name} on {test_on.__name__} ({kind})" - ) - analyse_run( - preds, - labels, - df_hyperparams=df, - chebi_version=chebi_version, - results_path=results_path, - violation_metrics=violation_metrics, - verbose_violation_output=True, + f"Supervised metrics have been written to 
{results_path_supervised}"
                     )
         except Exception as e:
-            print(f"Failed for run {run_id}: {e}")
-            # print(traceback.format_exc())
-
-    if nonwandb_runs:
-        for run_name, epoch in nonwandb_runs:
-            try:
-                for test_on, kind in datasets:
-                    df = {
-                        "run-id": run_name,
-                        "epoch": int(epoch),
-                        "kind": kind,
-                        "data_module": test_on.__name__,
-                        "chebi_version": chebi_version,
-                    }
-                    if not skip_preds:
-                        preds, labels = load_preds_labels_from_nonwandb(
-                            run_name, epoch, chebi_version, test_on, kind
-                        )
-                    else:
-                        buffer_dir = os.path.join(
-                            "results_buffer",
-                            f"{run_name}_ep{epoch}",
-                            f"{test_on.__name__}_{kind}",
-                        )
-                        preds, labels = load_results_from_buffer(
-                            buffer_dir, device=DEVICE
-                        )
-                        assert (
-                            preds is not None
-                        ), f"Did not find predictions in dir {buffer_dir}"
-                    if remove_violations:
-                        preds = smooth_preds(
-                            preds, label_names, chebi_graph, disjoint_groups
-                        )
-                    if not skip_analyse:
-                        print(
-                            f"Calculating metrics for run {run_name} on {test_on.__name__} ({kind})"
-                        )
-                        analyse_run(
-                            preds,
-                            labels,
-                            df_hyperparams=df,
-                            chebi_version=chebi_version,
-                            results_path=results_path,
-                            violation_metrics=violation_metrics,
-                        )
-            except Exception as e:
-                print(f"Failed for run {run_name}: {e}")
-                print(traceback.format_exc())
+            print(
+                f"Error during run {wandb_id if wandb_id is not None else run_name}: {e}"
+            )
+            print(traceback.format_exc())
 
 
-# run predictions / metrics calculations for semantic loss paper runs (NeSy 2024 submission)
-def run_semloss_eval(mode="eval"):
-    non_wandb_runs = []
-    if mode == "preds":
-        api = wandb.Api()
-        runs = api.runs("chebai/chebai", filters={"tags": "eval_semloss_paper"})
-        print(f"Found {len(runs)} tagged wandb runs")
-        ids = [run.id for run in runs]
-        run_all(ids, skip_analyse=True, nonwandb_runs=non_wandb_runs)
-
-    if mode == "eval":
-        prod = [
-            "tk15yznc",
-            "uke62a8m",
-            "w0h3zr5s",
-        ]
-        xu19 = [
-            "5ko8knb4",
-            "061fd85t",
-            "r50ioujs",
-        ]
-        prod_mixed = [
-            "hk8555ff",
-            "e0lxw8py",
-            "lig23cmg",
-        ]
-        luka = [
-            "0c0s48nh",
-            "lfg384bp",
-            "qeghvubh",
-        ]
-        baseline = ["i4wtz1k4", "zd020wkv", "rc1q3t49"]
-        prodk2 = ["ng3usn0p", "rp0wwzjv", "8fma1q7r"]
-        ids = baseline + prod + prodk2 + xu19 + luka + prod_mixed
-        run_all(
-            ids,
-            skip_preds=True,
-            nonwandb_runs=non_wandb_runs,
-            datasets=[(ChEBIOver100, "test")],
-            violation_metrics=[binary],
-            remove_violations=True,
-        )
+# follow-up to the NeSy 2024 submission
+def run_fuzzy_loss(tag="fuzzy_loss", skip_first_n=0):
+    api = wandb.Api()
+    runs = api.runs("chebai/chebai", filters={"tags": tag})
+    print(f"Found {len(runs)} wandb runs tagged with '{tag}'")
+    ids = [run.id for run in runs]
+    chebi100 = ChEBIOver100(
+        chebi_version=231,
+        splits_file_path=os.path.join(
+            "data", "chebi_v231", "ChEBI100", "fuzzy_loss_splits.csv"
+        ),
+    )
+    local_ckpts = [][skip_first_n:]  # insert (run_name, epoch) tuples here; skip_first_n skips local ckpts first
+    pubchem_kmeans = PubChemKMeans()
+    run_all(
+        ids[max(0, skip_first_n - len(local_ckpts)) :],
+        local_ckpts,
+        consistency_metrics=[binary],
+        check_consistency_on=chebi100,
+        prediction_datasets=[
+            (chebi100, "test"),
+            # (pubchem_kmeans, "cluster1_cutoff2k.pt"),
+            # (pubchem_kmeans, "cluster2.pt"),
+            # (pubchem_kmeans, "ten_from_each_cluster.pt"),
+            # (pubchem_kmeans, "chebi_close.pt"),
+        ],
+    )
 
 
 if __name__ == "__main__":
-    if len(sys.argv) > 1:
-        run_semloss_eval(sys.argv[1])
+    if len(sys.argv) > 2:
+        run_fuzzy_loss(sys.argv[1], int(sys.argv[2]))
+    elif len(sys.argv) > 1:
+        run_fuzzy_loss(sys.argv[1])
     else:
-        run_semloss_eval()
+        run_fuzzy_loss()
diff --git a/chebai/result/utils.py b/chebai/result/utils.py
index 
80bf56e2..35dbc319 100644
--- a/chebai/result/utils.py
+++ b/chebai/result/utils.py
@@ -1,4 +1,5 @@
 import os
+import shutil
 from typing import Optional, Tuple, Union
 
 import torch
@@ -16,9 +17,7 @@ def get_checkpoint_from_wandb(
     epoch: int,
     run: wandb.apis.public.Run,
     root: str = os.path.join("logs", "downloaded_ckpts"),
-    model_class: Optional[Union[Electra, ChebaiBaseNet]] = None,
-    map_device_to: Optional[torch.device] = None,
-) -> Optional[ChebaiBaseNet]:
+) -> Optional[str]:
     """
     Gets a wandb checkpoint based on run and epoch, downloads it if necessary.
 
@@ -26,28 +25,31 @@
         epoch: The epoch number of the checkpoint to retrieve.
         run: The wandb run object.
         root: The root directory to save the downloaded checkpoint.
-        model_class: The class of the model to load.
-        map_device_to: The device to map the model to.
 
     Returns:
-        The loaded model or None if no checkpoint is found.
+        The path to the downloaded checkpoint, or None if no matching checkpoint is found.
     """
     api = wandb.Api()
-    if model_class is None:
-        model_class = Electra
     files = run.files()
     for file in files:
         if file.name.startswith(
             f"checkpoints/per_epoch={epoch}"
         ) or file.name.startswith(f"checkpoints/best_epoch={epoch}"):
-            dest_path = os.path.join(root, run.name, file.name.split("/")[-1])
-            if not os.path.isfile(dest_path):
-                print(f"Downloading checkpoint to {dest_path}")
-                wandb_util.download_file_from_url(dest_path, file.url, api.api_key)
-            return model_class.load_from_checkpoint(
-                dest_path, strict=False, map_location=map_device_to
+            dest_path = os.path.join(
+                root, run.id, file.name.split("/")[-1].split("_")[1] + ".ckpt"
             )
+            # legacy: also look for ckpts in the old format
+            old_dest_path = os.path.join(root, run.name, file.name.split("/")[-1])
+            if not os.path.isfile(dest_path):
+                if os.path.isfile(old_dest_path):
+                    print(f"Copying checkpoint from {old_dest_path} to {dest_path}")
+                    os.makedirs(os.path.dirname(dest_path), exist_ok=True)
+                    shutil.copy2(old_dest_path, dest_path)
+                else:
+                    print(f"Downloading checkpoint to {dest_path}")
+                    wandb_util.download_file_from_url(dest_path, file.url, api.api_key)
+            return dest_path
     print(f"No checkpoint found for epoch {epoch}")
     return None
@@ -151,10 +153,10 @@ def evaluate_model(
         test_preds = _concat_tuple(preds_list)
         if labels_list is not None:
             test_labels = _concat_tuple(labels_list)
             return test_preds, test_labels
         return test_preds, None
-    else:
+    elif len(preds_list) > 0:
         torch.save(
             _concat_tuple(preds_list),
             os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
diff --git a/configs/data/pubchem_kmeans.yml b/configs/data/pubchem_kmeans.yml
new file mode 100644
index 00000000..05a0171d
--- /dev/null
+++ b/configs/data/pubchem_kmeans.yml
@@ -0,0 +1 @@
+class_path: chebai.preprocessing.datasets.pubchem.PubChemKMeans
diff --git a/configs/loss/semantic_loss.yml b/configs/loss/semantic_loss.yml
index cc254d4f..5434084b 100644
--- a/configs/loss/semantic_loss.yml
+++ b/configs/loss/semantic_loss.yml
@@ -5,5 +5,6 @@ init_args:
     class_path: chebai.loss.bce_weighted.BCEWeighted
     init_args:
      beta: 0.99
-  tnorm: product
-  impl_loss_weight: 0.01
+  multiply_by_softmax: true
+  impl_loss_weight: 100
+  disjoint_loss_weight: 1000000
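
Note on the reworked get_checkpoint_from_wandb: it now returns a checkpoint path rather than a loaded model, so loading moves to the call site. Below is a minimal sketch of that pattern, assuming Electra is importable from chebai.models.electra; the "chebai/chebai" run path and the load_from_checkpoint(..., strict=False) call are taken from the removed code, while the run id and epoch are placeholders:

import wandb

from chebai.models.electra import Electra
from chebai.result.utils import get_checkpoint_from_wandb

# Resolve the checkpoint path (downloading it, or copying it from the
# legacy location, if necessary), then load the model explicitly.
run = wandb.Api().run("chebai/chebai/<run-id>")  # <run-id> is a placeholder
ckpt_path = get_checkpoint_from_wandb(epoch=99, run=run)  # epoch is a placeholder
if ckpt_path is not None:
    model = Electra.load_from_checkpoint(ckpt_path, strict=False)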