diff --git a/conversion/guacamol.py b/conversion/guacamol.py index 60d10f9..70f5f8f 100644 --- a/conversion/guacamol.py +++ b/conversion/guacamol.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Union import torch from loguru import logger @@ -108,7 +109,11 @@ def check_smiles_graph_mapping_worker(smile_idx, smile): def process( - split: str, raw_dir: str, n_jobs: int, limit: int | None, chunk_size: int + split: str, + raw_dir: str, + n_jobs: int, + limit: Union[int, None], + chunk_size: int, ) -> None: path = os.path.join(raw_dir, f"guacamol_v1_{split}.smiles") smile_list = [ diff --git a/polygraph/metrics/base/polygraphdiscrepancy.py b/polygraph/metrics/base/polygraphdiscrepancy.py index a43c551..6dbd8e2 100644 --- a/polygraph/metrics/base/polygraphdiscrepancy.py +++ b/polygraph/metrics/base/polygraphdiscrepancy.py @@ -251,7 +251,7 @@ def _descriptions_to_classifier_metric( variant: Literal["informedness", "jsd"] = "jsd", classifier: Optional[ClassifierProtocol] = None, rng: Optional[np.random.Generator] = None, -) -> Tuple[float, int | float]: +) -> Tuple[float, Union[int, float]]: rng = np.random.default_rng(0) if rng is None else rng if isinstance(ref_descriptions, csr_array): diff --git a/pyproject.toml b/pyproject.toml index e87419f..dd19486 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "polygraph-benchmark" -version = "1.0.1" +version = "1.0.2" description = "Evaluation benchmarks for graph generative models" readme = "README.md" authors = [ @@ -13,22 +13,22 @@ authors = [ { name = "Dexiong Chen", email = "dchen@biochem.mpg.de" }, { name = "Karsten Borgwardt", email = "borgwardt@biochem.mpg.de" }, ] -requires-python = ">=3.7" +requires-python = ">=3.9" dependencies = [ "numpy>=1.26.4,<3.0", "torch>=2.4.0,<3.0", "torch_geometric>=2.6.1,<3.0", "rich", - "scipy>=1.14.0,<2.0", + "scipy>=1.12.0,<2.0", "pydantic~=2.11.7", - "networkx>=3.4,<4.0", + "networkx>=3.2,<4.0", "joblib", "appdirs", "loguru", "rdkit", "pandas", "orbit-count", - "numba~=0.61.2", + "numba>=0.60.0,<0.62.0", "scikit-learn>=1.6.1,<2.0", "tabpfn==2.0.9", "fcd~=1.2.2" diff --git a/tests/test_mmd.py b/tests/test_mmd.py index eea07a0..5330144 100644 --- a/tests/test_mmd.py +++ b/tests/test_mmd.py @@ -48,8 +48,6 @@ from polygraph.utils.mmd_utils import mmd_from_gram from polygraph.metrics.base.metric_interval import MetricInterval -import grakel - class WeisfeilerLehmanMMD2(DescriptorMMD2): def __init__(self, reference_graphs, iterations=3): @@ -67,6 +65,8 @@ def __init__(self, reference_graphs, iterations=3): def grakel_wl_mmd( reference_graphs, test_graphs, is_parallel=False, iterations=3 ): + import grakel + grakel_kernel = grakel.WeisfeilerLehman(n_iter=iterations) all_graphs = reference_graphs + test_graphs for g in all_graphs: