Skip to content

Commit 9aa64ff

Browse files
put gcn inputs check behind function barrier (#261)
1 parent c5bd656 commit 9aa64ff

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

src/layers/conv.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,24 @@ function GCNConv(ch::Pair{Int, Int}, σ = identity;
7474
GCNConv(W, b, σ, add_self_loops, use_edge_weight)
7575
end
7676

77-
function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T},
78-
edge_weight::EW = nothing) where
79-
{T, EW <: Union{Nothing, AbstractVector}}
80-
@assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"
77+
check_gcnconv_input(g::GNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) =
78+
throw(ArgumentError("Providing external edge_weight is not yet supported for adjacency matrix graphs"))
8179

82-
if edge_weight !== nothing
83-
@assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"
80+
function check_gcnconv_input(g::GNNGraph, edge_weight::AbstractVector)
81+
if length(edge_weight) !== g.num_edges
82+
throw(ArgumentError("Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"))
8483
end
84+
end
85+
86+
check_gcnconv_input(g::GNNGraph, edge_weight::Nothing) = nothing
87+
88+
89+
function (l::GCNConv)(g::GNNGraph,
90+
x::AbstractMatrix{T},
91+
edge_weight::EW = nothing
92+
) where {T, EW <: Union{Nothing, AbstractVector}}
93+
94+
check_gcnconv_input(g, edge_weight)
8595

8696
if l.add_self_loops
8797
g = add_self_loops(g)

0 commit comments

Comments
 (0)