Skip to content

Commit 39757b9

Browse files
committed
refactoring: fit function of WideSampler
1 parent 1d31eb8 commit 39757b9

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

pyrdf2vec/samplers/wide.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,11 @@ def fit(self, kg: KG) -> None:
5959
"""
6060
super().fit(kg)
6161
for vertex in kg._vertices:
62-
if vertex.predicate:
63-
self._neighbor_counts[vertex.name] = len(
64-
kg.get_neighbors(vertex)
65-
)
66-
counter = self._pred_degs
67-
else:
68-
self._neighbor_counts[vertex.name] = len(
69-
kg.get_neighbors(vertex, is_reverse=True)
70-
)
71-
counter = self._obj_degs
62+
is_reverse = True if vertex.predicate else False
63+
counter = self._pred_degs if vertex.predicate else self._obj_degs
64+
self._neighbor_counts[vertex.name] = len(
65+
kg.get_neighbors(vertex, is_reverse=is_reverse)
66+
)
7267

7368
if vertex.name in counter:
7469
counter[vertex.name] += 1

0 commit comments

Comments
 (0)