Skip to content

Commit 0dce958

Browse files
committed
- pyyaml instead of yaml
- union instead of pipe
1 parent 3ca5707 commit 0dce958

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

chebai/loss/semantic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import math
66
import torch
7-
from typing import Literal
7+
from typing import Literal, Union
88

99
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor, ChEBIOver100
1010
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
@@ -14,7 +14,7 @@
1414
class ImplicationLoss(torch.nn.Module):
1515
def __init__(
1616
self,
17-
data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
17+
data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed],
1818
base_loss: torch.nn.Module = None,
1919
tnorm: Literal["product", "lukasiewicz", "xu19"] = "product",
2020
impl_loss_weight=0.1, # weight of implication loss in relation to base_loss
@@ -114,7 +114,7 @@ class DisjointLoss(ImplicationLoss):
114114
def __init__(
115115
self,
116116
path_to_disjointness,
117-
data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
117+
data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed],
118118
base_loss: torch.nn.Module = None,
119119
disjoint_loss_weight=100,
120120
**kwargs,

chebai/result/analyse_sem.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torchmetrics.functional.classification import multilabel_f1_score
1212
import wandb
1313
import gc
14+
from typing import List,Union
1415
from utils import *
1516

1617
DEVICE = "cpu" # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -244,7 +245,7 @@ def analyse_run(
244245
labeled_data_cls=ChEBIOver100, # use labels from this dataset for violations
245246
chebi_version=231,
246247
results_path=os.path.join("_semantic", "eval_results.csv"),
247-
violation_metrics: [str | list[callable]] = "all",
248+
violation_metrics: Union[str, List[callable]] = "all",
248249
verbose_violation_output=False,
249250
):
250251
"""Calculates all semantic metrics for given predictions (and supervised metrics if labels are provided),

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"iterative-stratification",
4949
"wandb",
5050
"chardet",
51-
"yaml",
51+
"pyyaml",
5252
"torchmetrics",
5353
],
5454
extras_require={"dev": ["black", "isort", "pre-commit"]},

0 commit comments

Comments
 (0)