Calling a HeteroGraphConv built from GATConv layers on a heterograph that contains node types with zero nodes fails with a DimensionMismatch. (As an aside, the dropout keyword of GATConv defaults to a 64-bit float, which causes more problems but can be easily fixed by passing a Float32 explicitly, as done in the reproduction below.)
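For reference, a minimal sketch of the dropout workaround mentioned above (that the Float64 literal is what triggers the promotion is an assumption based on the description):

using GraphNeuralNetworks, NNlib

GATConv(4 => 4, NNlib.elu; dropout = 0.25)    # 0.25 is a Float64 literal; assumed to leave a 64-bit dropout in the layer
GATConv(4 => 4, NNlib.elu; dropout = 0.25f0)  # 0.25f0 is Float32 and keeps the layer consistent with Float32 features

Minimal reproduction of the DimensionMismatch: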
using GNNGraphs, GraphNeuralNetworks, NNlib, Flux

# Heterograph with two empty node types (:D and :E) and several relations
# that contain no edges.
graph = GNNHeteroGraph(
    Dict(
        (:A, :a, :B) => ([1, 2], [3, 4]),
        (:B, :a, :A) => ([1], [2]),
        (:C, :a, :A) => (Int[], Int[]),
        (:A, :a, :C) => (Int[], Int[]),
        (:D, :a, :A) => (Int[], Int[]),
        (:E, :a, :A) => (Int[], Int[]),
        (:E, :a, :D) => (Int[], Int[]),
        (:D, :a, :E) => (Int[], Int[]),
    );
    num_nodes = Dict(:A => 3, :B => 5, :C => 7, :D => 0, :E => 0),
)
# One GATConv per relation; keys(graph.edata) yields the (src, relation, dst) triples.
layer = HeteroGraphConv(
    [
        (src, edge, dst) => GATConv(4 => 4, NNlib.elu; dropout = Float32(0.25))
        for (src, edge, dst) in keys(graph.edata)
    ]
)
layer2 = HeteroGraphConv(
    [
        (src, edge, dst) => GATConv(4 => 4, NNlib.elu; dropout = Float32(0.25))
        for (src, edge, dst) in keys(graph.edata)
    ]
)
# Float32 features for every node type; :D and :E get 4x0 matrices.
x = (
    A = rand(Float32, 4, 3),
    B = rand(Float32, 4, 5),
    C = rand(Float32, 4, 7),
    D = rand(Float32, 4, 0),
    E = rand(Float32, 4, 0),
)

x1 = layer(graph, x)
x2 = layer2(graph, x1)
@info "$x2"
g = Flux.gradient(x) do x
    y = layer(graph, x)
    sum(y[:A])
end

The error is:
ERROR: LoadError: DimensionMismatch: arrays could not be broadcast to a common size: a has axes Base.OneTo(0) and b has axes Base.OneTo(4)
Stacktrace:
[1] _bcs1
@ ./broadcast.jl:535 [inlined]
[2] _bcs
@ ./broadcast.jl:529 [inlined]
[3] broadcast_shape
@ ./broadcast.jl:523 [inlined]
[4] combine_axes
@ ./broadcast.jl:504 [inlined]
[5] _axes
@ ./broadcast.jl:240 [inlined]
[6] axes
@ ./broadcast.jl:238 [inlined]
[7] combine_axes
@ ./broadcast.jl:505 [inlined]
[8] instantiate
@ ./broadcast.jl:313 [inlined]
[9] materialize
@ ./broadcast.jl:894 [inlined]
[10] gat_conv(l::GATConv{Flux.Dense{typeof(identity), Matrix{Float32}, Bool}, Nothing, Float32, Float32, Matrix{Float32}, typeof(elu), Vector{Float32}}, g::GNNHeteroGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}, x::Tuple{Matrix{Float32}, Matrix{Float32}}, e::Nothing)
@ GNNlib ~/.julia/packages/GNNlib/wxiDz/src/layers/conv.jl:147
[11] GATConv
@ ~/.julia/packages/GraphNeuralNetworks/XGIXF/src/layers/conv.jl:346 [inlined]
[12] (::GATConv{Flux.Dense{typeof(identity), Matrix{Float32}, Bool}, Nothing, Float32, Float32, Matrix{Float32}, typeof(elu), Vector{Float32}})(g::GNNHeteroGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}, x::Tuple{Matrix{Float32}, Matrix{Float32}})
@ GraphNeuralNetworks ~/.julia/packages/GraphNeuralNetworks/XGIXF/src/layers/conv.jl:346
[13] (::GraphNeuralNetworks.var"#forw#forw##0"{GNNHeteroGraph{Tuple{T, T, Union{Nothing, AbstractVector}} where T<:(AbstractVector{<:Integer})}, @NamedTuple{A::Matrix{Float32}, B::Matrix{Float32}, C::Matrix{Float32}, D::Matrix{Float32}, E::Matrix{Float32}}})(l::GATConv{Flux.Dense{typeof(identity), Matrix{Float32}, Bool}, Nothing, Float32, Float32, Matrix{Float32}, typeof(elu), Vector{Float32}}, et::Tuple{Symbol, Symbol, Symbol})
@ GraphNeuralNetworks ~/.julia/packages/GraphNeuralNetworks/XGIXF/src/layers/heteroconv.jl:63
[14] #60
@ ./none:-1 [inlined]
[15] iterate
@ ./generator.jl:48 [inlined]
[16] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{GATConv{Flux.Dense{typeof(identity), Matrix{Float32}, Bool}, Nothing, Float32, Float32, Matrix{Float32}, typeof(elu), Vector{Float32}}}, Vector{Tuple{Symbol, Symbol, Symbol}}}}, GraphNeuralNetworks.var"#60#61"{GraphNeuralNetworks.var"#forw#forw##0"{GNNHeteroGraph{Tuple{T, T, Union{Nothing, AbstractVector}} where T<:(AbstractVector{<:Integer})}, @NamedTuple{A::Matrix{Float32}, B::Matrix{Float32}, C::Matrix{Float32}, D::Matrix{Float32}, E::Matrix{Float32}}}}})
@ Base ./array.jl:790
[17] (::HeteroGraphConv)(g::GNNHeteroGraph{Tuple{T, T, Union{Nothing, AbstractVector}} where T<:(AbstractVector{<:Integer})}, x::@NamedTuple{A::Matrix{Float32}, B::Matrix{Float32}, C::Matrix{Float32}, D::Matrix{Float32}, E::Matrix{Float32}})
@ GraphNeuralNetworks ~/.julia/packages/GraphNeuralNetworks/XGIXF/src/layers/heteroconv.jl:65
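From the trace, the broadcast in gat_conv (GNNlib conv.jl:147) ends up with one operand that has zero columns, which matches the relations touching the empty node types. A hedged workaround sketch, assuming num_edges on a GNNHeteroGraph is the per-relation edge-count Dict (and accepting that node types with no surviving incoming relation then disappear from the layer's output):

# Hypothetical mitigation: only build GATConv layers for relations that
# actually contain edges, so gat_conv never broadcasts an empty operand.
nonempty = [et for et in keys(graph.num_edges) if graph.num_edges[et] > 0]
layer = HeteroGraphConv(
    [et => GATConv(4 => 4, NNlib.elu; dropout = Float32(0.25)) for et in nonempty]
)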