Skip to content

Commit 64d7623

Browse files
committed
change loss module for protein data
1 parent 2422518 commit 64d7623

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

chebai/loss/bce_weighted.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import os
22
from typing import Optional
33

4-
import pandas as pd
54
import torch
65

76
from chebai.preprocessing.datasets.base import XYBaseDataModule
8-
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
9-
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
7+
from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtDataExtractor
108

119

1210
class BCEWeighted(torch.nn.BCEWithLogitsLoss):
@@ -29,11 +27,9 @@ def __init__(
2927
**kwargs,
3028
):
3129
self.beta = beta
32-
if isinstance(data_extractor, LabeledUnlabeledMixed):
33-
data_extractor = data_extractor.labeled
3430
self.data_extractor = data_extractor
3531
assert (
36-
isinstance(self.data_extractor, _ChEBIDataExtractor)
32+
isinstance(self.data_extractor, _GOUniProtDataExtractor)
3733
or self.data_extractor is None
3834
)
3935
super().__init__(**kwargs)

chebai/loss/semantic.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,24 @@
22
import math
33
import os
44
import pickle
5-
from typing import List, Literal, Union
5+
from typing import List, Literal, Type, Union
66

77
import torch
88

99
from chebai.loss.bce_weighted import BCEWeighted
1010
from chebai.preprocessing.datasets import XYBaseDataModule
11-
from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor
12-
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
11+
from chebai.preprocessing.datasets.deepGO.go_uniprot import (
12+
GOUniProtOver250,
13+
_GOUniProtDataExtractor,
14+
)
1315

1416

1517
class ImplicationLoss(torch.nn.Module):
1618
"""
1719
Implication Loss module.
1820
1921
Args:
20-
data_extractor _ChEBIDataExtractor: Data extractor for labels.
22+
data_extractor _GOUniProtDataExtractor: Data extractor for labels.
2123
base_loss (torch.nn.Module, optional): Base loss function. Defaults to None.
2224
fuzzy_implication (Literal["product", "lukasiewicz", "xu19"], optional): T-norm type. Defaults to "product".
2325
impl_loss_weight (float, optional): Weight of implication loss relative to base loss. Defaults to 0.1.
@@ -70,9 +72,7 @@ def __init__(
7072
):
7173
super().__init__()
7274
# automatically choose labeled subset for implication filter in case of mixed dataset
73-
if isinstance(data_extractor, LabeledUnlabeledMixed):
74-
data_extractor = data_extractor.labeled
75-
assert isinstance(data_extractor, _ChEBIDataExtractor)
75+
assert isinstance(data_extractor, _GOUniProtDataExtractor)
7676
self.data_extractor = data_extractor
7777
# propagate data_extractor to base loss
7878
if isinstance(base_loss, BCEWeighted):
@@ -329,7 +329,7 @@ class DisjointLoss(ImplicationLoss):
329329
330330
Args:
331331
path_to_disjointness (str): Path to the disjointness data file (a csv file containing pairs of disjoint classes)
332-
data_extractor (Union[_ChEBIDataExtractor, LabeledUnlabeledMixed]): Data extractor for labels.
332+
data_extractor (_GOUniProtDataExtractor): Data extractor for labels.
333333
base_loss (torch.nn.Module, optional): Base loss function. Defaults to None.
334334
disjoint_loss_weight (float, optional): Weight of disjointness loss. Defaults to 100.
335335
**kwargs: Additional arguments.
@@ -338,7 +338,7 @@ class DisjointLoss(ImplicationLoss):
338338
def __init__(
339339
self,
340340
path_to_disjointness: str,
341-
data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed],
341+
data_extractor: _GOUniProtDataExtractor,
342342
base_loss: torch.nn.Module = None,
343343
disjoint_loss_weight: float = 100,
344344
**kwargs,
@@ -502,7 +502,7 @@ def _build_disjointness_filter(
502502
if __name__ == "__main__":
503503
loss = DisjointLoss(
504504
os.path.join("data", "disjoint.csv"),
505-
ChEBIOver100(chebi_version=231),
505+
GOUniProtOver250(),
506506
base_loss=BCEWeighted(),
507507
impl_loss_weight=1,
508508
disjoint_loss_weight=1,

0 commit comments

Comments
 (0)