11"""Module to evaluate diagnostic metrics on tracks."""
22
3+ from typing import Any , Dict , List , Optional , Tuple , Union
4+
35import numpy as np
46
57from spine .ana .base import AnaBase
68from spine .math .distance import cdist
7- from spine .utils .globals import TRACK_SHP
9+ from spine .utils .globals import MUON_PID , TRACK_SHP
810
911__all__ = ["TrackCompletenessAna" ]
1012
@@ -22,14 +24,24 @@ class TrackCompletenessAna(AnaBase):
     name = "track_completeness"
 
     def __init__(
-        self, time_window=None, run_mode="both", truth_point_mode="points", **kwargs
+        self,
+        time_window: Optional[Union[List[float], Tuple[float, float]]] = None,
+        length_threshold: Optional[float] = 10,
+        include_pids: Optional[Union[Tuple[int], List[int]]] = (MUON_PID,),
+        run_mode: str = "both",
+        truth_point_mode: str = "points",
+        **kwargs,
     ):
         """Initialize the analysis script.
 
         Parameters
         ----------
         time_window : List[float]
             Time window within which to include particles (only works for `truth`)
+        length_threshold : float, optional
+            Minimum length of tracks to consider, in cm
+        include_pids : Union[Tuple[int], List[int]], optional
+            Particle IDs to include in the analysis
         **kwargs : dict, optional
             Additional arguments to pass to :class:`AnaBase`
         """
@@ -45,22 +57,29 @@ def __init__(
             time_window is None or run_mode == "truth"
         ), "Time of reconstructed particle is unknown."
 
+        # Store the length threshold and included PIDs
+        self.length_threshold = length_threshold
+        self.include_pids = include_pids
+
         # Make sure the metadata is provided (rasterization needed)
         self.update_keys({"meta": True})
 
         # Initialize the CSV writer(s) you want
         for prefix in self.prefixes:
             self.initialize_writer(prefix)
 
-    def process(self, data):
+    def process(self, data: Dict[str, Any]) -> None:
         """Evaluate track completeness for tracks in one entry.
 
         Parameters
         ----------
-        data : dict
+        data : Dict[str, Any]
             Dictionary of data products
         """
         # Fetch the pixel size in this image (assume cubic cells)
+        assert np.all(
+            data["meta"].size[0] == data["meta"].size
+        ), "Non-cubic pixels not supported."
         pixel_size = data["meta"].size[0]
 
         # Loop over the types of particle data products
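The new assert assumes `meta.size` lists the voxel pitch per axis. A toy illustration of what the guard accepts and rejects; the `Meta` stand-in is hypothetical, not spine's actual metadata class:

    import numpy as np

    class Meta:  # stand-in for the `meta` data product
        def __init__(self, size):
            self.size = np.asarray(size)

    cubic = Meta([0.3, 0.3, 0.3])  # uniform pitch -> guard passes
    assert np.all(cubic.size[0] == cubic.size), "Non-cubic pixels not supported."

    flat = Meta([0.3, 0.3, 0.6])   # anisotropic voxels -> guard trips
    try:
        assert np.all(flat.size[0] == flat.size), "Non-cubic pixels not supported."
    except AssertionError as err:
        print(err)  # -> Non-cubic pixels not supported.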
@@ -79,6 +98,11 @@ def process(self, data):
                     if part.t < self.time_window[0] or part.t > self.time_window[1]:
                         continue
 
+                # If needed, check on the particle PID
+                if self.include_pids is not None:
+                    if part.pid not in self.include_pids:
+                        continue
+
 
                 # Initialize the particle dictionary
                 comp_dict = {"particle_id": part.id}
@@ -95,9 +119,18 @@ def process(self, data):
                 if length:
                     vec /= length
 
+                # If needed, check on the particle length
+                if self.length_threshold is not None:
+                    if length < self.length_threshold:
+                        continue
+
                 comp_dict["size"] = len(points)
                 comp_dict["length"] = length
                 comp_dict.update({"dir_x": vec[0], "dir_y": vec[1], "dir_z": vec[2]})
+                comp_dict.update(
+                    {"start_x": start[0], "start_y": start[1], "start_z": start[2]}
+                )
+                comp_dict.update({"end_x": end[0], "end_y": end[1], "end_z": end[2]})
 
                 # Chunk out the track along gaps, estimate gap length
                 chunk_labels = self.cluster_track_chunks(points, start, end, pixel_size)
@@ -116,7 +149,9 @@ def process(self, data):
                 self.append(prefix, **comp_dict)
 
     @staticmethod
-    def cluster_track_chunks(points, start_point, end_point, pixel_size):
+    def cluster_track_chunks(
+        points, start_point: np.ndarray, end_point: np.ndarray, pixel_size: float
+    ) -> np.ndarray:
120155 """Find point where the track is broken, divide out the track
121156 into self-contained chunks which are Linf connect (Moore neighbors).
122157
@@ -139,19 +174,21 @@ def cluster_track_chunks(points, start_point, end_point, pixel_size):
139174 """
140175 # Project and cluster on the projected axis
141176 direction = (end_point - start_point ) / np .linalg .norm (end_point - start_point )
142- scale = pixel_size * np .max (direction )
177+ scale = pixel_size / np .max (np . abs ( direction ) )
143178 projs = np .dot (points - start_point , direction )
144179 perm = np .argsort (projs )
145180 seps = projs [perm ][1 :] - projs [perm ][:- 1 ]
146- breaks = np .where (seps > scale * 1.1 )[0 ] + 1
181+ breaks = np .where (seps > scale * 1.49 )[0 ] + 1
147182 cluster_labels = np .empty (len (projs ), dtype = int )
148183 for i , index in enumerate (np .split (np .arange (len (projs )), breaks )):
149184 cluster_labels [perm [index ]] = i
150185
151186 return cluster_labels
152187
     @staticmethod
-    def sequential_cluster_distances(points, labels, start_point):
+    def sequential_cluster_distances(
+        points: np.ndarray, labels: np.ndarray, start_point: np.ndarray
+    ) -> np.ndarray:
155192 """Order clusters in order of distance from a starting point, compute
156193 the distances between successive clusters.
157194
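The two fixes in the preceding hunk are linked: for a unit direction `d`, Moore-adjacent voxels advance at most `pixel_size / max|d|` along the projection axis, so the old `pixel_size * np.max(direction)` understated the nominal step (by a factor of two on a 45-degree track) and, lacking `abs`, could even go negative for tracks pointing into negative coordinates. The 1.49 factor then sits between the nominal step and the 2x step left by a genuine one-voxel gap. A toy check mirroring the method body on synthetic points (not the repo's test suite):

    import numpy as np

    # 45-degree track rasterized at a hypothetical 0.3 cm pitch, one voxel
    # (step 5) knocked out to create a genuine gap.
    pixel_size = 0.3
    start = np.zeros(3)
    steps = np.array([i for i in range(11) if i != 5], dtype=float)
    points = steps[:, None] * pixel_size * np.array([1.0, 1.0, 0.0])
    end = points[-1]

    direction = (end - start) / np.linalg.norm(end - start)
    scale = pixel_size / np.max(np.abs(direction))  # ~0.424 cm nominal step
    projs = np.dot(points - start, direction)
    seps = np.diff(np.sort(projs))

    # The old scale (~0.212 cm) would flag every separation as a break;
    # the corrected one flags only the real gap.
    print(np.sum(seps > 1.49 * scale))  # -> 1 break, i.e. two chunks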
@@ -167,7 +204,7 @@ def sequential_cluster_distances(points, labels, start_point):
         # If there's only one cluster, nothing to do here
         unique_labels = np.unique(labels)
         if len(unique_labels) < 2:
-            return np.empty(0, dtype=float), np.empty(0, dtype=float)
+            return np.empty(0, dtype=float)
 
         # Order clusters
         start_dist = cdist(start_point[None, :], points).flatten()
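The early-return fix makes the degenerate case consistent with the annotated `-> np.ndarray` return: a single-chunk track previously handed back a 2-tuple where callers expect one array. A quick contract check on synthetic inputs (assumes `TrackCompletenessAna` is importable):

    import numpy as np

    points = np.random.rand(20, 3)    # synthetic single-chunk track
    labels = np.zeros(20, dtype=int)  # every point carries the same label
    start_point = np.zeros(3)

    dists = TrackCompletenessAna.sequential_cluster_distances(
        points, labels, start_point
    )
    assert dists.shape == (0,)  # one chunk -> no inter-chunk gaps to report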