Skip to content

Commit bf5da42

Browse files
authored
Make conversion true id to int id in RW run (#15)
1 parent c96831e commit bf5da42

File tree

4 files changed

+8
-6
lines changed

4 files changed

+8
-6
lines changed

dynnode2vec/biased_random_walk.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def _generate_walk_simple(
145145

146146
def run(
147147
self,
148-
nodes: List[int],
148+
nodes: List[Any],
149149
*,
150150
n_walks: int = 10,
151151
walk_length: int = 10,
@@ -160,6 +160,8 @@ def run(
160160
"""
161161
rn = random.Random(seed)
162162

163+
nodes = self.convert_true_ids_to_int_ids(nodes)
164+
163165
# weights are multiplied by inverse p and q
164166
ip, iq = 1.0 / p, 1.0 / q
165167

dynnode2vec/dynnode2vec.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,8 @@ def generate_updated_walks(
163163
# that changed compared to the previous time step
164164
delta_nodes = self.get_delta_nodes(current_graph, previous_graph)
165165

166-
brw = BiasedRandomWalk(current_graph)
167-
delta_nodes = brw.convert_true_ids_to_int_ids(delta_nodes)
168-
169166
# run walks for updated nodes only
170-
updated_walks = brw.run(
167+
updated_walks = BiasedRandomWalk(current_graph).run(
171168
nodes=delta_nodes,
172169
walk_length=self.walk_length,
173170
n_walks=self.n_walks_per_node,

dynnode2vec/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ def generate_dynamic_graphs(
2424
# Create a random graph
2525
graph = nx.fast_gnp_random_graph(n=n_base_nodes, p=base_density)
2626

27+
# add one to each node to avoid the perfect case where true_ids match int_ids
28+
graph = nx.relabel_nodes(graph, mapping={n: n + 1 for n in graph.nodes()})
29+
2730
# initialize graphs list with first graph
2831
graphs = [graph.copy()]
2932

tests/test_biased_random_walk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,4 @@ def test_run(graphs, p, q, weighted):
8282
random_walks = brw.run(graph.nodes(), p=p, q=q, weighted=weighted)
8383

8484
assert all(isinstance(walk, list) for walk in random_walks)
85-
assert all(n in brw.graph.nodes() for walk in random_walks for n in walk)
85+
assert all(n in graph.nodes() for walk in random_walks for n in walk)

0 commit comments

Comments
 (0)