Skip to content

Commit d763845

Browse files
committed
Allow users to define functions that shuffle vertex and edge sets
1 parent fffe5f0 commit d763845

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

cyaron/graph.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,26 +46,33 @@ def to_str(self, **kwargs):
4646
**kwargs(Keyword args):
4747
bool shuffle = False -> whether shuffle the output or not
4848
str output(Edge) = str -> the convert function which converts object Edge to str. the default way is to use str()
49+
list[int] node_shuffler(list[int])
50+
= lambda table: random.sample(table, k=len(table))
51+
-> the random function which shuffles the vertex sequence
52+
list[Edge] edge_shuffler(list[int])
53+
-> a random function. the default is to shuffle the edge sequence,
54+
also, if the graph is undirected, it will swap `u` and `v` randomly.
4955
"""
56+
def _edge_shuffler_default(table):
57+
edge_buf = random.sample(table, k=len(table))
58+
for edge in edge_buf:
59+
if not self.directed and random.randint(0, 1) == 0:
60+
(edge.start, edge.end) = (edge.end, edge.start)
61+
return edge_buf
62+
5063
shuffle = kwargs.get("shuffle", False)
5164
output = kwargs.get("output", str)
52-
buf = []
65+
node_shuffler = kwargs.get("node_shuffler", lambda table: random.sample(table, k=len(table)))
66+
edge_shuffler = kwargs.get("edge_shuffler", _edge_shuffler_default)
5367
if shuffle:
54-
new_node_id = [i for i in range(1, len(self.edges))]
55-
random.shuffle(new_node_id)
56-
new_node_id = [0] + new_node_id
68+
new_node_id = [0] + node_shuffler(range(1, len(self.edges)))
5769
edge_buf = []
5870
for edge in self.iterate_edges():
5971
edge_buf.append(
6072
Edge(new_node_id[edge.start], new_node_id[edge.end], edge.weight))
61-
random.shuffle(edge_buf)
62-
for edge in edge_buf:
63-
if not self.directed and random.randint(0, 1) == 0:
64-
(edge.start, edge.end) = (edge.end, edge.start)
65-
buf.append(output(edge))
73+
buf = map(output, edge_shuffler(edge_buf))
6674
else:
67-
for edge in self.iterate_edges():
68-
buf.append(output(edge))
75+
buf = map(output, self.iterate_edges())
6976
return "\n".join(buf)
7077

7178
def __str__(self):

0 commit comments

Comments
 (0)