Skip to content

Commit 2f19045

Browse files
authored
Merge pull request #17 from berenslab/post
[wip] post-hoc soma detection
2 parents 536ba33 + 798fbf2 commit 2f19045

File tree

2 files changed

+363
-6
lines changed

2 files changed

+363
-6
lines changed

notebooks/example.post.ipynb

Lines changed: 91 additions & 3 deletions
Large diffs are not rendered by default.

skeliner/post.py

Lines changed: 272 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def _prune_nodes(
167167

168168
def _rebuild_drop_set(skel, drop: Iterable[int]):
169169
"""Compact skeleton arrays after dropping a set of vertices."""
170+
170171
drop_set = set(map(int, drop))
171172
keep_mask = np.ones(len(skel.nodes), dtype=bool)
172173
for i in drop_set:
@@ -178,11 +179,11 @@ def _rebuild_drop_set(skel, drop: Iterable[int]):
178179
remap[keep_mask] = np.arange(keep_mask.sum(), dtype=np.int64)
179180

180181
skel.nodes = skel.nodes[keep_mask]
181-
skel.node2verts = [skel.node2verts[i] for i in np.where(keep_mask)[0]] if skel.node2verts is not None else None
182+
skel.node2verts = [skel.node2verts[i] for i in np.where(keep_mask)[0]] if skel.node2verts is not None and len(skel.node2verts) > 0 else None
182183
skel.radii = {k: v[keep_mask] for k, v in skel.radii.items()}
183184

184185
# update vert2node mapping
185-
if skel.vert2node is not None:
186+
if skel.vert2node is not None and len(skel.vert2node) > 0:
186187
skel.vert2node = {int(v): int(remap[n]) for v, n in skel.vert2node.items() if keep_mask[n]}
187188

188189
# rebuild edges
@@ -262,4 +263,272 @@ def set_ntype(
262263
if not target:
263264
return
264265

265-
skel.ntype[np.fromiter(target, dtype=int)] = int(code)
266+
skel.ntype[np.fromiter(target, dtype=int)] = int(code)
267+
268+
# -----------------------------------------------------------------------------
269+
# Re-detect Soma
270+
# -----------------------------------------------------------------------------
271+
272+
def _find_soma(
273+
nodes: np.ndarray,
274+
radii: np.ndarray,
275+
*,
276+
pct_large: float = 99.9,
277+
dist_factor: float = 3.0,
278+
min_keep: int = 2,
279+
):
280+
"""
281+
Geometry-only soma heuristic used by both the core pipeline and
282+
:pyfunc:`detect_soma`.
283+
284+
Returns
285+
-------
286+
soma – *Soma* instance (sphere model – no surface verts)
287+
soma_idx – 1-D int64 array of node IDs judged to belong to the soma
288+
has_soma – True when ≥ `min_keep` nodes qualified
289+
"""
290+
from .core import Soma
291+
292+
if nodes.shape[0] == 0:
293+
raise ValueError("empty skeleton")
294+
295+
# 1. radius threshold – pick the fattest ~0.1 % of nodes
296+
large_thresh = np.percentile(radii, pct_large)
297+
cand_idx = np.where(radii >= large_thresh)[0]
298+
299+
if cand_idx.size == 0:
300+
return Soma.from_sphere(nodes[0], radii[0], verts=None), cand_idx, False
301+
302+
# 2. anchor = single largest node
303+
idx_max = int(np.argmax(radii))
304+
R_max = float(radii[idx_max])
305+
306+
# 3. keep candidates that cluster around the anchor
307+
d = np.linalg.norm(nodes[cand_idx] - nodes[idx_max], axis=1)
308+
soma_idx = cand_idx[d <= dist_factor * R_max]
309+
has_soma = soma_idx.size >= min_keep
310+
311+
soma_est = Soma.from_sphere(
312+
center=nodes[soma_idx].mean(0) if has_soma else nodes[idx_max],
313+
radius=R_max,
314+
verts=None,
315+
)
316+
return soma_est, soma_idx, has_soma
317+
318+
319+
def detect_soma(
320+
skel,
321+
*,
322+
radius_key: str = "median",
323+
soma_radius_percentile_threshold: float = 99.9,
324+
soma_radius_distance_factor: float = 4.0,
325+
soma_min_nodes: int = 3,
326+
verbose: bool = True,
327+
):
328+
"""
329+
Post-hoc soma detection **on an existing Skeleton**.
330+
331+
Examples
332+
--------
333+
>>> import skeliner as sk
334+
>>> s = sk.core.skeletonize(mesh, detect_soma=False) # soma missed
335+
>>> s2 = sk.post.detect_soma(s, verbose=True) # re-root to soma
336+
337+
Parameters
338+
----------
339+
radius_key
340+
Which radius estimator column to use for node “fatness”.
341+
pct_large, dist_factor, min_keep
342+
Hyper-parameters forwarded to the internal :pyfunc:`_find_soma`.
343+
merge
344+
When *True* (default) every node classified as soma is **collapsed**
345+
into a single centroid that becomes vertex 0. When *False* only the
346+
fattest soma node is promoted to root and the others stay, simply
347+
re-connected to it.
348+
verbose
349+
Print a concise log of what happened.
350+
351+
Returns
352+
-------
353+
Skeleton
354+
*Either* the original instance (no change was necessary) *or* a new
355+
skeleton whose node 0 is the freshly detected soma centroid.
356+
"""
357+
from .core import Skeleton, Soma, _build_mst
358+
if radius_key not in skel.radii:
359+
raise KeyError(
360+
f"radius_key '{radius_key}' not found in skel.radii "
361+
f"(available keys: {tuple(skel.radii)})"
362+
)
363+
if len(skel.nodes) <= 1:
364+
return skel # trivial graph → nothing to do
365+
366+
367+
has_node2verts = skel.node2verts is not None and len(skel.node2verts) > 0
368+
has_vert2node = skel.vert2node is not None and len(skel.vert2node) > 0
369+
# ------------------------------------------------------------------
370+
# A. re-detect the soma cluster
371+
# ------------------------------------------------------------------
372+
soma_est, soma_idx, has_soma = _find_soma(
373+
skel.nodes, skel.radii[radius_key],
374+
pct_large=soma_radius_percentile_threshold,
375+
dist_factor=soma_radius_distance_factor,
376+
min_keep=soma_min_nodes,
377+
)
378+
379+
# Already fine?
380+
if (not has_soma) or set(map(int, soma_idx)) == {0}:
381+
if verbose:
382+
print("[skeliner] detect_soma – existing soma kept unchanged.")
383+
return skel
384+
385+
# Which node will be the *new* root?
386+
new_root_old = int(
387+
soma_idx[np.argmax(skel.radii[radius_key][soma_idx])]
388+
)
389+
drop_nodes = {int(i) for i in soma_idx if i != new_root_old}
390+
391+
# ------------------------------------------------------------------
392+
# B. clone arrays so we do not mutate the caller’s object
393+
# ------------------------------------------------------------------
394+
nodes = skel.nodes.copy()
395+
radii = {k: v.copy() for k, v in skel.radii.items()}
396+
edges = skel.edges.copy()
397+
node2verts = [vs.copy() for vs in skel.node2verts] if has_node2verts else None
398+
vert2node = dict(skel.vert2node) if has_vert2node else None
399+
ntype = skel.ntype.copy() if skel.ntype is not None else None
400+
401+
# ------------------------------------------------------------------
402+
# C. **collapse** of multiple soma nodes
403+
# ------------------------------------------------------------------
404+
if drop_nodes:
405+
#
406+
# 1. move geometric centre to the keeper (new_root_old)
407+
#
408+
nodes[new_root_old] = nodes[list(drop_nodes) + [new_root_old]].mean(0)
409+
410+
#
411+
# 2. merge vertex memberships + radii (tolerate missing node2verts)
412+
#
413+
for idx in drop_nodes:
414+
if has_node2verts:
415+
# auto-extend the mapping list if it is shorter than needed
416+
if idx >= len(node2verts):
417+
node2verts.extend(
418+
[np.empty(0, dtype=np.int64) for _ in range(idx + 1 - len(node2verts))]
419+
)
420+
if new_root_old >= len(node2verts):
421+
node2verts.extend(
422+
[np.empty(0, dtype=np.int64) for _ in range(new_root_old + 1 - len(node2verts))]
423+
)
424+
node2verts[new_root_old] = np.concatenate(
425+
(node2verts[new_root_old], node2verts[idx])
426+
)
427+
428+
for k in radii:
429+
radii[k][new_root_old] = max(radii[k][new_root_old], radii[k][idx])
430+
431+
#
432+
# 3. RE-WIRE: connect every neighbour of a soon-to-be-dropped node
433+
# directly to the keeper so the skeleton stays in one piece.
434+
#
435+
drop_set = set(drop_nodes)
436+
extra_edges = []
437+
for a, b in edges:
438+
if a in drop_set and b not in drop_set:
439+
extra_edges.append((new_root_old, b))
440+
elif b in drop_set and a not in drop_set:
441+
extra_edges.append((new_root_old, a))
442+
443+
if extra_edges:
444+
edges = np.vstack([edges, np.asarray(extra_edges, dtype=np.int64)])
445+
# row-wise sort then deduplicate
446+
edges = np.unique(np.sort(edges, axis=1), axis=0)
447+
448+
# ------------------------------------------------------------------
449+
# D. build keep-mask & remap after the (optional) merge
450+
# ------------------------------------------------------------------
451+
keep_mask = np.ones(len(nodes), bool)
452+
keep_mask[list(drop_nodes)] = False
453+
remap = -np.ones(len(nodes), np.int64)
454+
remap[np.where(keep_mask)[0]] = np.arange(keep_mask.sum(), dtype=np.int64)
455+
456+
nodes = nodes[keep_mask]
457+
radii = {k: v[keep_mask] for k, v in radii.items()}
458+
if ntype is not None:
459+
ntype = ntype[keep_mask]
460+
if has_node2verts:
461+
node2verts = [node2verts[i] for i in np.where(keep_mask)[0]]
462+
if has_vert2node:
463+
vert2node = {int(v): remap[int(n)]
464+
for v, n in vert2node.items() if keep_mask[n]}
465+
466+
# edges – remap & de-duplicate
467+
edges = np.asarray(
468+
[(remap[a], remap[b]) for a, b in edges if keep_mask[a] and keep_mask[b]],
469+
dtype=np.int64,
470+
)
471+
if edges.size:
472+
edges = np.unique(np.sort(edges, axis=1), axis=0)
473+
474+
new_root = remap[new_root_old]
475+
476+
# ------------------------------------------------------------------
477+
# E. enforce: soma → vertex 0
478+
# ------------------------------------------------------------------
479+
if new_root != 0:
480+
swap = new_root
481+
482+
nodes[[0, swap]] = nodes[[swap, 0]]
483+
for k in radii:
484+
radii[k][[0, swap]] = radii[k][[swap, 0]]
485+
if ntype is not None:
486+
ntype[[0, swap]] = ntype[[swap, 0]]
487+
if has_node2verts:
488+
node2verts[0], node2verts[swap] = node2verts[swap], node2verts[0]
489+
if has_vert2node:
490+
for v, n in list(vert2node.items()):
491+
if n == 0:
492+
vert2node[v] = swap
493+
elif n == swap:
494+
vert2node[v] = 0
495+
496+
a0, a1 = edges == 0, edges == swap
497+
edges[a0] = swap
498+
edges[a1] = 0
499+
edges = np.unique(np.sort(edges, axis=1), axis=0)
500+
501+
# ------------------------------------------------------------------
502+
# F. rebuild the Soma object (sphere model – no mesh available)
503+
# ------------------------------------------------------------------
504+
r0 = float(radii[radius_key][0])
505+
soma_new = Soma.from_sphere(nodes[0], r0,
506+
verts=node2verts[0] if node2verts is not None and len(node2verts) > 0 else None)
507+
508+
if ntype is not None:
509+
ntype[0] = 1 # SWC code for soma
510+
511+
if verbose:
512+
centre_txt = ", ".join(f"{c:7.1f}" for c in nodes[0])
513+
merged = len(drop_nodes)
514+
what = f"merged {merged} node{'s' if merged != 1 else ''}"
515+
print(f"[skeliner] detect_soma – {what} → soma @ [{centre_txt}], r ≈ {r0:.1f}")
516+
517+
# ------------------------------------------------------------------
518+
# G. return the **new** skeleton object
519+
# ------------------------------------------------------------------
520+
new_skel = Skeleton(
521+
soma=soma_new,
522+
nodes=nodes,
523+
radii=radii,
524+
edges=_build_mst(nodes, edges),
525+
ntype=ntype,
526+
node2verts=node2verts,
527+
vert2node=vert2node,
528+
meta={**skel.meta}, # shallow copies are fine
529+
extra={**skel.extra},
530+
)
531+
532+
new_skel.prune(num_nodes=1) # remove any remaining twigs
533+
return new_skel
534+

0 commit comments

Comments
 (0)