22import math
33import os
44import pickle
5- from typing import List , Literal , Union
5+ from typing import List , Literal , Type , Union
66
77import torch
88
99from chebai .loss .bce_weighted import BCEWeighted
1010from chebai .preprocessing .datasets import XYBaseDataModule
11- from chebai .preprocessing .datasets .chebi import ChEBIOver100 , _ChEBIDataExtractor
12- from chebai .preprocessing .datasets .pubchem import LabeledUnlabeledMixed
11+ from chebai .preprocessing .datasets .deepGO .go_uniprot import (
12+ GOUniProtOver250 ,
13+ _GOUniProtDataExtractor ,
14+ )
1315
1416
1517class ImplicationLoss (torch .nn .Module ):
1618 """
1719 Implication Loss module.
1820
1921 Args:
20- data_extractor _ChEBIDataExtractor : Data extractor for labels.
22+ data_extractor _GOUniProtDataExtractor : Data extractor for labels.
2123 base_loss (torch.nn.Module, optional): Base loss function. Defaults to None.
2224 fuzzy_implication (Literal["product", "lukasiewicz", "xu19"], optional): T-norm type. Defaults to "product".
2325 impl_loss_weight (float, optional): Weight of implication loss relative to base loss. Defaults to 0.1.
@@ -70,9 +72,7 @@ def __init__(
7072 ):
7173 super ().__init__ ()
7274 # automatically choose labeled subset for implication filter in case of mixed dataset
73- if isinstance (data_extractor , LabeledUnlabeledMixed ):
74- data_extractor = data_extractor .labeled
75- assert isinstance (data_extractor , _ChEBIDataExtractor )
75+ assert isinstance (data_extractor , _GOUniProtDataExtractor )
7676 self .data_extractor = data_extractor
7777 # propagate data_extractor to base loss
7878 if isinstance (base_loss , BCEWeighted ):
@@ -329,7 +329,7 @@ class DisjointLoss(ImplicationLoss):
329329
330330 Args:
331331 path_to_disjointness (str): Path to the disjointness data file (a csv file containing pairs of disjoint classes)
332- data_extractor (Union[_ChEBIDataExtractor, LabeledUnlabeledMixed] ): Data extractor for labels.
332+ data_extractor (_GOUniProtDataExtractor ): Data extractor for labels.
333333 base_loss (torch.nn.Module, optional): Base loss function. Defaults to None.
334334 disjoint_loss_weight (float, optional): Weight of disjointness loss. Defaults to 100.
335335 **kwargs: Additional arguments.
@@ -338,7 +338,7 @@ class DisjointLoss(ImplicationLoss):
338338 def __init__ (
339339 self ,
340340 path_to_disjointness : str ,
341- data_extractor : Union [ _ChEBIDataExtractor , LabeledUnlabeledMixed ] ,
341+ data_extractor : _GOUniProtDataExtractor ,
342342 base_loss : torch .nn .Module = None ,
343343 disjoint_loss_weight : float = 100 ,
344344 ** kwargs ,
@@ -502,7 +502,7 @@ def _build_disjointness_filter(
502502if __name__ == "__main__" :
503503 loss = DisjointLoss (
504504 os .path .join ("data" , "disjoint.csv" ),
505- ChEBIOver100 ( chebi_version = 231 ),
505+ GOUniProtOver250 ( ),
506506 base_loss = BCEWeighted (),
507507 impl_loss_weight = 1 ,
508508 disjoint_loss_weight = 1 ,
0 commit comments