Skip to content

Commit 6de0871

Browse files
committed
reformatting again
1 parent a4d0668 commit 6de0871

File tree

3 files changed

+46
-37
lines changed

3 files changed

+46
-37
lines changed

examples/pretraining_maskpred_example.py

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,39 @@
99

1010
import torch
1111

12-
from graphnet.models.detector.prometheus import Prometheus
13-
from graphnet.models.graphs import KNNGraph
14-
from graphnet.models.graphs.nodes import NodesAsPulses
12+
from graphnet.models.detector.prometheus import Prometheus
13+
from graphnet.models.graphs import KNNGraph
14+
from graphnet.models.graphs.nodes import NodesAsPulses
1515

1616

1717
class simple_model(Model):
18-
def __init__(self,
19-
):
18+
def __init__(
19+
self,
20+
):
2021
super().__init__()
21-
self.net = torch.nn.Sequential(torch.nn.Linear(4,10),
22-
torch.nn.SELU(),
23-
torch.nn.Linear(10,5))
22+
self.net = torch.nn.Sequential(
23+
torch.nn.Linear(4, 10), torch.nn.SELU(), torch.nn.Linear(10, 5)
24+
)
2425

25-
def forward(self, data:Data):
26+
def forward(self, data: Data):
2627
x = self.net(data.x)
27-
x_rep = scatter(src=x, index=data.batch, dim = 0, reduce='max')
28+
x_rep = scatter(src=x, index=data.batch, dim=0, reduce="max")
2829
return x, x_rep
29-
30+
31+
3032
class simple_target_gen(Model):
31-
def __init__(self,
32-
):
33+
def __init__(
34+
self,
35+
):
3336
super().__init__()
3437

35-
def forward(self, data:Data):
36-
target = torch.sum(scatter(src=data.x, index=data.batch, dim = 0, reduce='max'), dim=1)
37-
return target.view(-1,1)
38+
def forward(self, data: Data):
39+
target = torch.sum(
40+
scatter(src=data.x, index=data.batch, dim=0, reduce="max"), dim=1
41+
)
42+
return target.view(-1, 1)
43+
3844

39-
4045
def test():
4146
graph_definition = KNNGraph(
4247
detector=Prometheus(),
@@ -50,7 +55,7 @@ def test():
5055
truth_table="mc_truth",
5156
features=["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"],
5257
truth=["injection_energy", "injection_zenith"],
53-
data_representation = graph_definition,
58+
data_representation=graph_definition,
5459
)
5560

5661
dataloader = DataLoader(
@@ -63,15 +68,15 @@ def test():
6368
data = batch
6469
break
6570

66-
#Standard Parameters
71+
# Standard Parameters
6772
# encoder: Model,
6873
# encoder_out_dim: int = None,
6974
# masked_ratio: float = 0.25,
7075
# masked_feat: List[int] = [0,1,2,3,4],
7176
# learned_masking_value: bool = True,
7277
# hlc_pos: int = None,
7378
# mask_pred_net: Model = None,
74-
# default_hidden_dim: int = 1000,
79+
# default_hidden_dim: int = 1000,
7580
# default_nb_linear: int = 5,
7681
# final_loss: str = 'mse',
7782
# add_charge_pred: bool = False,
@@ -81,23 +86,23 @@ def test():
8186
dummy_model = simple_model()
8287
dummy_target = simple_target_gen()
8388

84-
model = mask_pred_frame(encoder=dummy_model,
85-
encoder_out_dim=5,
86-
masked_feat = [0,1],
87-
learned_masking_value=True,
88-
final_loss = 'cosine',
89-
add_charge_pred = True,
90-
need_charge_rep = False,
91-
custom_charge_target = dummy_target,
92-
)
93-
94-
89+
model = mask_pred_frame(
90+
encoder=dummy_model,
91+
encoder_out_dim=5,
92+
masked_feat=[0, 1],
93+
learned_masking_value=True,
94+
final_loss="cosine",
95+
add_charge_pred=True,
96+
need_charge_rep=False,
97+
custom_charge_target=dummy_target,
98+
)
99+
95100
out = model(data)
96101
print(out)
97102

98-
#for saving
99-
#model.save_pretrained_model('some/path')
103+
# for saving
104+
# model.save_pretrained_model('some/path')
100105

101106

102107
if __name__ == "__main__":
103-
test()
108+
test()

src/graphnet/models/gnn/pretraining_maskpred.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,12 @@ def forward(self, data: Union[Data, List[Data]]) -> List[Tensor]:
256256
nodes = rep[~mask.bool()]
257257
btch = data.batch[~mask.bool()]
258258

259-
loss = scatter(src=self.loss_func(nodes,target,return_elements=True),
260-
index = btch, reduce="mean",
261-
dim=0).view(-1,1)
259+
loss = scatter(
260+
src=self.loss_func(nodes, target, return_elements=True),
261+
index=btch,
262+
reduce="mean",
263+
dim=0,
264+
).view(-1, 1)
262265

263266
if self.add_charge_pred:
264267
if self.custom_charge_target is None:

src/graphnet/training/loss_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,7 @@ def __init__(self, vmfs_factor: float = 0.05) -> None:
544544
prediction_keys=[[0, 1, 2], [0, 1, 2, 3]],
545545
)
546546

547+
547548
class NegCosLoss(LossFunction):
548549
"""Negative Cosine error loss."""
549550

0 commit comments

Comments
 (0)