
Commit 25177b3

Merge branch 'dev' into protein_prediction
2 parents: d7e8097 + 4968b3b

16 files changed: +92 −37 lines
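Note on the recurring change: almost every hunk in this merge adds an explicit `weights_only=False` to a `torch.load` call. PyTorch 2.6 changed the default of `weights_only` from `False` to `True`, and the restricted unpickler behind `weights_only=True` rejects arbitrary pickled Python objects, which is exactly what these cached datasets and checkpoint dicts contain. A minimal sketch of the failure mode and the opt-out (the `Row` class and file name are hypothetical stand-ins):

```python
import torch

class Row:
    """Stand-in for the non-tensor objects stored in the cached .pt files."""

    def __init__(self, ident: int):
        self.ident = ident

torch.save([Row(1), Row(2)], "cache.pt")

# On PyTorch >= 2.6 the default is weights_only=True, whose restricted
# unpickler only accepts tensors and a small allowlist of builtin types,
# so loading custom classes raises an UnpicklingError:
# rows = torch.load("cache.pt")  # fails on 2.6+

# Passing weights_only=False restores the old full-pickle behaviour.
rows = torch.load("cache.pt", weights_only=False)
print([r.ident for r in rows])  # [1, 2]
```

The trade-off is that `weights_only=False` re-enables full pickle, which can execute arbitrary code on load, so it should only be used on files the pipeline wrote itself.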

chebai/models/electra.py

Lines changed: 6 additions & 2 deletions
@@ -256,7 +256,9 @@ def __init__(
         # Load pretrained checkpoint if provided
         if pretrained_checkpoint:
             with open(pretrained_checkpoint, "rb") as fin:
-                model_dict = torch.load(fin, map_location=self.device)
+                model_dict = torch.load(
+                    fin, map_location=self.device, weights_only=False
+                )
                 if load_prefix:
                     state_dict = filter_dict(model_dict["state_dict"], load_prefix)
                 else:
@@ -414,7 +416,9 @@ def __init__(self, cone_dimensions=20, **kwargs):
         model_prefix = kwargs.get("load_prefix", None)
         if pretrained_checkpoint:
             with open(pretrained_checkpoint, "rb") as fin:
-                model_dict = torch.load(fin, map_location=self.device)
+                model_dict = torch.load(
+                    fin, map_location=self.device, weights_only=False
+                )
                 if model_prefix:
                     state_dict = {
                         str(k)[len(model_prefix) :]: v

chebai/preprocessing/datasets/base.py

Lines changed: 10 additions & 4 deletions
@@ -200,7 +200,9 @@ def load_processed_data(
             filename = self.processed_file_names_dict[kind]
         except NotImplementedError:
             filename = f"{kind}.pt"
-        return torch.load(os.path.join(self.processed_dir, filename))
+        return torch.load(
+            os.path.join(self.processed_dir, filename), weights_only=False
+        )
 
     def dataloader(self, kind: str, **kwargs) -> DataLoader:
         """
@@ -519,7 +521,7 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader:
             DataLoader: DataLoader object for the specified subset.
         """
         subdatasets = [
-            torch.load(os.path.join(s.processed_dir, f"{kind}.pt"))
+            torch.load(os.path.join(s.processed_dir, f"{kind}.pt"), weights_only=False)
             for s in self.subsets
         ]
         dataset = [
@@ -1022,7 +1024,9 @@ def _retrieve_splits_from_csv(self) -> None:
         splits_df = pd.read_csv(self.splits_file_path)
 
         filename = self.processed_file_names_dict["data"]
-        data = torch.load(os.path.join(self.processed_dir, filename))
+        data = torch.load(
+            os.path.join(self.processed_dir, filename), weights_only=False
+        )
         df_data = pd.DataFrame(data)
 
         train_ids = splits_df[splits_df["split"] == "train"]["id"]
@@ -1081,7 +1085,9 @@ def load_processed_data(
 
         # If filename is provided
         try:
-            return torch.load(os.path.join(self.processed_dir, filename))
+            return torch.load(
+                os.path.join(self.processed_dir, filename), weights_only=False
+            )
         except FileNotFoundError:
             raise FileNotFoundError(f"File {filename} doesn't exist")

chebai/preprocessing/datasets/chebi.py

Lines changed: 29 additions & 10 deletions
@@ -13,7 +13,7 @@
 import pickle
 from abc import ABC
 from collections import OrderedDict
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
 
 import fastobo
 import networkx as nx
@@ -244,16 +244,26 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
         with open(data_path, encoding="utf-8") as chebi:
             chebi = "\n".join(l for l in chebi if not l.startswith("xref:"))
 
-        elements = [
-            term_callback(clause)
-            for clause in fastobo.loads(chebi)
-            if clause and ":" in str(clause.id)
-        ]
+        elements = []
+        for term_doc in fastobo.loads(chebi):
+            if (
+                term_doc
+                and isinstance(term_doc.id, fastobo.id.PrefixedIdent)
+                and term_doc.id.prefix == "CHEBI"
+            ):
+                term_dict = term_callback(term_doc)
+                if term_dict:
+                    elements.append(term_dict)
 
         g = nx.DiGraph()
         for n in elements:
             g.add_node(n["id"], **n)
-        g.add_edges_from([(p, q["id"]) for q in elements for p in q["parents"]])
+
+        # Only add edges that connect existing nodes, to avoid implicit creation of obsolete nodes
+        # https://github.com/ChEB-AI/python-chebai/pull/55#issuecomment-2386654142
+        g.add_edges_from(
+            [(p, q["id"]) for q in elements for p in q["parents"] if g.has_node(p)]
+        )
 
         print("Compute transitive closure")
         return nx.transitive_closure_dag(g)
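The new `g.has_node(p)` guard is the substantive part of this hunk: `networkx`'s `add_edges_from` silently creates any edge endpoint that is not already in the graph, so a parent reference to a term filtered out above (for example an obsolete one) would reappear as an attribute-less node. A small self-contained sketch of that behaviour:

```python
import networkx as nx

g = nx.DiGraph()
g.add_node(1, name="kept term")

# add_edges_from implicitly creates missing endpoints:
g.add_edges_from([(999, 1)])
print(g.has_node(999))  # True -- node 999 now exists, with no attributes

# Guarding each edge with has_node keeps the graph limited to known terms:
h = nx.DiGraph()
h.add_node(1, name="kept term")
h.add_edges_from([(p, c) for (p, c) in [(999, 1)] if h.has_node(p)])
print(h.has_node(999))  # False
```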
@@ -397,7 +407,9 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """
         try:
             filename = self.processed_file_names_dict["data"]
-            data_chebi_version = torch.load(os.path.join(self.processed_dir, filename))
+            data_chebi_version = torch.load(
+                os.path.join(self.processed_dir, filename), weights_only=False
+            )
         except FileNotFoundError:
             raise FileNotFoundError(
                 f"File data.pt doesn't exist. "
@@ -418,7 +430,8 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
             data_chebi_train_version = torch.load(
                 os.path.join(
                     self._chebi_version_train_obj.processed_dir, filename_train
-                )
+                ),
+                weights_only=False,
             )
         except FileNotFoundError:
             raise FileNotFoundError(
@@ -812,7 +825,7 @@ def chebi_to_int(s: str) -> int:
     return int(s[s.index(":") + 1 :])
 
 
-def term_callback(doc) -> dict:
+def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
     """
     Extracts information from a ChEBI term document.
     This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents,
@@ -852,6 +865,12 @@ def term_callback(doc) -> dict:
             parents.append(chebi_to_int(str(clause.term)))
         elif isinstance(clause, fastobo.term.NameClause):
             name = str(clause.name)
+
+        if isinstance(clause, fastobo.term.IsObsoleteClause):
+            if clause.obsolete:
+                # the term is marked as obsolete, so skip this document
+                return False
+
     return {
         "id": chebi_to_int(str(doc.id)),
         "parents": parents,

chebai/preprocessing/datasets/go_uniprot.py

Lines changed: 3 additions & 1 deletion
@@ -514,7 +514,9 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """
         try:
             filename = self.processed_file_names_dict["data"]
-            data_go = torch.load(os.path.join(self.processed_dir, filename))
+            data_go = torch.load(
+                os.path.join(self.processed_dir, filename), weights_only=False
+            )
         except FileNotFoundError:
             raise FileNotFoundError(
                 f"File data.pt doesn't exist. "

chebai/preprocessing/datasets/pubchem.py

Lines changed: 2 additions & 2 deletions
@@ -891,10 +891,10 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader:
             DataLoader: DataLoader instance.
         """
         labeled_data = torch.load(
-            os.path.join(self.labeled.processed_dir, f"{kind}.pt")
+            os.path.join(self.labeled.processed_dir, f"{kind}.pt"), weights_only=False
         )
         unlabeled_data = torch.load(
-            os.path.join(self.unlabeled.processed_dir, f"{kind}.pt")
+            os.path.join(self.unlabeled.processed_dir, f"{kind}.pt"), weights_only=False
         )
         if self.data_limit is not None:
             labeled_data = labeled_data[: self.data_limit]

chebai/preprocessing/migration/chebi_data_migration.py

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ def _combine_pt_splits(
         df_list: List[pd.DataFrame] = []
         for split, file_name in old_splits_file_names.items():
             file_path = os.path.join(old_dir, file_name)
-            file_df = pd.DataFrame(torch.load(file_path))
+            file_df = pd.DataFrame(torch.load(file_path, weights_only=False))
             df_list.append(file_df)
 
         return pd.concat(df_list, ignore_index=True)

chebai/result/analyse_sem.py

Lines changed: 3 additions & 1 deletion
@@ -427,7 +427,9 @@ def run_all(
             os.path.join(buffer_dir_smoothed, "preds000.pt")
         ):
             preds = torch.load(
-                os.path.join(buffer_dir_smoothed, "preds000.pt"), DEVICE
+                os.path.join(buffer_dir_smoothed, "preds000.pt"),
+                DEVICE,
+                weights_only=False,
             )
             labels = None
         else:

chebai/result/base.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def _generate_predictions(self, data_path, raw=False, **kwargs):
         else:
             data_tuples = [
                 (x.get("raw_features", x["ident"]), x["ident"], x)
-                for x in torch.load(data_path)
+                for x in torch.load(data_path, weights_only=False)
             ]
 
         for raw_features, ident, row in tqdm.tqdm(data_tuples):

chebai/result/pretraining.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def evaluate_model(logs_base_path, model_filename, data_module):
     collate = data_module.reader.COLLATOR()
     test_file = "test.pt"
     data_path = os.path.join(data_module.processed_dir, test_file)
-    data_list = torch.load(data_path)
+    data_list = torch.load(data_path, weights_only=False)
     preds_list = []
     labels_list = []

chebai/result/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def load_results_from_buffer(
182182
torch.load(
183183
os.path.join(buffer_dir, filename),
184184
map_location=torch.device(device),
185+
weights_only=False,
185186
)
186187
)
187188
i += 1
@@ -194,6 +195,7 @@ def load_results_from_buffer(
194195
torch.load(
195196
os.path.join(buffer_dir, filename),
196197
map_location=torch.device(device),
198+
weights_only=False,
197199
)
198200
)
199201
i += 1
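In these buffer readers, `map_location` and `weights_only` are independent concerns: the former remaps tensor storages to a device, the latter controls the unpickler. A minimal combined call (the path is a hypothetical example):

```python
import torch

preds = torch.load(
    "buffer/preds000.pt",
    map_location=torch.device("cpu"),  # remap any CUDA tensors onto the CPU
    weights_only=False,                # buffers also pickle non-tensor objects
)
```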
