Skip to content

Commit 2b5b033

Browse files
committed
remove unused imports + minor type hints update
1 parent 4a34864 commit 2b5b033

File tree

7 files changed

+18
-21
lines changed

7 files changed

+18
-21
lines changed

chebai/callbacks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from lightning.pytorch.callbacks import BasePredictionWriter
55
import torch
6-
from typing import Any, Dict, List, Union
6+
from typing import Any, Dict, List, Union, Literal
77

88

99
class ChebaiPredictionWriter(BasePredictionWriter):
@@ -22,7 +22,7 @@ class ChebaiPredictionWriter(BasePredictionWriter):
2222
def __init__(
2323
self,
2424
output_dir: str,
25-
write_interval: str,
25+
write_interval: Literal["batch", "epoch", "batch_and_epoch"],
2626
target_file: str = "predictions.json",
2727
) -> None:
2828
super().__init__(write_interval)

chebai/callbacks/prediction_callback.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
import os
55
import pickle
6-
from typing import Sequence, Any
6+
from typing import Sequence, Any, Literal
77

88

99
class PredictionWriter(BasePredictionWriter):
@@ -15,7 +15,11 @@ class PredictionWriter(BasePredictionWriter):
1515
write_interval (str): When to write predictions. Options are "batch" or "epoch".
1616
"""
1717

18-
def __init__(self, output_dir: str, write_interval: str):
18+
def __init__(
19+
self,
20+
output_dir: str,
21+
write_interval: Literal["batch", "epoch", "batch_and_epoch"],
22+
):
1923
super().__init__(write_interval)
2024
self.output_dir = output_dir
2125
self.prediction_file_name = "predictions.pkl"

chebai/loss/bce_weighted.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from typing import Optional
2+
13
import torch
24
from chebai.preprocessing.datasets.base import XYBaseDataModule
35
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
46
import pandas as pd
57
import os
6-
import pickle
78

89

910
class BCEWeighted(torch.nn.BCEWithLogitsLoss):

chebai/molecule.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import networkx as nx
1212
import numpy as np
13-
import six
1413
import torch
1514

1615
logger = logging.getLogger(__name__)
@@ -131,6 +130,9 @@ def create_directed_graphs(self):
131130
def create_feature_vectors(self):
132131
"""
133132
Creates feature vectors based on local environments of atoms.
133+
134+
Note:
135+
create a three-dimensional matrix I, such that I_{i,j} is the local input vector for jth vertex in ith DAG
134136
"""
135137
length_of_bond_features = Molecule.num_bond_features()
136138
length_of_atom_features = Molecule.num_atom_features()

chebai/preprocessing/datasets/tox21.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,17 @@
33
import csv
44
import gzip
55
import os
6-
import random
76
import shutil
87
import zipfile
9-
from typing import List, Dict, Generator
8+
from typing import List, Dict, Generator, Optional
109

1110
from rdkit import Chem
1211
from sklearn.model_selection import GroupShuffleSplit, train_test_split
1312
import numpy as np
1413
import torch
15-
import pysmiles
1614

1715
from chebai.preprocessing import reader as dr
18-
from chebai.preprocessing.datasets.base import MergedDataset, XYBaseDataModule
19-
from chebai.preprocessing.datasets.chebi import JCIExtendedTokenData
20-
from chebai.preprocessing.datasets.pubchem import Hazardous
16+
from chebai.preprocessing.datasets.base import XYBaseDataModule
2117

2218

2319
class Tox21MolNet(XYBaseDataModule):
@@ -218,7 +214,7 @@ def download(self) -> None:
218214
)
219215

220216
def _retrieve_file(
221-
self, url: str, target_file: str, compression: str = None
217+
self, url: str, target_file: str, compression: Optional[str] = None
222218
) -> None:
223219
"""Retrieves a file from a URL and saves it locally.
224220

chebai/preprocessing/structures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Tuple, Union
1+
from typing import Any, Tuple, Union
22
from torch.utils.data.dataset import T_co
33
import networkx as nx
44
import torch

chebai/result/classification.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,16 @@
1-
import os
2-
from typing import List, Optional, Tuple
1+
from typing import List
32

43
import matplotlib.pyplot as plt
54
import pandas as pd
65
import seaborn as sns
7-
import torch
86
from torch import Tensor
97
from torchmetrics.classification import (
108
MultilabelF1Score,
119
MultilabelPrecision,
1210
MultilabelRecall,
1311
)
14-
import tqdm
1512

1613
from chebai.callbacks.epoch_metrics import MacroF1
17-
from chebai.models import ChebaiBaseNet
18-
from chebai.models.electra import Electra
19-
from chebai.preprocessing.datasets import XYBaseDataModule
2014
from chebai.result.utils import *
2115

2216

0 commit comments

Comments
 (0)