Skip to content

Commit b38b22a

Browse files
committed
fix dropout bug
1 parent 1bdda76 commit b38b22a

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

SuperMoon/models/CNN_HGNN.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def forward(self, x):
3636
for hyconv in self.hyconvs:
3737
x = hyconv(x, H)
3838
x = F.leaky_relu(x, inplace=True)
39-
x = F.dropout(x, self.dropout)
39+
x = F.dropout(x, self.dropout, training=self.training)
4040
# N x C -> 1 x C x N
4141
x = x.permute(1, 0).unsqueeze(0)
4242

SuperMoon/models/HGNN.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, in_ch, n_class, hiddens=[16], dropout=0.5) -> None:
1919

2020
def forward(self, x, H, hyedge_weight=None):
2121
for hyconv in self.hyconvs:
22-
x = F.dropout(x, self.dropout)
22+
x = F.dropout(x, self.dropout, training=self.training)
2323
x = hyconv(x, H)
2424
x = F.leaky_relu(x, inplace=True)
2525
x = self.last_hyconv(x, H)

0 commit comments

Comments
 (0)