Skip to content

Commit 2911f92

Browse files
committed
fix merge conflict
2 parents 078bfb6 + 8e51a61 commit 2911f92

File tree

15 files changed

+190
-50
lines changed

15 files changed

+190
-50
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ jobs:
99
strategy:
1010
fail-fast: false
1111
matrix:
12-
python-version: ["3.9", "3.10", "3.11", "3.12"]
12+
python-version: ["3.10", "3.11", "3.12"]
1313

1414
steps:
1515
- uses: actions/checkout@v4
@@ -24,7 +24,7 @@ jobs:
2424
python -m pip install --upgrade pip
2525
python -m pip install --upgrade pip setuptools wheel
2626
python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
27-
python -m pip install -e .
27+
python -m pip install -e .[dev]
2828
2929
- name: Display Python & Installed Packages
3030
run: |

chebai/loggers/custom.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22
from datetime import datetime
33
from typing import List, Literal, Optional, Union
44

5-
import wandb
65
from lightning.fabric.utilities.types import _PATH
76
from lightning.pytorch.callbacks import ModelCheckpoint
87
from lightning.pytorch.loggers import WandbLogger
@@ -105,6 +104,8 @@ def set_fold(self, fold: int) -> None:
105104
Args:
106105
fold (int): Cross-validation fold number.
107106
"""
107+
import wandb
108+
108109
if fold != self._fold:
109110
self._fold = fold
110111
# Start new experiment

chebai/loss/bce_weighted.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55

66
from chebai.preprocessing.datasets.base import XYBaseDataModule
77
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
8-
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
98

109

1110
class BCEWeighted(torch.nn.BCEWithLogitsLoss):
@@ -27,6 +26,8 @@ def __init__(
2726
data_extractor: Optional[XYBaseDataModule] = None,
2827
**kwargs,
2928
):
29+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
30+
3031
self.beta = beta
3132
if isinstance(data_extractor, LabeledUnlabeledMixed):
3233
data_extractor = data_extractor.labeled

chebai/loss/semantic.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,14 +2,16 @@
22
import math
33
import os
44
import pickle
5-
from typing import List, Literal, Union
5+
from typing import TYPE_CHECKING, List, Literal, Union
66

77
import torch
88

99
from chebai.loss.bce_weighted import BCEWeighted
1010
from chebai.preprocessing.datasets.base import XYBaseDataModule
1111
from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor
12-
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
12+
13+
if TYPE_CHECKING:
14+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
1315

1416

1517
class ImplicationLoss(torch.nn.Module):
@@ -68,6 +70,8 @@ def __init__(
6870
multiply_with_base_loss: bool = True,
6971
no_grads: bool = False,
7072
):
73+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
74+
7175
super().__init__()
7276
# automatically choose labeled subset for implication filter in case of mixed dataset
7377
if isinstance(data_extractor, LabeledUnlabeledMixed):
@@ -338,7 +342,7 @@ class DisjointLoss(ImplicationLoss):
338342
def __init__(
339343
self,
340344
path_to_disjointness: str,
341-
data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed],
345+
data_extractor: Union[_ChEBIDataExtractor, "LabeledUnlabeledMixed"],
342346
base_loss: torch.nn.Module = None,
343347
disjoint_loss_weight: float = 100,
344348
**kwargs,

chebai/models/ffn.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@ class FFN(ChebaiBaseNet):
1111

1212
def __init__(
1313
self,
14-
input_size: int,
1514
hidden_layers: List[int] = [
1615
1024,
1716
],
@@ -20,7 +19,7 @@ def __init__(
2019
super().__init__(**kwargs)
2120

2221
layers = []
23-
current_layer_input_size = input_size
22+
current_layer_input_size = self.input_dim
2423
for hidden_dim in hidden_layers:
2524
layers.append(MLPBlock(current_layer_input_size, hidden_dim))
2625
layers.append(Residual(MLPBlock(hidden_dim, hidden_dim)))

chebai/preprocessing/datasets/base.py

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,24 +1,21 @@
11
import os
22
import random
33
from abc import ABC, abstractmethod
4-
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
4+
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
55

66
import lightning as pl
7-
import networkx as nx
87
import pandas as pd
98
import torch
109
import tqdm
11-
from iterstrat.ml_stratifiers import (
12-
MultilabelStratifiedKFold,
13-
MultilabelStratifiedShuffleSplit,
14-
)
1510
from lightning.pytorch.core.datamodule import LightningDataModule
1611
from lightning_utilities.core.rank_zero import rank_zero_info
17-
from sklearn.model_selection import StratifiedShuffleSplit
1812
from torch.utils.data import DataLoader
1913

2014
from chebai.preprocessing import reader as dr
2115

16+
if TYPE_CHECKING:
17+
import networkx as nx
18+
2219

2320
class XYBaseDataModule(LightningDataModule):
2421
"""
@@ -822,7 +819,7 @@ def _download_required_data(self) -> str:
822819
pass
823820

824821
@abstractmethod
825-
def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
822+
def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph":
826823
"""
827824
Extracts the class hierarchy from the data.
828825
Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from
@@ -837,7 +834,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
837834
pass
838835

839836
@abstractmethod
840-
def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
837+
def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame:
841838
"""
842839
Converts the graph to a raw dataset.
843840
Uses the graph created by `_extract_class_hierarchy` method to extract the
@@ -852,7 +849,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
852849
pass
853850

854851
@abstractmethod
855-
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
852+
def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List:
856853
"""
857854
Selects classes from the dataset based on specified criteria.
858855
@@ -1027,6 +1024,9 @@ def get_test_split(
10271024
Raises:
10281025
ValueError: If the DataFrame does not contain a column named "labels".
10291026
"""
1027+
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
1028+
from sklearn.model_selection import StratifiedShuffleSplit
1029+
10301030
print("Get test data split")
10311031

10321032
labels_list = df["labels"].tolist()
@@ -1064,6 +1064,12 @@ def get_train_val_splits_given_test(
10641064
and validation DataFrames. The keys are the names of the train and validation sets, and the values
10651065
are the corresponding DataFrames.
10661066
"""
1067+
from iterstrat.ml_stratifiers import (
1068+
MultilabelStratifiedKFold,
1069+
MultilabelStratifiedShuffleSplit,
1070+
)
1071+
from sklearn.model_selection import StratifiedShuffleSplit
1072+
10671073
print("Split dataset into train / val with given test set")
10681074

10691075
test_ids = test_df["ident"].tolist()

chebai/preprocessing/datasets/chebi.py

Lines changed: 34 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -13,17 +13,28 @@
1313
import pickle
1414
from abc import ABC
1515
from collections import OrderedDict
16-
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union
16+
from typing import (
17+
TYPE_CHECKING,
18+
Any,
19+
Dict,
20+
Generator,
21+
List,
22+
Literal,
23+
Optional,
24+
Tuple,
25+
Union,
26+
)
1727

18-
import fastobo
19-
import networkx as nx
2028
import pandas as pd
21-
import requests
2229
import torch
2330

2431
from chebai.preprocessing import reader as dr
2532
from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset
2633

34+
if TYPE_CHECKING:
35+
import fastobo
36+
import networkx as nx
37+
2738
# exclude some entities from the dataset because they violate disjointness axioms
2839
CHEBI_BLACKLIST = [
2940
194026,
@@ -214,6 +225,8 @@ def _load_chebi(self, version: int) -> str:
214225
Returns:
215226
str: The file path of the loaded ChEBI ontology.
216227
"""
228+
import requests
229+
217230
chebi_name = self.raw_file_names_dict["chebi"]
218231
chebi_path = os.path.join(self.raw_dir, chebi_name)
219232
if not os.path.isfile(chebi_path):
@@ -225,7 +238,7 @@ def _load_chebi(self, version: int) -> str:
225238
open(chebi_path, "wb").write(r.content)
226239
return chebi_path
227240

228-
def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
241+
def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph":
229242
"""
230243
Extracts the class hierarchy from the ChEBI ontology.
231244
Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from
@@ -237,6 +250,9 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
237250
Returns:
238251
nx.DiGraph: The class hierarchy.
239252
"""
253+
import fastobo
254+
import networkx as nx
255+
240256
with open(data_path, encoding="utf-8") as chebi:
241257
chebi = "\n".join(line for line in chebi if not line.startswith("xref:"))
242258

@@ -266,7 +282,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
266282
print("Compute transitive closure")
267283
return nx.transitive_closure_dag(g)
268284

269-
def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
285+
def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame:
270286
"""
271287
Converts the graph to a raw dataset.
272288
Uses the graph created by `_extract_class_hierarchy` method to extract the
@@ -278,6 +294,8 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
278294
Returns:
279295
pd.DataFrame: The raw dataset created from the graph.
280296
"""
297+
import networkx as nx
298+
281299
smiles = nx.get_node_attributes(g, "smiles")
282300
names = nx.get_node_attributes(g, "name")
283301

@@ -590,7 +608,7 @@ def _name(self) -> str:
590608
"""
591609
return f"ChEBI{self.THRESHOLD}"
592610

593-
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
611+
def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List:
594612
"""
595613
Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold.
596614
@@ -615,6 +633,8 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
615633
- The `THRESHOLD` attribute should be defined in the subclass of this class.
616634
- Nodes without a 'smiles' attribute are ignored in the successor count.
617635
"""
636+
import networkx as nx
637+
618638
smiles = nx.get_node_attributes(g, "smiles")
619639
nodes = list(
620640
sorted(
@@ -753,7 +773,7 @@ def processed_dir_main(self) -> str:
753773
"processed",
754774
)
755775

756-
def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
776+
def _extract_class_hierarchy(self, chebi_path: str) -> "nx.DiGraph":
757777
"""
758778
Extracts a subset of ChEBI based on subclasses of the top class ID.
759779
@@ -791,8 +811,10 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
791811
)
792812
return g
793813

794-
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
814+
def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List:
795815
"""Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself)."""
816+
import networkx as nx
817+
796818
smiles = nx.get_node_attributes(g, "smiles")
797819
nodes = list(
798820
sorted(
@@ -868,7 +890,7 @@ def chebi_to_int(s: str) -> int:
868890
return int(s[s.index(":") + 1 :])
869891

870892

871-
def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
893+
def term_callback(doc: "fastobo.term.TermFrame") -> Union[Dict, bool]:
872894
"""
873895
Extracts information from a ChEBI term document.
874896
This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents,
@@ -885,6 +907,8 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
885907
- "name": The name of the ChEBI term.
886908
- "smiles": The SMILES string associated with the ChEBI term, if available.
887909
"""
910+
import fastobo
911+
888912
parts = set()
889913
parents = []
890914
name = None

chebai/preprocessing/reader.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,11 +5,8 @@
55
from itertools import islice
66
from typing import Any, Dict, List, Optional
77

8-
import deepsmiles
9-
import selfies as sf
108
from pysmiles.read_smiles import _tokenize
119
from rdkit import Chem
12-
from transformers import RobertaTokenizerFast
1310

1411
from chebai.preprocessing.collate import DefaultCollator, RaggedCollator
1512

@@ -224,6 +221,8 @@ class DeepChemDataReader(ChemDataReader):
224221
"""
225222

226223
def __init__(self, *args, **kwargs):
224+
import deepsmiles
225+
227226
super().__init__(*args, **kwargs)
228227
self.converter = deepsmiles.Converter(rings=True, branches=True)
229228
self.error_count = 0
@@ -298,6 +297,8 @@ def __init__(
298297
vsize: int = 4000,
299298
**kwargs,
300299
):
300+
from transformers import RobertaTokenizerFast
301+
301302
super().__init__(*args, **kwargs)
302303
self.tokenizer = RobertaTokenizerFast.from_pretrained(
303304
data_path, max_len=max_len
@@ -331,6 +332,8 @@ def __init__(
331332
vsize: int = 4000,
332333
**kwargs,
333334
):
335+
import selfies as sf
336+
334337
super().__init__(*args, **kwargs)
335338
self.error_count = 0
336339
sf.set_semantic_constraints("hypervalent")
@@ -342,6 +345,8 @@ def name(cls) -> str:
342345

343346
def _read_data(self, raw_data: str) -> Optional[List[int]]:
344347
"""Read and tokenize raw data using SELFIES."""
348+
import selfies as sf
349+
345350
try:
346351
tokenized = sf.split_selfies(sf.encoder(raw_data.strip(), strict=True))
347352
tokenized = [self._get_token_index(v) for v in tokenized]

chebai/preprocessing/structures.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
1-
from typing import Any, Tuple, Union
1+
from typing import TYPE_CHECKING, Any, Tuple, Union
22

3-
import networkx as nx
43
import torch
54

5+
if TYPE_CHECKING:
6+
import networkx as nx
7+
68

79
class XYData(torch.utils.data.Dataset):
810
"""
@@ -119,7 +121,7 @@ class XYMolData(XYData):
119121
kwargs: Additional fields to store in the dataset.
120122
"""
121123

122-
def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]:
124+
def to_x(self, device: torch.device) -> Tuple["nx.Graph", ...]:
123125
"""
124126
Moves the node attributes of the molecular graphs to the specified device.
125127
@@ -129,6 +131,8 @@ def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]:
129131
Returns:
130132
A tuple of molecular graphs with node attributes on the specified device.
131133
"""
134+
import networkx as nx
135+
132136
l_ = []
133137
for g in self.x:
134138
graph = g.copy()

0 commit comments

Comments
 (0)