diff --git a/GNNlib/src/utils.jl b/GNNlib/src/utils.jl index 8c739f3d9..f281b297d 100644 --- a/GNNlib/src/utils.jl +++ b/GNNlib/src/utils.jl @@ -27,6 +27,17 @@ function reduce_nodes(aggr, indicator::AbstractVector, x) return NNlib.scatter(aggr, x, indicator) end +""" + reduce_nodes(aggr, node_type, g, x) + +Return the graph-wise aggregation of the node features `x` on type `node_type` +given a heterogeneous graph `g`. The aggregation operator `aggr` can be `+`, +`mean`, `max`, or `min`. +""" +function reduce_nodes(aggr, node_type, g::GNNHeteroGraph, x) + return NNlib.scatter(aggr, x[node_type], graph_indicator(g, node_type)) +end + """ reduce_edges(aggr, g, e) diff --git a/GNNlib/test/utils.jl b/GNNlib/test/utils.jl index bf06f86fd..784329db9 100644 --- a/GNNlib/test/utils.jl +++ b/GNNlib/test/utils.jl @@ -19,6 +19,17 @@ @test r2 == r end + @testset "reduce_nodes" begin + g = rand_bipartite_heterograph((5, 10), 20) + x = ( + A = [Float32(i) for j = 1:1, i = 1:g.num_nodes[:A]], + B = [Float32(0) for j = 1:2, _ = 1:g.num_nodes[:B]], + ) + expected = sum(i for i = 1:g.num_nodes[:A]) + result = reduce_nodes(+, :A, g, x) + @test result == [expected;;] + end + @testset "reduce_edges" begin r = reduce_edges(mean, g, e) @test size(r) == (De, g.num_graphs)