Skip to content

Commit 4d4d67d

Browse files
committed
add degree option to ignore weights (and tests)
1 parent 7b36352 commit 4d4d67d

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

graphconstructor/graph.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,16 @@ def sorted_by(self, col: str) -> "Graph":
241241
meta2 = self.meta.iloc[order].reset_index(drop=True)
242242
return Graph(adj=A2, directed=self.directed, weighted=self.weighted, meta=meta2)
243243

244-
def degree(self) -> np.ndarray:
245-
"""Return (out-)degree for directed, degree for undirected. For weighted graphs sum of weights."""
246-
if self.weighted:
244+
def degree(self, ignore_weights: bool = False) -> np.ndarray:
245+
"""Return (out-)degree for directed, degree for undirected. For weighted graphs sum of weights.
246+
247+
Parameters
248+
----------
249+
ignore_weights
250+
If True, count number of edges only (treat as unweighted).
251+
Default is False.
252+
"""
253+
if self.weighted and not ignore_weights:
247254
deg = np.asarray(self.adj.sum(axis=1)).ravel()
248255
else:
249256
# count nonzeros per row

tests/test_graph.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,21 @@ def test_graph_connected_components_method():
155155
assert np.allclose(labels, np.array([0, 0, 0, 1, 1]))
156156

157157

158+
def test_graph_degree_method_weighted_and_unweighted():
159+
A = _csr([2.0, 3.0, 4.0, 5.0], [0, 1, 3, 3], [1, 2, 2, 0], 4)
160+
G_weighted = Graph.from_csr(A, directed=False, weighted=True)
161+
deg_weighted = G_weighted.degree()
162+
assert np.allclose(deg_weighted, np.array([7.0, 5.0, 7.0, 9.0]))
163+
164+
# ignore weights
165+
deg_weighted = G_weighted.degree(ignore_weights=True)
166+
assert np.allclose(deg_weighted, np.array([2, 2, 2, 2]))
167+
168+
G_unweighted = Graph.from_csr(A, directed=False, weighted=False)
169+
deg_unweighted = G_unweighted.degree()
170+
assert np.allclose(deg_unweighted, np.array([2, 2, 2, 2]))
171+
172+
158173
# ----------------- exporters -----------------
159174
@pytest.mark.skipif(not HAS_NX, reason="networkx not installed")
160175
def test_to_networkx_types_and_node_attributes():

0 commit comments

Comments
 (0)