Skip to content

Commit b322e3f

Browse files
Merge pull request #81 from francois-drielsma/develop
Refactored particle point assignment in the full chain
2 parents 812c36c + 04017f5 commit b322e3f

File tree

6 files changed

+270
-116
lines changed

6 files changed

+270
-116
lines changed

spine/build/manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(self, fragments, particles, interactions,
7070
assert units in self._units, (
7171
f"Units not recognized: {units}. Must be one {self._units}")
7272
self.units = units
73-
73+
7474
# If custom sources are provided, update the tuple
7575
if sources is not None:
7676
sources_dict = dict(self._sources)
@@ -92,7 +92,7 @@ def __init__(self, fragments, particles, interactions,
9292
if particles:
9393
self.builders['particle'] = ParticleBuilder(mode, units)
9494
if interactions:
95-
assert particles is not None, (
95+
assert particles, (
9696
"Interactions are built from particles. If `interactions` "
9797
"is True, so must `particles` be.")
9898
self.builders['interaction'] = InteractionBuilder(mode, units)
@@ -178,7 +178,7 @@ def build_sources(self, data, entry=None):
178178

179179
if 'sources' in sources:
180180
update['sources'] = sources['sources'].astype(int)
181-
181+
182182
if self.mode != 'reco':
183183
update['label_tensor'] = sources['label_tensor']
184184
update['points_label'] = sources['label_tensor'][:, COORD_COLS]

spine/io/read/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,10 @@ def process_entry_list(self, n_entry=None, n_skip=None, entry_list=None,
184184
entry_list is not None, skip_entry_list is not None,
185185
run_event_list is not None,
186186
skip_run_event_list is not None]):
187-
assert (bool(n_entry or n_skip) ^
188-
bool(entry_list or skip_entry_list) ^
189-
bool(run_event_list or skip_run_event_list)), (
187+
assert ((n_entry is not None or n_skip is not None) ^
188+
(entry_list is not None or skip_entry_list is not None) ^
189+
(run_event_list is not None
190+
or skip_run_event_list is not None)), (
190191
"Cannot specify `n_entry` or `n_skip` at the same time "
191192
"as `entry_list` or `skip_entry_list` or at the same time "
192193
"as `run_event_list` or `skip_run_event_list`.")

spine/model/full_chain.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717
# TODO: rename it to something more generic like ParticleClusterImageClassifier?
1818

1919
from spine.data import TensorBatch, IndexBatch, RunInfo
20+
from spine.utils.logger import logger
2021
from spine.utils.globals import (
2122
COORD_COLS, VALUE_COL, CLUST_COL, SHAPE_COL, SHOWR_SHP, TRACK_SHP,
2223
MICHL_SHP, DELTA_SHP, GHOST_SHP)
2324
from spine.utils.calib import CalibrationManager
24-
from spine.utils.logger import logger
25-
from spine.utils.ppn import get_particle_points
25+
from spine.utils.ppn import ParticlePointPredictor
2626
from spine.utils.ghost import compute_rescaled_charge_batch
2727
from spine.utils.cluster.label import ClusterLabelAdapter
2828
from spine.utils.gnn.cluster import (
@@ -115,10 +115,10 @@ class FullChain(torch.nn.Module):
115115
)
116116

117117
def __init__(self, chain, uresnet_deghost=None, uresnet=None,
118-
uresnet_ppn=None, adapt_labels=None, graph_spice=None,
119-
dbscan=None, grappa_shower=None, grappa_track=None,
120-
grappa_particle=None, grappa_inter=None, calibration=None,
121-
uresnet_deghost_loss=None, uresnet_loss=None,
118+
uresnet_ppn=None, adapt_labels=None, predict_points=None,
119+
graph_spice=None, dbscan=None, grappa_shower=None,
120+
grappa_track=None, grappa_particle=None, grappa_inter=None,
121+
calibration=None, uresnet_deghost_loss=None, uresnet_loss=None,
122122
uresnet_ppn_loss=None, graph_spice_loss=None,
123123
grappa_shower_loss=None, grappa_track_loss=None,
124124
grappa_particle_loss=None, grappa_inter_loss=None):
@@ -134,6 +134,8 @@ def __init__(self, chain, uresnet_deghost=None, uresnet=None,
134134
Segmentation and point proposal model configuration
135135
adapt_labels : dict, optional
136136
Parameters for the cluster label adaptation (if non-standard)
137+
predict_points : dict, optional
138+
Parameters for the particle point predictor (if non-standard)
137139
dbscan : dict, optional
138140
Connected component clustering configuration
139141
graph_spice : dict, optional
@@ -175,6 +177,9 @@ def __init__(self, chain, uresnet_deghost=None, uresnet=None,
175177
# Initialize the relabeling process (adapt to the semantic predictions)
176178
self.label_adapter = ClusterLabelAdapter(**(adapt_labels or {}))
177179

180+
# Initialize the point predictor (for fragment/particle clusters)
181+
self.point_predictor = ParticlePointPredictor(**(predict_points or {}))
182+
178183
# Initialize the dense clustering model
179184
self.fragment_shapes = []
180185
if (self.fragmentation is not None and
@@ -1017,7 +1022,7 @@ def prepare_grappa_input(self, model, data, clusts, clust_shapes,
10171022
ref_clusts = clust_primaries
10181023

10191024
# Get and store the points
1020-
points = get_particle_points(
1025+
points = self.point_predictor(
10211026
data, ref_clusts, clust_shapes, self.result['ppn_points'])
10221027

10231028
grappa_input['points'] = points

spine/utils/globals.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@
8585
KAON_PID = 5
8686

8787
# Mapping between particle PDG code and particle ID labels
88-
PHOT_PID = 0
8988
PDG_TO_PID = defaultdict(lambda: -1)
9089
PDG_TO_PID.update({
9190
22: PHOT_PID,

spine/utils/numba_local.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -642,9 +642,9 @@ def principal_components(x: nb.float32[:,:]) -> nb.float32[:,:]:
642642

643643
@nb.njit(cache=True)
644644
def farthest_pair(x: nb.float32[:,:],
645-
algorithm: bool = 'brute') -> (nb.int32, nb.int32, nb.float32):
646-
"""Algorithm which finds the two points which are
647-
farthest from each other in a set.
645+
algorithm: str = 'brute') -> (nb.int32, nb.int32, nb.float32):
646+
"""Algorithm which finds the two points which are farthest from each other
647+
in a set.
648648
649649
Two algorithms:
650650
- `brute`: compute pdist, use argmax
@@ -675,7 +675,9 @@ def farthest_pair(x: nb.float32[:,:],
675675
dist = dist_mat[idxs[0], idxs[1]]
676676

677677
elif algorithm == 'recursive':
678-
idxs, subidx, dist, tempdist = [0, 0], 0, 0., -1.
678+
centroid = mean(x, 0)
679+
start_idx = np.argmax(cdist(centroid.reshape(1, -1), x))
680+
idxs, subidx, dist, tempdist = [start_idx, start_idx], 0, 0., -1.
679681
while dist > tempdist:
680682
tempdist = dist
681683
dists = cdist(np.ascontiguousarray(x[idxs[subidx]]).reshape(1,-1), x).flatten()

0 commit comments

Comments
 (0)