Skip to content

Commit fd8a66e

Browse files
committed
refactor: walkers
1 parent 646b8ac commit fd8a66e

File tree

8 files changed

+199
-178
lines changed

8 files changed

+199
-178
lines changed

pyrdf2vec/walkers/anonymous.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
import attr
44

55
from pyrdf2vec.graphs import KG, Vertex
6-
from pyrdf2vec.typings import EntityWalks, SWalk
6+
from pyrdf2vec.typings import EntityWalks, List, SWalk
77
from pyrdf2vec.walkers import RandomWalker
88

99

1010
@attr.s
1111
class AnonymousWalker(RandomWalker):
12-
"""Walker that transforms label information into positional information in
13-
order to anonymize the random walks.
12+
"""Anonymous walking strategy which transforms each vertex name other than
13+
the root node, into positional information, in order to anonymize the
14+
randomly extracted walks.
1415
1516
Attributes:
1617
_is_support_remote: True if the walking strategy can be used with a
@@ -33,27 +34,24 @@ class AnonymousWalker(RandomWalker):
3334
3435
"""
3536

36-
def _extract(self, kg: KG, instance: Vertex) -> EntityWalks:
37-
"""Extracts walks rooted at the provided entities which are then each
38-
transformed into a numerical representation.
37+
def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
38+
"""Extracts random walks for an entity based on a Knowledge Graph.
3939
4040
Args:
4141
kg: The Knowledge Graph.
42-
instance: The instance to be extracted from the Knowledge Graph.
42+
entity: The root node to extract walks.
4343
4444
Returns:
45-
The 2D matrix with its number of rows equal to the number of
46-
provided entities; number of column equal to the embedding size.
45+
A dictionary having the entity as key and a list of tuples as value
46+
corresponding to the extracted walks.
4747
4848
"""
4949
canonical_walks: Set[SWalk] = set()
50-
for walk in self.extract_walks(kg, instance):
51-
canonical_walk = []
52-
str_walk = [hop.name for hop in walk]
53-
for i, hop in enumerate(walk):
54-
if i == 0:
55-
canonical_walk.append(hop.name)
56-
else:
57-
canonical_walk.append(str(str_walk.index(hop.name)))
50+
for walk in self.extract_walks(kg, entity):
51+
vertex_names = [vertex.name for vertex in walk]
52+
canonical_walk: List[str] = [
53+
vertex.name if i == 0 else str(vertex_names.index(vertex.name))
54+
for i, vertex in enumerate(walk)
55+
]
5856
canonical_walks.add(tuple(canonical_walk))
59-
return {instance.name: list(canonical_walks)}
57+
return {entity.name: list(canonical_walks)}

pyrdf2vec/walkers/community.py

Lines changed: 37 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def sample_from_iterable(x):
3535

3636
@attr.s
3737
class CommunityWalker(Walker):
38-
"""Defines the community walking strategy.
38+
"""Community walking strategy which groups vertices with similar properties
39+
through probabilities and relations that are not explicitly modeled in a
40+
Knowledge Graph. Similar to the Random walking strategy, the Depth First
41+
Search (DFS) algorithm is used if a maximum number of walks is specified.
42+
Otherwise, the Breath First Search (BFS) algorithm is chosen.
3943
4044
Attributes:
4145
_is_support_remote: True if the walking strategy can be used with a
@@ -127,27 +131,23 @@ def _community_detection(self, kg: KG) -> None:
127131
self.labels_per_community[self.communities[node]].append(node)
128132

129133
def _bfs(
130-
self, kg: KG, root: Vertex, is_reverse: bool = False
134+
self, kg: KG, entity: Vertex, is_reverse: bool = False
131135
) -> List[Walk]:
132-
"""Extracts random walks of depth - 1 hops rooted in root with
133-
Breadth-first search.
136+
"""Extracts random walks for an entity based on Knowledge Graph using
137+
the Depth First Search (DFS) algorithm.
134138
135139
Args:
136140
kg: The Knowledge Graph.
137-
138-
The graph from which the neighborhoods are extracted for the
139-
provided entities.
140-
root: The root node to extract walks.
141+
entity: The root node to extract walks.
141142
is_reverse: True to get the parent neighbors instead of the child
142143
neighbors, False otherwise.
143144
Defaults to False.
144145
145146
Returns:
146-
The list of walks for the root node according to the depth and
147-
max_walks.
147+
The list of unique walks for the provided entity.
148148
149149
"""
150-
walks: Set[Walk] = {(root,)}
150+
walks: Set[Walk] = {(entity,)}
151151
for i in range(self.max_depth):
152152
for walk in walks.copy():
153153
if is_reverse:
@@ -201,31 +201,27 @@ def _bfs(
201201
return list(walks)
202202

203203
def _dfs(
204-
self, kg: KG, root: Vertex, is_reverse: bool = False
204+
self, kg: KG, entity: Vertex, is_reverse: bool = False
205205
) -> List[Walk]:
206-
"""Extracts a random limited number of walks of depth - 1 hops rooted
207-
in root with Depth-first search.
206+
"""Extracts random walks for an entity based on Knowledge Graph using
207+
the Depth First Search (DFS) algorithm.
208208
209209
Args:
210210
kg: The Knowledge Graph.
211-
212-
The graph from which the neighborhoods are extracted for the
213-
provided entities.
214-
root: The root node to extract walks.
211+
entity: The root node to extract walks.
215212
is_reverse: True to get the parent neighbors instead of the child
216213
neighbors, False otherwise.
217-
Defaults to False
214+
Defaults to False.
218215
219216
Returns:
220-
The list of walks for the root node according to the depth and
221-
max_walks.
217+
The list of unique walks for the provided entity.
222218
223219
"""
224220
self.sampler.visited = set()
225221
walks: List[Walk] = []
226222
assert self.max_walks is not None
227223
while len(walks) < self.max_walks:
228-
sub_walk: Walk = (root,)
224+
sub_walk: Walk = (entity,)
229225
d = 1
230226
while d // 2 < self.max_depth:
231227
pred_obj = self.sampler.sample_hop(
@@ -301,18 +297,15 @@ def extract(
301297
self._community_detection(kg)
302298
return super().extract(kg, entities, verbose)
303299

304-
def extract_walks(self, kg: KG, root: Vertex) -> List[Walk]:
300+
def extract_walks(self, kg: KG, entity: Vertex) -> List[Walk]:
305301
"""Extracts random walks of depth - 1 hops rooted in root.
306302
307303
Args:
308304
kg: The Knowledge Graph.
309-
310-
The graph from which the neighborhoods are extracted for the
311-
provided entities.
312-
root: The root node to extract walks.
305+
entity: The root node to extract walks.
313306
314307
Returns:
315-
The list of walks.
308+
The list of unique walks for the provided entity.
316309
317310
"""
318311
if self.max_walks is None:
@@ -322,33 +315,30 @@ def extract_walks(self, kg: KG, root: Vertex) -> List[Walk]:
322315
if self.with_reverse:
323316
return [
324317
r_walk[:-1] + walk
325-
for walk in fct_search(kg, root)
326-
for r_walk in fct_search(kg, root, is_reverse=True)
318+
for walk in fct_search(kg, entity)
319+
for r_walk in fct_search(kg, entity, is_reverse=True)
327320
]
328-
return [walk for walk in fct_search(kg, root)]
321+
return [walk for walk in fct_search(kg, entity)]
329322

330-
def _extract(self, kg: KG, instance: Vertex) -> EntityWalks:
331-
"""Extracts walks rooted at the provided entities which are then each
332-
transformed into a numerical representation.
323+
def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
324+
"""Extracts random walks for an entity based on a Knowledge Graph.
333325
334326
Args:
335327
kg: The Knowledge Graph.
336-
instance: The instance to be extracted from the Knowledge Graph.
328+
entity: The root node to extract walks.
337329
338330
Returns:
339-
The 2D matrix with its number of rows equal to the number of
340-
provided entities; number of column equal to the embedding size.
331+
A dictionary having the entity as key and a list of tuples as value
332+
corresponding to the extracted walks.
341333
342334
"""
343335
canonical_walks: Set[SWalk] = set()
344-
for walk in self.extract_walks(kg, instance):
345-
canonical_walk: List[str] = []
346-
for i, hop in enumerate(walk):
347-
if i == 0 or i % 2 == 1 or self.md5_bytes is None:
348-
canonical_walk.append(hop.name)
349-
else:
350-
canonical_walk.append(
351-
str(md5(hop.name.encode()).digest()[: self.md5_bytes])
352-
)
336+
for walk in self.extract_walks(kg, entity):
337+
canonical_walk: List[str] = [
338+
vertex.name
339+
if i == 0 or i % 2 == 1 or self.md5_bytes is None
340+
else str(md5(vertex.name.encode()).digest()[: self.md5_bytes])
341+
for i, vertex in enumerate(walk)
342+
]
353343
canonical_walks.add(tuple(canonical_walk))
354-
return {instance.name: list(canonical_walks)}
344+
return {entity.name: list(canonical_walks)}

pyrdf2vec/walkers/halk.py

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
from collections import defaultdict
22
from hashlib import md5
3-
from typing import List, Set
3+
from typing import DefaultDict, List, Set
44

55
import attr
66

77
from pyrdf2vec.graphs import KG, Vertex
8-
from pyrdf2vec.typings import EntityWalks, SWalk
8+
from pyrdf2vec.typings import EntityWalks, SWalk, Walk
99
from pyrdf2vec.walkers import RandomWalker
1010

1111

1212
@attr.s
1313
class HALKWalker(RandomWalker):
14-
"""Walker that removes the rare entities from the random walks in order to
15-
increase the quality of the generated embeddings while decreasing the
16-
memory usage.
14+
"""HALK walking strategy which removes rare vertices from randomly
15+
extracted walks, increasing the quality of the generated embeddings while
16+
memory usage decreases.
1717
1818
Attributes:
1919
_is_support_remote: True if the walking strategy can be used with a
@@ -54,47 +54,83 @@ class HALKWalker(RandomWalker):
5454

5555
md5_bytes = attr.ib(kw_only=True, default=8, type=int, repr=False)
5656

57-
def _extract(self, kg: KG, instance: Vertex) -> EntityWalks:
58-
"""Extracts walks rooted at the provided entities which are then each
59-
transformed into a numerical representation.
57+
def build_dictionary(
58+
self, walks: List[Walk]
59+
) -> DefaultDict[Vertex, Set[int]]:
60+
"""Builds a dictionary of vertices mapped to the extracted walk indices.
6061
6162
Args:
62-
kg: The Knowledge Graph.
63+
walks: The walks to build the dictionary.
64+
65+
Returns:
66+
The dictionary of vertex.
67+
68+
"""
69+
vertex_to_windices: DefaultDict[Vertex, Set[int]] = defaultdict(set)
70+
for i in range(len(walks)):
71+
for vertex in walks[i]:
72+
vertex_to_windices[vertex].add(i)
73+
return vertex_to_windices
74+
75+
def get_rare_vertices(
76+
self,
77+
vertex_to_windices: DefaultDict[Vertex, Set[int]],
78+
walks: List[Walk],
79+
freq_threshold: float,
80+
) -> Set[Vertex]:
81+
"""Gets vertices which doesn't reach a certain threshold of frequency
82+
of occurrence.
83+
84+
Args:
85+
vertex_to_windices: The dictionary of vertices mapped to the
86+
extracted walk indices.
87+
walks: The walks.
88+
freq_threshold: The threshold frequency of occurrence.
6389
64-
The graph from which the neighborhoods are extracted for the
65-
provided entities.
66-
instance: The instance to be extracted from the Knowledge Graph.
6790
Returns:
68-
The 2D matrix with its number of rows equal to the number of
69-
provided entities; number of column equal to the embedding size.
91+
the infrequent vertices.
7092
7193
"""
72-
walks = self.extract_walks(kg, instance)
94+
rare_vertices = set()
95+
for vertex in vertex_to_windices:
96+
if len(vertex_to_windices[vertex]) / len(walks) < freq_threshold:
97+
rare_vertices.add(vertex)
98+
return rare_vertices
7399

100+
def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
101+
"""Extracts random walks for an entity based on a Knowledge Graph.
102+
103+
Args:
104+
kg: The Knowledge Graph.
105+
entity: The root node to extract walks.
106+
107+
Returns:
108+
A dictionary having the entity as key and a list of tuples as value
109+
corresponding to the extracted walks.
110+
111+
"""
112+
walks = self.extract_walks(kg, entity)
113+
vertex_to_windices = self.build_dictionary(walks)
74114
canonical_walks: Set[SWalk] = set()
75-
hop_to_freq = defaultdict(set)
76-
for i in range(len(walks)):
77-
for hop in walks[i]:
78-
hop_to_freq[hop].add(i)
79115

80116
for freq_threshold in self.freq_thresholds:
81-
uniformative_hops = set()
82-
for hop in hop_to_freq:
83-
if len(hop_to_freq[hop]) / len(walks) < freq_threshold:
84-
uniformative_hops.add(hop)
85-
117+
rare_vertices = self.get_rare_vertices(
118+
vertex_to_windices, walks, freq_threshold
119+
)
86120
for walk in walks:
87121
canonical_walk = []
88-
for i, hop in enumerate(walk):
89-
if i == 0 or self.md5_bytes is None:
90-
canonical_walk.append(hop.name)
91-
elif hop.name not in uniformative_hops:
122+
for i, vertex in enumerate(walk):
123+
if i == 0 or (
124+
vertex not in rare_vertices and self.md5_bytes is None
125+
):
126+
canonical_walk.append(vertex.name)
127+
elif vertex not in rare_vertices:
92128
canonical_walk.append(
93129
str(
94-
md5(hop.name.encode()).digest()[
130+
md5(vertex.name.encode()).digest()[
95131
: self.md5_bytes
96132
]
97133
)
98134
)
99135
canonical_walks.add(tuple(canonical_walk))
100-
return {instance.name: list(canonical_walks)}
136+
return {entity.name: list(canonical_walks)}

0 commit comments

Comments
 (0)