1717# TODO: raname it something more generic like ParticleClusterImageClassifier?
1818
1919from spine .data import TensorBatch , IndexBatch , RunInfo
20+ from spine .utils .logger import logger
2021from spine .utils .globals import (
2122 COORD_COLS , VALUE_COL , CLUST_COL , SHAPE_COL , SHOWR_SHP , TRACK_SHP ,
2223 MICHL_SHP , DELTA_SHP , GHOST_SHP )
2324from spine .utils .calib import CalibrationManager
24- from spine .utils .logger import logger
25- from spine .utils .ppn import get_particle_points
25+ from spine .utils .ppn import ParticlePointPredictor
2626from spine .utils .ghost import compute_rescaled_charge_batch
2727from spine .utils .cluster .label import ClusterLabelAdapter
2828from spine .utils .gnn .cluster import (
@@ -115,10 +115,10 @@ class FullChain(torch.nn.Module):
115115 )
116116
117117 def __init__ (self , chain , uresnet_deghost = None , uresnet = None ,
118- uresnet_ppn = None , adapt_labels = None , graph_spice = None ,
119- dbscan = None , grappa_shower = None , grappa_track = None ,
120- grappa_particle = None , grappa_inter = None , calibration = None ,
121- uresnet_deghost_loss = None , uresnet_loss = None ,
118+ uresnet_ppn = None , adapt_labels = None , predict_points = None ,
119+ graph_spice = None , dbscan = None , grappa_shower = None ,
120+ grappa_track = None , grappa_particle = None , grappa_inter = None ,
121+ calibration = None , uresnet_deghost_loss = None , uresnet_loss = None ,
122122 uresnet_ppn_loss = None , graph_spice_loss = None ,
123123 grappa_shower_loss = None , grappa_track_loss = None ,
124124 grappa_particle_loss = None , grappa_inter_loss = None ):
@@ -134,6 +134,8 @@ def __init__(self, chain, uresnet_deghost=None, uresnet=None,
134134 Segmentation and point proposal model configuration
135135 adapt_labels : dict, optional
136136 Parameters for the cluster label adaptation (if non-standard)
137+ predict_points : dict, optional
138+ Parameters for the particle point predictor (if non-standard)
137139 dbscan : dict, optional
138140 Connected component clustering configuration
139141 graph_spice : dict, optional
@@ -175,6 +177,9 @@ def __init__(self, chain, uresnet_deghost=None, uresnet=None,
175177 # Initialize the relabeling process (adapt to the semantic predictions)
176178 self .label_adapter = ClusterLabelAdapter (** (adapt_labels or {}))
177179
180+ # Initialize the point predictor (for fragment/particle clusters)
181+ self .point_predictor = ParticlePointPredictor (** (predict_points or {}))
182+
178183 # Initialize the dense clustering model
179184 self .fragment_shapes = []
180185 if (self .fragmentation is not None and
@@ -1017,7 +1022,7 @@ def prepare_grappa_input(self, model, data, clusts, clust_shapes,
10171022 ref_clusts = clust_primaries
10181023
10191024 # Get and store the points
1020- points = get_particle_points (
1025+ points = self . point_predictor (
10211026 data , ref_clusts , clust_shapes , self .result ['ppn_points' ])
10221027
10231028 grappa_input ['points' ] = points
0 commit comments