Skip to content

Commit e01f7b7

Browse files
author
sfluegel
committed
Merge branch 'refs/heads/dev' into refactor_term_callback
2 parents 630add7 + bf9e642 commit e01f7b7

File tree

2 files changed

+31
-24
lines changed

2 files changed

+31
-24
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from lightning.pytorch.core.datamodule import LightningDataModule
1616
from lightning_utilities.core.rank_zero import rank_zero_info
17+
from sklearn.model_selection import StratifiedShuffleSplit
1718
from torch.utils.data import DataLoader
1819

1920
from chebai.preprocessing import reader as dr
@@ -929,11 +930,17 @@ def get_test_split(
929930
labels_list = df["labels"].tolist()
930931

931932
test_size = 1 - self.train_split - (1 - self.train_split) ** 2
932-
msss = MultilabelStratifiedShuffleSplit(
933-
n_splits=1, test_size=test_size, random_state=seed
934-
)
935933

936-
train_indices, test_indices = next(msss.split(labels_list, labels_list))
934+
if len(labels_list[0]) > 1:
935+
splitter = MultilabelStratifiedShuffleSplit(
936+
n_splits=1, test_size=test_size, random_state=seed
937+
)
938+
else:
939+
splitter = StratifiedShuffleSplit(
940+
n_splits=1, test_size=test_size, random_state=seed
941+
)
942+
943+
train_indices, test_indices = next(splitter.split(labels_list, labels_list))
937944

938945
df_train = df.iloc[train_indices]
939946
df_test = df.iloc[test_indices]
@@ -985,12 +992,18 @@ def get_train_val_splits_given_test(
985992

986993
# scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split)
987994
test_size = ((1 - self.train_split) ** 2) / self.train_split
988-
msss = MultilabelStratifiedShuffleSplit(
989-
n_splits=1, test_size=test_size, random_state=seed
990-
)
995+
996+
if len(labels_list_trainval[0]) > 1:
997+
splitter = MultilabelStratifiedShuffleSplit(
998+
n_splits=1, test_size=test_size, random_state=seed
999+
)
1000+
else:
1001+
splitter = StratifiedShuffleSplit(
1002+
n_splits=1, test_size=test_size, random_state=seed
1003+
)
9911004

9921005
train_indices, validation_indices = next(
993-
msss.split(labels_list_trainval, labels_list_trainval)
1006+
splitter.split(labels_list_trainval, labels_list_trainval)
9941007
)
9951008

9961009
df_validation = df_trainval.iloc[validation_indices]

chebai/preprocessing/datasets/chebi.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,9 @@ def __init__(self, top_class_id: int, **kwargs):
750750
top_class_id (int): The ID of the top class from which to extract subclasses.
751751
**kwargs: Additional keyword arguments passed to the superclass initializer.
752752
"""
753+
if "top_class_id" not in kwargs:
754+
kwargs["top_class_id"] = top_class_id
755+
753756
self.top_class_id: int = top_class_id
754757
super().__init__(**kwargs)
755758

@@ -772,27 +775,18 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
772775
"""
773776
Extracts a subset of ChEBI based on subclasses of the top class ID.
774777
778+
This method calls the superclass method to extract the full class hierarchy,
779+
then extracts the subgraph containing only the descendants of the top class ID, including itself.
780+
775781
Args:
776782
chebi_path (str): The file path to the ChEBI ontology file.
777783
778784
Returns:
779-
nx.DiGraph: The extracted class hierarchy as a directed graph.
785+
nx.DiGraph: The extracted class hierarchy as a directed graph, limited to the
786+
descendants of the top class ID.
780787
"""
781-
with open(chebi_path, encoding="utf-8") as chebi:
782-
chebi = "\n".join(l for l in chebi if not l.startswith("xref:"))
783-
elements = [
784-
term_callback(clause)
785-
for clause in fastobo.loads(chebi)
786-
if clause and ":" in str(clause.id)
787-
]
788-
g = nx.DiGraph()
789-
for n in elements:
790-
g.add_node(n["id"], **n)
791-
g.add_edges_from([(p, q["id"]) for q in elements for p in q["parents"]])
792-
793-
g = nx.transitive_closure_dag(g)
794-
g = g.subgraph(list(nx.descendants(g, self.top_class_id)) + [self.top_class_id])
795-
print("Compute transitive closure")
788+
g = super()._extract_class_hierarchy(chebi_path)
789+
g = g.subgraph(list(g.successors(self.top_class_id)) + [self.top_class_id])
796790
return g
797791

798792

0 commit comments

Comments
 (0)