@@ -144,7 +144,8 @@ If `weighted=true`, the `A` will contain the edge weights if any, otherwise the
144
144
function Graphs. adjacency_matrix (g:: GNNGraph{<:COO_T} , T:: DataType = eltype (g); dir = :out ,
145
145
weighted = true )
146
146
if g. graph[1 ] isa CuVector
147
- # TODO revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152
147
+ # Revisit after
148
+ # https://github.com/JuliaGPU/CUDA.jl/issues/1113
148
149
A, n, m = to_dense (g. graph, T; num_nodes = g. num_nodes, weighted)
149
150
else
150
151
A, n, m = to_sparse (g. graph, T; num_nodes = g. num_nodes, weighted)
@@ -164,63 +165,101 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g
164
165
return dir == :out ? A : A'
165
166
end
166
167
167
- function _get_edge_weight (g, edge_weight)
168
- if edge_weight === true || edge_weight === nothing
169
- ew = get_edge_weight (g)
170
- elseif edge_weight === false
171
- ew = nothing
172
- elseif edge_weight isa AbstractVector
173
- ew = edge_weight
168
+ function ChainRulesCore. rrule (:: typeof (adjacency_matrix), g:: G , T:: DataType ;
169
+ dir = :out , weighted = true ) where {G <: GNNGraph{<:ADJMAT_T} }
170
+ A = adjacency_matrix (g, T; dir, weighted)
171
+ if ! weighted
172
+ function adjacency_matrix_pullback_noweight (Δ)
173
+ return (NoTangent (), ZeroTangent (), NoTangent ())
174
+ end
175
+ return A, adjacency_matrix_pullback_noweight
174
176
else
175
- error (" Invalid edge_weight argument." )
177
+ function adjacency_matrix_pullback_weighted (Δ)
178
+ dg = Tangent {G} (; graph = Δ .* binarize (A))
179
+ return (NoTangent (), dg, NoTangent ())
180
+ end
181
+ return A, adjacency_matrix_pullback_weighted
182
+ end
183
+ end
184
+
185
+ function ChainRulesCore. rrule (:: typeof (adjacency_matrix), g:: G , T:: DataType ;
186
+ dir = :out , weighted = true ) where {G <: GNNGraph{<:COO_T} }
187
+ A = adjacency_matrix (g, T; dir, weighted)
188
+ w = get_edge_weight (g)
189
+ if ! weighted || w === nothing
190
+ function adjacency_matrix_pullback_noweight (Δ)
191
+ return (NoTangent (), ZeroTangent (), NoTangent ())
192
+ end
193
+ return A, adjacency_matrix_pullback_noweight
194
+ else
195
+ function adjacency_matrix_pullback_weighted (Δ)
196
+ s, t = edge_index (g)
197
+ dg = Tangent {G} (; graph = (NoTangent (), NoTangent (), NNlib. gather (Δ, s, t)))
198
+ return (NoTangent (), dg, NoTangent ())
199
+ end
200
+ return A, adjacency_matrix_pullback_weighted
201
+ end
202
+ end
203
+
204
+ function _get_edge_weight (g, edge_weight:: Bool )
205
+ if edge_weight === true
206
+ return get_edge_weight (g)
207
+ elseif edge_weight === false
208
+ return nothing
176
209
end
177
- return ew
178
210
end
179
211
212
+ _get_edge_weight (g, edge_weight:: AbstractVector ) = edge_weight
213
+
180
214
"""
181
215
degree(g::GNNGraph, T=nothing; dir=:out, edge_weight=true)
182
216
183
217
Return a vector containing the degrees of the nodes in `g`.
184
218
219
+ The gradient is propagated through this function only if `edge_weight` is `true`
220
+ or a vector.
221
+
185
222
# Arguments
223
+
186
224
- `g`: A graph.
187
225
- `T`: Element type of the returned vector. If `nothing`, is
188
226
chosen based on the graph type and will be an integer
189
- if `edge_weight=false`.
227
+ if `edge_weight=false`. Default `nothing`.
190
228
- `dir`: For `dir=:out` the degree of a node is counted based on the outgoing edges.
191
229
For `dir=:in`, the ingoing edges are used. If `dir=:both` we have the sum of the two.
192
230
- `edge_weight`: If `true` and the graph contains weighted edges, the degree will
193
231
be weighted. Set to `false` instead to just count the number of
194
- outgoing/ingoing edges.
195
- In alternative , you can also pass a vector of weights to be used
232
+ outgoing/ingoing edges.
233
+ Finally , you can also pass a vector of weights to be used
196
234
instead of the graph's own weights.
235
+ Default `true`.
236
+
197
237
"""
198
238
function Graphs. degree (g:: GNNGraph{<:COO_T} , T:: TT = nothing ; dir = :out ,
199
239
edge_weight = true ) where {
200
240
TT <: Union{Nothing, Type{<:Number}} }
201
241
s, t = edge_index (g)
202
242
203
- edge_weight = _get_edge_weight (g, edge_weight)
204
- edge_weight = edge_weight === nothing ? ones_like (s) : edge_weight
205
-
206
- T = isnothing (T) ? eltype (edge_weight) : T
207
- degs = fill! (similar (s, T, g. num_nodes), 0 )
208
-
209
- if dir ∈ [:out , :both ]
210
- degs = degs .+ NNlib. scatter (+ , edge_weight, s, dstsize = (g. num_nodes,))
211
- end
212
- if dir ∈ [:in , :both ]
213
- degs = degs .+ NNlib. scatter (+ , edge_weight, t, dstsize = (g. num_nodes,))
214
- end
215
- return degs
243
+ ew = _get_edge_weight (g, edge_weight)
244
+
245
+ T = if isnothing (T)
246
+ if ! isnothing (ew)
247
+ eltype (ew)
248
+ else
249
+ eltype (s)
250
+ end
251
+ else
252
+ T
253
+ end
254
+ return _degree ((s, t), T, dir, ew, g. num_nodes)
216
255
end
217
256
218
257
# TODO :: Make efficient
219
258
Graphs. degree (g:: GNNGraph , i:: Union{Int, AbstractVector} ; dir = :out ) = degree (g; dir)[i]
220
259
221
260
function Graphs. degree (g:: GNNGraph{<:ADJMAT_T} , T:: TT = nothing ; dir = :out ,
222
- edge_weight = true ) where {TT}
223
- TT <: Union{Nothing, Type{<:Number}}
261
+ edge_weight = true ) where {TT<: Union{Nothing, Type{<:Number}} }
262
+
224
263
# edge_weight=true or edge_weight=nothing act the same here
225
264
@assert ! (edge_weight isa AbstractArray) " passing the edge weights is not support by adjacency matrix representations"
226
265
@assert dir ∈ (:in , :out , :both )
@@ -234,6 +273,26 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT = nothing; dir = :out,
234
273
end
235
274
end
236
275
A = adjacency_matrix (g)
276
+ return _degree (A, T, dir, edge_weight, g. num_nodes)
277
+ end
278
+
279
+ function _degree ((s, t):: Tuple , T:: Type , dir:: Symbol , edge_weight:: Nothing , num_nodes:: Int )
280
+ _degree ((s, t), T, dir, ones_like (s, T), num_nodes)
281
+ end
282
+
283
+ function _degree ((s, t):: Tuple , T:: Type , dir:: Symbol , edge_weight:: AbstractVector , num_nodes:: Int )
284
+ degs = fill! (similar (s, T, num_nodes), 0 )
285
+
286
+ if dir ∈ [:out , :both ]
287
+ degs = degs .+ NNlib. scatter (+ , edge_weight, s, dstsize = (num_nodes,))
288
+ end
289
+ if dir ∈ [:in , :both ]
290
+ degs = degs .+ NNlib. scatter (+ , edge_weight, t, dstsize = (num_nodes,))
291
+ end
292
+ return degs
293
+ end
294
+
295
+ function _degree (A:: AbstractMatrix , T:: Type , dir:: Symbol , edge_weight:: Bool , num_nodes:: Int )
237
296
if edge_weight === false
238
297
A = binarize (A)
239
298
end
@@ -243,6 +302,40 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT = nothing; dir = :out,
243
302
vec (sum (A, dims = 1 )) .+ vec (sum (A, dims = 2 ))
244
303
end
245
304
305
+ function ChainRulesCore. rrule (:: typeof (_degree), graph, T, dir, edge_weight:: Nothing , num_nodes)
306
+ degs = _degree (graph, T, dir, edge_weight, num_nodes)
307
+ function _degree_pullback (Δ)
308
+ return (NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent ())
309
+ end
310
+ return degs, _degree_pullback
311
+ end
312
+
313
+ function ChainRulesCore. rrule (:: typeof (_degree), A:: ADJMAT_T , T, dir, edge_weight:: Bool , num_nodes)
314
+ degs = _degree (A, T, dir, edge_weight, num_nodes)
315
+ if edge_weight === false
316
+ function _degree_pullback_noweights (Δ)
317
+ return (NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent ())
318
+ end
319
+ return degs, _degree_pullback_noweights
320
+ else
321
+ function _degree_pullback_weights (Δ)
322
+ # We propagate the gradient only to the non-zero elements
323
+ # of the adjacency matrix.
324
+ bA = binarize (A)
325
+ if dir == :in
326
+ dA = bA .* Δ'
327
+ elseif dir == :out
328
+ dA = Δ .* bA
329
+ else # dir == :both
330
+ dA = Δ .* bA + Δ' .* bA
331
+ end
332
+ return (NoTangent (), dA, NoTangent (), NoTangent (), NoTangent (), NoTangent ())
333
+ end
334
+ return degs, _degree_pullback_weights
335
+ end
336
+ end
337
+
338
+
246
339
"""
247
340
has_isolated_nodes(g::GNNGraph; dir=:out)
248
341
0 commit comments