Skip to content

Commit 72fc97e

Browse files
remove nodetype for eltype
1 parent 4464550 commit 72fc97e

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

src/GNNGraphs/query.jl

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,20 @@ get_edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes=g.num_nodes
2020
Graphs.edges(g::GNNGraph) = zip(edge_index(g)...)
2121

2222
Graphs.edgetype(g::GNNGraph) = Tuple{Int, Int}
23-
nodetype(g::GNNGraph) = Base.eltype(g)
2423

25-
"""
26-
nodetype(g::GNNGraph)
27-
28-
Type of nodes in `g`,
29-
an integer type like `Int`, `Int32`, `Uint16`, ....
30-
"""
31-
function nodetype(g::GNNGraph{<:COO_T}, T=nothing)
24+
# """
25+
# eltype(g::GNNGraph)
26+
#
27+
# Type of nodes in `g`,
28+
# an integer type like `Int`, `Int32`, `Uint16`, ....
29+
# """
30+
function Base.eltype(g::GNNGraph{<:COO_T})
3231
s, t = edge_index(g)
33-
return eltype(s)
32+
w = get_edge_weight
33+
return w !== nothing ? eltype(w) : eltype(s)
3434
end
3535

36-
function nodetype(g::GNNGraph{<:ADJMAT_T}, T=nothing)
37-
T !== nothing && return T
38-
return eltype(g.graph)
39-
end
36+
Base.eltype(g::GNNGraph{<:ADJMAT_T}) = eltype(g.graph)
4037

4138
function Graphs.has_edge(g::GNNGraph{<:COO_T}, i::Integer, j::Integer)
4239
s, t = edge_index(g)
@@ -94,7 +91,7 @@ function adjacency_list(g::GNNGraph; dir=:out)
9491
return [fneighs(g, i) for i in 1:g.num_nodes]
9592
end
9693

97-
function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=nodetype(g); dir=:out)
94+
function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=eltype(g); dir=:out)
9895
if g.graph[1] isa CuVector
9996
# TODO revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152
10097
A, n, m = to_dense(g.graph, T, num_nodes=g.num_nodes)
@@ -105,7 +102,7 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=nodetype(g);
105102
return dir == :out ? A : A'
106103
end
107104

108-
function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=nodetype(g); dir=:out)
105+
function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=eltype(g); dir=:out)
109106
@assert dir [:in, :out]
110107
A = g.graph
111108
A = T != eltype(A) ? T.(A) : A
@@ -165,7 +162,7 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T=nothing; dir=:out, edge_weight
165162
@assert !(edge_weight isa AbstractArray) "passing the edge weights is not support by adjacency matrix representations"
166163
@assert dir (:in, :out, :both)
167164
if T === nothing
168-
Nt = nodetype(g)
165+
Nt = eltype(g)
169166
if edge_weight === false && !(Nt <: Integer)
170167
T = Nt == Float32 ? Int32 :
171168
Nt == Float16 ? Int16 : Int
@@ -183,7 +180,7 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T=nothing; dir=:out, edge_weight
183180
vec(sum(A, dims=1)) .+ vec(sum(A, dims=2))
184181
end
185182

186-
function Graphs.laplacian_matrix(g::GNNGraph, T::DataType=nodetype(g); dir::Symbol=:out)
183+
function Graphs.laplacian_matrix(g::GNNGraph, T::DataType=eltype(g); dir::Symbol=:out)
187184
A = adjacency_matrix(g, T; dir=dir)
188185
D = Diagonal(vec(sum(A; dims=2)))
189186
return D - A

test/GNNGraphs/query.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
A = adjacency_matrix(g)
9090
D = Diagonal(vec(sum(A, dims=2)))
9191
L = laplacian_matrix(g)
92-
@test eltype(L) == GNNGraphs.nodetype(g)
92+
@test eltype(L) == eltype(g)
9393
@test L D - A
9494
end
9595
end

0 commit comments

Comments
 (0)