Skip to content

Commit a118d2b

Browse files
committed
fix: HALKWalker
1 parent b696a77 commit a118d2b

File tree

1 file changed

+16
-39
lines changed

1 file changed

+16
-39
lines changed

pyrdf2vec/walkers/halk.py

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
import itertools
12
from collections import defaultdict
2-
from hashlib import md5
33
from typing import DefaultDict, List, Set
44

55
import attr
66

77
from pyrdf2vec.graphs import KG, Vertex
8-
from pyrdf2vec.typings import EntityWalks, SWalk, Walk
8+
from pyrdf2vec.typings import EntityWalks, SWalk
99
from pyrdf2vec.walkers import RandomWalker
1010

1111

@@ -56,35 +56,36 @@ class HALKWalker(RandomWalker):
5656
md5_bytes = attr.ib(kw_only=True, default=8, type=int, repr=False)
5757

5858
def build_dictionary(
59-
self, walks: List[Walk]
60-
) -> DefaultDict[Vertex, Set[int]]:
61-
"""Builds a dictionary of vertices mapped to the extracted walk indices.
59+
self, walks: List[SWalk]
60+
) -> DefaultDict[str, Set[int]]:
61+
"""Builds a dictionary of predicates mapped with the walk(s)
62+
identifiers to which it appears.
6263
6364
Args:
6465
walks: The walks to build the dictionary.
6566
6667
Returns:
67-
The dictionary of vertex.
68+
The dictionary of predicate names.
6869
6970
"""
70-
vertex_to_windices: DefaultDict[Vertex, Set[int]] = defaultdict(set)
71+
vertex_to_windices: DefaultDict[str, Set[int]] = defaultdict(set)
7172
for i in range(len(walks)):
72-
for vertex in walks[i]:
73+
for vertex in itertools.islice(walks[i], 1, None, 2):
7374
vertex_to_windices[vertex].add(i)
7475
return vertex_to_windices
7576

76-
def get_rare_vertices(
77+
def get_rare_predicates(
7778
self,
78-
vertex_to_windices: DefaultDict[Vertex, Set[int]],
79-
walks: List[Walk],
79+
vertex_to_windices: DefaultDict[str, Set[int]],
80+
walks: List[SWalk],
8081
freq_threshold: float,
81-
) -> Set[Vertex]:
82+
) -> Set[str]:
8283
"""Gets vertices which doesn't reach a certain threshold of frequency
8384
of occurrence.
8485
8586
Args:
86-
vertex_to_windices: The dictionary of vertices mapped to the
87-
extracted walk indices.
87+
vertex_to_windices: The dictionary of predicates mapped with the
88+
walk(s) identifiers to which it appears.
8889
walks: The walks.
8990
freq_threshold: The threshold frequency of occurrence.
9091
@@ -110,31 +111,7 @@ def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
110111
corresponding to the extracted walks.
111112
112113
"""
113-
walks = self.extract_walks(kg, entity)
114-
vertex_to_windices = self.build_dictionary(walks)
115-
canonical_walks: Set[SWalk] = set()
116-
117-
for freq_threshold in self.freq_thresholds:
118-
rare_vertices = self.get_rare_vertices(
119-
vertex_to_windices, walks, freq_threshold
120-
)
121-
for walk in walks:
122-
canonical_walk = []
123-
for i, vertex in enumerate(walk):
124-
if i == 0 or (
125-
vertex not in rare_vertices and self.md5_bytes is None
126-
):
127-
canonical_walk.append(vertex.name)
128-
elif vertex not in rare_vertices:
129-
canonical_walk.append(
130-
str(
131-
md5(vertex.name.encode()).digest()[
132-
: self.md5_bytes
133-
]
134-
)
135-
)
136-
canonical_walks.add(tuple(canonical_walk))
137-
return {entity.name: list(canonical_walks)}
114+
return super()._extract(kg, entity)
138115

139116
def _post_extract(self, res: List[EntityWalks]) -> List[List[SWalk]]:
140117
"""Post processed walks.

0 commit comments

Comments
 (0)