Skip to content

Commit 11d545f

Browse files
committed
Fix mypy type errors and improve type annotations
- Fix method signature mismatches between base and derived classes - Update load() and save() methods to return bool instead of dict - Fix return types for methods that should return List but were returning Set - Change dictionary type annotations from Dict[str, Any] to Dict[Any, Any] to support various key types - Fix draw_hypergraph() function to accept BaseHypergraphDB instead of HypergraphDB - Add proper type annotations and fix Liskov Substitution Principle violations - Update HypergraphViewer constructor to accept BaseHypergraphDB - Add missing imports for BaseHypergraphDB in draw.py - Format code with black to maintain consistency
1 parent d4a746d commit 11d545f

File tree

4 files changed

+29
-19
lines changed

4 files changed

+29
-19
lines changed

hyperdb/base.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@ class BaseHypergraphDB:
1212

1313
storage_file: Union[str, Path] = field(default="my_hypergraph.hgdb", compare=False)
1414

15-
def save(self, file_path: Union[str, Path]):
15+
def save(self, file_path: Union[str, Path]) -> bool:
1616
r"""
1717
Save the hypergraph to a file.
1818
1919
Args:
2020
``file_path`` (``Union[str, Path]``): The file path to save the
2121
hypergraph.
22+
23+
Returns:
24+
``bool``: True if successful, False otherwise.
2225
"""
2326
raise NotImplementedError
2427

@@ -32,13 +35,15 @@ def save_as(self, format: str, file_path: Union[str, Path]):
3235
"""
3336
raise NotImplementedError
3437

35-
@staticmethod
36-
def load(self, file_path: Union[str, Path]):
38+
def load(self, file_path: Union[str, Path]) -> bool:
3739
r"""
3840
Load the hypergraph from a file.
3941
4042
Args:
4143
``file_path`` (``Union[str, Path]``): The file path to load the hypergraph from.
44+
45+
Returns:
46+
``bool``: True if successful, False otherwise.
4247
"""
4348
raise NotImplementedError
4449

@@ -153,21 +158,23 @@ def remove_e(self, e_tuple: Tuple):
153158
"""
154159
raise NotImplementedError
155160

156-
def update_v(self, v_id: Any):
161+
def update_v(self, v_id: Any, v_data: Dict[Any, Any]):
157162
r"""
158163
Update the vertex data.
159164
160165
Args:
161166
``v_id`` (``Any``): The vertex id.
167+
``v_data`` (``Dict[Any, Any]``): The vertex data.
162168
"""
163169
raise NotImplementedError
164170

165-
def update_e(self, e_tuple: Tuple):
171+
def update_e(self, e_tuple: Union[List[Any], Set[Any], Tuple[Any, ...]], e_data: Dict[Any, Any]):
166172
r"""
167173
Update the hyperedge data.
168174
169175
Args:
170-
``e_tuple`` (``Tuple``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name).
176+
``e_tuple`` (``Union[List[Any], Set[Any], Tuple[Any, ...]]``): The hyperedge tuple.
177+
``e_data`` (``Dict[Any, Any]``): The hyperedge data.
171178
"""
172179
raise NotImplementedError
173180

hyperdb/draw.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Dict
88
from urllib.parse import parse_qs, urlparse
99

10+
from .base import BaseHypergraphDB
1011
from .hypergraph import HypergraphDB
1112

1213

@@ -224,7 +225,7 @@ def _get_html_template(self) -> str:
224225
class HypergraphViewer:
225226
"""Hypergraph visualization tool"""
226227

227-
def __init__(self, hypergraph_db: HypergraphDB, port: int = 8080):
228+
def __init__(self, hypergraph_db: BaseHypergraphDB, port: int = 8080):
228229
self.hypergraph_db = hypergraph_db
229230
self.port = port
230231

@@ -262,7 +263,9 @@ def stop_server(self):
262263
self.httpd.server_close()
263264

264265

265-
def draw_hypergraph(hypergraph_db: HypergraphDB, port: int = 8080, open_browser: bool = True, blocking: bool = True):
266+
def draw_hypergraph(
267+
hypergraph_db: BaseHypergraphDB, port: int = 8080, open_browser: bool = True, blocking: bool = True
268+
):
266269
"""
267270
Main function to draw hypergraph
268271

hyperdb/hypergraph.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ class HypergraphDB(BaseHypergraphDB):
1616
Hypergraph database.
1717
"""
1818

19-
_v_data: Dict[str, Any] = field(default_factory=dict)
19+
_v_data: Dict[Any, Any] = field(default_factory=dict)
2020
_e_data: Dict[Tuple, Any] = field(default_factory=dict)
21-
_v_inci: Dict[str, Set[Tuple]] = field(default_factory=lambda: defaultdict(set))
21+
_v_inci: Dict[Any, Set[Tuple]] = field(default_factory=lambda: defaultdict(set))
2222

2323
def __post_init__(self):
2424
assert isinstance(self.storage_file, (str, Path))
@@ -27,7 +27,7 @@ def __post_init__(self):
2727
if self.storage_file.exists():
2828
self.load(self.storage_file)
2929

30-
def load(self, storage_file: Path) -> dict:
30+
def load(self, storage_file: Union[str, Path]) -> bool:
3131
r"""
3232
Load the hypergraph database from the storage file.
3333
"""
@@ -41,7 +41,7 @@ def load(self, storage_file: Path) -> dict:
4141
except Exception:
4242
return False
4343

44-
def save(self, storage_file: Path) -> dict:
44+
def save(self, storage_file: Union[str, Path]) -> bool:
4545
r"""
4646
Save the hypergraph database to the storage file.
4747
"""
@@ -114,14 +114,14 @@ def all_v(self) -> List[str]:
114114
r"""
115115
Return a list of all vertices in the hypergraph.
116116
"""
117-
return set(self._v_data.keys())
117+
return list(self._v_data.keys())
118118

119119
@cached_property
120120
def all_e(self) -> List[Tuple]:
121121
r"""
122122
Return a list of all hyperedges in the hypergraph.
123123
"""
124-
return set(self._e_data.keys())
124+
return list(self._e_data.keys())
125125

126126
@cached_property
127127
def num_v(self) -> int:
@@ -307,7 +307,7 @@ def nbr_e_of_v(self, v_id: Any) -> list:
307307
"""
308308
assert isinstance(v_id, Hashable), "The vertex id must be hashable."
309309
assert v_id in self._v_data, f"The vertex {v_id} does not exist in the hypergraph."
310-
return set(self._v_inci[v_id])
310+
return list(self._v_inci[v_id])
311311

312312
def nbr_v_of_e(self, e_tuple: Union[List, Set, Tuple]) -> list:
313313
r"""
@@ -319,7 +319,7 @@ def nbr_v_of_e(self, e_tuple: Union[List, Set, Tuple]) -> list:
319319
assert isinstance(e_tuple, (list, set, tuple)), "The hyperedge must be a list, set, or tuple of vertex ids."
320320
e_tuple = self.encode_e(e_tuple)
321321
assert e_tuple in self._e_data, f"The hyperedge {e_tuple} does not exist in the hypergraph."
322-
return set(e_tuple)
322+
return list(e_tuple)
323323

324324
def nbr_v(self, v_id: Any, exclude_self=True) -> list:
325325
r"""
@@ -330,9 +330,9 @@ def nbr_v(self, v_id: Any, exclude_self=True) -> list:
330330
"""
331331
assert isinstance(v_id, Hashable), "The vertex id must be hashable."
332332
assert v_id in self._v_data, f"The vertex {v_id} does not exist in the hypergraph."
333-
nbrs = set()
333+
nbrs: set = set()
334334
for e_tuple in self._v_inci[v_id]:
335335
nbrs.update(e_tuple)
336336
if exclude_self:
337337
nbrs.remove(v_id)
338-
return set(nbrs)
338+
return list(nbrs)

tests/test_hypergraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_remove_v(hg):
8585
hg.remove_v(6)
8686
assert hg.has_v(6) is False
8787
assert hg.has_e((1, 5, 6)) is False
88-
assert hg.has_e((5, 1)) == {"relation": "study"}
88+
assert hg.e((1, 5)) == {"relation": "study"}
8989

9090

9191
def test_remove_e(hg):

0 commit comments

Comments
 (0)