|
1 | 1 | import itertools |
| 2 | +import math |
2 | 3 | from collections import defaultdict |
3 | 4 | from typing import DefaultDict, List, Set |
4 | 5 |
|
@@ -37,9 +38,10 @@ class HALKWalker(RandomWalker): |
37 | 38 | Defaults to None. |
38 | 39 | sampler: The sampling strategy. |
39 | 40 | Defaults to UniformSampler. |
40 | | - with_reverse: True to extracts children's and parents' walks from the |
41 | | - root, creating (max_walks * max_walks) more walks of 2 * depth, |
42 | | - False otherwise. |
| 41 | + with_reverse: True to extracts parents and children hops from an |
| 42 | + entity, creating (max_walks * max_walks) walks of 2 * depth, |
| 43 | + allowing also to centralize this entity in the walks. False |
| 44 | + otherwise. |
43 | 45 | Defaults to False. |
44 | 46 |
|
45 | 47 | """ |
@@ -141,16 +143,29 @@ def _post_extract(self, res: List[EntityWalks]) -> List[List[SWalk]]: |
141 | 143 | for rare_predicates in pred_thresholds: |
142 | 144 | for entity_walks in conv_res: |
143 | 145 | canonical_walks = [] |
144 | | - curr_entity = entity_walks[0][0] |
| 146 | + if not self.with_reverse: |
| 147 | + curr_entity = entity_walks[0][0] |
| 148 | + else: |
| 149 | + curr_walk = list(entity_walks[0]) |
| 150 | + curr_entity = curr_walk[math.trunc(len(curr_walk) / 2)] |
145 | 151 | for walk in entity_walks: |
146 | | - canonical_walk = [curr_entity] |
| 152 | + if not self.with_reverse: |
| 153 | + canonical_walk = [curr_entity] |
| 154 | + else: |
| 155 | + canonical_walk = [walk[0]] |
| 156 | + reverse = True |
| 157 | + j = 0 |
147 | 158 | for i, vertex in enumerate(walk[1::2], 2): |
148 | 159 | if vertex not in rare_predicates: |
149 | | - obj = walk[i] if i % 2 == 0 else walk[i + 1] |
150 | 160 | if self.with_reverse: |
151 | | - canonical_walk = [obj, vertex] + canonical_walk |
| 161 | + obj = walk[i + j] |
| 162 | + j += 1 |
152 | 163 | else: |
153 | | - canonical_walk += [vertex, obj] |
| 164 | + obj = walk[i] if i % 2 == 0 else walk[i + 1] |
| 165 | + if self.with_reverse and reverse: |
| 166 | + if obj == curr_entity: |
| 167 | + reverse = False |
| 168 | + canonical_walk += [vertex, obj] |
154 | 169 | if len(canonical_walk) >= 3: |
155 | 170 | canonical_walks.append(tuple(canonical_walk)) |
156 | 171 | if canonical_walks: |
|
0 commit comments