Skip to content

Commit c78f358

Browse files
authored
Fix full test (#7007)
1 parent 5f4a21c commit c78f358

File tree

12 files changed

+84
-62
lines changed

12 files changed

+84
-62
lines changed

test/nn/conv/test_gen_conv.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,12 @@ def test_gen_conv(aggr):
4040
if is_full_test():
4141
t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
4242
jit = torch.jit.script(conv.jittable(t))
43-
assert torch.allclose(jit(x1, edge_index), out11)
44-
assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out11)
45-
assert torch.allclose(jit(x1, edge_index, value), out12)
46-
assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out12)
43+
assert torch.allclose(jit(x1, edge_index), out11, atol=1e-6)
44+
assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out11,
45+
atol=1e-6)
46+
assert torch.allclose(jit(x1, edge_index, value), out12, atol=1e-6)
47+
assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out12,
48+
atol=1e-6)
4749

4850
t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
4951
jit = torch.jit.script(conv.jittable(t))
@@ -71,10 +73,13 @@ def test_gen_conv(aggr):
7173
if is_full_test():
7274
t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
7375
jit = torch.jit.script(conv.jittable(t))
74-
assert torch.allclose(jit((x1, x2), edge_index), out21)
75-
assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out21)
76-
assert torch.allclose(jit((x1, x2), edge_index, value), out22)
77-
assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out22)
76+
assert torch.allclose(jit((x1, x2), edge_index), out21, atol=1e-6)
77+
assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out21,
78+
atol=1e-6)
79+
assert torch.allclose(jit((x1, x2), edge_index, value), out22,
80+
atol=1e-6)
81+
assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out22,
82+
atol=1e-6)
7883

7984
t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
8085
jit = torch.jit.script(conv.jittable(t))
@@ -120,13 +125,14 @@ def test_gen_conv(aggr):
120125
if is_full_test():
121126
t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
122127
jit = torch.jit.script(conv.jittable(t))
123-
assert torch.allclose(jit((x1, x2), edge_index, value), out1)
128+
assert torch.allclose(jit((x1, x2), edge_index, value), out1,
129+
atol=1e-6)
124130
assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)),
125-
out1)
131+
out1, atol=1e-6)
126132
assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)),
127-
out2)
133+
out2, atol=1e-6)
128134

129135
t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
130136
jit = torch.jit.script(conv.jittable(t))
131-
assert torch.allclose(jit((x1, x2), adj1.t()), out1)
132-
assert torch.allclose(jit((x1, None), adj1.t()), out2)
137+
assert torch.allclose(jit((x1, x2), adj1.t()), out1, atol=1e-6)
138+
assert torch.allclose(jit((x1, None), adj1.t()), out2, atol=1e-6)

test/nn/conv/test_graph_conv.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_graph_conv():
8080

8181
t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
8282
jit = torch.jit.script(conv.jittable(t))
83-
assert torch.allclose(jit((x1, x2), adj1.t()), out21)
84-
assert torch.allclose(jit((x1, x2), adj2.t()), out22)
85-
assert torch.allclose(jit((x1, None), adj1.t()), out23)
86-
assert torch.allclose(jit((x1, None), adj2.t()), out24)
83+
assert torch.allclose(jit((x1, x2), adj1.t()), out21, atol=1e-6)
84+
assert torch.allclose(jit((x1, x2), adj2.t()), out22, atol=1e-6)
85+
assert torch.allclose(jit((x1, None), adj1.t()), out23, atol=1e-6)
86+
assert torch.allclose(jit((x1, None), adj2.t()), out24, atol=1e-6)

test/nn/dense/test_dense_gat_conv.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@ def test_dense_gat_conv(heads, concat):
4141

4242
dense_out = dense_conv(x, adj, mask)
4343

44-
assert dense_out[1, 2].abs().sum() == 0
45-
dense_out = dense_out.view(6, dense_out.size(-1))[:-1]
46-
assert torch.allclose(sparse_out, dense_out, atol=1e-4)
47-
4844
if is_full_test():
4945
jit = torch.jit.script(dense_conv)
5046
assert torch.allclose(jit(x, adj, mask), dense_out)
5147

48+
assert dense_out[1, 2].abs().sum() == 0
49+
dense_out = dense_out.view(6, dense_out.size(-1))[:-1]
50+
assert torch.allclose(sparse_out, dense_out, atol=1e-4)
51+
5252

5353
def test_dense_gat_conv_with_broadcasting():
5454
batch_size, num_nodes, channels = 8, 3, 16

test/nn/dense/test_dense_gcn_conv.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ def test_dense_gcn_conv():
3939
dense_out = dense_conv(x, adj, mask)
4040
assert dense_out.size() == (2, 3, channels)
4141

42-
assert dense_out[1, 2].abs().sum() == 0
43-
dense_out = dense_out.view(6, channels)[:-1]
44-
assert torch.allclose(sparse_out, dense_out, atol=1e-4)
45-
4642
if is_full_test():
4743
jit = torch.jit.script(dense_conv)
4844
assert torch.allclose(jit(x, adj, mask), dense_out)
4945

46+
assert dense_out[1, 2].abs().sum() == 0
47+
dense_out = dense_out.view(6, channels)[:-1]
48+
assert torch.allclose(sparse_out, dense_out, atol=1e-4)
49+
5050

5151
def test_dense_gcn_conv_with_broadcasting():
5252
batch_size, num_nodes, channels = 8, 3, 16

test/nn/dense/test_linear.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def test_lazy_hetero_linear():
125125

126126
out = lin(x, type_vec)
127127
assert out.size() == (3, 32)
128-
assert str(lin) == 'HeteroLinear(16, 32, num_types=3, bias=True)'
129128

130129

131130
def test_hetero_dict_linear():
@@ -160,7 +159,6 @@ def test_lazy_hetero_dict_linear():
160159
assert len(out_dict) == 2
161160
assert out_dict['v'].size() == (3, 32)
162161
assert out_dict['w'].size() == (2, 32)
163-
assert str(lin) == "HeteroDictLinear({'v': 16, 'w': 8}, 32, bias=True)"
164162

165163

166164
@withPackage('pyg_lib')

test/utils/test_sparse.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,14 @@ def test_to_torch_coo_tensor():
7676
])
7777
edge_attr = torch.randn(edge_index.size(1), 8)
7878

79-
adj = to_torch_coo_tensor(edge_index)
79+
adj = to_torch_coo_tensor(edge_index, is_coalesced=False)
80+
assert adj.is_coalesced()
81+
assert adj.size() == (4, 4)
82+
assert adj.layout == torch.sparse_coo
83+
assert torch.allclose(adj.indices(), edge_index)
84+
85+
adj = to_torch_coo_tensor(edge_index, is_coalesced=True)
86+
assert adj.is_coalesced()
8087
assert adj.size() == (4, 4)
8188
assert adj.layout == torch.sparse_coo
8289
assert torch.allclose(adj.indices(), edge_index)

torch_geometric/compile.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def fn(model: Callable) -> Callable:
7676
for key in prev_state.keys():
7777
setattr(torch_geometric.typing, key, False)
7878

79-
# Temporarily adjust the logging level of `torch.compile`:
79+
# Adjust the logging level of `torch.compile`:
80+
# TODO (matthias) Disable only temporarily
8081
prev_log_level = {
8182
'torch._dynamo': logging.getLogger('torch._dynamo').level,
8283
'torch._inductor': logging.getLogger('torch._inductor').level,
@@ -91,8 +92,4 @@ def fn(model: Callable) -> Callable:
9192
# Finally, run `torch.compile` to create an optimized version:
9293
out = torch.compile(model, *args, **kwargs)
9394

94-
# Restore the previous state:
95-
for key, value in prev_log_level.items():
96-
logging.getLogger(key).setLevel(value)
97-
9895
return out

torch_geometric/nn/conv/cluster_gcn_conv.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
spmm,
1313
to_edge_index,
1414
)
15-
from torch_geometric.utils.sparse import get_sparse_diag, set_sparse_value
15+
from torch_geometric.utils.sparse import set_sparse_value
1616

1717

1818
class ClusterGCNConv(MessagePassing):
@@ -71,6 +71,7 @@ def reset_parameters(self):
7171
self.lin_root.reset_parameters()
7272

7373
def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
74+
num_nodes = x.size(self.node_dim)
7475
edge_weight: OptTensor = None
7576

7677
if isinstance(edge_index, SparseTensor):
@@ -94,13 +95,7 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
9495
"supported in 'gcn_norm'")
9596

9697
if self.add_self_loops:
97-
diag = get_sparse_diag(edge_index.size(0), 1.0,
98-
edge_index.layout, edge_index.dtype,
99-
edge_index.device)
100-
edge_index = edge_index + diag
101-
102-
if edge_index.layout == torch.sparse_coo:
103-
edge_index = edge_index.coalesce()
98+
edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
10499

105100
col_and_row, value = to_edge_index(edge_index)
106101
col, row = col_and_row[0], col_and_row[1]
@@ -112,7 +107,6 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
112107
edge_index = set_sparse_value(edge_index, edge_weight)
113108

114109
else:
115-
num_nodes = x.size(self.node_dim)
116110
if self.add_self_loops:
117111
edge_index, _ = remove_self_loops(edge_index)
118112
edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

torch_geometric/nn/conv/gatv2_conv.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
251251
size=None)
252252

253253
alpha = self._alpha
254+
assert alpha is not None
254255
self._alpha = None
255256

256257
if self.concat:

torch_geometric/nn/conv/gcn_conv.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@
1414
SparseTensor,
1515
torch_sparse,
1616
)
17+
from torch_geometric.utils import add_remaining_self_loops
18+
from torch_geometric.utils import add_self_loops as add_self_loops_fn
1719
from torch_geometric.utils import (
18-
add_remaining_self_loops,
1920
is_torch_sparse_tensor,
2021
scatter,
2122
spmm,
2223
to_edge_index,
2324
)
2425
from torch_geometric.utils.num_nodes import maybe_num_nodes
25-
from torch_geometric.utils.sparse import get_sparse_diag, set_sparse_value
26+
from torch_geometric.utils.sparse import set_sparse_value
2627

2728

2829
@torch.jit._overload
@@ -70,14 +71,8 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
7071
"supported in 'gcn_norm'")
7172

7273
adj_t = edge_index
73-
7474
if add_self_loops:
75-
diag = get_sparse_diag(adj_t.size(0), fill_value, adj_t.layout,
76-
adj_t.dtype, adj_t.device)
77-
adj_t = adj_t + diag
78-
79-
if adj_t.layout == torch.sparse_coo:
80-
adj_t = adj_t.coalesce()
75+
adj_t, _ = add_self_loops_fn(adj_t, None, fill_value, num_nodes)
8176

8277
edge_index, value = to_edge_index(adj_t)
8378
col, row = edge_index[0], edge_index[1]

0 commit comments

Comments
 (0)