44
55from collections import defaultdict
66from numbers import Number
7- from typing import TYPE_CHECKING , Callable
7+ from typing import TYPE_CHECKING , Callable , cast
88
99import numpy as np
1010import torch
1818 from numpy .typing import ArrayLike
1919
2020
21- def delaunay_adjacency (points : ArrayLike , dthresh : float ) -> list :
21+ def delaunay_adjacency (points : np . ndarray , dthresh : float ) -> np . ndarray :
2222 """Create an adjacency matrix via Delaunay triangulation from a list of coordinates.
2323
2424 Points which are further apart than dthresh will not be connected.
2525
2626 See https://en.wikipedia.org/wiki/Adjacency_matrix.
2727
2828 Args:
29- points (ArrayLike ):
29+ points (np.ndarray ):
3030 An nxm list of coordinates.
3131 dthresh (float):
3232 Distance threshold for triangulation.
@@ -113,11 +113,11 @@ def triangle_signed_area(triangle: ArrayLike) -> int:
113113 )
114114
115115
116- def edge_index_to_triangles (edge_index : ArrayLike ) -> ArrayLike :
116+ def edge_index_to_triangles (edge_index : np . ndarray ) -> np . ndarray :
117117 """Convert an edged index to triangle simplices (triplets of coordinate indices).
118118
119119 Args:
120- edge_index (ArrayLike ):
120+ edge_index (np.ndarray ):
121121 An Nx2 array of edges.
122122
123123 Returns:
@@ -157,24 +157,24 @@ def edge_index_to_triangles(edge_index: ArrayLike) -> ArrayLike:
157157
158158
159159def affinity_to_edge_index (
160- affinity_matrix : torch .Tensor | ArrayLike ,
160+ affinity_matrix : torch .Tensor | np . ndarray ,
161161 threshold : float = 0.5 ,
162- ) -> torch .tensor | ArrayLike :
162+ ) -> torch .Tensor | np . ndarray :
163163 """Convert an affinity matrix (similarity matrix) to an edge index.
164164
165165 Converts an NxN affinity matrix to a 2xM edge index, where M is the
166166 number of node pairs with a similarity greater than the threshold
167167 value (defaults to 0.5).
168168
169169 Args:
170- affinity_matrix:
170+ affinity_matrix (torch.Tensor | np.ndarray) :
171171 An NxN matrix of affinities between nodes.
172172 threshold (Number):
173173 Threshold above which to be considered connected. Defaults
174174 to 0.5.
175175
176176 Returns:
177- ArrayLike or torch.Tensor:
177+ torch.Tensor | np.ndarray :
178178 The edge index of shape (2, M).
179179
180180 Example:
@@ -191,9 +191,9 @@ def affinity_to_edge_index(
191191 raise ValueError (msg )
192192 # Handle cases for pytorch and numpy inputs
193193 if isinstance (affinity_matrix , torch .Tensor ):
194- return (affinity_matrix > threshold ).nonzero ().t ().contiguous ()
194+ return (affinity_matrix > threshold ).nonzero ().t ().contiguous (). to ( torch . int64 )
195195 return np .ascontiguousarray (
196- np .stack ((affinity_matrix > threshold ).nonzero (), axis = 1 ).T ,
196+ np .stack ((affinity_matrix > threshold ).nonzero (), axis = 1 ).T . astype ( np . int64 ) ,
197197 )
198198
199199
@@ -208,7 +208,7 @@ class SlideGraphConstructor:
208208 """
209209
210210 @staticmethod
211- def _umap_reducer (graph : dict [str , ArrayLike ]) -> ArrayLike :
211+ def _umap_reducer (graph : dict [str , np . ndarray ]) -> np . ndarray :
212212 """Default reduction which reduces `graph["x"]` to 3D values.
213213
214214 Reduces graph features to 3D values using UMAP which are suitable
@@ -220,7 +220,7 @@ def _umap_reducer(graph: dict[str, ArrayLike]) -> ArrayLike:
220220 "coordinates".
221221
222222 Returns:
223- ArrayLike :
223+ np.ndarray :
224224 A UMAP embedding of `graph["x"]` with shape (N, 3) and
225225 values ranging from 0 to 1.
226226 """
@@ -232,15 +232,15 @@ def _umap_reducer(graph: dict[str, ArrayLike]) -> ArrayLike:
232232
233233 @staticmethod
234234 def build (
235- points : ArrayLike ,
236- features : ArrayLike ,
235+ points : np . ndarray ,
236+ features : np . ndarray ,
237237 lambda_d : float = 3.0e-3 ,
238238 lambda_f : float = 1.0e-3 ,
239239 lambda_h : float = 0.8 ,
240240 connectivity_distance : int = 4000 ,
241241 neighbour_search_radius : int = 2000 ,
242242 feature_range_thresh : float | None = 1e-4 ,
243- ) -> dict [str , ArrayLike ]:
243+ ) -> dict [str , np . ndarray ]:
244244 """Build a graph via hybrid clustering in spatial and feature space.
245245
246246 The graph is constructed via hybrid hierarchical clustering
@@ -266,10 +266,10 @@ def build(
266266 connected.
267267
268268 Args:
269- points (ArrayLike ):
269+ points (np.ndarray ):
270270 A list of (x, y) spatial coordinates, e.g. pixel
271271 locations within a WSI.
272- features (ArrayLike ):
272+ features (np.ndarray ):
273273 A list of features associated with each coordinate in
274274 `points`. Must be the same length as `points`.
275275 lambda_d (Number):
@@ -400,27 +400,27 @@ def build(
400400 # Find the xy and feature space averages of the cluster
401401 point_centroids .append (np .round (points [idx , :].mean (axis = 0 )))
402402 feature_centroids .append (features [idx , :].mean (axis = 0 ))
403- point_centroids = np .array (point_centroids )
404- feature_centroids = np .array (feature_centroids )
403+ point_centroids_arr = np .array (point_centroids )
404+ feature_centroids_arr = np .array (feature_centroids )
405405
406406 adjacency_matrix = delaunay_adjacency (
407- points = point_centroids ,
407+ points = point_centroids_arr ,
408408 dthresh = connectivity_distance ,
409409 )
410410 edge_index = affinity_to_edge_index (adjacency_matrix )
411-
411+ edge_index = cast ( np . ndarray , edge_index )
412412 return {
413- "x" : feature_centroids ,
414- "edge_index" : edge_index . astype ( np . int64 ) ,
415- "coordinates" : point_centroids ,
413+ "x" : feature_centroids_arr ,
414+ "edge_index" : edge_index ,
415+ "coordinates" : point_centroids_arr ,
416416 }
417417
418418 @classmethod
419419 def visualise (
420420 cls : type [SlideGraphConstructor ],
421- graph : dict [str , ArrayLike ],
422- color : ArrayLike | str | Callable | None = None ,
423- node_size : Number | ArrayLike | Callable = 25 ,
421+ graph : dict [str , np . ndarray ],
422+ color : np . ndarray | str | Callable | None = None ,
423+ node_size : int | np . ndarray | Callable = 25 ,
424424 edge_color : str | ArrayLike = (0 , 0 , 0 , 0.33 ),
425425 ax : Axes | None = None ,
426426 ) -> Axes :
@@ -510,7 +510,8 @@ def visualise(
510510
511511 # Plot the nodes
512512 plt .scatter (
513- * nodes .T ,
513+ x = nodes .T [0 ],
514+ y = nodes .T [1 ],
514515 c = color (graph ) if callable (color ) else color ,
515516 s = node_size (graph ) if callable (node_size ) else node_size ,
516517 zorder = 2 ,
0 commit comments