Skip to content

Commit 5534fdd

Browse files
feat: add degree functionality for GNNHeteroGraph (#360)
* add degree functionality for heterognn * use _degree function * update docstring * change ordering to keep _degree functions at the bottom * Update src/GNNGraphs/query.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update src/GNNGraphs/query.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update src/GNNGraphs/query.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update src/GNNGraphs/query.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update src/GNNGraphs/query.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update src/GNNGraphs/query.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update src/GNNGraphs/query.jl Co-authored-by: Carlo Lucibello <[email protected]> * change test according to the new default behavior --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 7ddadab commit 5534fdd

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

src/GNNGraphs/query.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,36 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT = nothing; dir = :out,
314314
return _degree(A, T, dir, edge_weight, g.num_nodes)
315315
end
316316

317+
"""
318+
degree(g::GNNHeteroGraph, edge_type::EType; dir = :in)
319+
320+
Return a vector containing the degrees of the nodes in `g` GNNHeteroGraph
321+
given `edge_type`.
322+
323+
# Arguments
324+
325+
- `g`: A graph.
326+
- `edge_type`: A tuple of symbols `(source_t, edge_t, target_t)` representing the edge type.
327+
- `T`: Element type of the returned vector. If `nothing`, is
328+
chosen based on the graph type. Default `nothing`.
329+
- `dir`: For `dir=:out` the degree of a node is counted based on the outgoing edges.
330+
For `dir = :in`, the ingoing edges are used. If `dir = :both` we have the sum of the two.
331+
Default `dir = :out`.
332+
333+
"""
334+
function Graphs.degree(g::GNNHeteroGraph, edge::Tuple{Symbol, Symbol, Symbol},
335+
T::TT = nothing; dir = :out) where {
336+
TT <: Union{Nothing, Type{<:Number}}}
337+
338+
s, t = edge_index(g, edge)
339+
340+
T = isnothing(T) ? eltype(s) : T
341+
342+
n_type = dir == :in ? g.ntypes[2] : g.ntypes[1]
343+
344+
return _degree((s, t), T, dir, nothing, g.num_nodes[n_type])
345+
end
346+
317347
function _degree((s, t)::Tuple, T::Type, dir::Symbol, edge_weight::Nothing, num_nodes::Int)
318348
_degree((s, t), T, dir, ones_like(s, T), num_nodes)
319349
end
@@ -373,7 +403,6 @@ function ChainRulesCore.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weigh
373403
end
374404
end
375405

376-
377406
"""
378407
has_isolated_nodes(g::GNNGraph; dir=:out)
379408

test/GNNGraphs/query.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,13 @@ end
159159

160160
@test grad ngrad
161161
end
162+
163+
@testset "heterognn, degree" begin
164+
g = GNNHeteroGraph((:A, :to, :B) => ([1,1,2,3], [7,13,5,7]))
165+
@test degree(g, (:A, :to, :B), dir = :out) == [2, 1, 1]
166+
@test degree(g, (:A, :to, :B), dir = :in) == [0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 1]
167+
@test degree(g, (:A, :to, :B)) == [2, 1, 1]
168+
end
162169
end
163170
end
164171
end

0 commit comments

Comments
 (0)