1+ import itertools
12from collections import defaultdict
2- from hashlib import md5
33from typing import DefaultDict , List , Set
44
55import attr
66
77from pyrdf2vec .graphs import KG , Vertex
8- from pyrdf2vec .typings import EntityWalks , SWalk , Walk
8+ from pyrdf2vec .typings import EntityWalks , SWalk
99from pyrdf2vec .walkers import RandomWalker
1010
1111
@@ -56,35 +56,36 @@ class HALKWalker(RandomWalker):
5656 md5_bytes = attr .ib (kw_only = True , default = 8 , type = int , repr = False )
5757
5858 def build_dictionary (
59- self , walks : List [Walk ]
60- ) -> DefaultDict [Vertex , Set [int ]]:
61- """Builds a dictionary of vertices mapped to the extracted walk indices.
59+ self , walks : List [SWalk ]
60+ ) -> DefaultDict [str , Set [int ]]:
61+ """Builds a dictionary of predicates mapped with the walk(s)
62+ identifiers to which it appears.
6263
6364 Args:
6465 walks: The walks to build the dictionary.
6566
6667 Returns:
67- The dictionary of vertex .
68+ The dictionary of predicate names .
6869
6970 """
70- vertex_to_windices : DefaultDict [Vertex , Set [int ]] = defaultdict (set )
71+ vertex_to_windices : DefaultDict [str , Set [int ]] = defaultdict (set )
7172 for i in range (len (walks )):
72- for vertex in walks [i ]:
73+ for vertex in itertools . islice ( walks [i ], 1 , None , 2 ) :
7374 vertex_to_windices [vertex ].add (i )
7475 return vertex_to_windices
7576
76- def get_rare_vertices (
77+ def get_rare_predicates (
7778 self ,
78- vertex_to_windices : DefaultDict [Vertex , Set [int ]],
79- walks : List [Walk ],
79+ vertex_to_windices : DefaultDict [str , Set [int ]],
80+ walks : List [SWalk ],
8081 freq_threshold : float ,
81- ) -> Set [Vertex ]:
82+ ) -> Set [str ]:
8283 """Gets vertices which doesn't reach a certain threshold of frequency
8384 of occurrence.
8485
8586 Args:
86- vertex_to_windices: The dictionary of vertices mapped to the
87- extracted walk indices .
87+ vertex_to_windices: The dictionary of predicates mapped with the
88+ walk(s) identifiers to which it appears .
8889 walks: The walks.
8990 freq_threshold: The threshold frequency of occurrence.
9091
@@ -110,31 +111,7 @@ def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
110111 corresponding to the extracted walks.
111112
112113 """
113- walks = self .extract_walks (kg , entity )
114- vertex_to_windices = self .build_dictionary (walks )
115- canonical_walks : Set [SWalk ] = set ()
116-
117- for freq_threshold in self .freq_thresholds :
118- rare_vertices = self .get_rare_vertices (
119- vertex_to_windices , walks , freq_threshold
120- )
121- for walk in walks :
122- canonical_walk = []
123- for i , vertex in enumerate (walk ):
124- if i == 0 or (
125- vertex not in rare_vertices and self .md5_bytes is None
126- ):
127- canonical_walk .append (vertex .name )
128- elif vertex not in rare_vertices :
129- canonical_walk .append (
130- str (
131- md5 (vertex .name .encode ()).digest ()[
132- : self .md5_bytes
133- ]
134- )
135- )
136- canonical_walks .add (tuple (canonical_walk ))
137- return {entity .name : list (canonical_walks )}
114+ return super ()._extract (kg , entity )
138115
139116 def _post_extract (self , res : List [EntityWalks ]) -> List [List [SWalk ]]:
140117 """Post processed walks.
0 commit comments