Skip to content

Commit 77f0e0f

Browse files
remove cast rrule not needed anymore (#135)
1 parent dd4a54c commit 77f0e0f

File tree

2 files changed

+2
-16
lines changed

2 files changed

+2
-16
lines changed

src/utils.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,3 @@ function broadcast_edges(g::GNNGraph, x)
9999
gi = graph_indicator(g, edges=true)
100100
return gather(x, gi)
101101
end
102-
103-
# More generic version of
104-
# https://github.com/JuliaDiff/ChainRules.jl/pull/586
105-
# This applies to all arrays
106-
# Withouth this, gradient of T.(A) for A dense gpu matrix errors.
107-
function ChainRulesCore.rrule(::typeof(Broadcast.broadcasted), T::Type{<:Number}, x::AbstractArray)
108-
proj = ProjectTo(x)
109-
110-
function broadcasted_cast(Δ)
111-
return NoTangent(), NoTangent(), proj(Δ)
112-
end
113-
114-
return T.(x), broadcasted_cast
115-
end

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ tests = [
4141

4242
!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
4343

44-
@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:dense, :coo, :sparse)
44+
@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :dense, :sparse)
4545
global GRAPH_T = graph_type
4646
global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse)
47-
47+
4848
for t in tests
4949
startswith(t, "examples") && GRAPH_T == :dense && continue # not testing :dense since causes OutOfMememory on github's CI
5050
include("$t.jl")

0 commit comments

Comments
 (0)