File tree Expand file tree Collapse file tree 2 files changed +8
-3
lines changed Expand file tree Collapse file tree 2 files changed +8
-3
lines changed Original file line number Diff line number Diff line change 55
66from chebai .preprocessing .datasets .base import XYBaseDataModule
77from chebai .preprocessing .datasets .chebi import _ChEBIDataExtractor
8- from chebai .preprocessing .datasets .pubchem import LabeledUnlabeledMixed
98
109
1110class BCEWeighted (torch .nn .BCEWithLogitsLoss ):
@@ -27,6 +26,8 @@ def __init__(
2726 data_extractor : Optional [XYBaseDataModule ] = None ,
2827 ** kwargs ,
2928 ):
29+ from chebai .preprocessing .datasets .pubchem import LabeledUnlabeledMixed
30+
3031 self .beta = beta
3132 if isinstance (data_extractor , LabeledUnlabeledMixed ):
3233 data_extractor = data_extractor .labeled
Original file line number Diff line number Diff line change 22import math
33import os
44import pickle
5- from typing import List , Literal , Union
5+ from typing import TYPE_CHECKING , List , Literal , Union
66
77import torch
88
99from chebai .loss .bce_weighted import BCEWeighted
1010from chebai .preprocessing .datasets .base import XYBaseDataModule
1111from chebai .preprocessing .datasets .chebi import ChEBIOver100 , _ChEBIDataExtractor
12- from chebai .preprocessing .datasets .pubchem import LabeledUnlabeledMixed
12+
13+ if TYPE_CHECKING :
14+ from chebai .preprocessing .datasets .pubchem import LabeledUnlabeledMixed
1315
1416
1517class ImplicationLoss (torch .nn .Module ):
@@ -68,6 +70,8 @@ def __init__(
6870 multiply_with_base_loss : bool = True ,
6971 no_grads : bool = False ,
7072 ):
73+ from chebai .preprocessing .datasets .pubchem import LabeledUnlabeledMixed
74+
7175 super ().__init__ ()
7276 # automatically choose labeled subset for implication filter in case of mixed dataset
7377 if isinstance (data_extractor , LabeledUnlabeledMixed ):
You can’t perform that action at this time.
0 commit comments