
Commit 81f8025 (1 parent: fb6fdb7)

ruff fixes

6 files changed (+31, -64 lines)

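All of the changes below are mechanical lint fixes. Judging by the patterns (the commit message says only "ruff fixes"), three auto-fixable ruff rules are involved: F401 (unused import), F541 (f-string without placeholders), and E713 (membership test written as "not x in y"); a command along the lines of "ruff check --fix" would reproduce edits like these. A compact, purely illustrative Python example of each pattern:

    import random  # F401: imported but never used below -- the line gets deleted

    name = f"bbbp.csv"  # F541: an f-string with no placeholders...
    name = "bbbp.csv"   # ...is rewritten as a plain string literal

    seen = ["CCO"]
    if not "CCN" in seen:  # E713: negated membership test...
        pass
    if "CCN" not in seen:  # ...becomes the not-in operator
        pass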

chebai/preprocessing/datasets/molecule_classification.py

Lines changed: 12 additions & 20 deletions
@@ -1,28 +1,20 @@
-from tempfile import NamedTemporaryFile, TemporaryDirectory
+from tempfile import NamedTemporaryFile
 from urllib import request
 import csv
 import gzip
 import os
-import random
 import shutil
-import zipfile
-from typing import Dict, Generator, List, Optional
+from typing import Dict, List
 
-from rdkit import Chem
 from sklearn.model_selection import (
     GroupShuffleSplit,
     train_test_split,
-    StratifiedShuffleSplit,
 )
 import numpy as np
-import pysmiles
 import torch
-from sklearn.preprocessing import LabelBinarizer
 
 from chebai.preprocessing import reader as dr
-from chebai.preprocessing.datasets.base import MergedDataset, XYBaseDataModule
-from chebai.preprocessing.datasets.chebi import JCIExtendedTokenData
-from chebai.preprocessing.datasets.pubchem import Hazardous
+from chebai.preprocessing.datasets.base import XYBaseDataModule
 
 
 class ClinTox(XYBaseDataModule):
@@ -76,7 +68,7 @@ def setup_processed(self) -> None:
         """Processes and splits the dataset."""
         print("Create splits")
         data = list(
-            self._load_data_from_file(os.path.join(self.raw_dir, f"clintox.csv"))
+            self._load_data_from_file(os.path.join(self.raw_dir, "clintox.csv"))
         )
         groups = np.array([d["group"] for d in data])
         if not all(g is None for g in groups):
@@ -229,14 +221,14 @@ def download(self) -> None:
         """Downloads and extracts the dataset."""
         with open(os.path.join(self.raw_dir, "bbbp.csv"), "ab") as dst:
             with request.urlopen(
-                f"https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv",
+                "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv",
             ) as src:
                 shutil.copyfileobj(src, dst)
 
     def setup_processed(self) -> None:
         """Processes and splits the dataset."""
         print("Create splits")
-        data = list(self._load_data_from_file(os.path.join(self.raw_dir, f"bbbp.csv")))
+        data = list(self._load_data_from_file(os.path.join(self.raw_dir, "bbbp.csv")))
         groups = np.array([d["group"] for d in data])
         if not all(g is None for g in groups):
             print("Group shuffled")
@@ -426,7 +418,7 @@ def download(self) -> None:
     def setup_processed(self) -> None:
         """Processes and splits the dataset."""
         print("Create splits")
-        data = list(self._load_data_from_file(os.path.join(self.raw_dir, f"sider.csv")))
+        data = list(self._load_data_from_file(os.path.join(self.raw_dir, "sider.csv")))
         groups = np.array([d["group"] for d in data])
         if not all(g is None for g in groups):
             split_size = int(
@@ -581,14 +573,14 @@ def download(self) -> None:
         """Downloads and extracts the dataset."""
         with open(os.path.join(self.raw_dir, "bace.csv"), "ab") as dst:
             with request.urlopen(
-                f"https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv",
+                "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv",
             ) as src:
                 shutil.copyfileobj(src, dst)
 
     def setup_processed(self) -> None:
         """Processes and splits the dataset."""
         print("Create splits")
-        data = list(self._load_data_from_file(os.path.join(self.raw_dir, f"bace.csv")))
+        data = list(self._load_data_from_file(os.path.join(self.raw_dir, "bace.csv")))
         # groups = np.array([d.get("group") for d in data])
 
         # if not all(g is None for g in groups):
@@ -729,14 +721,14 @@ def download(self) -> None:
         """Downloads and extracts the dataset."""
         with open(os.path.join(self.raw_dir, "hiv.csv"), "ab") as dst:
             with request.urlopen(
-                f"https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv",
+                "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv",
             ) as src:
                 shutil.copyfileobj(src, dst)
 
     def setup_processed(self) -> None:
         """Processes and splits the dataset."""
         print("Create splits")
-        data = list(self._load_data_from_file(os.path.join(self.raw_dir, f"hiv.csv")))
+        data = list(self._load_data_from_file(os.path.join(self.raw_dir, "hiv.csv")))
         groups = np.array([d["group"] for d in data])
         if not all(g is None for g in groups):
             print("Group shuffled")
@@ -913,7 +905,7 @@ def download(self) -> None:
     def setup_processed(self) -> None:
         """Processes and splits the dataset."""
         print("Create splits")
-        data = list(self._load_data_from_file(os.path.join(self.raw_dir, f"muv.csv")))
+        data = list(self._load_data_from_file(os.path.join(self.raw_dir, "muv.csv")))
         groups = np.array([d["group"] for d in data])
         if not all(g is None for g in groups):
             split_size = int(
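Most of the deletions in molecule_classification.py are F401 hits: rdkit, pysmiles, LabelBinarizer, MergedDataset, JCIExtendedTokenData, and Hazardous were imported but never referenced. As rough intuition for what an unused-import check does, here is a stdlib-only sketch (illustrative only; ruff's actual rule also accounts for __all__, explicit re-exports, string annotations, and more):

    import ast

    def unused_imports(source: str) -> list:
        # Collect every imported binding, then report the names that never
        # appear anywhere else in the module.
        tree = ast.parse(source)
        imported, used = {}, set()
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    imported[(alias.asname or alias.name).split(".")[0]] = alias.name
            elif isinstance(node, ast.ImportFrom):
                for alias in node.names:
                    imported[alias.asname or alias.name] = alias.name
            elif isinstance(node, ast.Name):
                used.add(node.id)
        return sorted(name for name in imported if name not in used)

    print(unused_imports("import os\nimport random\nprint(os.sep)"))  # ['random']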

chebai/preprocessing/datasets/molecule_regression.py

Lines changed: 7 additions & 17 deletions
@@ -1,24 +1,14 @@
-from tempfile import NamedTemporaryFile, TemporaryDirectory
 from urllib import request
 import csv
-import gzip
 import os
-import random
 import shutil
-import zipfile
-from typing import Dict, Generator, List, Optional
+from typing import Dict, List
 
-from rdkit import Chem
-from sklearn.model_selection import GroupShuffleSplit, train_test_split
-import numpy as np
-import pysmiles
+from sklearn.model_selection import train_test_split
 import torch
-from sklearn.preprocessing import LabelBinarizer
 
 from chebai.preprocessing import reader as dr
-from chebai.preprocessing.datasets.base import MergedDataset, XYBaseDataModule
-from chebai.preprocessing.datasets.chebi import JCIExtendedTokenData
-from chebai.preprocessing.datasets.pubchem import Hazardous
+from chebai.preprocessing.datasets.base import XYBaseDataModule
 
 
 class Lipo(XYBaseDataModule):
@@ -54,13 +44,13 @@ def download(self):
         # download
         with open(os.path.join(self.raw_dir, "Lipo.csv"), "ab") as dst:
             with request.urlopen(
-                f"https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv",
+                "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv",
             ) as src:
                 shutil.copyfileobj(src, dst)
 
     def setup_processed(self):
         print("Create splits")
-        data = list(self._load_data_from_file(os.path.join(self.raw_dir, f"Lipo.csv")))
+        data = list(self._load_data_from_file(os.path.join(self.raw_dir, "Lipo.csv")))
         print(len(data))
 
         train_split, test_split = train_test_split(
@@ -189,14 +179,14 @@ def download(self):
         # download
         with open(os.path.join(self.raw_dir, "FreeSolv.csv"), "ab") as dst:
             with request.urlopen(
-                f"https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/SAMPL.csv",
+                "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/SAMPL.csv",
             ) as src:
                 shutil.copyfileobj(src, dst)
 
     def setup_processed(self):
         print("Create splits")
         data = list(
-            self._load_data_from_file(os.path.join(self.raw_dir, f"FreeSolv.csv"))
+            self._load_data_from_file(os.path.join(self.raw_dir, "FreeSolv.csv"))
         )
         print(len(data))
         train_split, test_split = train_test_split(

chebai/preprocessing/datasets/solCuration.py

Lines changed: 7 additions & 17 deletions
@@ -1,24 +1,14 @@
-from tempfile import NamedTemporaryFile, TemporaryDirectory
 from urllib import request
 import csv
-import gzip
 import os
-import random
 import shutil
-import zipfile
-from typing import Dict, Generator, List, Optional
+from typing import Dict, List
 
-from rdkit import Chem
-from sklearn.model_selection import GroupShuffleSplit, train_test_split
-import numpy as np
-import pysmiles
+from sklearn.model_selection import train_test_split
 import torch
-from sklearn.preprocessing import LabelBinarizer
 
 from chebai.preprocessing import reader as dr
-from chebai.preprocessing.datasets.base import MergedDataset, XYBaseDataModule
-from chebai.preprocessing.datasets.chebi import JCIExtendedTokenData
-from chebai.preprocessing.datasets.pubchem import Hazardous
+from chebai.preprocessing.datasets.base import XYBaseDataModule
 
 
 class SolCuration(XYBaseDataModule):
@@ -65,7 +55,7 @@ def download(self):
     def setup_processed(self):
         print("Create splits")
         data = list(
-            self._load_data_from_file(os.path.join(self.raw_dir, f"solCuration.csv"))
+            self._load_data_from_file(os.path.join(self.raw_dir, "solCuration.csv"))
         )
         print(len(data))
 
@@ -144,7 +134,7 @@ def _load_data_from_file(self, input_file_path: str) -> List[Dict]:
         with open(input_file_path, "r") as input_file:
             reader = csv.DictReader(input_file)
             for row in reader:
-                if not row["smiles"] in smiles_l:
+                if row["smiles"] not in smiles_l:
                     smiles_l.append(row["smiles"])
                     labels_l.append(float(row["logS"]))
         # print(len(smiles_l), len(labels_l))
@@ -204,14 +194,14 @@ def download(self):
         # download
         with open(os.path.join(self.raw_dir, "solESOL.csv"), "ab") as dst:
             with request.urlopen(
-                f"https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv",
+                "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv",
             ) as src:
                 shutil.copyfileobj(src, dst)
 
     def setup_processed(self):
         print("Create splits")
         data = list(
-            self._load_data_from_file(os.path.join(self.raw_dir, f"solESOL.csv"))
+            self._load_data_from_file(os.path.join(self.raw_dir, "solESOL.csv"))
         )
         print(len(data))
 
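Beyond the E713 style fix, note that the membership test in _load_data_from_file scans smiles_l (a list) once per CSV row, so deduplication is quadratic in the number of rows. A hedged sketch of an equivalent loop that uses a set for constant-time lookups -- the helper mirrors the names in the snippet above but is illustrative, not part of this commit:

    import csv
    from typing import List, Tuple

    def load_unique_smiles(input_file_path: str) -> Tuple[List[str], List[float]]:
        # Keep the first occurrence of each SMILES string and its logS value.
        smiles_l: List[str] = []
        labels_l: List[float] = []
        seen: set = set()  # O(1) membership checks instead of a list scan
        with open(input_file_path, "r") as input_file:
            for row in csv.DictReader(input_file):
                if row["smiles"] not in seen:  # same not-in idiom as the fix
                    seen.add(row["smiles"])
                    smiles_l.append(row["smiles"])
                    labels_l.append(float(row["logS"]))
        return smiles_l, labels_l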
chebai/preprocessing/datasets/tox21.py

Lines changed: 1 addition & 2 deletions
@@ -13,7 +13,6 @@
 from sklearn.model_selection import (
     GroupShuffleSplit,
     train_test_split,
-    StratifiedShuffleSplit,
 )
 
 from chebai.preprocessing import reader as dr
@@ -75,7 +74,7 @@ def download(self) -> None:
     def setup_processed(self) -> None:
         """Processes and splits the dataset."""
         print("Create splits")
-        data = list(self._load_data_from_file(os.path.join(self.raw_dir, f"tox21.csv")))
+        data = list(self._load_data_from_file(os.path.join(self.raw_dir, "tox21.csv")))
         groups = np.array([d.get("group") for d in data])
 
         if not all(g is None for g in groups):

chebai/result/regression.py

Lines changed: 0 additions & 4 deletions
@@ -1,8 +1,4 @@
-from typing import List
 
-import matplotlib.pyplot as plt
-import pandas as pd
-import seaborn as sns
 from torch import Tensor
 from torchmetrics.regression import (
     MeanSquaredError,
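A caveat that applies to bulk F401 cleanups like the one above: deleting an import only preserves behavior when nothing depends on it indirectly (re-exports, import-time side effects). Where an apparently unused import is intentional, the usual convention is to mark it so the linter skips it -- an illustrative line, not from this repository:

    from os import path as path  # noqa: F401 -- redundant alias marks a deliberate re-export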

chebai/result/utils.py

Lines changed: 4 additions & 4 deletions
@@ -234,7 +234,7 @@ def evaluate_model_regression(
     save_batch_size = 128
     n_saved = 1
 
-    print(f"")
+    print("")
     for i in tqdm.tqdm(range(0, len(data_list), batch_size)):
         if not (
             skip_existing_preds
@@ -333,7 +333,7 @@ def evaluate_model_regression_attention(
     save_batch_size = 128
     n_saved = 1
 
-    print(f"")
+    print("")
     for i in tqdm.tqdm(range(0, len(data_list), batch_size)):
         if not (
             skip_existing_preds
@@ -434,7 +434,7 @@ def evaluate_model_regression(
     save_batch_size = 128
     n_saved = 1
 
-    print(f"")
+    print("")
     for i in tqdm.tqdm(range(0, len(data_list), batch_size)):
         if not (
             skip_existing_preds
@@ -533,7 +533,7 @@ def evaluate_model_regression_attention(
     save_batch_size = 128
     n_saved = 1
 
-    print(f"")
+    print("")
     for i in tqdm.tqdm(range(0, len(data_list), batch_size)):
         if not (
             skip_existing_preds
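All four hits in utils.py are the same F541 pattern: print(f"") carries a pointless f prefix, and the auto-fix leaves print(""). For the record, print("") and a bare print() are equivalent here, each writing exactly one newline -- a quick illustrative check:

    import io
    from contextlib import redirect_stdout

    buf = io.StringIO()
    with redirect_stdout(buf):
        print("")  # the post-fix form used in utils.py
        print()    # no arguments: prints only the default end="\n"
    assert buf.getvalue() == "\n\n"  # one newline per call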
