Skip to content

Commit 406aa20

Browse files
committed
feature: add WideSampler
1 parent 8038549 commit 406aa20

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

pyrdf2vec/samplers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .uniform import UniformSampler
66
from .frequency import ObjFreqSampler, ObjPredFreqSampler, PredFreqSampler
77
from .pagerank import PageRankSampler
8+
from .wide import WideSampler
89

910
__all__ = [
1011
"ObjFreqSampler",
@@ -13,4 +14,5 @@
1314
"PredFreqSampler",
1415
"Sampler",
1516
"UniformSampler",
17+
"WideSampler",
1618
]

pyrdf2vec/samplers/wide.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from collections import defaultdict
2+
from typing import DefaultDict, Tuple
3+
4+
import attr
5+
6+
from pyrdf2vec.graphs import KG
7+
from pyrdf2vec.samplers import Sampler
8+
from pyrdf2vec.typings import Hop
9+
10+
11+
@attr.s
class WideSampler(Sampler):
    """Sampling strategy that weights a hop by how connected its vertices are.

    After fitting, the weight of a hop ``(pred, obj)`` is the sum of the
    neighbor counts of both vertices, multiplied by the mean of the number
    of predicate vertices named like ``pred`` and the number of
    non-predicate vertices named like ``obj``.

    """

    # Number of predicate vertices seen in the fitted graph, keyed by name.
    _pred_degs: DefaultDict[str, int] = attr.ib(
        init=False, repr=False, factory=lambda: defaultdict(int)
    )

    # Number of non-predicate vertices seen in the fitted graph, keyed by
    # name.
    _obj_degs: DefaultDict[str, int] = attr.ib(
        init=False, repr=False, factory=lambda: defaultdict(int)
    )

    # Number of neighbors of each vertex, keyed by vertex name.
    _neighbor_counts: DefaultDict[str, int] = attr.ib(
        init=False, repr=False, factory=lambda: defaultdict(int)
    )

    def fit(self, kg: KG) -> None:
        """Fits the sampling strategy by counting, for every vertex of the
        Knowledge Graph, its neighbors and the occurrences of its name.

        Args:
            kg: The Knowledge Graph.

        """
        super().fit(kg)

        # Start from a clean state so that fitting again (on the same or on
        # another graph) neither double-counts nor keeps stale vertices.
        self._pred_degs.clear()
        self._obj_degs.clear()
        self._neighbor_counts.clear()

        for vertex in kg._vertices:
            if vertex.predicate:
                neighbors = kg.get_neighbors(vertex)
                counter = self._pred_degs
            else:
                # Non-predicate vertices are looked up in the reverse
                # direction.
                neighbors = kg.get_neighbors(vertex, is_reverse=True)
                counter = self._obj_degs

            self._neighbor_counts[vertex.name] = len(neighbors)
            # The counters default to 0, so no membership test is needed.
            counter[vertex.name] += 1

    def get_weight(self, hop: Hop) -> int:
        """Gets the weight of a hop in the Knowledge Graph.

        Args:
            hop: The hop (pred, obj) to get the weight.

        Returns:
            The weight for a given hop. A vertex name that was not seen
            during ``fit`` contributes 0 to each term.

        Raises:
            ValueError: If one of the internal counters is empty, which is
                the case when ``fit(kg)`` has not been called yet.

        """
        if not (self._pred_degs and self._obj_degs and self._neighbor_counts):
            raise ValueError(
                "You must call the `fit(kg)` method before get the weight of"
                " a hop."
            )

        pred_name = hop[0].name
        obj_name = hop[1].name
        neighbor_total = (
            self._neighbor_counts[pred_name] + self._neighbor_counts[obj_name]
        )
        mean_degree = (self._pred_degs[pred_name] + self._obj_degs[obj_name]) / 2
        # NOTE(review): the true division makes this a float although the
        # signature says ``int``; the annotation is kept as-is in case it
        # mirrors the base class -- confirm against ``Sampler.get_weight``.
        return neighbor_total * mean_degree

0 commit comments

Comments
 (0)