2727"""
2828
2929import json
30+ from collections .abc import Iterable
3031from pathlib import Path
3132from typing import IO
3233
34+ from more_itertools import peekable , split_before
35+
3336from toponetx .classes import CellComplex , SimplicialComplex
3437from toponetx .classes .complex import Complex
3538
@@ -56,7 +59,9 @@ def _assert_ahorn_loader_installed() -> None:
5659 )
5760
5861
59- def load_ahorn_dataset [T : Complex ](name : str , create_using : type [T ] | None = None ) -> T :
62+ def load_ahorn_dataset [T : Complex ](
63+ name : str , create_using : type [T ] | None = None
64+ ) -> T | Iterable [T ]:
6065 """Load the specified dataset from the Aachen Higher-Order Repository of Networks.
6166
6267 The dataset file will be stored in your system cache and can be deleted according
@@ -73,8 +78,9 @@ def load_ahorn_dataset[T: Complex](name: str, create_using: type[T] | None = Non
7378
7479 Returns
7580 -------
76- Complex
77- The complex representing the AHORN dataset.
81+ Complex or list[Complex]
82+ The complex representing the AHORN dataset. A list of complexes if the dataset
83+ contains multiple networks.
7884
7985 Raises
8086 ------
@@ -96,7 +102,7 @@ def load_ahorn_dataset[T: Complex](name: str, create_using: type[T] | None = Non
96102
97103def read_ahorn_dataset [T ](
98104 path : str | Path | IO [str ], create_using : type [T ] | None = None
99- ) -> T :
105+ ) -> T | Iterable [ T ] :
100106 """Read an AHORN dataset from a local file or file-like object.
101107
102108 This function accepts file paths and file-like objects provided by users. When
@@ -113,8 +119,9 @@ def read_ahorn_dataset[T](
113119
114120 Returns
115121 -------
116- Complex
117- The complex representing the AHORN dataset.
122+ Complex or list[Complex]
123+ The complex representing the AHORN dataset. A list of complexes if the dataset
124+ contains multiple networks.
118125
119126 Raises
120127 ------
@@ -138,9 +145,15 @@ def read_ahorn_dataset[T](
138145 raise RuntimeError (f"Failed to read dataset: { e !s} " ) from e
139146
140147
141- def _read_ahorn_dataset [T ](file : IO [str ], create_using : type [T ] | None = None ) -> T :
148+ def _read_ahorn_dataset [T ](
149+ file : Iterable [str ], create_using : type [T ] | None = None
150+ ) -> T | Iterable [T ]:
142151 """Read AHORN dataset from file-like object.
143152
153+ Supports both single-network and multi-network datasets. Multi-network datasets are
154+ detected by checking if the first two lines both start with '{', indicating they are
155+ JSON objects representing separate networks.
156+
144157 Parameters
145158 ----------
146159 file : IO
@@ -150,15 +163,38 @@ def _read_ahorn_dataset[T](file: IO[str], create_using: type[T] | None = None) -
150163
151164 Returns
152165 -------
153- Complex
154- The complex representing the AHORN dataset.
166+ Complex or list[Complex]
167+ The complex representing the AHORN dataset. A list of complexes if the dataset
168+ contains multiple networks.
155169 """
156170 if create_using is None :
157171 create_using = SimplicialComplex
158172
159- complex_obj = create_using (** json .loads (next (file )))
173+ # Convert to peekable iterator to detect multi-network datasets
174+ lines = peekable (file )
175+
176+ # Check if this is a multi-network dataset by peeking at the first two lines
177+ first_line = next (lines )
178+ is_multi_network = False
179+ try :
180+ is_multi_network = lines .peek ().strip ().startswith ("{" )
181+ except StopIteration :
182+ is_multi_network = False
183+
184+ # If multi-network, create empty complex; otherwise parse first line
185+ if is_multi_network :
186+ network_lines = list (
187+ split_before (lines , lambda line : line .strip ().startswith ("{" ))
188+ )
189+ return [
190+ _read_ahorn_dataset (network , create_using = create_using )
191+ for network in network_lines
192+ ]
193+
194+ complex_obj = create_using (** json .loads (first_line ))
160195
161- for line_num , line in enumerate (file , start = 2 ):
196+ # Process remaining lines
197+ for line_num , line in enumerate (lines , start = 2 ):
162198 try :
163199 elements_part , metadata = line .split (" " , maxsplit = 1 )
164200 elements = list (map (int , elements_part .split ("," )))
0 commit comments