1+ """Minimal example for use of maskpred pretraining."""
2+
13from graphnet .models .gnn .pretraining_maskpred import mask_pred_frame
24from graphnet .models import Model
35from torch_geometric .data import Data
1517
1618
1719class simple_model (Model ):
20+ """Just for a dummy model."""
21+
1822 def __init__ (
1923 self ,
20- ):
24+ ) -> None :
25+ """Construct."""
2126 super ().__init__ ()
2227 self .net = torch .nn .Sequential (
2328 torch .nn .Linear (4 , 10 ), torch .nn .SELU (), torch .nn .Linear (10 , 5 )
2429 )
2530
2631 def forward (self , data : Data ):
32+ """Forward pass."""
2733 x = self .net (data .x )
2834 x_rep = scatter (src = x , index = data .batch , dim = 0 , reduce = "max" )
2935 return x , x_rep
3036
3137
3238class simple_target_gen (Model ):
39+ """Just for a dummy charge target."""
40+
3341 def __init__ (
3442 self ,
35- ):
43+ ) -> None :
44+ """Construct."""
3645 super ().__init__ ()
3746
3847 def forward (self , data : Data ):
48+ """Forward pass."""
3949 target = torch .sum (
4050 scatter (src = data .x , index = data .batch , dim = 0 , reduce = "max" ), dim = 1
4151 )
4252 return target .view (- 1 , 1 )
4353
4454
45- def test ():
55+ def test () -> None :
56+ """Function that just evaluates the model to test it and has a save example commented in the end."""
4657 graph_definition = KNNGraph (
4758 detector = Prometheus (),
4859 node_definition = NodesAsPulses (),
@@ -68,21 +79,6 @@ def test():
6879 data = batch
6980 break
7081
71- # Standard Parameters
72- # encoder: Model,
73- # encoder_out_dim: int = None,
74- # masked_ratio: float = 0.25,
75- # masked_feat: List[int] = [0,1,2,3,4],
76- # learned_masking_value: bool = True,
77- # hlc_pos: int = None,
78- # mask_pred_net: Model = None,
79- # default_hidden_dim: int = 1000,
80- # default_nb_linear: int = 5,
81- # final_loss: str = 'mse',
82- # add_charge_pred: bool = False,
83- # need_charge_rep: bool = False,
84- # custom_charge_target: Tensor = None,
85-
8682 dummy_model = simple_model ()
8783 dummy_target = simple_target_gen ()
8884
0 commit comments