Skip to content

Commit b69c538

Browse files
authored
Merge pull request #525 from pyt-team/frantzen/ahorn-multinetwork
Support for multi-network AHORN datasets
2 parents 5061329 + b201558 commit b69c538

File tree

2 files changed

+82
-11
lines changed

2 files changed

+82
-11
lines changed

test/datasets/test_ahorn.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Tests for AHORN dataset loader."""
22
# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
33

4+
from io import StringIO
5+
46
import networkx as nx
57
import pytest
68

@@ -50,3 +52,36 @@ def test_read_missing_dependency_raises() -> None:
5052
"""When `ahorn-loader` is not installed, read_ahorn_dataset raises RuntimeError."""
5153
with pytest.raises(RuntimeError, match=r"optional `ahorn-loader`"):
5254
read_ahorn_dataset("dummy.json")
55+
56+
57+
@pytest.mark.skipif(
58+
ahorn_loader is None, reason="Optional dependency `ahorn-loader` not installed."
59+
)
60+
def test_read_multi_network_dataset() -> None:
61+
"""Test reading a multi-network AHORN dataset from a mock file."""
62+
mock_data = """{"name": "Mock Multi-Network Dataset"}
63+
{"id": "network-001"}
64+
0 {"label": "node0_net1"}
65+
1 {"label": "node1_net1"}
66+
2 {"label": "node2_net1"}
67+
{"id": "network-002"}
68+
0,1 {"weight": 2.0}
69+
1,2 {"weight": 3.0}
70+
0,1,2 {}
71+
"""
72+
73+
result = read_ahorn_dataset(StringIO(mock_data))
74+
75+
assert isinstance(result, list)
76+
assert len(result) == 2
77+
78+
assert len(list(result[0].simplices)) == 3
79+
assert result[0].complex["id"] == "network-001"
80+
assert result[0].nodes[0]["label"] == "node0_net1"
81+
assert result[0].nodes[1]["label"] == "node1_net1"
82+
assert result[0].nodes[2]["label"] == "node2_net1"
83+
84+
assert len(list(result[1].simplices)) == 7
85+
assert result[1].complex["id"] == "network-002"
86+
assert result[1].simplices[(0, 1)]["weight"] == 2.0
87+
assert result[1].simplices[(1, 2)]["weight"] == 3.0

toponetx/datasets/ahorn.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,12 @@
2727
"""
2828

2929
import json
30+
from collections.abc import Iterable
3031
from pathlib import Path
3132
from typing import IO
3233

34+
from more_itertools import peekable, split_before
35+
3336
from toponetx.classes import CellComplex, SimplicialComplex
3437
from toponetx.classes.complex import Complex
3538

@@ -56,7 +59,9 @@ def _assert_ahorn_loader_installed() -> None:
5659
)
5760

5861

59-
def load_ahorn_dataset[T: Complex](name: str, create_using: type[T] | None = None) -> T:
62+
def load_ahorn_dataset[T: Complex](
63+
name: str, create_using: type[T] | None = None
64+
) -> T | Iterable[T]:
6065
"""Load the specified dataset from the Aachen Higher-Order Repository of Networks.
6166
6267
The dataset file will be stored in your system cache and can be deleted according
@@ -73,8 +78,9 @@ def load_ahorn_dataset[T: Complex](name: str, create_using: type[T] | None = Non
7378
7479
Returns
7580
-------
76-
Complex
77-
The complex representing the AHORN dataset.
81+
Complex or list[Complex]
82+
The complex representing the AHORN dataset. A list of complexes if the dataset
83+
contains multiple networks.
7884
7985
Raises
8086
------
@@ -96,7 +102,7 @@ def load_ahorn_dataset[T: Complex](name: str, create_using: type[T] | None = Non
96102

97103
def read_ahorn_dataset[T](
98104
path: str | Path | IO[str], create_using: type[T] | None = None
99-
) -> T:
105+
) -> T | Iterable[T]:
100106
"""Read an AHORN dataset from a local file or file-like object.
101107
102108
This function accepts file paths and file-like objects provided by users. When
@@ -113,8 +119,9 @@ def read_ahorn_dataset[T](
113119
114120
Returns
115121
-------
116-
Complex
117-
The complex representing the AHORN dataset.
122+
Complex or list[Complex]
123+
The complex representing the AHORN dataset. A list of complexes if the dataset
124+
contains multiple networks.
118125
119126
Raises
120127
------
@@ -138,9 +145,15 @@ def read_ahorn_dataset[T](
138145
raise RuntimeError(f"Failed to read dataset: {e!s}") from e
139146

140147

141-
def _read_ahorn_dataset[T](file: IO[str], create_using: type[T] | None = None) -> T:
148+
def _read_ahorn_dataset[T](
149+
file: Iterable[str], create_using: type[T] | None = None
150+
) -> T | Iterable[T]:
142151
"""Read AHORN dataset from file-like object.
143152
153+
Supports both single-network and multi-network datasets. Multi-network datasets are
154+
detected by checking if the first two lines both start with '{', indicating they are
155+
JSON objects representing separate networks.
156+
144157
Parameters
145158
----------
146159
file : IO
@@ -150,15 +163,38 @@ def _read_ahorn_dataset[T](file: IO[str], create_using: type[T] | None = None) -
150163
151164
Returns
152165
-------
153-
Complex
154-
The complex representing the AHORN dataset.
166+
Complex or list[Complex]
167+
The complex representing the AHORN dataset. A list of complexes if the dataset
168+
contains multiple networks.
155169
"""
156170
if create_using is None:
157171
create_using = SimplicialComplex
158172

159-
complex_obj = create_using(**json.loads(next(file)))
173+
# Convert to peekable iterator to detect multi-network datasets
174+
lines = peekable(file)
175+
176+
# Check if this is a multi-network dataset by peeking at the first two lines
177+
first_line = next(lines)
178+
is_multi_network = False
179+
try:
180+
is_multi_network = lines.peek().strip().startswith("{")
181+
except StopIteration:
182+
is_multi_network = False
183+
184+
# If multi-network, create empty complex; otherwise parse first line
185+
if is_multi_network:
186+
network_lines = list(
187+
split_before(lines, lambda line: line.strip().startswith("{"))
188+
)
189+
return [
190+
_read_ahorn_dataset(network, create_using=create_using)
191+
for network in network_lines
192+
]
193+
194+
complex_obj = create_using(**json.loads(first_line))
160195

161-
for line_num, line in enumerate(file, start=2):
196+
# Process remaining lines
197+
for line_num, line in enumerate(lines, start=2):
162198
try:
163199
elements_part, metadata = line.split(" ", maxsplit=1)
164200
elements = list(map(int, elements_part.split(",")))

0 commit comments

Comments (0)