41
41
@test length (Flux. gradient (x -> sum (sum (ginconv (tg, x))), tg. ndata. x)[1 ]) == S
42
42
end
43
43
44
+
45
+ @testset " ChebConv" begin
46
+ chebconv = ChebConv (in_channel => out_channel, 5 )
47
+ @test length (chebconv (tg, tg. ndata. x)) == S
48
+ @test size (chebconv (tg, tg. ndata. x)[1 ]) == (out_channel, N)
49
+ @test length (Flux. gradient (x -> sum (sum (chebconv (tg, x))), tg. ndata. x)[1 ]) == S
50
+ end
51
+
52
+ @testset " GATConv" begin
53
+ gatconv = GATConv (in_channel => out_channel)
54
+ @test length (gatconv (tg, tg. ndata. x)) == S
55
+ @test size (gatconv (tg, tg. ndata. x)[1 ]) == (out_channel, N)
56
+ @test length (Flux. gradient (x -> sum (sum (gatconv (tg, x))), tg. ndata. x)[1 ]) == S
57
+ end
58
+
59
+ @testset " GATv2Conv" begin
60
+ gatv2conv = GATv2Conv (in_channel => out_channel)
61
+ @test length (gatv2conv (tg, tg. ndata. x)) == S
62
+ @test size (gatv2conv (tg, tg. ndata. x)[1 ]) == (out_channel, N)
63
+ @test length (Flux. gradient (x -> sum (sum (gatv2conv (tg, x))), tg. ndata. x)[1 ]) == S
64
+ end
65
+
66
+ @testset " GatedGraphConv" begin
67
+ gatedgraphconv = GatedGraphConv (5 , 5 )
68
+ @test length (gatedgraphconv (tg, tg. ndata. x)) == S
69
+ @test size (gatedgraphconv (tg, tg. ndata. x)[1 ]) == (out_channel, N)
70
+ @test length (Flux. gradient (x -> sum (sum (gatedgraphconv (tg, x))), tg. ndata. x)[1 ]) == S
71
+ end
72
+
73
+ @testset " CGConv" begin
74
+ cgconv = CGConv (in_channel => out_channel)
75
+ @test length (cgconv (tg, tg. ndata. x)) == S
76
+ @test size (cgconv (tg, tg. ndata. x)[1 ]) == (out_channel, N)
77
+ @test length (Flux. gradient (x -> sum (sum (cgconv (tg, x))), tg. ndata. x)[1 ]) == S
78
+ end
79
+
80
+ @testset " SGConv" begin
81
+ sgconv = SGConv (in_channel => out_channel)
82
+ @test length (sgconv (tg, tg. ndata. x)) == S
83
+ @test size (sgconv (tg, tg. ndata. x)[1 ]) == (out_channel, N)
84
+ @test length (Flux. gradient (x -> sum (sum (sgconv (tg, x))), tg. ndata. x)[1 ]) == S
85
+ end
86
+
87
+ @testset " TransformerConv" begin
88
+ transformerconv = TransformerConv (in_channel => out_channel)
89
+ @test length (transformerconv (tg, tg. ndata. x)) == S
90
+ @test size (transformerconv (tg, tg. ndata. x)[1 ]) == (out_channel, N)
91
+ @test length (Flux. gradient (x -> sum (sum (transformerconv (tg, x))), tg. ndata. x)[1 ]) == S
92
+ end
93
+
44
94
@testset " GCNConv" begin
45
95
gcnconv = GCNConv (in_channel => out_channel)
46
96
@test length (gcnconv (tg, tg. ndata. x)) == S
67
117
@test length (graphconv (tg, tg. ndata. x)) == S
68
118
@test size (graphconv (tg, tg. ndata. x)[1 ]) == (out_channel, N)
69
119
@test length (Flux. gradient (x -> sum (sum (graphconv (tg, x))), tg. ndata. x)[1 ]) == S
70
- end
120
+ end
121
+
0 commit comments