Skip to content

Commit b696a77

Browse files
committed
refactor: add _post_extract function for walkers
1 parent a91b6bd commit b696a77

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

pyrdf2vec/walkers/halk.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,43 @@ def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
135135
)
136136
canonical_walks.add(tuple(canonical_walk))
137137
return {entity.name: list(canonical_walks)}
138+
139+
def _post_extract(self, res: List[EntityWalks]) -> List[List[SWalk]]:
140+
"""Post processed walks.
141+
142+
Args:
143+
res: the result of the walks extracted with multiprocessing.
144+
145+
Returns:
146+
The 2D matrix with its number of rows equal to the number of
147+
provided entities; number of column equal to the embedding size.
148+
149+
"""
150+
conv_res = list(
151+
walks
152+
for entity_to_walks in res
153+
for walks in entity_to_walks.values()
154+
)
155+
walks: List[SWalk] = [
156+
walk for entity_walks in conv_res for walk in entity_walks
157+
]
158+
159+
predicates_dict = self.build_dictionary(walks)
160+
pred_thresholds = [
161+
self.get_rare_predicates(predicates_dict, walks, freq_threshold)
162+
for freq_threshold in self.freq_thresholds
163+
]
164+
res_halk = []
165+
for rare_predicates in pred_thresholds:
166+
for entity_walks in conv_res:
167+
canonical_walks = []
168+
for walk in entity_walks:
169+
canonical_walk = [walk[0]]
170+
for i, vertex in enumerate(walk[1::2], 2):
171+
if vertex not in rare_predicates:
172+
obj = walk[i] if i % 2 == 0 else walk[i + 1]
173+
canonical_walk += [vertex, obj]
174+
if len(canonical_walk) > 1:
175+
canonical_walks.append(tuple(canonical_walk))
176+
res_halk.append(canonical_walks)
177+
return res_halk

pyrdf2vec/walkers/walker.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def extract(
158158
disable=True if verbose == 0 else False,
159159
)
160160
)
161-
return list(walks for elm in res for walks in elm.values())
161+
return self._post_extract(res)
162162

163163
@abstractmethod
164164
def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
@@ -189,6 +189,27 @@ def _init_worker(self, init_kg: KG) -> None:
189189
global kg
190190
kg = init_kg # type: ignore
191191

192+
def _post_extract(self, res: List[EntityWalks]) -> List[List[SWalk]]:
193+
"""Post processed walks.
194+
195+
Args:
196+
res: the result of the walks extracted with multiprocessing.
197+
198+
Returns:
199+
The 2D matrix with its number of rows equal to the number of
200+
provided entities; number of column equal to the embedding size.
201+
202+
Raises:
203+
NotImplementedError: If this method is called, without having
204+
provided an implementation.
205+
206+
"""
207+
return list(
208+
walks
209+
for entity_to_walks in res
210+
for walks in entity_to_walks.values()
211+
)
212+
192213
def _proc(self, entity: str) -> EntityWalks:
193214
"""Executed by each process.
194215

0 commit comments

Comments
 (0)