|  | 
| 2 | 2 |     using .TestModuleGNNlib | 
| 3 | 3 |     #TODO test all graph types | 
| 4 | 4 |     g = TEST_GRAPHS[1] | 
| 5 |  | -    out_channel = 10 | 
|  | 5 | +    out_channel = size(g.x, 1) | 
| 6 | 6 |     num_V = g.num_nodes | 
| 7 | 7 |     num_E = g.num_edges | 
| 8 |  | - | 
|  | 8 | +    g = GNNGraph(g, edata = rand(Float32, size(g.x, 1), g.num_edges)) | 
|  | 9 | +     | 
| 9 | 10 |     @testset "propagate" begin | 
| 10 | 11 |         function message(xi, xj, e) | 
| 11 | 12 |             @test xi === nothing | 
|  | 
| 20 | 21 |         @testset "isolated nodes" begin | 
| 21 | 22 |             x1 = rand(1, 6) | 
| 22 | 23 |             g1 = GNNGraph(collect(1:5), collect(1:5), num_nodes = 6) | 
| 23 |  | -            y1 = propagate((xi, xj, e) -> xj, g, +, xj = x1) | 
|  | 24 | +            y1 = propagate((xi, xj, e) -> xj, g1, +, xj = x1) | 
| 24 | 25 |             @test size(y1) == (1, 6) | 
| 25 | 26 |         end | 
| 26 | 27 |     end | 
|  | 
| 123 | 124 |             @test_throws AssertionError aggregate_neighbors(g, +, m) | 
| 124 | 125 |         end | 
| 125 | 126 |     end | 
|  | 127 | +end | 
|  | 128 | + | 
|  | 129 | +@testitem "propagate" setup=[TestModuleGNNlib] begin | 
|  | 130 | +    using .TestModuleGNNlib | 
|  | 131 | + | 
|  | 132 | +    @testset "copy_xj +" begin | 
|  | 133 | +        for g in TEST_GRAPHS | 
|  | 134 | +            f(g, x) = propagate(copy_xj, g, +, xj = x) | 
|  | 135 | +            test_gradients(f, g, g.x; test_grad_f=false) | 
|  | 136 | +        end | 
|  | 137 | +    end | 
| 126 | 138 | 
 | 
|  | 139 | +    @testset "copy_xj mean" begin | 
|  | 140 | +        for g in TEST_GRAPHS | 
|  | 141 | +            f(g, x) = propagate(copy_xj, g, mean, xj = x) | 
|  | 142 | +            test_gradients(f, g, g.x; test_grad_f=false) | 
|  | 143 | +        end | 
|  | 144 | +    end | 
|  | 145 | + | 
|  | 146 | +    @testset "e_mul_xj +" begin | 
|  | 147 | +        for g in TEST_GRAPHS | 
|  | 148 | +            e = rand(Float32, size(g.x, 1), g.num_edges) | 
|  | 149 | +            f(g, x, e) = propagate(e_mul_xj, g, +; xj = x, e) | 
|  | 150 | +            test_gradients(f, g, g.x, e; test_grad_f=false) | 
|  | 151 | +        end | 
|  | 152 | +    end | 
|  | 153 | + | 
|  | 154 | +    @testset "w_mul_xj +" begin | 
|  | 155 | +        for g in TEST_GRAPHS | 
|  | 156 | +            w = rand(Float32, g.num_edges) | 
|  | 157 | +            function f(g, x, w) | 
|  | 158 | +                g = set_edge_weight(g, w) | 
|  | 159 | +                return propagate(w_mul_xj, g, +, xj = x) | 
|  | 160 | +            end | 
|  | 161 | +            test_gradients(f, g, g.x, w; test_grad_f=false) | 
|  | 162 | +        end | 
|  | 163 | +    end | 
| 127 | 164 | end | 
| 128 | 165 | 
 | 
| 129 |  | -@testitem "msgpass GPU" setup=[TestModuleGNNlib] begin | 
|  | 166 | +@testitem "propagate GPU" setup=[TestModuleGNNlib] tags=[:gpu] begin | 
| 130 | 167 |     using .TestModuleGNNlib | 
| 131 |  | -    n, m = 10, 20 | 
| 132 |  | -    g = rand_graph(n, m, graph_type = :coo) | 
| 133 |  | -    x = rand(Float32, 2, n)     | 
| 134 |  | -    f(g, x) = propagate(copy_xj, g, +, xj = x) | 
| 135 |  | -    test_gradients(f, g, x; test_gpu=true, test_grad_f=false, compare_finite_diff=false) | 
|  | 168 | + | 
|  | 169 | +    @testset "copy_xj +" begin | 
|  | 170 | +        for g in TEST_GRAPHS | 
|  | 171 | +            f(g, x) = propagate(copy_xj, g, +, xj = x) | 
|  | 172 | +            test_gradients(f, g, g.x; test_gpu=true, test_grad_f=false, compare_finite_diff=false) | 
|  | 173 | +        end | 
|  | 174 | +    end | 
|  | 175 | + | 
|  | 176 | +    @testset "copy_xj mean" begin | 
|  | 177 | +        for g in TEST_GRAPHS | 
|  | 178 | +            f(g, x) = propagate(copy_xj, g, mean, xj = x) | 
|  | 179 | +            test_gradients(f, g, g.x; test_gpu=true, test_grad_f=false, compare_finite_diff=false) | 
|  | 180 | +        end | 
|  | 181 | +    end | 
|  | 182 | + | 
|  | 183 | +    @testset "e_mul_xj +" begin | 
|  | 184 | +        for g in TEST_GRAPHS | 
|  | 185 | +            e = rand(Float32, size(g.x, 1), g.num_edges) | 
|  | 186 | +            f(g, x, e) = propagate(e_mul_xj, g, +; xj = x, e) | 
|  | 187 | +            test_gradients(f, g, g.x, e; test_gpu=true, test_grad_f=false, compare_finite_diff=false) | 
|  | 188 | +        end | 
|  | 189 | +    end | 
|  | 190 | + | 
|  | 191 | +    @testset "w_mul_xj +" begin | 
|  | 192 | +        for g in TEST_GRAPHS | 
|  | 193 | +            w = rand(Float32, g.num_edges) | 
|  | 194 | +            function f(g, x, w) | 
|  | 195 | +                g = set_edge_weight(g, w) | 
|  | 196 | +                return propagate(w_mul_xj, g, +, xj = x) | 
|  | 197 | +            end | 
|  | 198 | +            @test test_gradients( | 
|  | 199 | +                f, g, g.x, w; test_gpu=true, test_grad_f=false, compare_finite_diff=false | 
|  | 200 | +            ) broken=true | 
|  | 201 | +        end | 
|  | 202 | +    end | 
| 136 | 203 | end | 
0 commit comments