Skip to content

Commit b230ab1

Browse files
Merge pull request #77 from CarloLucibello/cl/eweight
remove nodetype for eltype
2 parents 4464550 + c882ebe commit b230ab1

File tree

4 files changed

+17
-20
lines changed

4 files changed

+17
-20
lines changed

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
- uses: actions/checkout@v2
1515
- uses: julia-actions/setup-julia@latest
1616
with:
17-
version: '1.6'
17+
version: '1.7'
1818
- name: Install dependencies
1919
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
2020
- name: Build and deploy

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GraphNeuralNetworks"
22
uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
33
authors = ["Carlo Lucibello and contributors"]
4-
version = "0.3.4"
4+
version = "0.3.5"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/GNNGraphs/query.jl

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,20 @@ get_edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes=g.num_nodes
2020
Graphs.edges(g::GNNGraph) = zip(edge_index(g)...)
2121

2222
Graphs.edgetype(g::GNNGraph) = Tuple{Int, Int}
23-
nodetype(g::GNNGraph) = Base.eltype(g)
2423

25-
"""
26-
nodetype(g::GNNGraph)
27-
28-
Type of nodes in `g`,
29-
an integer type like `Int`, `Int32`, `Uint16`, ....
30-
"""
31-
function nodetype(g::GNNGraph{<:COO_T}, T=nothing)
24+
# """
25+
# eltype(g::GNNGraph)
26+
#
27+
# Type of nodes in `g`,
28+
# an integer type like `Int`, `Int32`, `Uint16`, ....
29+
# """
30+
function Base.eltype(g::GNNGraph{<:COO_T})
3231
s, t = edge_index(g)
33-
return eltype(s)
32+
w = get_edge_weight(g)
33+
return w !== nothing ? eltype(w) : eltype(s)
3434
end
3535

36-
function nodetype(g::GNNGraph{<:ADJMAT_T}, T=nothing)
37-
T !== nothing && return T
38-
return eltype(g.graph)
39-
end
36+
Base.eltype(g::GNNGraph{<:ADJMAT_T}) = eltype(g.graph)
4037

4138
function Graphs.has_edge(g::GNNGraph{<:COO_T}, i::Integer, j::Integer)
4239
s, t = edge_index(g)
@@ -94,7 +91,7 @@ function adjacency_list(g::GNNGraph; dir=:out)
9491
return [fneighs(g, i) for i in 1:g.num_nodes]
9592
end
9693

97-
function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=nodetype(g); dir=:out)
94+
function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=eltype(g); dir=:out)
9895
if g.graph[1] isa CuVector
9996
# TODO revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152
10097
A, n, m = to_dense(g.graph, T, num_nodes=g.num_nodes)
@@ -105,7 +102,7 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=nodetype(g);
105102
return dir == :out ? A : A'
106103
end
107104

108-
function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=nodetype(g); dir=:out)
105+
function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=eltype(g); dir=:out)
109106
@assert dir [:in, :out]
110107
A = g.graph
111108
A = T != eltype(A) ? T.(A) : A
@@ -165,7 +162,7 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T=nothing; dir=:out, edge_weight
165162
@assert !(edge_weight isa AbstractArray) "passing the edge weights is not support by adjacency matrix representations"
166163
@assert dir (:in, :out, :both)
167164
if T === nothing
168-
Nt = nodetype(g)
165+
Nt = eltype(g)
169166
if edge_weight === false && !(Nt <: Integer)
170167
T = Nt == Float32 ? Int32 :
171168
Nt == Float16 ? Int16 : Int
@@ -183,7 +180,7 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T=nothing; dir=:out, edge_weight
183180
vec(sum(A, dims=1)) .+ vec(sum(A, dims=2))
184181
end
185182

186-
function Graphs.laplacian_matrix(g::GNNGraph, T::DataType=nodetype(g); dir::Symbol=:out)
183+
function Graphs.laplacian_matrix(g::GNNGraph, T::DataType=eltype(g); dir::Symbol=:out)
187184
A = adjacency_matrix(g, T; dir=dir)
188185
D = Diagonal(vec(sum(A; dims=2)))
189186
return D - A

test/GNNGraphs/query.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
A = adjacency_matrix(g)
9090
D = Diagonal(vec(sum(A, dims=2)))
9191
L = laplacian_matrix(g)
92-
@test eltype(L) == GNNGraphs.nodetype(g)
92+
@test eltype(L) == eltype(g)
9393
@test L D - A
9494
end
9595
end

0 commit comments

Comments
 (0)