diff --git a/dynnode2vec/dynnode2vec.py b/dynnode2vec/dynnode2vec.py
index f11c54e..2358c64 100644
--- a/dynnode2vec/dynnode2vec.py
+++ b/dynnode2vec/dynnode2vec.py
@@ -39,6 +39,7 @@ def __init__(
         n_walks_per_node: int = 10,
         embedding_size: int = 128,
         window: int = 10,
+        weighted: bool = False,
         seed: int | None = 0,
         parallel_processes: int = 4,
         plain_node2vec: bool = False,
@@ -70,6 +71,7 @@ def __init__(
         assert (
             isinstance(window, int) and embedding_size > 0
         ), "window should be a strictly positive integer"
+        assert isinstance(weighted, bool), "weighted should be a boolean"
         assert (
             seed is None or isinstance(seed, int)
         ) and embedding_size > 0, "seed should be either None or int"
@@ -84,6 +86,7 @@ def __init__(
         self.n_walks_per_node = n_walks_per_node
         self.embedding_size = embedding_size
         self.window = window
+        self.weighted = weighted
         self.seed = seed
         self.parallel_processes = parallel_processes
         self.plain_node2vec = plain_node2vec
@@ -91,6 +94,29 @@ def __init__(
         # see https://stackoverflow.com/questions/53417258/what-is-workers-parameter-in-word2vec-in-nlp # pylint: disable=line-too-long
         self.gensim_workers = max(self.parallel_processes - 1, 12)
 
+    def _check_edge_weights(self, graphs: list[nx.Graph]) -> None:
+        """
+        Check that every edge weight is strictly positive.
+
+        Random walks on weighted graphs need strictly positive weights. The check
+        is skipped when ``self.weighted`` is False, and edges without a "weight"
+        attribute are not inspected.
+
+        Raises AssertionError if a graph contains a weight that is not > 0.
+        """
+        if not self.weighted:
+            return
+
+        for i, graph in enumerate(graphs):
+            weights = nx.get_edge_attributes(graph, name="weight")
+
+            # raise explicitly: a bare `assert` would be stripped under `python -O`
+            if not all(weight > 0 for weight in weights.values()):
+                raise AssertionError(
+                    "All edge weights should be strictly positive to run DynNode2Vec, "
+                    f"found a non-positive weight in graph {i}"
+                )
+
     def _initialize_embeddings(
         self, graphs: list[nx.Graph]
     ) -> tuple[Word2Vec, list[Embedding]]:
@@ -232,7 +258,7 @@ def compute_embeddings(self, graphs: list[nx.Graph]) -> list[Embedding]:
         """
         Compute dynamic embeddings on a list of graphs.
         """
-        # TO DO : check graph weights valid
+        self._check_edge_weights(graphs)
         model, embeddings = self._initialize_embeddings(graphs)
         time_walks = self._simulate_walks(graphs)
         self._update_embeddings(embeddings, time_walks, model)
diff --git a/tests/test_dynnode2vec.py b/tests/test_dynnode2vec.py
index 70dc250..fcface7 100644
--- a/tests/test_dynnode2vec.py
+++ b/tests/test_dynnode2vec.py
@@ -2,6 +2,7 @@
 Test the DynNode2Vec class
 """
 # pylint: disable=missing-function-docstring
+import random
 
 import gensim
 import networkx as nx
@@ -24,6 +25,13 @@ def dynnode2vec_fixture():
     )
 
 
+@pytest.fixture(name="weighted_dynnode2vec_object")
+def weighted_dynnode2vec_fixture():
+    return dynnode2vec.DynNode2Vec(
+        n_walks_per_node=5, walk_length=5, weighted=True, parallel_processes=1
+    )
+
+
 @pytest.fixture(name="parallel_dynnode2vec_object")
 def dynnode2vec_parallel_fixture():
     return dynnode2vec.DynNode2Vec(
@@ -93,5 +101,21 @@ def test_compute_embeddings(graphs, dynnode2vec_object):
     assert all(isinstance(emb, dynnode2vec.Embedding) for emb in embeddings)
 
 
+def test_compute_weighted_embeddings(graphs, weighted_dynnode2vec_object):
+    embeddings = weighted_dynnode2vec_object.compute_embeddings(graphs)
+
+    assert isinstance(embeddings, list)
+    assert all(isinstance(emb, dynnode2vec.Embedding) for emb in embeddings)
+
+    # add random negative weights to the graph and check that it raises
+    rng = random.Random(0)
+    for graph in graphs:
+        for _, _, data in graph.edges(data=True):
+            data["weight"] = -rng.random()
+
+    with pytest.raises(AssertionError):
+        weighted_dynnode2vec_object.compute_embeddings(graphs)
+
+
 def test_parallel_compute_embeddings(graphs, parallel_dynnode2vec_object):
     embeddings = parallel_dynnode2vec_object.compute_embeddings(graphs)