Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion conversion/guacamol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import os
from typing import Union

import torch
from loguru import logger
Expand Down Expand Up @@ -108,7 +109,11 @@ def check_smiles_graph_mapping_worker(smile_idx, smile):


def process(
split: str, raw_dir: str, n_jobs: int, limit: int | None, chunk_size: int
split: str,
raw_dir: str,
n_jobs: int,
limit: Union[int, None],
chunk_size: int,
) -> None:
path = os.path.join(raw_dir, f"guacamol_v1_{split}.smiles")
smile_list = [
Expand Down
2 changes: 1 addition & 1 deletion polygraph/metrics/base/polygraphdiscrepancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _descriptions_to_classifier_metric(
variant: Literal["informedness", "jsd"] = "jsd",
classifier: Optional[ClassifierProtocol] = None,
rng: Optional[np.random.Generator] = None,
) -> Tuple[float, int | float]:
) -> Tuple[float, Union[int, float]]:
rng = np.random.default_rng(0) if rng is None else rng

if isinstance(ref_descriptions, csr_array):
Expand Down
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "polygraph-benchmark"
version = "1.0.1"
version = "1.0.2"
description = "Evaluation benchmarks for graph generative models"
readme = "README.md"
authors = [
Expand All @@ -13,22 +13,22 @@ authors = [
{ name = "Dexiong Chen", email = "[email protected]" },
{ name = "Karsten Borgwardt", email = "[email protected]" },
]
requires-python = ">=3.7"
requires-python = ">=3.9"
dependencies = [
"numpy>=1.26.4,<3.0",
"torch>=2.4.0,<3.0",
"torch_geometric>=2.6.1,<3.0",
"rich",
"scipy>=1.14.0,<2.0",
"scipy>=1.12.0,<2.0",
"pydantic~=2.11.7",
"networkx>=3.4,<4.0",
"networkx>=3.2,<4.0",
"joblib",
"appdirs",
"loguru",
"rdkit",
"pandas",
"orbit-count",
"numba~=0.61.2",
"numba>=0.60.0,<0.62.0",
"scikit-learn>=1.6.1,<2.0",
"tabpfn==2.0.9",
"fcd~=1.2.2"
Expand Down
4 changes: 2 additions & 2 deletions tests/test_mmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@
from polygraph.utils.mmd_utils import mmd_from_gram
from polygraph.metrics.base.metric_interval import MetricInterval

import grakel


class WeisfeilerLehmanMMD2(DescriptorMMD2):
def __init__(self, reference_graphs, iterations=3):
Expand All @@ -67,6 +65,8 @@ def __init__(self, reference_graphs, iterations=3):
def grakel_wl_mmd(
reference_graphs, test_graphs, is_parallel=False, iterations=3
):
import grakel

grakel_kernel = grakel.WeisfeilerLehman(n_iter=iterations)
all_graphs = reference_graphs + test_graphs
for g in all_graphs:
Expand Down