Skip to content

Commit f2707ac

Browse files
committed
pydocstyle on example
1 parent 6de0871 commit f2707ac

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

examples/pretraining_maskpred_example.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Minimal example for use of maskpred pretraining."""
2+
13
from graphnet.models.gnn.pretraining_maskpred import mask_pred_frame
24
from graphnet.models import Model
35
from torch_geometric.data import Data
@@ -15,34 +17,43 @@
1517

1618

1719
class simple_model(Model):
20+
"""Just for a dummy model."""
21+
1822
def __init__(
1923
self,
20-
):
24+
) -> None:
25+
"""Construct."""
2126
super().__init__()
2227
self.net = torch.nn.Sequential(
2328
torch.nn.Linear(4, 10), torch.nn.SELU(), torch.nn.Linear(10, 5)
2429
)
2530

2631
def forward(self, data: Data):
32+
"""Forward pass."""
2733
x = self.net(data.x)
2834
x_rep = scatter(src=x, index=data.batch, dim=0, reduce="max")
2935
return x, x_rep
3036

3137

3238
class simple_target_gen(Model):
39+
"""Just for a dummy charge target."""
40+
3341
def __init__(
3442
self,
35-
):
43+
) -> None:
44+
"""Construct."""
3645
super().__init__()
3746

3847
def forward(self, data: Data):
48+
"""Forward pass."""
3949
target = torch.sum(
4050
scatter(src=data.x, index=data.batch, dim=0, reduce="max"), dim=1
4151
)
4252
return target.view(-1, 1)
4353

4454

45-
def test():
55+
def test() -> None:
56+
"""Function that just evaluates the model to test it and has a save example commented in the end."""
4657
graph_definition = KNNGraph(
4758
detector=Prometheus(),
4859
node_definition=NodesAsPulses(),
@@ -68,21 +79,6 @@ def test():
6879
data = batch
6980
break
7081

71-
# Standard Parameters
72-
# encoder: Model,
73-
# encoder_out_dim: int = None,
74-
# masked_ratio: float = 0.25,
75-
# masked_feat: List[int] = [0,1,2,3,4],
76-
# learned_masking_value: bool = True,
77-
# hlc_pos: int = None,
78-
# mask_pred_net: Model = None,
79-
# default_hidden_dim: int = 1000,
80-
# default_nb_linear: int = 5,
81-
# final_loss: str = 'mse',
82-
# add_charge_pred: bool = False,
83-
# need_charge_rep: bool = False,
84-
# custom_charge_target: Tensor = None,
85-
8682
dummy_model = simple_model()
8783
dummy_target = simple_target_gen()
8884

0 commit comments

Comments
 (0)