Skip to content

Commit 22a6b84

Browse files
Merge pull request #78 from francois-drielsma/develop
Large update incorporating all changes necessary for BNB nue analysis
2 parents e951a65 + e94337b commit 22a6b84

30 files changed

+1186
-601
lines changed

spine/ana/metric/cluster.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(self, obj_type=None, use_objects=False, per_object=True,
8787
self.label_key = label_key
8888

8989
# Parse the label_col column, if necessary
90+
self.label_col = None
9091
if label_col is not None:
9192
self.label_col = enum_factory('cluster', label_col)
9293

@@ -106,7 +107,8 @@ def __init__(self, obj_type=None, use_objects=False, per_object=True,
106107
keys[label_key] = True
107108
for obj in self.obj_type:
108109
keys[f'{obj}_clusts'] = True
109-
keys[f'{obj}_shapes'] = True
110+
if obj != 'interaction':
111+
keys[f'{obj}_shapes'] = True
110112

111113
else:
112114
keys['points'] = True
@@ -150,7 +152,8 @@ def process(self, data):
150152
label_col = self.label_col or self.label_cols[obj_type]
151153
num_points = len(data[self.label_key])
152154
labels = data[self.label_key][:, label_col]
153-
shapes = data[self.label_key][:, SHAPE_COL]
155+
if obj_type != 'interaction':
156+
shapes = data[self.label_key][:, SHAPE_COL]
154157
num_truth = len(np.unique(labels[labels > -1]))
155158

156159
else:
@@ -170,7 +173,8 @@ def process(self, data):
170173
num_reco = len(data[f'{obj_type}_clusts'])
171174
for i, index in enumerate(data[f'{obj_type}_clusts']):
172175
preds[index] = i
173-
shapes[index] = data[f'{obj_type}_shapes'][i]
176+
if obj_type != 'interaction':
177+
shapes[index] = data[f'{obj_type}_shapes'][i]
174178

175179
else:
176180
# Use clusters from the object indexes

spine/build/fragment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def load_truth(self, data):
337337

338338
def _load_truth(self, truth_fragments, points_label, depositions_label,
339339
depositions_q_label=None, points=None, depositions=None,
340-
points_g4=None, depositons_g4=None, sources_label=None,
340+
points_g4=None, depositions_g4=None, sources_label=None,
341341
sources=None):
342342
"""Load :class:`TruthFragment` objects from their stored versions.
343343

spine/data/out/interaction.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88

9-
from spine.utils.globals import PID_LABELS, PID_TAGS
9+
from spine.utils.globals import SHOWR_SHP, PID_LABELS, PID_TAGS
1010
from spine.utils.decorators import inherit_docstring
1111

1212
from spine.data.neutrino import Neutrino
@@ -308,6 +308,21 @@ def __str__(self):
308308
"""
309309
return 'Reco' + super().__str__()
310310

311+
@property
312+
def leading_shower(self):
313+
"""Leading primary shower of this interaction.
314+
315+
Returns
316+
-------
317+
RecoParticle
318+
Primary shower with the highest kinetic energy
319+
"""
320+
showers = [part for part in self.primary_particles if part.shape == SHOWR_SHP]
321+
if len(showers) == 0:
322+
return None
323+
324+
return max(showers, key=lambda x: x.ke)
325+
311326

312327
@dataclass(eq=False)
313328
@inherit_docstring(TruthBase, InteractionBase)

spine/data/out/particle.py

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from scipy.spatial.distance import cdist
88

99
from spine.utils.globals import (
10-
TRACK_SHP, SHAPE_LABELS, PID_LABELS, PID_MASSES, PID_TO_PDG)
10+
SHOWR_SHP, TRACK_SHP, SHAPE_LABELS, PID_LABELS, PID_MASSES, PID_TO_PDG)
1111
from spine.utils.decorators import inherit_docstring
1212

1313
from spine.data.particle import Particle
@@ -212,19 +212,28 @@ class RecoParticle(ParticleBase, RecoBase):
212212
(M) List of indexes of PPN points associated with this particle
213213
ppn_points : np.ndarray
214214
(M, 3) List of PPN points tagged to this particle
215-
vertex_distance: float
215+
vertex_distance : float
216216
Set-to-point distance between all particle points and the parent
217-
interaction vertex. (untis of cm)
218-
shower_split_angle: float
219-
Estimate of the opening angle of the shower. If particle is not a
220-
shower, then this is set to -1. (units of degrees)
217+
interaction vertex position in cm
218+
start_dedx : float
219+
dE/dx around a user-defined neighborhood of the start point in MeV/cm
220+
start_straightness : float
221+
Explained variance ratio of the beginning of the particle
222+
directional_spread : float
223+
Estimate of the angular spread of the particle (cosine spread)
224+
axial_spread : float
225+
Pearson correlation coefficient of the axial profile of the particle
226+
w.r.t. the distance from its start point
221227
"""
222228
pid_scores: np.ndarray = None
223229
primary_scores: np.ndarray = None
224230
ppn_ids: np.ndarray = None
225231
ppn_points: np.ndarray = None
226232
vertex_distance: float = -1.
227-
shower_split_angle: float = -1.
233+
start_dedx: float = -1.
234+
start_straightness: float = -1.
235+
directional_spread: float = -1.
236+
axial_spread: float = -np.inf
228237

229238
# Fixed-length attributes
230239
_fixed_length_attrs = (
@@ -265,19 +274,34 @@ def __str__(self):
265274
def merge(self, other):
266275
"""Merge another particle instance into this one.
267276
268-
This method can only merge two track objects with well defined start
269-
and end points.
277+
The merging strategy differs depending on the particle shapes
278+
merged together. There are two categories:
279+
- Track + track
280+
- The start/end points are produced by finding the combination of points
281+
which are farthest away from each other (one from each constituent)
282+
- The primary scores/primary status match that of the constituent
283+
particle with the highest primary score
284+
- The PID scores/PID value match that of the constituent particle with
285+
the highest primary score
286+
- Shower + Track
287+
- The track is always merged into the shower, not the other way around
288+
- The start point of the shower is updated to be the track end point
289+
further away from the current shower start point
290+
- The primary scores/primary status match that of the constituent
291+
particle with the highest primary score
292+
- The PID scores/PID value is kept unchanged (that of the shower)
270293
271294
Parameters
272295
----------
273296
other : RecoParticle
274297
Other reconstructed particle to merge into this one
275298
"""
276-
# Check that both particles being merged are tracks
277-
assert self.shape == TRACK_SHP and other.shape == TRACK_SHP, (
278-
"Can only merge two track particles.")
299+
# Check that the particles being merged fit one of two categories
300+
assert (self.shape in (SHOWR_SHP, TRACK_SHP) and
301+
other.shape == TRACK_SHP), (
302+
"Can only merge two track particles or a track into a shower.")
279303

280-
# Check that neither particle has yet been matches
304+
# Check that neither particle has yet been matched
281305
assert not self.is_matched and not other.is_matched, (
282306
"Cannot merge particles that already have matches.")
283307

@@ -287,27 +311,45 @@ def merge(self, other):
287311
setattr(self, attr, val)
288312

289313
# Select end points and end directions appropriately
290-
points_i = np.vstack([self.start_point, self.end_point])
291-
points_j = np.vstack([other.start_point, other.end_point])
292-
dirs_i = np.vstack([self.start_dir, self.end_dir])
293-
dirs_j = np.vstack([other.start_dir, other.end_dir])
314+
if self.shape == TRACK_SHP:
315+
# If two tracks, pick points furthest apart
316+
points_i = np.vstack([self.start_point, self.end_point])
317+
points_j = np.vstack([other.start_point, other.end_point])
318+
dirs_i = np.vstack([self.start_dir, self.end_dir])
319+
dirs_j = np.vstack([other.start_dir, other.end_dir])
320+
321+
dists = cdist(points_i, points_j)
322+
max_index = np.argmax(dists)
323+
max_i, max_j = max_index//2, max_index%2
324+
325+
self.start_point = points_i[max_i]
326+
self.end_point = points_j[max_j]
327+
self.start_dir = dirs_i[max_i]
328+
self.end_dir = dirs_j[max_j]
294329

295-
dists = cdist(points_i, points_j)
296-
max_index = np.argmax(dists)
297-
max_i, max_j = max_index//2, max_index%2
330+
else:
331+
# If a shower and a track, pick track point furthest from shower
332+
points_i = self.start_point.reshape(-1, 3)
333+
points_j = np.vstack([other.start_point, other.end_point])
334+
dirs_j = np.vstack([other.start_dir, other.end_dir])
335+
336+
dists = cdist(points_i, points_j)
337+
max_j = np.argmax(dists)
298338

299-
self.start_point = points_i[max_i]
300-
self.end_point = points_j[max_j]
301-
self.start_dir = dirs_i[max_i]
302-
self.end_dir = dirs_j[max_j]
339+
self.start_point = points_j[max_j]
340+
self.start_dir = dirs_j[max_j]
303341

304-
# If one of the two particles is a primary, the new one is
342+
# Match primary/PID to the most primary particle
305343
if other.primary_scores[-1] > self.primary_scores[-1]:
306344
self.primary_scores = other.primary_scores
345+
self.is_primary = other.is_primary
346+
if self.shape == TRACK_SHP:
347+
self.pid_scores = other.pid_scores
348+
self.pid = other.pid
307349

308-
# For PID, pick the most confident prediction (could be better...)
309-
if np.max(other.pid_scores) > np.max(self.pid_scores):
310-
self.pid_scores = other.pid_scores
350+
# If the calorimetric KEs have been computed, can safely sum
351+
if other.calo_ke > 0.:
352+
self.calo_ke += other.calo_ke
311353

312354
@property
313355
def mass(self):
@@ -387,12 +429,12 @@ def momentum(self, momentum):
387429
def reco_ke(self):
388430
"""Alias for `ke`, to match nomenclature in truth."""
389431
return self.ke
390-
432+
391433
@property
392434
def reco_momentum(self):
393435
"""Alias for `momentum`, to match nomenclature in truth."""
394436
return self.momentum
395-
437+
396438
@property
397439
def reco_length(self):
398440
"""Alias for `length`, to match nomenclature in truth."""
@@ -402,7 +444,7 @@ def reco_length(self):
402444
def reco_start_dir(self):
403445
"""Alias for `start_dir`, to match nomenclature in truth."""
404446
return self.start_dir
405-
447+
406448
@property
407449
def reco_end_dir(self):
408450
"""Alias for `end_dir`, to match nomenclature in truth."""

spine/driver.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ def __init__(self, cfg, rank=None):
124124
assert self.model is None or self.unwrap, (
125125
"Must unwrap the model output to run post-processors.")
126126
self.watch.initialize('post')
127-
self.post = PostManager(post, parent_path=self.parent_path)
127+
self.post = PostManager(
128+
post, post_list=self.post_list, parent_path=self.parent_path)
128129

129130
# Initialize the analysis scripts
130131
self.ana = None
@@ -354,12 +355,21 @@ def initialize_io(self, loader=None, reader=None, writer=None):
354355
self.watch.initialize('unwrap')
355356
self.unwrapper = Unwrapper(geometry=geo)
356357

358+
# If working from LArCV files, no post-processor was yet run
359+
self.post_list = ()
360+
357361
else:
358362
# Initialize the reader
359363
self.watch.initialize('read')
360364
self.reader = reader_factory(reader)
361365
self.iter_per_epoch = len(self.reader)
362366

367+
# Fetch the list of previously run post-processors
368+
# TODO: this only works with two runs in a row, not 3 and above
369+
self.post_list = None
370+
if self.reader.cfg is not None:
371+
self.post_list = tuple(self.reader.cfg['post'])
372+
363373
# Fetch an appropriate common prefix for all input files
364374
self.log_prefix, self.output_prefix = self.get_prefixes(
365375
self.reader.file_paths, self.split_output)
@@ -448,7 +458,7 @@ def get_prefixes(file_paths, split_output):
448458
log_prefix += f'--{suffix}'
449459

450460
# Truncate file names that are too long
451-
max_length = 230
461+
max_length = 150
452462
if len(log_prefix) > max_length:
453463
log_prefix = log_prefix[:max_length-3] + '---'
454464

spine/post/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,15 @@ class PostBase(ABC):
3030
# Units in which the post-processor expects objects to be expressed in
3131
units = 'cm'
3232

33+
# Whether this post-processor needs to know where the configuration lives
34+
need_parent_path = False
35+
3336
# Set of data keys needed for this post-processor to operate
3437
_keys = ()
3538

39+
# Set of post-processors which must be run before this one is
40+
_upstream = ()
41+
3642
# List of recognized object types
3743
_obj_types = ('fragment', 'particle', 'interaction')
3844

spine/post/factories.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from spine.utils.factory import module_dict, instantiate
44

5-
from . import reco, metric, optical, crt, trigger
5+
from . import reco, truth, optical, crt, trigger
66

77
# Build a dictionary of available calibration modules
88
POST_DICT = {}
9-
for module in [reco, metric, optical, crt, trigger]:
9+
for module in [reco, truth, optical, crt, trigger]:
1010
POST_DICT.update(**module_dict(module))
1111

1212

@@ -29,8 +29,7 @@ def post_processor_factory(name, cfg, parent_path=None):
2929
cfg['name'] = name
3030

3131
# Instantiate the post-processor module
32-
# TODO: This is hacky, fix it
33-
if name == 'flash_match':
32+
if name in POST_DICT and POST_DICT[name].need_parent_path:
3433
return instantiate(POST_DICT, cfg, parent_path=parent_path)
3534
else:
3635
return instantiate(POST_DICT, cfg)

spine/post/manager.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,35 +17,45 @@ class PostManager:
1717
It loads all the post-processor objects once and feeds them data.
1818
"""
1919

20-
def __init__(self, cfg, parent_path=None):
20+
def __init__(self, cfg, post_list=None, parent_path=None):
2121
"""Initialize the post-processing manager.
2222
2323
Parameters
2424
----------
2525
cfg : dict
2626
Post-processor configurations
27+
post_list : List[str], optional
28+
List of post-processors which have already been run
2729
parent_path : str, optional
2830
Path to the analysis tools configuration file
2931
"""
3032
# Loop over the post-processor modules and get their priorities
3133
cfg = deepcopy(cfg)
3234
keys = np.array(list(cfg.keys()))
3335
priorities = -np.ones(len(keys), dtype=np.int32)
34-
for i, k in enumerate(keys):
35-
if 'priority' in cfg[k]:
36-
priorities[i] = cfg[k].pop('priority')
36+
for i, key in enumerate(keys):
37+
if 'priority' in cfg[key]:
38+
priorities[i] = cfg[key].pop('priority')
3739

3840
# Add the modules to a processor list in decreasing order of priority
3941
self.watch = StopwatchManager()
4042
self.modules = OrderedDict()
4143
keys = keys[np.argsort(-priorities)]
42-
for k in keys:
44+
for key in keys:
4345
# Profile the module
44-
self.watch.initialize(k)
46+
self.watch.initialize(key)
4547

4648
# Append
47-
self.modules[k] = post_processor_factory(
48-
k, cfg[k], parent_path=parent_path)
49+
self.modules[key] = post_processor_factory(
50+
key, cfg[key], parent_path=parent_path)
51+
52+
# Check dependencies
53+
if post_list is not None:
54+
ups_post = tuple(self.modules)
55+
for post in self.modules[key]._upstream:
56+
assert post in (post_list + ups_post), (
57+
f"Post-processor `{key}` is missing an essential "
58+
f"upstream post-processor: `{post}`.")
4959

5060
def __call__(self, data):
5161
"""Pass one batch of data through the post-processors.

spine/post/optical/flash_matching.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ class FlashMatchProcessor(PostBase):
2525
# Alternative allowed names of the post-processor
2626
aliases = ('run_flash_matching',)
2727

28+
# Whether this post-processor needs to know where the configuration lives
29+
need_parent_path = True
30+
2831
def __init__(self, flash_key, volume, ref_volume_id=None,
2932
method='likelihood', detector=None, geometry_file=None,
3033
run_mode='reco', truth_point_mode='points',

spine/post/reco/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@
1111
from .calo import *
1212
from .pid import *
1313
from .kinematics import *
14-
from .label import *
1514
from .shower import *
15+
from .topology import *

0 commit comments

Comments
 (0)