Skip to content

Commit 1facaf0

Browse files
committed
fix: with_reverse for HALKWalker
1 parent 5240914 commit 1facaf0

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

pyrdf2vec/walkers/halk.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import itertools
2+
import math
23
from collections import defaultdict
34
from typing import DefaultDict, List, Set
45

@@ -37,9 +38,10 @@ class HALKWalker(RandomWalker):
3738
Defaults to None.
3839
sampler: The sampling strategy.
3940
Defaults to UniformSampler.
40-
with_reverse: True to extracts children's and parents' walks from the
41-
root, creating (max_walks * max_walks) more walks of 2 * depth,
42-
False otherwise.
41+
with_reverse: True to extracts parents and children hops from an
42+
entity, creating (max_walks * max_walks) walks of 2 * depth,
43+
allowing also to centralize this entity in the walks. False
44+
otherwise.
4345
Defaults to False.
4446
4547
"""
@@ -141,16 +143,29 @@ def _post_extract(self, res: List[EntityWalks]) -> List[List[SWalk]]:
141143
for rare_predicates in pred_thresholds:
142144
for entity_walks in conv_res:
143145
canonical_walks = []
144-
curr_entity = entity_walks[0][0]
146+
if not self.with_reverse:
147+
curr_entity = entity_walks[0][0]
148+
else:
149+
curr_walk = list(entity_walks[0])
150+
curr_entity = curr_walk[math.trunc(len(curr_walk) / 2)]
145151
for walk in entity_walks:
146-
canonical_walk = [curr_entity]
152+
if not self.with_reverse:
153+
canonical_walk = [curr_entity]
154+
else:
155+
canonical_walk = [walk[0]]
156+
reverse = True
157+
j = 0
147158
for i, vertex in enumerate(walk[1::2], 2):
148159
if vertex not in rare_predicates:
149-
obj = walk[i] if i % 2 == 0 else walk[i + 1]
150160
if self.with_reverse:
151-
canonical_walk = [obj, vertex] + canonical_walk
161+
obj = walk[i + j]
162+
j += 1
152163
else:
153-
canonical_walk += [vertex, obj]
164+
obj = walk[i] if i % 2 == 0 else walk[i + 1]
165+
if self.with_reverse and reverse:
166+
if obj == curr_entity:
167+
reverse = False
168+
canonical_walk += [vertex, obj]
154169
if len(canonical_walk) >= 3:
155170
canonical_walks.append(tuple(canonical_walk))
156171
if canonical_walks:

0 commit comments

Comments
 (0)