Skip to content

Commit 3785ac9

Browse files
committed
chore: add tests for WideSampler
1 parent ac52db0 commit 3785ac9

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

tests/samplers/test_wide.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import itertools
2+
3+
import pytest
4+
5+
from pyrdf2vec.graphs import KG, Vertex
6+
from pyrdf2vec.samplers import WideSampler
7+
8+
LOOP = [
9+
["Alice", "knows", "Bob"],
10+
["Alice", "knows", "Dean"],
11+
["Bob", "knows", "Dean"],
12+
["Dean", "loves", "Alice"],
13+
]
14+
LONG_CHAIN = [
15+
["Alice", "knows", "Bob"],
16+
["Alice", "knows", "Dean"],
17+
["Bob", "knows", "Mathilde"],
18+
["Mathilde", "knows", "Alfy"],
19+
["Alfy", "knows", "Stephane"],
20+
["Stephane", "knows", "Alfred"],
21+
["Alfred", "knows", "Emma"],
22+
["Emma", "knows", "Julio"],
23+
]
24+
URL = "http://pyRDF2Vec"
25+
26+
KG_LOOP = KG()
27+
KG_CHAIN = KG()
28+
29+
KGS = [KG_LOOP, KG_CHAIN]
30+
ROOTS_WITHOUT_URL = ["Alice", "Bob", "Dean"]
31+
32+
33+
class TestWideSampler:
34+
@pytest.fixture(scope="session")
35+
def setup(self):
36+
for i, graph in enumerate([LOOP, LONG_CHAIN]):
37+
for row in graph:
38+
subj = Vertex(f"{URL}#{row[0]}")
39+
obj = Vertex((f"{URL}#{row[2]}"))
40+
pred = Vertex(
41+
(f"{URL}#{row[1]}"), predicate=True, vprev=subj, vnext=obj
42+
)
43+
if i == 0:
44+
KG_LOOP.add_walk(subj, pred, obj)
45+
else:
46+
KG_CHAIN.add_walk(subj, pred, obj)
47+
48+
def test_invalid_weight(self):
49+
with pytest.raises(ValueError):
50+
WideSampler().get_weight(None)
51+
52+
@pytest.mark.parametrize("kg", list((KG_LOOP, KG_CHAIN)))
53+
def test_fit(self, setup, kg):
54+
sampler = WideSampler()
55+
assert len(sampler._pred_degs) == 0
56+
assert len(sampler._obj_degs) == 0
57+
assert len(sampler._neighbor_counts) == 0
58+
59+
sampler.fit(kg)
60+
assert len(sampler._pred_degs) > 0
61+
assert len(sampler._obj_degs) > 0
62+
assert len(sampler._neighbor_counts) > 0
63+
64+
@pytest.mark.parametrize(
65+
"kg, root",
66+
list(itertools.product(KGS, ROOTS_WITHOUT_URL)),
67+
)
68+
def test_weight(self, setup, kg, root):
69+
sampler = WideSampler()
70+
sampler.fit(kg)
71+
for hop in kg.get_hops(Vertex(f"{URL}#{root}")):
72+
weight = sampler.get_weight(hop)
73+
assert weight > 0
74+
assert isinstance(weight, float)

tests/test_walkers_samplers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
PageRankSampler,
2121
PredFreqSampler,
2222
UniformSampler,
23+
WideSampler,
2324
)
2425

2526

@@ -54,6 +55,7 @@
5455
PageRankSampler,
5556
PredFreqSampler,
5657
UniformSampler,
58+
WideSampler,
5759
]
5860
WALKERS = [
5961
AnonymousWalker,

0 commit comments

Comments
 (0)