Skip to content

Commit a91d756

Browse files
committed
fix: tests
1 parent 84287f1 commit a91d756

File tree

5 files changed

+36
-20
lines changed

5 files changed

+36
-20
lines changed

examples/mutag.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pyrdf2vec import RDF2VecTransformer
99
from pyrdf2vec.embedders import Word2Vec
1010
from pyrdf2vec.graphs import KG
11+
from pyrdf2vec.samplers import WideSampler
1112
from pyrdf2vec.walkers import HALKWalker
1213

1314
# Ensure the determinism of this script by initializing a pseudo-random number.
@@ -31,8 +32,17 @@
3132
Word2Vec(workers=1, epochs=10),
3233
# Extract all walks with a maximum depth of 2 for each entity using two
3334
# processes and use a random state to ensure that the same walks are
34-
# generated for the entities.
35-
walkers=[HALKWalker(2, None, n_jobs=2, random_state=RANDOM_STATE)],
35+
# generated for the entities without hashing as MUTAG is a short KG.
36+
walkers=[
37+
HALKWalker(
38+
2,
39+
None,
40+
n_jobs=2,
41+
sampler=WideSampler(),
42+
random_state=RANDOM_STATE,
43+
md5_bytes=None,
44+
)
45+
],
3646
verbose=1,
3747
).fit_transform(
3848
KG(

pyrdf2vec/samplers/frequency.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def get_weight(self, hop: Hop) -> int:
7373
without the sampling strategy having been trained.
7474
7575
"""
76-
if len(self._counts) == 0:
76+
if not self._counts:
7777
raise ValueError(
7878
"You must call the `fit(kg)` function before get the weight of"
7979
+ " a hop."
@@ -144,7 +144,7 @@ def get_weight(self, hop: Hop) -> int:
144144
without the sampling strategy having been trained.
145145
146146
"""
147-
if self._counts:
147+
if not self._counts:
148148
raise ValueError(
149149
"You must call the `fit(kg)` method before get the weight of"
150150
+ " a hop."
@@ -219,7 +219,7 @@ def get_weight(self, hop: Hop) -> int:
219219
without the sampling strategy having been trained.
220220
221221
"""
222-
if self._counts:
222+
if not self._counts:
223223
raise ValueError(
224224
"You must call the `fit(kg)` method before get the weight of"
225225
+ " a hop."

pyrdf2vec/samplers/wide.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def get_weight(self, hop: Hop) -> float:
9090
without the sampling strategy having been trained.
9191
9292
"""
93-
if self._pred_degs or self._obj_degs or self._neighbor_counts:
93+
if not (
94+
self._pred_degs or self._obj_degs or not self._neighbor_counts
95+
):
9496
raise ValueError(
9597
"You must call the `fit(kg)` method before get the weight of"
9698
+ " a hop."

pyrdf2vec/walkers/halk.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,17 @@ def _post_extract(self, res: List[EntityWalks]) -> List[List[SWalk]]:
141141
for rare_predicates in pred_thresholds:
142142
for entity_walks in conv_res:
143143
canonical_walks = []
144+
curr_entity = entity_walks[0][0]
144145
for walk in entity_walks:
145-
canonical_walk = [walk[0]]
146+
canonical_walk = [curr_entity]
146147
for i, vertex in enumerate(walk[1::2], 2):
147148
if vertex not in rare_predicates:
148149
obj = walk[i] if i % 2 == 0 else walk[i + 1]
149150
canonical_walk += [vertex, obj]
150-
if len(canonical_walk) > 1:
151+
if len(canonical_walk) >= 3:
151152
canonical_walks.append(tuple(canonical_walk))
152-
res_halk.append(canonical_walks)
153+
if canonical_walks:
154+
res_halk.append(canonical_walks)
155+
else:
156+
res_halk.append([(curr_entity,)])
153157
return res_halk

tests/walkers/test_halk.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,17 @@ def test_extract(
6363
walks = HALKWalker(
6464
max_depth,
6565
max_walks,
66-
freq_thresholds=[0.01],
66+
freq_thresholds=[0.001],
6767
with_reverse=with_reverse,
6868
random_state=42,
69-
)._extract(kg, Vertex(root))[root]
69+
).extract(kg, [root])
70+
7071
if max_walks is not None:
71-
if with_reverse:
72-
assert len(walks) <= max_walks * max_walks
73-
else:
74-
assert len(walks) <= max_walks
75-
for walk in walks:
76-
if not with_reverse:
77-
assert walk[0] == root
78-
for pred_or_obj in walk[1:]:
79-
assert pred_or_obj.startswith("b'")
72+
assert len(walks) == 1
73+
74+
for entity_walks in walks:
75+
for walk in entity_walks:
76+
if not with_reverse:
77+
assert walk[0] == root
78+
for obj in walk[2::2]:
79+
assert obj.startswith("b'")

0 commit comments

Comments
 (0)