diff --git a/torchts/utils/graph.py b/torchts/utils/graph.py index efb4c160..7aac8338 100644 --- a/torchts/utils/graph.py +++ b/torchts/utils/graph.py @@ -30,6 +30,9 @@ def reverse_random_walk(adj_mx): def scaled_laplacian(adj_mx, lambda_max=2, undirected=True): + if sp.issparse(adj_mx): + adj_mx = adj_mx.todense() + if undirected: adj_mx = np.maximum.reduce([adj_mx, adj_mx.T])