Skip to content

Commit ce368ef

Browse files
committed
Merge branch 'dev' into feature/test_chebai_cli
2 parents 061506a + 1d8a7c3 commit ce368ef

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

chebai/result/analyse_sem.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import traceback
44
from datetime import datetime
5+
from pathlib import Path
56
from typing import List, LiteralString, Optional, Tuple
67

78
import pandas as pd
@@ -155,9 +156,11 @@ def get_disjoint_groups(disjoint_files):
155156
disjoint_files = os.path.join("data", "chebi-disjoints.owl")
156157
disjoint_pairs, disjoint_groups = [], []
157158
for file in disjoint_files:
158-
if file.split(".")[-1] == "csv":
159+
if isinstance(file, Path):
160+
file = str(file)
161+
if file.endswith(".csv"):
159162
disjoint_pairs += pd.read_csv(file, header=None).values.tolist()
160-
elif file.split(".")[-1] == "owl":
163+
elif file.endswith(".owl"):
161164
with open(file, "r") as f:
162165
plaintext = f.read()
163166
segments = plaintext.split("<")
@@ -217,10 +220,16 @@ def set_label_names(self, label_names):
217220
self.label_successors = self.label_successors.unsqueeze(0)
218221

219222
def __call__(self, preds):
223+
if preds.shape[1] == 0:
224+
# no labels predicted
225+
return preds
226+
# preds shape: (n_samples, n_labels)
220227
preds_sum_orig = torch.sum(preds)
221228
# step 1: apply implications: for each class, set prediction to max of itself and all successors
222229
preds = preds.unsqueeze(1)
223230
preds_masked_succ = torch.where(self.label_successors, preds, 0)
231+
# preds_masked_succ shape: (n_samples, n_labels, n_labels)
232+
224233
preds = preds_masked_succ.max(dim=2).values
225234
if torch.sum(preds) != preds_sum_orig:
226235
print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
[build-system]
2-
requires = ["setuptools>=61.0", "wheel"]
2+
requires = ["setuptools >= 77.0.3", "wheel"]
33
build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "chebai"
7-
version = "1.0.2"
7+
version = "1.0.3"
88
description = "ChEBai is a deep learning library designed for the integration of deep learning methods with chemical ontologies, particularly ChEBI."
99
authors = [
1010
{ name = "MGlauer", email = "[email protected]" }

0 commit comments

Comments
 (0)