|
| 1 | +from collections import defaultdict |
| 2 | +from typing import DefaultDict, Tuple |
| 3 | + |
| 4 | +import attr |
| 5 | + |
| 6 | +from pyrdf2vec.graphs import KG |
| 7 | +from pyrdf2vec.samplers import Sampler |
| 8 | +from pyrdf2vec.typings import Hop |
| 9 | + |
| 10 | + |
@attr.s
class WideSampler(Sampler):
    """Sampler weighting a hop by neighbour counts and occurrence degrees.

    ``fit`` walks every vertex of the Knowledge Graph once and stores, per
    vertex name, the number of neighbours returned by the graph and how
    often that name was seen (tracked separately for predicate vertices and
    for all other vertices). ``get_weight`` combines those statistics for
    the (pred, obj) pair of a hop.

    """

    # How many times each predicate vertex name was seen during ``fit``.
    _pred_degs: DefaultDict[str, int] = attr.ib(
        init=False, repr=False, factory=lambda: defaultdict(int)
    )

    # How many times each non-predicate vertex name was seen during ``fit``.
    _obj_degs: DefaultDict[str, int] = attr.ib(
        init=False, repr=False, factory=lambda: defaultdict(int)
    )

    # Number of neighbours recorded for each vertex name during ``fit``.
    _neighbor_counts: DefaultDict[str, int] = attr.ib(
        init=False, repr=False, factory=lambda: defaultdict(int)
    )

    def fit(self, kg: KG) -> None:
        """Collects the per-vertex statistics used by ``get_weight``.

        For every vertex of the graph, the number of neighbours is stored
        under the vertex name (forward neighbours for predicate vertices,
        reverse neighbours otherwise) and the matching occurrence counter
        is incremented.

        Args:
            kg: The Knowledge Graph.

        """
        super().fit(kg)

        for vertex in kg._vertices:
            if vertex.predicate:
                neighbors = kg.get_neighbors(vertex)
                counter = self._pred_degs
            else:
                neighbors = kg.get_neighbors(vertex, is_reverse=True)
                counter = self._obj_degs

            self._neighbor_counts[vertex.name] = len(neighbors)
            # The counters default to 0, so no membership test is needed.
            counter[vertex.name] += 1

    def get_weight(self, hop: Hop) -> int:
        """Gets the weight of a hop in the Knowledge Graph.

        The weight is the sum of the neighbour counts of the predicate and
        the object, multiplied by the mean of the predicate's and the
        object's occurrence counts. Names that were not seen during
        ``fit`` contribute 0.

        Args:
            hop: The hop (pred, obj) to get the weight.

        Returns:
            The weight for a given hop. NOTE: because of the division by 2
            the value is a float at runtime even though the signature is
            annotated with ``int``.

        Raises:
            ValueError: If one of the internal counters is still empty,
                which is the case when ``fit(kg)`` has not been called.

        """
        if not (self._pred_degs and self._obj_degs and self._neighbor_counts):
            raise ValueError(
                "You must call the `fit(kg)` method before get the weight of"
                + " a hop."
            )

        pred_name = hop[0].name
        obj_name = hop[1].name

        # ``.get`` is used instead of indexing so that looking up an unseen
        # name does not insert a new key into the defaultdicts.
        neighbor_total = self._neighbor_counts.get(
            pred_name, 0
        ) + self._neighbor_counts.get(obj_name, 0)
        mean_degree = (
            self._pred_degs.get(pred_name, 0)
            + self._obj_degs.get(obj_name, 0)
        ) / 2
        return neighbor_total * mean_degree
0 commit comments