Skip to content

Commit 646b8ac

Browse files
committed
refactor: samplers
1 parent b7e7042 commit 646b8ac

File tree

4 files changed

+78
-78
lines changed

4 files changed

+78
-78
lines changed

pyrdf2vec/samplers/frequency.py

Lines changed: 53 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,29 @@
1010

1111
@attr.s
1212
class ObjFreqSampler(Sampler):
13-
"""Defines the Object Frequency Weight sampling strategy.
14-
15-
This sampling strategy is a node-centric object frequency approach. With
16-
this strategy, entities which have a high in degree get visisted more
17-
often.
18-
19-
Attributes:
20-
_counts: The counter for vertices.
21-
Defaults to defaultdict.
22-
_is_support_remote: True if the sampling strategy can be used with a
23-
remote Knowledge Graph, False Otherwise
24-
Defaults to False.
25-
_random_state: The random state to use to keep random determinism with
26-
the sampling strategy.
27-
Defaults to None.
28-
_vertices_deg: The degree of the vertices.
29-
Defaults to {}.
30-
_visited: Tags vertices that appear at the max depth or of which all
31-
their children are tagged.
32-
Defaults to set.
33-
inverse: True if the inverse algorithm must be used, False otherwise.
34-
Defaults to False.
35-
split: True if the split algorithm must be used, False otherwise.
36-
Defaults to False.
13+
"""Object Frequency Weight node-centric sampling strategy which prioritizes
14+
walks containing edges with the highest degree objects. The degree of an
15+
object being defined by the number of predicates present in its
16+
neighborhood.
17+
18+
Attributes:
19+
_counts: The counter for vertices.
20+
Defaults to defaultdict.
21+
_is_support_remote: True if the sampling strategy can be used with a
22+
remote Knowledge Graph, False Otherwise
23+
Defaults to False.
24+
_random_state: The random state to use to keep random determinism with
25+
the sampling strategy.
26+
Defaults to None.
27+
_vertices_deg: The degree of the vertices.
28+
Defaults to {}.
29+
_visited: Tags vertices that appear at the max depth or of which all
30+
their children are tagged.
31+
Defaults to set.
32+
inverse: True if the inverse algorithm must be used, False otherwise.
33+
Defaults to False.
34+
split: True if the split algorithm must be used, False otherwise.
35+
Defaults to False.
3736
3837
"""
3938

@@ -45,8 +44,8 @@ class ObjFreqSampler(Sampler):
4544
)
4645

4746
def fit(self, kg: KG) -> None:
48-
"""Fits the sampling strategy by counting the number of available
49-
neighbors for each vertex.
47+
"""Fits the sampling strategy by counting the number of parent
48+
predicates present in the neighborhood of each vertex.
5049
5150
Args:
5251
kg: The Knowledge Graph.
@@ -63,10 +62,11 @@ def get_weight(self, hop: Hop) -> int:
6362
"""Gets the weight of a hop in the Knowledge Graph.
6463
6564
Args:
66-
hop: The hop (pred, obj) to get the weight.
65+
hop: The hop of a vertex in a (predicate, object) form to get the
66+
weight.
6767
6868
Returns:
69-
The weight for a given hop.
69+
The weight of a given hop.
7070
7171
Raises:
7272
ValueError: If there is an attempt to access the weight of a hop
@@ -83,11 +83,10 @@ def get_weight(self, hop: Hop) -> int:
8383

8484
@attr.s
8585
class PredFreqSampler(Sampler):
86-
"""Defines the Predicate Frequency Weight sampling strategy.
87-
88-
This sampling strategy is an edge-centric approach. With this strategy,
89-
edges with predicates which are commonly used in the dataset are more often
90-
followed.
86+
"""Predicate Frequency Weight edge-centric sampling strategy which
87+
prioritizes walks containing edges with the highest degree predicates. The
88+
degree of a predicate being defined by the number of occurences that a
89+
predicate appears in a Knowledge Graph.
9190
9291
Attributes:
9392
_counts: The counter for vertices.
@@ -115,7 +114,7 @@ class PredFreqSampler(Sampler):
115114
)
116115

117116
def fit(self, kg: KG) -> None:
118-
"""Fits the sampling strategy by counting the number of occurance that
117+
"""Fits the sampling strategy by counting the number of occurences that
119118
a predicate appears in the Knowledge Graph.
120119
121120
Args:
@@ -134,31 +133,32 @@ def get_weight(self, hop: Hop) -> int:
134133
"""Gets the weight of a hop in the Knowledge Graph.
135134
136135
Args:
137-
hop: The hop (pred, obj) to get the weight.
136+
hop: The hop of a vertex in a (predicate, object) form to get the
137+
weight.
138138
139139
Returns:
140-
The weight for a given hop.
140+
The weight of a given hop.
141141
142142
Raises:
143143
ValueError: If there is an attempt to access the weight of a hop
144144
without the sampling strategy having been trained.
145145
146146
"""
147-
if len(self._counts) == 0:
147+
if self._counts:
148148
raise ValueError(
149-
"You must call the `fit(kg)` function before get the weight of"
149+
"You must call the `fit(kg)` method before get the weight of"
150150
+ " a hop."
151151
)
152152
return self._counts[hop[0].name]
153153

154154

155155
@attr.s
156156
class ObjPredFreqSampler(Sampler):
157-
"""Defines the Predicate-Object Frequency Weight sampling strategy.
158-
159-
This sampling strategy is a edge-centric approach. This strategy is similar
160-
to the Predicate Frequency Weigh sampling strategy, but differentiates
161-
between the objects as well.
157+
"""Predicate-Object Frequency Weight edge-centric sampling strategy which
158+
prioritizes walks containing edges with the highest degree of (predicate,
159+
object) relations. The degree of a such relation being defined by the
160+
number of occurences that a (predicate, object) relation appears in a
161+
Knowledge Graph.
162162
163163
Attributes:
164164
_counts: The counter for vertices.
@@ -186,8 +186,8 @@ class ObjPredFreqSampler(Sampler):
186186
)
187187

188188
def fit(self, kg: KG) -> None:
189-
"""Fits the sampling strategy by counting the number of occurance of
190-
having two neighboring vertices.
189+
"""Fits the sampling strategy by counting the number of occurrences of
190+
an object belonging to a subject.
191191
192192
Args:
193193
kg: The Knowledge Graph.
@@ -196,9 +196,9 @@ def fit(self, kg: KG) -> None:
196196
super().fit(kg)
197197
for vertex in kg._vertices:
198198
if vertex.predicate:
199-
neighbors = list(kg.get_neighbors(vertex))
200-
if len(neighbors) > 0:
201-
obj = neighbors[0]
199+
objs = list(kg.get_neighbors(vertex))
200+
if objs:
201+
obj = objs[0]
202202
if (vertex.name, obj.name) in self._counts:
203203
self._counts[(vertex.name, obj.name)] += 1
204204
else:
@@ -208,19 +208,20 @@ def get_weight(self, hop: Hop) -> int:
208208
"""Gets the weight of a hop in the Knowledge Graph.
209209
210210
Args:
211-
hop: The hop (pred, obj) to get the weight.
211+
hop: The hop of a vertex in a (predicate, object) form to get the
212+
weight.
212213
213214
Returns:
214-
The weight for a given hop.
215+
The weight of a given hop.
215216
216217
Raises:
217218
ValueError: If there is an attempt to access the weight of a hop
218219
without the sampling strategy having been trained.
219220
220221
"""
221-
if len(self._counts) == 0:
222+
if self._counts:
222223
raise ValueError(
223-
"You must call the `fit(kg)` function before get the weight of"
224+
"You must call the `fit(kg)` method before get the weight of"
224225
+ " a hop."
225226
)
226227
return self._counts[(hop[0].name, hop[1].name)]

pyrdf2vec/samplers/pagerank.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@
1010

1111
@attr.s
1212
class PageRankSampler(Sampler):
13-
"""Defines the Object Frequency Weight sampling strategy.
14-
This sampling strategy is a node-centric approach. With this strategy, some
15-
nodes are more important than others and hence there will be resources
16-
which are more frequent in the walks as others.
13+
"""PageRank node-centric sampling strategy which prioritizes walks
14+
containing the most frequent objects. This frequency being defined by
15+
assigning a higher weight to the most frequent objects using the
16+
PageRank ranking.
1717
1818
Attributes:
1919
_is_support_remote: True if the sampling strategy can be used with a
2020
remote Knowledge Graph, False Otherwise
2121
Defaults to False.
22-
_pageranks: The Page Rank dictionary.
22+
_pageranks: The PageRank dictionary.
2323
Defaults to {}.
2424
_random_state: The random state to use to keep random determinism with
2525
the sampling strategy.
@@ -29,7 +29,7 @@ class PageRankSampler(Sampler):
2929
_visited: Tags vertices that appear at the max depth or of which all
3030
their children are tagged.
3131
Defaults to set.
32-
alpha: The damping for Page Rank.
32+
alpha: The damping for PageRank.
3333
Defaults to 0.85.
3434
inverse: True if the inverse algorithm must be used, False otherwise.
3535
Defaults to False.
@@ -60,33 +60,31 @@ def fit(self, kg: KG) -> None:
6060
super().fit(kg)
6161
nx_graph = nx.DiGraph()
6262

63-
for vertex in kg._vertices:
64-
if not vertex.predicate:
65-
nx_graph.add_node(vertex.name, vertex=vertex)
66-
for predicate in kg.get_neighbors(vertex):
67-
for obj in kg.get_neighbors(predicate):
68-
nx_graph.add_edge(
69-
vertex.name, obj.name, name=predicate.name
70-
)
63+
subs_objs = [vertex for vertex in kg._vertices if not vertex.predicate]
64+
for vertex in subs_objs:
65+
nx_graph.add_node(vertex.name, vertex=vertex)
66+
for hop in kg.get_hops(vertex):
67+
nx_graph.add_edge(vertex.name, hop[1].name, name=hop[0].name)
7168
self._pageranks = nx.pagerank(nx_graph, alpha=self.alpha)
7269

7370
def get_weight(self, hop: Hop) -> float:
7471
"""Gets the weight of a hop in the Knowledge Graph.
7572
7673
Args:
77-
hop: The hop (pred, obj) to get the weight.
74+
hop: The hop of a vertex in a (predicate, object) form to get the
75+
weight.
7876
7977
Returns:
80-
The weight for a given hop.
78+
The weight of a given hop.
8179
8280
Raises:
8381
ValueError: If there is an attempt to access the weight of a hop
8482
without the sampling strategy having been trained.
8583
8684
"""
87-
if len(self._pageranks) == 0:
85+
if not self._pageranks:
8886
raise ValueError(
89-
"You must call the `fit(kg)` function before get the weight of"
87+
"You must call the `fit(kg)` method before get the weight of"
9088
+ " a hop."
9189
)
9290
return self._pageranks[hop[1].name]

pyrdf2vec/samplers/sampler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,11 @@ def get_weight(self, hop: Hop):
9595
"""Gets the weight of a hop in the Knowledge Graph.
9696
9797
Args:
98-
hop: The hop (pred, obj) to get the weight.
98+
hop: The hop of a vertex in a (predicate, object) form to get the
99+
weight.
99100
100101
Returns:
101-
The weight for a given hop.
102+
The weight of a given hop.
102103
103104
Raises:
104105
NotImplementedError: If this method is called, without having

pyrdf2vec/samplers/uniform.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77

88
@attr.s
99
class UniformSampler(Sampler):
10-
"""Sampler that assigns a uniform weight to each hop in a Knowledge Graph.
11-
This sampling strategy is the most straight forward approach. With this
12-
strategy, strongly connected entities will have a higher influence on the
13-
resulting embeddings.
10+
"""Uniform sampling strategy that assigns a uniform weight to each edge in
11+
a Knowledge Graph, in order to prioritizes walks with strongly connected
12+
entities.
1413
1514
Attributes:
1615
_is_support_remote: True if the sampling strategy can be used with a
@@ -60,10 +59,11 @@ def get_weight(self, hop: Hop) -> int:
6059
"""Gets the weight of a hop in the Knowledge Graph.
6160
6261
Args:
63-
hop: The hop (pred, obj) to get the weight.
62+
hop: The hop of a vertex in a (predicate, object) form to get the
63+
weight.
6464
6565
Returns:
66-
The weight for a given hop.
66+
The weight of a given hop.
6767
6868
"""
6969
return 1

0 commit comments

Comments
 (0)