Skip to content

Commit b87129d

Browse files
committed
to avoid access to pubchem file: dynamic import
1 parent 109723c commit b87129d

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

chebai/loss/bce_weighted.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from chebai.preprocessing.datasets.base import XYBaseDataModule
77
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
8-
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
98

109

1110
class BCEWeighted(torch.nn.BCEWithLogitsLoss):
@@ -27,6 +26,8 @@ def __init__(
2726
data_extractor: Optional[XYBaseDataModule] = None,
2827
**kwargs,
2928
):
29+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
30+
3031
self.beta = beta
3132
if isinstance(data_extractor, LabeledUnlabeledMixed):
3233
data_extractor = data_extractor.labeled

chebai/loss/semantic.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
import math
33
import os
44
import pickle
5-
from typing import List, Literal, Union
5+
from typing import TYPE_CHECKING, List, Literal, Union
66

77
import torch
88

99
from chebai.loss.bce_weighted import BCEWeighted
1010
from chebai.preprocessing.datasets.base import XYBaseDataModule
1111
from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor
12-
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
12+
13+
if TYPE_CHECKING:
14+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
1315

1416

1517
class ImplicationLoss(torch.nn.Module):
@@ -68,6 +70,8 @@ def __init__(
6870
multiply_with_base_loss: bool = True,
6971
no_grads: bool = False,
7072
):
73+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
74+
7175
super().__init__()
7276
# automatically choose labeled subset for implication filter in case of mixed dataset
7377
if isinstance(data_extractor, LabeledUnlabeledMixed):

0 commit comments

Comments
 (0)