Skip to content

Commit 316ebdc

Browse files
committed
Modify API
1 parent 3ca37c8 commit 316ebdc

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

cyaron/graph.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,9 @@ def to_str(self, **kwargs):
7979
**kwargs(Keyword args):
8080
bool shuffle = False -> whether shuffle the output or not
8181
str output(Edge) = str -> the convert function which converts object Edge to str. the default way is to use str()
82-
list[int] node_shuffler(list[int])
83-
= lambda table: random.sample(table, k=len(table))
82+
list[int] node_shuffler(int)
83+
= lambda n: random.sample(range(1, n + 1), k=n)
8484
-> the random function which shuffles the vertex sequence.
85-
Note that this function will actually be passed in a `range`!
8685
list[Edge] edge_shuffler(list[Edge])
8786
-> a random function. the default is to shuffle the edge sequence,
8887
also, if the graph is undirected, it will swap `u` and `v` randomly.
@@ -96,10 +95,10 @@ def _edge_shuffler_default(table):
9695

9796
shuffle = kwargs.get("shuffle", False)
9897
output = kwargs.get("output", str)
99-
node_shuffler = kwargs.get("node_shuffler", lambda table: random.sample(table, k=len(table)))
98+
node_shuffler = kwargs.get("node_shuffler", lambda n: random.sample(range(1, n + 1), k=n))
10099
edge_shuffler = kwargs.get("edge_shuffler", _edge_shuffler_default)
101100
if shuffle:
102-
new_node_id = [0] + node_shuffler(range(1, len(self.edges)))
101+
new_node_id = [0] + node_shuffler(self.vertex_count())
103102
edge_buf = []
104103
for edge in self.iterate_edges():
105104
edge_buf.append(

cyaron/tests/graph_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,8 @@ def unit_test(n, m, shuffle_kwargs = {}, check_kwargs = {}):
221221
unit_test(8, 20)
222222
unit_test(8, 20, {"shuffle": True})
223223
mapping = [0] + random.sample(range(1, 8), k = 7)
224-
shuffer = lambda seq: list(map(lambda i: mapping[i], seq))
224+
shuffer = lambda n: list(map(lambda i: mapping[i], range(1, n + 1)))
225225
unit_test(7, 10, {"shuffle": True, "node_shuffler": shuffer})
226226
unit_test(7, 14, {"shuffle": True, "node_shuffler": shuffer}, {"mapping": mapping})
227227
shuffer_without_swap = lambda table: random.sample(table, k=len(table))
228228
unit_test(7, 12, {"shuffle": True, "edge_shuffler": shuffer_without_swap}, {"directed": True})
229-

0 commit comments

Comments
 (0)