99
1010import torch
1111
12- from graphnet .models .detector .prometheus import Prometheus
13- from graphnet .models .graphs import KNNGraph
14- from graphnet .models .graphs .nodes import NodesAsPulses
12+ from graphnet .models .detector .prometheus import Prometheus
13+ from graphnet .models .graphs import KNNGraph
14+ from graphnet .models .graphs .nodes import NodesAsPulses
1515
1616
1717class simple_model (Model ):
18- def __init__ (self ,
19- ):
18+ def __init__ (
19+ self ,
20+ ):
2021 super ().__init__ ()
21- self .net = torch .nn .Sequential (torch . nn . Linear ( 4 , 10 ),
22- torch .nn .SELU (),
23- torch . nn . Linear ( 10 , 5 ) )
22+ self .net = torch .nn .Sequential (
23+ torch . nn . Linear ( 4 , 10 ), torch .nn .SELU (), torch . nn . Linear ( 10 , 5 )
24+ )
2425
25- def forward (self , data :Data ):
26+ def forward (self , data : Data ):
2627 x = self .net (data .x )
27- x_rep = scatter (src = x , index = data .batch , dim = 0 , reduce = ' max' )
28+ x_rep = scatter (src = x , index = data .batch , dim = 0 , reduce = " max" )
2829 return x , x_rep
29-
30+
31+
3032class simple_target_gen (Model ):
31- def __init__ (self ,
32- ):
33+ def __init__ (
34+ self ,
35+ ):
3336 super ().__init__ ()
3437
35- def forward (self , data :Data ):
36- target = torch .sum (scatter (src = data .x , index = data .batch , dim = 0 , reduce = 'max' ), dim = 1 )
37- return target .view (- 1 ,1 )
38+ def forward (self , data : Data ):
39+ target = torch .sum (
40+ scatter (src = data .x , index = data .batch , dim = 0 , reduce = "max" ), dim = 1
41+ )
42+ return target .view (- 1 , 1 )
43+
3844
39-
4045def test ():
4146 graph_definition = KNNGraph (
4247 detector = Prometheus (),
@@ -50,7 +55,7 @@ def test():
5055 truth_table = "mc_truth" ,
5156 features = ["sensor_pos_x" , "sensor_pos_y" , "sensor_pos_z" , "t" ],
5257 truth = ["injection_energy" , "injection_zenith" ],
53- data_representation = graph_definition ,
58+ data_representation = graph_definition ,
5459 )
5560
5661 dataloader = DataLoader (
@@ -63,15 +68,15 @@ def test():
6368 data = batch
6469 break
6570
66- #Standard Parameters
71+ # Standard Parameters
6772 # encoder: Model,
6873 # encoder_out_dim: int = None,
6974 # masked_ratio: float = 0.25,
7075 # masked_feat: List[int] = [0,1,2,3,4],
7176 # learned_masking_value: bool = True,
7277 # hlc_pos: int = None,
7378 # mask_pred_net: Model = None,
74- # default_hidden_dim: int = 1000,
79+ # default_hidden_dim: int = 1000,
7580 # default_nb_linear: int = 5,
7681 # final_loss: str = 'mse',
7782 # add_charge_pred: bool = False,
@@ -81,23 +86,23 @@ def test():
8186 dummy_model = simple_model ()
8287 dummy_target = simple_target_gen ()
8388
84- model = mask_pred_frame (encoder = dummy_model ,
85- encoder_out_dim = 5 ,
86- masked_feat = [ 0 , 1 ] ,
87- learned_masking_value = True ,
88- final_loss = 'cosine' ,
89- add_charge_pred = True ,
90- need_charge_rep = False ,
91- custom_charge_target = dummy_target ,
92- )
93-
94-
89+ model = mask_pred_frame (
90+ encoder = dummy_model ,
91+ encoder_out_dim = 5 ,
92+ masked_feat = [ 0 , 1 ] ,
93+ learned_masking_value = True ,
94+ final_loss = "cosine" ,
95+ add_charge_pred = True ,
96+ need_charge_rep = False ,
97+ custom_charge_target = dummy_target ,
98+ )
99+
95100 out = model (data )
96101 print (out )
97102
98- #for saving
99- #model.save_pretrained_model('some/path')
103+ # for saving
104+ # model.save_pretrained_model('some/path')
100105
101106
102107if __name__ == "__main__" :
103- test ()
108+ test ()
0 commit comments