Skip to content

Commit d1b2c7f

Browse files
committed
change: use mostly np.float64
1 parent fdb0dee commit d1b2c7f

File tree

7 files changed

+95
-92
lines changed

7 files changed

+95
-92
lines changed

notebooks/example.miminal.ipynb

Lines changed: 15 additions & 15 deletions
Large diffs are not rendered by default.

notebooks/example.post.ipynb

Lines changed: 24 additions & 24 deletions
Large diffs are not rendered by default.

notebooks/method.tour.ipynb

Lines changed: 18 additions & 18 deletions
Large diffs are not rendered by default.

skeliner/core.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ class Soma:
3636
3737
Parameters
3838
----------
39-
centre : (3,) float32
39+
centre : (3,) float
4040
XYZ world-space coordinates of the ellipsoid centre.
41-
axes : (3,) float32
41+
axes : (3,) float
4242
Semi-axis lengths **sorted** as a ≥ b ≥ c.
43-
R : (3,3) float32
43+
R : (3,3) float
4444
Right-handed rotation matrix whose *columns* are the principal
4545
axes expressed in world space.
4646
verts : optional (N,) int64
@@ -59,23 +59,23 @@ class Soma:
5959
# dataclass life-cycle
6060
# ---------------------------------------------------------------------
6161
def __post_init__(self) -> None:
62-
self.centre = np.asarray(self.centre, dtype=np.float32).reshape(3)
63-
self.axes = np.asarray(self.axes, dtype=np.float32).reshape(3)
64-
self.R = np.asarray(self.R, dtype=np.float32).reshape(3, 3)
62+
self.centre = np.asarray(self.centre, dtype=np.float64).reshape(3)
63+
self.axes = np.asarray(self.axes, dtype=np.float64).reshape(3)
64+
self.R = np.asarray(self.R, dtype=np.float64).reshape(3, 3)
6565

6666
# ---- fast safety checks -----------------------------------------
6767
if not np.all(np.diff(self.axes) <= 0):
6868
raise ValueError("axes must be sorted a ≥ b ≥ c")
6969

7070
# ---- pre-compute affine map ξ = (x−c) @ W -----------------------
71-
self._W = (self.R / self.axes).astype(np.float32)
71+
self._W = (self.R / self.axes).astype(np.float64)
7272

7373
# ---------------------------------------------------------------------
7474
# geometry
7575
# ---------------------------------------------------------------------
7676
def _body_coords(self, x: np.ndarray) -> np.ndarray:
7777
"""World ➜ body coords where the ellipsoid becomes the *unit sphere*."""
78-
x = np.asarray(x, dtype=np.float32)
78+
x = np.asarray(x, dtype=np.float64)
7979
return (x - self.centre) @ self._W
8080

8181
def contains(self, x: np.ndarray, *, inside_frac: float = 1.0) -> np.ndarray:
@@ -112,7 +112,7 @@ def distance(self, x, to="center"):
112112

113113
def distance_to_center(self, x: np.ndarray) -> np.ndarray | float:
114114
"""Unsigned Euclidean distance from *x* to the soma *centre*."""
115-
x = np.asanyarray(x, dtype=np.float32)
115+
x = np.asanyarray(x, dtype=np.float64)
116116
single_input = x.ndim == 1
117117
if single_input:
118118
x = x[None, :]
@@ -197,7 +197,7 @@ def fit(cls, pts: np.ndarray, verts=None) -> "Soma":
197197
Fast PCA-based ellipsoid fit to ≥ 3×`axes` sample points.
198198
Rough 95 %-mass envelope, same idea as the original *sphere* fit.
199199
"""
200-
pts = np.asarray(pts, dtype=np.float32)
200+
pts = np.asarray(pts, dtype=np.float64)
201201
centre = pts.mean(axis=0)
202202
cov = np.cov(pts - centre, rowvar=False)
203203
evals, evecs = np.linalg.eigh(cov) # λ₁ ≤ λ₂ ≤ λ₃
@@ -208,9 +208,9 @@ def fit(cls, pts: np.ndarray, verts=None) -> "Soma":
208208
@classmethod
209209
def from_sphere(cls, centre: np.ndarray, radius: float, verts: np.ndarray | None) -> "Soma":
210210
"""Backward-compat helper – treat a sphere as a = b = c = radius."""
211-
centre = np.asarray(centre, dtype=np.float32)
212-
axes = np.full(3, float(radius), dtype=np.float32)
213-
R = np.eye(3, dtype=np.float32)
211+
centre = np.asarray(centre, dtype=np.float64)
212+
axes = np.full(3, float(radius), dtype=np.float64)
213+
R = np.eye(3, dtype=np.float64)
214214
return cls(centre, axes, R, verts=verts)
215215

216216

@@ -225,9 +225,9 @@ class Skeleton:
225225
Parameters
226226
----------
227227
nodes
228-
(N, 3) float32 Cartesian coordinates.
228+
(N, 3) float64 Cartesian coordinates.
229229
radii
230-
(N,) float32 local radii.
230+
(N,) float64 local radii.
231231
edges
232232
(E, 2) int64 undirected sorted vertex pairs.
233233
soma_verts
@@ -239,8 +239,8 @@ class Skeleton:
239239
soma: Soma
240240

241241
# ---- mandatory skeleton data (except ntype)---------------------------------
242-
nodes: np.ndarray # (N, 3) float32
243-
radii: dict[str, np.ndarray] # (N,) float32
242+
nodes: np.ndarray # (N, 3) float64
243+
radii: dict[str, np.ndarray] # (N,) float64
244244
edges: np.ndarray # (E, 2) int64 – undirected, **sorted** pairs
245245
ntype: np.ndarray | None # (N,) int64, node type
246246
# SWC type codes we will follow by default
@@ -585,7 +585,7 @@ def _split_comp_if_elongated(
585585
return
586586

587587
# ── fast 3-D PCA ----------------------------------------------------
588-
pts = v[comp_idx].astype(np.float32)
588+
pts = v[comp_idx].astype(np.float64)
589589
cov = np.cov(pts, rowvar=False)
590590
evals, vec = np.linalg.eigh(cov) # ascending order
591591
elong = evals[-1] / (evals[-2] + 1e-9)
@@ -880,7 +880,7 @@ def _merge_near_soma_nodes(
880880

881881
if merged_idx.size:
882882
w = np.array([len(node2verts[0]), *[len(node2verts[i])
883-
for i in merged_idx]], dtype=np.float32)
883+
for i in merged_idx]], dtype=np.float64)
884884
nodes_keep[0] = np.average(
885885
np.vstack([nodes[0], nodes[merged_idx]]), axis=0, weights=w
886886
)
@@ -969,7 +969,7 @@ def _bridge_gaps(
969969
Parameters
970970
----------
971971
nodes
972-
``(N, 3)`` float32 array of mesh-vertex coordinates.
972+
``(N, 3)`` float64 array of mesh-vertex coordinates.
973973
edges
974974
``(E, 2)`` int64 array of **undirected, sorted** mesh edges.
975975
bridge_max_factor
@@ -1455,7 +1455,7 @@ def _make_nodes(
14551455
all_shells
14561456
Output of `_bin_geodesic_shells()`.
14571457
vertices
1458-
`mesh.vertices` as `(N,3) float32`.
1458+
`mesh.vertices` as `(N,3) float64`.
14591459
radius_estimators
14601460
Names understood by `_estimate_radius()`.
14611461
merge_nested
@@ -1488,13 +1488,13 @@ def _make_nodes(
14881488
_estimate_radius(d, method=est, trim_fraction=0.05)
14891489
)
14901490

1491-
nodes.append(centre.astype(np.float32))
1491+
nodes.append(centre.astype(np.float64))
14921492
node2verts.append(bin_ids)
14931493
for vid in bin_ids:
14941494
vert2node[int(vid)] = next_id
14951495
next_id += 1
14961496

1497-
nodes_arr = np.asarray(nodes, dtype=np.float32)
1497+
nodes_arr = np.asarray(nodes, dtype=np.float64)
14981498
radii_dict = {k: np.asarray(v) for k, v in radii_dict.items()}
14991499

15001500
# ---- optional containment-based merge ----------------------------

skeliner/io.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ def load_swc(
115115
raise ValueError(f"No usable nodes found in {path}")
116116

117117
# --- core arrays ----------------------------------------------------
118-
nodes_arr = np.asarray(xyz, dtype=np.float32) * scale
119-
radii_arr = np.asarray(radii, dtype=np.float32) * scale
118+
nodes_arr = np.asarray(xyz, dtype=np.float64) * scale
119+
radii_arr = np.asarray(radii, dtype=np.float64) * scale
120120
radii_dict = {"median": radii_arr, "mean": radii_arr, "trim": radii_arr}
121121
ntype_arr = np.asarray(ntype, dtype=np.int8)
122122
# --- edges (parent IDs → 0-based indices) ---------------------------
@@ -228,12 +228,12 @@ def load_npz(path: str | Path) -> Skeleton:
228228
path = Path(path)
229229

230230
with np.load(path, allow_pickle=True) as z:
231-
nodes = z["nodes"].astype(np.float32)
231+
nodes = z["nodes"].astype(np.float64)
232232
edges = z["edges"].astype(np.int64)
233233

234234
# radii dict (keys start with 'r_')
235235
radii = {
236-
k[2:]: z[k].astype(np.float32) for k in z.files if k.startswith("r_")
236+
k[2:]: z[k].astype(np.float64) for k in z.files if k.startswith("r_")
237237
}
238238

239239
# node types (optional in older archives)

skeliner/plot.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import trimesh
77
from matplotlib.axes import Axes
88
from matplotlib.collections import LineCollection, PolyCollection
9-
from matplotlib.patches import Circle, Ellipse
9+
from matplotlib.figure import Figure
10+
from matplotlib.patches import Ellipse
1011
from scipy.stats import binned_statistic_2d
1112

1213
from .core import Skeleton
@@ -189,7 +190,7 @@ def projection(
189190
draw_soma_mask: bool = True,
190191
# colors
191192
color_by: str = "fixed", # "ntype" or "fixed"
192-
) -> Tuple[plt.Figure, Axes]:
193+
) -> Tuple[Figure, Axes]:
193194
"""Orthographic 2‑D overview of a skeleton with an **optional** mesh‑density
194195
background.
195196
@@ -949,7 +950,7 @@ def node_details(
949950
scale: Union[Number, Tuple[Number, Number], Sequence[Number]] = 1.0,
950951
highlight_alpha: float = 0.5,
951952
**kwargs,
952-
) -> Tuple[plt.Figure, plt.Axes]:
953+
) -> Tuple[Figure, plt.Axes]:
953954
"""Zoomed‑in view of a specific skeleton node (mesh optional).
954955
955956
A surface `mesh` can now be omitted. The zoom window is derived purely

skeliner/post.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""skeliner.post – post-processing functions for skeletons.
22
"""
3-
from typing import Iterable, Set
3+
from typing import Iterable, Set, cast
44

55
import igraph as ig
66
import numpy as np
7+
from numpy.typing import ArrayLike
78

89
from . import dx
910

@@ -243,7 +244,11 @@ def set_ntype(
243244
if node_ids is not None:
244245
target = set(map(int, node_ids))
245246
else:
246-
bases = np.atleast_1d(root).astype(int)
247+
bases_arr = np.atleast_1d(
248+
cast(ArrayLike, root)
249+
).astype(int)
250+
251+
bases: set[int] = set(bases_arr)
247252
target: set[int] = set()
248253
if subtree:
249254
for nid in bases:
@@ -257,7 +262,4 @@ def set_ntype(
257262
if not target:
258263
return
259264

260-
# ----------------------------------------------------------- #
261-
# fast vectorised assignment #
262-
# ----------------------------------------------------------- #
263265
skel.ntype[np.fromiter(target, dtype=int)] = int(code)

0 commit comments

Comments
 (0)