Skip to content

Commit ef22e9a

Browse files
fix dense test (#479)
1 parent 83b6b7e commit ef22e9a

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

GNNGraphs/src/gnngraph.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,10 @@ function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata
209209
else
210210
graph = g.graph
211211
end
212-
GNNGraph(graph,
213-
g.num_nodes, g.num_edges, g.num_graphs,
214-
g.graph_indicator,
215-
ndata, edata, gdata)
212+
return GNNGraph(graph,
213+
g.num_nodes, g.num_edges, g.num_graphs,
214+
g.graph_indicator,
215+
ndata, edata, gdata)
216216
end
217217

218218
"""

GNNlib/src/layers/conv.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ end
7474
# when we also have edge_weight we need to convert the graph to COO
7575
function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, norm_fn::F, conv_weight::CW) where
7676
{EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F}
77-
g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO
77+
g = GNNGraph(g, graph_type = :coo)
7878
return gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight)
7979
end
8080

@@ -449,9 +449,10 @@ function sgc_conv(l, g::GNNGraph, x::AbstractMatrix{T},
449449
return (x .+ l.bias)
450450
end
451451

452+
# when we also have edge_weight we need to convert the graph to COO
452453
function sgc_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
453454
edge_weight::AbstractVector)
454-
g = GNNGraph(edge_index(g)...; g.num_nodes)
455+
g = GNNGraph(g; graph_type=:coo)
455456
return sgc_conv(l, g, x, edge_weight)
456457
end
457458

@@ -542,9 +543,10 @@ function sg_conv(l, g::GNNGraph, x::AbstractMatrix{T},
542543
return (x .+ l.bias)
543544
end
544545

546+
# when we also have edge_weight we need to convert the graph to COO
545547
function sg_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
546548
edge_weight::AbstractVector)
547-
g = GNNGraph(edge_index(g)...; g.num_nodes)
549+
g = GNNGraph(g; graph_type=:coo)
548550
return sg_conv(l, g, x, edge_weight)
549551
end
550552

@@ -684,9 +686,10 @@ function tag_conv(l, g::GNNGraph, x::AbstractMatrix{T},
684686
return (sum_total .+ l.bias)
685687
end
686688

689+
# when we also have edge_weight we need to convert the graph to COO
687690
function tag_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
688691
edge_weight::AbstractVector)
689-
g = GNNGraph(edge_index(g)...; g.num_nodes)
692+
g = GNNGraph(g; graph_type = :coo)
690693
return l(g, x, edge_weight)
691694
end
692695

0 commit comments

Comments
 (0)