Skip to content

Commit c5e9b95

Browse files
Fixed up the track completeness analaysis script
1 parent 6183783 commit c5e9b95

File tree

1 file changed

+46
-9
lines changed

1 file changed

+46
-9
lines changed

src/spine/ana/diag/track.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Module to evaluate diagnostic metrics on tracks."""
22

3+
from typing import Any, Dict, List, Optional, Tuple, Union
4+
35
import numpy as np
46

57
from spine.ana.base import AnaBase
68
from spine.math.distance import cdist
7-
from spine.utils.globals import TRACK_SHP
9+
from spine.utils.globals import MUON_PID, TRACK_SHP
810

911
__all__ = ["TrackCompletenessAna"]
1012

@@ -22,14 +24,24 @@ class TrackCompletenessAna(AnaBase):
2224
name = "track_completeness"
2325

2426
def __init__(
25-
self, time_window=None, run_mode="both", truth_point_mode="points", **kwargs
27+
self,
28+
time_window: Optional[Union[List[float], Tuple[float, float]]] = None,
29+
length_threshold: Optional[float] = 10,
30+
include_pids: Optional[Union[Tuple[int], List[int]]] = (MUON_PID,),
31+
run_mode: str = "both",
32+
truth_point_mode: str = "points",
33+
**kwargs,
2634
):
2735
"""Initialize the analysis script.
2836
2937
Parameters
3038
----------
3139
time_window : List[float]
3240
Time window within which to include particle (only works for `truth`)
41+
length_threshold : float, optional
42+
Minimum length of tracks to consider, in cm
43+
include_pids : Union[Tuple[int], List[int]], optional
44+
Particle IDs to include in the analysis
3345
**kwargs : dict, optional
3446
Additional arguments to pass to :class:`AnaBase`
3547
"""
@@ -45,22 +57,29 @@ def __init__(
4557
time_window is None or run_mode == "truth"
4658
), "Time of reconstructed particle is unknown."
4759

60+
# Store the length threshold and included PIDs
61+
self.length_threshold = length_threshold
62+
self.include_pids = include_pids
63+
4864
# Make sure the metadata is provided (rasterization needed)
4965
self.update_keys({"meta": True})
5066

5167
# Initialize the CSV writer(s) you want
5268
for prefix in self.prefixes:
5369
self.initialize_writer(prefix)
5470

55-
def process(self, data):
71+
def process(self, data: Dict[str, Any]) -> None:
5672
"""Evaluate track completeness for tracks in one entry.
5773
5874
Parameters
5975
----------
60-
data : dict
76+
data : Dict[str, Any]
6177
Dictionary of data products
6278
"""
6379
# Fetch the pixel size in this image (assume cubic cells)
80+
assert np.all(
81+
data["meta"].size[0] == data["meta"].size
82+
), "Non-cubic pixels not supported."
6483
pixel_size = data["meta"].size[0]
6584

6685
# Loop over the types of particle data products
@@ -79,6 +98,11 @@ def process(self, data):
7998
if part.t < self.time_window[0] or part.t > self.time_window[1]:
8099
continue
81100

101+
# If needed, check on the particle PID
102+
if self.include_pids is not None:
103+
if part.pid not in self.include_pids:
104+
continue
105+
82106
# Initialize the particle dictionary
83107
comp_dict = {"particle_id": part.id}
84108

@@ -95,9 +119,18 @@ def process(self, data):
95119
if length:
96120
vec /= length
97121

122+
# If needed, check on the particle length
123+
if self.length_threshold is not None:
124+
if length < self.length_threshold:
125+
continue
126+
98127
comp_dict["size"] = len(points)
99128
comp_dict["length"] = length
100129
comp_dict.update({"dir_x": vec[0], "dir_y": vec[1], "dir_z": vec[2]})
130+
comp_dict.update(
131+
{"start_x": start[0], "start_y": start[1], "start_z": start[2]}
132+
)
133+
comp_dict.update({"end_x": end[0], "end_y": end[1], "end_z": end[2]})
101134

102135
# Chunk out the track along gaps, estimate gap length
103136
chunk_labels = self.cluster_track_chunks(points, start, end, pixel_size)
@@ -116,7 +149,9 @@ def process(self, data):
116149
self.append(prefix, **comp_dict)
117150

118151
@staticmethod
119-
def cluster_track_chunks(points, start_point, end_point, pixel_size):
152+
def cluster_track_chunks(
153+
points, start_point: np.ndarray, end_point: np.ndarray, pixel_size: float
154+
) -> np.ndarray:
120155
"""Find point where the track is broken, divide out the track
121156
into self-contained chunks which are Linf connect (Moore neighbors).
122157
@@ -139,19 +174,21 @@ def cluster_track_chunks(points, start_point, end_point, pixel_size):
139174
"""
140175
# Project and cluster on the projected axis
141176
direction = (end_point - start_point) / np.linalg.norm(end_point - start_point)
142-
scale = pixel_size * np.max(direction)
177+
scale = pixel_size / np.max(np.abs(direction))
143178
projs = np.dot(points - start_point, direction)
144179
perm = np.argsort(projs)
145180
seps = projs[perm][1:] - projs[perm][:-1]
146-
breaks = np.where(seps > scale * 1.1)[0] + 1
181+
breaks = np.where(seps > scale * 1.49)[0] + 1
147182
cluster_labels = np.empty(len(projs), dtype=int)
148183
for i, index in enumerate(np.split(np.arange(len(projs)), breaks)):
149184
cluster_labels[perm[index]] = i
150185

151186
return cluster_labels
152187

153188
@staticmethod
154-
def sequential_cluster_distances(points, labels, start_point):
189+
def sequential_cluster_distances(
190+
points: np.ndarray, labels: np.ndarray, start_point: np.ndarray
191+
) -> np.ndarray:
155192
"""Order clusters in order of distance from a starting point, compute
156193
the distances between successive clusters.
157194
@@ -167,7 +204,7 @@ def sequential_cluster_distances(points, labels, start_point):
167204
# If there's only one cluster, nothing to do here
168205
unique_labels = np.unique(labels)
169206
if len(unique_labels) < 2:
170-
return np.empty(0, dtype=float), np.empty(0, dtype=float)
207+
return np.empty(0, dtype=float)
171208

172209
# Order clusters
173210
start_dist = cdist(start_point[None, :], points).flatten()

0 commit comments

Comments
 (0)