diff --git a/GNNGraphs/src/query.jl b/GNNGraphs/src/query.jl index 76bc2dd28..879aee9df 100644 --- a/GNNGraphs/src/query.jl +++ b/GNNGraphs/src/query.jl @@ -272,13 +272,11 @@ function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType; function adjacency_matrix_pullback_weighted(Δ) dy = CRC.unthunk(Δ) s, t = edge_index(g) - @show dy s t #TODO using CRC.@thunk gives an error #TODO use gather when https://github.com/FluxML/NNlib.jl/issues/625 is fixed - dw = zeros_like(w) + dw = zeros_like(w, eltype(dy)) idx = CartesianIndex.(s, t) #TODO remove when https://github.com/FluxML/NNlib.jl/issues/626 is fixed NNlib.gather!(dw, dy, idx) - @show dw dg = CRC.Tangent{G}(; graph = (CRC.NoTangent(), CRC.NoTangent(), dw)) return (CRC.NoTangent(), dg, CRC.NoTangent()) end