Skip to content

Commit 27f944f

Browse files
committed
fix last layer bn bugs
1 parent 6278f0d commit 27f944f

File tree

9 files changed

+51
-37
lines changed

9 files changed

+51
-37
lines changed

dhg/nn/convs/graphs/gat_conv.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ def forward(self, X: torch.Tensor, g: Graph) -> torch.Tensor:
4545
g (``dhg.Graph``): The graph structure that contains :math:`N_v` vertices.
4646
"""
4747
X = self.theta(X)
48-
if self.bn is not None:
49-
X = self.bn(X)
5048
x_for_src = self.atten_src(X)
5149
x_for_dst = self.atten_dst(X)
5250
e_atten_score = x_for_src[g.e_src] + x_for_dst[g.e_dst]
@@ -56,6 +54,10 @@ def forward(self, X: torch.Tensor, g: Graph) -> torch.Tensor:
5654
e_atten_score = torch.clamp(e_atten_score, min=0.001, max=5)
5755
# ================================================================================
5856
X = g.v2v(X, aggr="softmax_then_sum", e_weight=e_atten_score)
57+
5958
if not self.is_last:
6059
X = self.act(X)
60+
if self.bn is not None:
61+
X = self.bn(X)
62+
X = self.drop(X)
6163
return X

dhg/nn/convs/graphs/gcn_conv.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ def forward(self, X: torch.Tensor, g: Graph) -> torch.Tensor:
4848
g (``dhg.Graph``): The graph structure that contains :math:`N` vertices.
4949
"""
5050
X = self.theta(X)
51-
if self.bn is not None:
52-
X = self.bn(X)
5351
X = g.smoothing_with_GCN(X)
5452
if not self.is_last:
55-
X = self.drop(self.act(X))
53+
X = self.act(X)
54+
if self.bn is not None:
55+
X = self.bn(X)
56+
X = self.drop(X)
5657
return X

dhg/nn/convs/graphs/graphsage_conv.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,9 @@ def forward(self, X: torch.Tensor, g: Graph) -> torch.Tensor:
5656
else:
5757
raise NotImplementedError()
5858
X = self.theta(X)
59-
if self.bn is not None:
60-
X = self.bn(X)
6159
if not self.is_last:
62-
X = self.drop(self.act(X))
60+
X = self.act(X)
61+
if self.bn is not None:
62+
X = self.bn(X)
63+
X = self.drop(X)
6364
return X

dhg/nn/convs/hypergraphs/dhcf_conv.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
4747
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices.
4848
"""
4949
X_ = self.theta(X)
50-
if self.bn is not None:
51-
X_ = self.bn(X_)
52-
X_ = hg.smoothing_with_HGNN(X_) + X
50+
X = hg.smoothing_with_HGNN(X_) + X
5351
if not self.is_last:
54-
X_ = self.drop(self.act(X_))
55-
return X_
52+
X = self.act(X)
53+
if self.bn is not None:
54+
X = self.bn(X)
55+
X = self.drop(X)
56+
return X

dhg/nn/convs/hypergraphs/hgnn_conv.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,10 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
5050
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices.
5151
"""
5252
X = self.theta(X)
53-
if self.bn is not None:
54-
X = self.bn(X)
5553
X = hg.smoothing_with_HGNN(X)
5654
if not self.is_last:
57-
X = self.drop(self.act(X))
55+
X = self.act(X)
56+
if self.bn is not None:
57+
X = self.bn(X)
58+
X = self.drop(X)
5859
return X

dhg/nn/convs/hypergraphs/hgnnp_conv.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,10 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
5959
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`|\mathcal{V}|` vertices.
6060
"""
6161
X = self.theta(X)
62-
if self.bn is not None:
63-
X = self.bn(X)
6462
X = hg.v2v(X, aggr="mean")
6563
if not self.is_last:
66-
X = self.drop(self.act(X))
64+
X = self.act(X)
65+
if self.bn is not None:
66+
X = self.bn(X)
67+
X = self.drop(X)
6768
return X

dhg/nn/convs/hypergraphs/hnhn_conv.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,13 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
4242
"""
4343
# v -> e
4444
X = self.theta_v2e(X)
45-
if self.bn is not None:
46-
X = self.bn(X)
4745
Y = self.act(hg.v2e(X, aggr="mean"))
4846
# e -> v
4947
Y = self.theta_e2v(Y)
5048
X = hg.e2v(Y, aggr="mean")
5149
if not self.is_last:
52-
X = self.drop(self.act(X))
50+
X = self.act(X)
51+
if self.bn is not None:
52+
X = self.bn(X)
53+
X = self.drop(X)
5354
return X

dhg/nn/convs/hypergraphs/hypergcn_conv.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ def forward(
4949
``cached_g`` (``dhg.Graph``): The pre-transformed graph structure from the hypergraph structure that contains :math:`N` vertices. If not provided, the graph structure will be transformed for each forward time. Defaults to ``None``.
5050
"""
5151
X = self.theta(X)
52-
if self.bn is not None:
53-
X = self.bn(X)
5452
if cached_g is None:
5553
g = Graph.from_hypergraph_hypergcn(
5654
hg, X, self.use_mediator, device=X.device
@@ -59,5 +57,8 @@ def forward(
5957
else:
6058
X = cached_g.smoothing_with_GCN(X)
6159
if not self.is_last:
62-
X = self.drop(self.act(X))
60+
X = self.act(X)
61+
if self.bn is not None:
62+
X = self.bn(X)
63+
X = self.drop(X)
6364
return X

dhg/nn/convs/hypergraphs/unignn_conv.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
5757
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`|\mathcal{V}|` vertices.
5858
"""
5959
X = self.theta(X)
60-
if self.bn is not None:
61-
X = self.bn(X)
6260
Y = hg.v2e(X, aggr="mean")
6361
# ===============================================
6462
# compute the special degree of hyperedges
@@ -71,8 +69,12 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
7169
# ===============================================
7270
X = hg.e2v(Y, aggr="sum")
7371
X = torch.sparse.mm(hg.D_v_neg_1_2, X)
72+
7473
if not self.is_last:
75-
X = self.drop(self.act(X))
74+
X = self.act(X)
75+
if self.bn is not None:
76+
X = self.bn(X)
77+
X = self.drop(X)
7678
return X
7779

7880

@@ -128,8 +130,6 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
128130
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`|\mathcal{V}|` vertices.
129131
"""
130132
X = self.theta(X)
131-
if self.bn is not None:
132-
X = self.bn(X)
133133
Y = hg.v2e(X, aggr="mean")
134134
# ===============================================
135135
alpha_e = self.atten_e(Y)
@@ -140,8 +140,12 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
140140
e_atten_score = torch.clamp(e_atten_score, min=0.001, max=5)
141141
# ================================================================================
142142
X = hg.e2v(Y, aggr="softmax_then_sum", e2v_weight=e_atten_score)
143+
143144
if not self.is_last:
144145
X = self.act(X)
146+
if self.bn is not None:
147+
X = self.bn(X)
148+
X = self.drop(X)
145149
return X
146150

147151

@@ -196,12 +200,13 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
196200
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`|\mathcal{V}|` vertices.
197201
"""
198202
X = self.theta(X)
199-
if self.bn is not None:
200-
X = self.bn(X)
201203
Y = hg.v2e(X, aggr="mean")
202204
X = hg.e2v(Y, aggr="sum") + X
203205
if not self.is_last:
204-
X = self.drop(self.act(X))
206+
X = self.act(X)
207+
if self.bn is not None:
208+
X = self.bn(X)
209+
X = self.drop(X)
205210
return X
206211

207212

@@ -265,11 +270,11 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
265270
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`|\mathcal{V}|` vertices.
266271
"""
267272
X = self.theta(X)
268-
if self.bn is not None:
269-
X = self.bn(X)
270273
Y = hg.v2e(X, aggr="mean")
271274
X = (1 + self.eps) * hg.e2v(Y, aggr="sum") + X
272275
if not self.is_last:
273-
X = self.drop(self.act(X))
276+
X = self.act(X)
277+
if self.bn is not None:
278+
X = self.bn(X)
279+
X = self.drop(X)
274280
return X
275-

0 commit comments

Comments
 (0)