@@ -36,11 +36,11 @@ class Soma:
3636
3737 Parameters
3838 ----------
39- centre : (3,) float32
39+ centre : (3,) float
4040 XYZ world-space coordinates of the ellipsoid centre.
41- axes : (3,) float32
41+ axes : (3,) float
4242 Semi-axis lengths **sorted** as a ≥ b ≥ c.
43- R : (3,3) float32
43+ R : (3,3) float
4444 Right-handed rotation matrix whose *columns* are the principal
4545 axes expressed in world space.
4646 verts : optional (N,) int64
@@ -59,23 +59,23 @@ class Soma:
5959 # dataclass life-cycle
6060 # ---------------------------------------------------------------------
6161 def __post_init__ (self ) -> None :
62- self .centre = np .asarray (self .centre , dtype = np .float32 ).reshape (3 )
63- self .axes = np .asarray (self .axes , dtype = np .float32 ).reshape (3 )
64- self .R = np .asarray (self .R , dtype = np .float32 ).reshape (3 , 3 )
62+ self .centre = np .asarray (self .centre , dtype = np .float64 ).reshape (3 )
63+ self .axes = np .asarray (self .axes , dtype = np .float64 ).reshape (3 )
64+ self .R = np .asarray (self .R , dtype = np .float64 ).reshape (3 , 3 )
6565
6666 # ---- fast safety checks -----------------------------------------
6767 if not np .all (np .diff (self .axes ) <= 0 ):
6868 raise ValueError ("axes must be sorted a ≥ b ≥ c" )
6969
7070 # ---- pre-compute affine map ξ = (x−c) @ W -----------------------
71- self ._W = (self .R / self .axes ).astype (np .float32 )
71+ self ._W = (self .R / self .axes ).astype (np .float64 )
7272
7373 # ---------------------------------------------------------------------
7474 # geometry
7575 # ---------------------------------------------------------------------
7676 def _body_coords (self , x : np .ndarray ) -> np .ndarray :
7777 """World ➜ body coords where the ellipsoid becomes the *unit sphere*."""
78- x = np .asarray (x , dtype = np .float32 )
78+ x = np .asarray (x , dtype = np .float64 )
7979 return (x - self .centre ) @ self ._W
8080
8181 def contains (self , x : np .ndarray , * , inside_frac : float = 1.0 ) -> np .ndarray :
@@ -112,7 +112,7 @@ def distance(self, x, to="center"):
112112
113113 def distance_to_center (self , x : np .ndarray ) -> np .ndarray | float :
114114 """Unsigned Euclidean distance from *x* to the soma *centre*."""
115- x = np .asanyarray (x , dtype = np .float32 )
115+ x = np .asanyarray (x , dtype = np .float64 )
116116 single_input = x .ndim == 1
117117 if single_input :
118118 x = x [None , :]
@@ -197,7 +197,7 @@ def fit(cls, pts: np.ndarray, verts=None) -> "Soma":
197197 Fast PCA-based ellipsoid fit to ≥ 3×`axes` sample points.
198198 Rough 95 %-mass envelope, same idea as the original *sphere* fit.
199199 """
200- pts = np .asarray (pts , dtype = np .float32 )
200+ pts = np .asarray (pts , dtype = np .float64 )
201201 centre = pts .mean (axis = 0 )
202202 cov = np .cov (pts - centre , rowvar = False )
203203 evals , evecs = np .linalg .eigh (cov ) # λ₁ ≤ λ₂ ≤ λ₃
@@ -208,9 +208,9 @@ def fit(cls, pts: np.ndarray, verts=None) -> "Soma":
208208 @classmethod
209209 def from_sphere (cls , centre : np .ndarray , radius : float , verts : np .ndarray | None ) -> "Soma" :
210210 """Backward-compat helper – treat a sphere as a = b = c = radius."""
211- centre = np .asarray (centre , dtype = np .float32 )
212- axes = np .full (3 , float (radius ), dtype = np .float32 )
213- R = np .eye (3 , dtype = np .float32 )
211+ centre = np .asarray (centre , dtype = np .float64 )
212+ axes = np .full (3 , float (radius ), dtype = np .float64 )
213+ R = np .eye (3 , dtype = np .float64 )
214214 return cls (centre , axes , R , verts = verts )
215215
216216
@@ -225,9 +225,9 @@ class Skeleton:
225225 Parameters
226226 ----------
227227 nodes
228- (N, 3) float32 Cartesian coordinates.
228+ (N, 3) float64 Cartesian coordinates.
229229 radii
230- (N,) float32 local radii.
230+ (N,) float64 local radii.
231231 edges
232232 (E, 2) int64 undirected sorted vertex pairs.
233233 soma_verts
@@ -239,8 +239,8 @@ class Skeleton:
239239 soma : Soma
240240
241241 # ---- mandatory skeleton data (except ntype)---------------------------------
242- nodes : np .ndarray # (N, 3) float32
243- radii : dict [str , np .ndarray ] # (N,) float32
242+ nodes : np .ndarray # (N, 3) float64
243+ radii : dict [str , np .ndarray ] # (N,) float64
244244 edges : np .ndarray # (E, 2) int64 – undirected, **sorted** pairs
245245 ntype : np .ndarray | None # (N,) int64, node type
246246 # SWC type codes we will follow by default
@@ -585,7 +585,7 @@ def _split_comp_if_elongated(
585585 return
586586
587587 # ── fast 3-D PCA ----------------------------------------------------
588- pts = v [comp_idx ].astype (np .float32 )
588+ pts = v [comp_idx ].astype (np .float64 )
589589 cov = np .cov (pts , rowvar = False )
590590 evals , vec = np .linalg .eigh (cov ) # ascending order
591591 elong = evals [- 1 ] / (evals [- 2 ] + 1e-9 )
@@ -880,7 +880,7 @@ def _merge_near_soma_nodes(
880880
881881 if merged_idx .size :
882882 w = np .array ([len (node2verts [0 ]), * [len (node2verts [i ])
883- for i in merged_idx ]], dtype = np .float32 )
883+ for i in merged_idx ]], dtype = np .float64 )
884884 nodes_keep [0 ] = np .average (
885885 np .vstack ([nodes [0 ], nodes [merged_idx ]]), axis = 0 , weights = w
886886 )
@@ -969,7 +969,7 @@ def _bridge_gaps(
969969 Parameters
970970 ----------
971971 nodes
972- ``(N, 3)`` float32 array of mesh-vertex coordinates.
972+ ``(N, 3)`` float64 array of mesh-vertex coordinates.
973973 edges
974974 ``(E, 2)`` int64 array of **undirected, sorted** mesh edges.
975975 bridge_max_factor
@@ -1455,7 +1455,7 @@ def _make_nodes(
14551455 all_shells
14561456 Output of `_bin_geodesic_shells()`.
14571457 vertices
1458- `mesh.vertices` as `(N,3) float32 `.
1458+ `mesh.vertices` as `(N,3) float64 `.
14591459 radius_estimators
14601460 Names understood by `_estimate_radius()`.
14611461 merge_nested
@@ -1488,13 +1488,13 @@ def _make_nodes(
14881488 _estimate_radius (d , method = est , trim_fraction = 0.05 )
14891489 )
14901490
1491- nodes .append (centre .astype (np .float32 ))
1491+ nodes .append (centre .astype (np .float64 ))
14921492 node2verts .append (bin_ids )
14931493 for vid in bin_ids :
14941494 vert2node [int (vid )] = next_id
14951495 next_id += 1
14961496
1497- nodes_arr = np .asarray (nodes , dtype = np .float32 )
1497+ nodes_arr = np .asarray (nodes , dtype = np .float64 )
14981498 radii_dict = {k : np .asarray (v ) for k , v in radii_dict .items ()}
14991499
15001500 # ---- optional containment-based merge ----------------------------
0 commit comments