Skip to content

Commit 96eb264

Browse files
committed
fix
1 parent d2ab349 commit 96eb264

File tree

1 file changed

+39
-51
lines changed

1 file changed

+39
-51
lines changed

GNNGraphs/src/transform.jl

Lines changed: 39 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -149,57 +149,6 @@ function remove_self_loops(g::GNNGraph{<:ADJMAT_T})
149149
g.ndata, g.edata, g.gdata)
150150
end
151151

152-
"""
153-
remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer})
154-
155-
Remove specified edges from a GNNGraph.
156-
157-
# Arguments
158-
- `g`: The input graph from which edges will be removed.
159-
- `edges_to_remove`: Vector of edge indices to be removed.
160-
161-
# Returns
162-
A new GNNGraph with the specified edges removed.
163-
164-
# Example
165-
```julia
166-
julia> using GraphNeuralNetworks
167-
168-
# Construct a GNNGraph
169-
julia> g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1])
170-
GNNGraph:
171-
num_nodes: 3
172-
num_edges: 5
173-
174-
# Remove the second edge
175-
julia> g_new = remove_edges(g, [2]);
176-
177-
julia> g_new
178-
GNNGraph:
179-
num_nodes: 3
180-
num_edges: 4
181-
```
182-
"""
183-
function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer})
184-
s, t = edge_index(g)
185-
w = get_edge_weight(g)
186-
edata = g.edata
187-
188-
mask_to_keep = trues(length(s))
189-
190-
mask_to_keep[edges_to_remove] .= false
191-
192-
s = s[mask_to_keep]
193-
t = t[mask_to_keep]
194-
edata = getobs(edata, mask_to_keep)
195-
w = isnothing(w) ? nothing : getobs(w, mask_to_keep)
196-
197-
return GNNGraph((s, t, w),
198-
g.num_nodes, length(s), g.num_graphs,
199-
g.graph_indicator,
200-
g.ndata, edata, g.gdata)
201-
end
202-
203152
"""
204153
remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer})
205154
remove_edges(g::GNNGraph, p::Float64=0.5)
@@ -275,6 +224,45 @@ function remove_edges(g::GNNGraph{<:COO_T}, p = 0.5)
275224
g.ndata, edata, g.gdata)
276225
end
277226

227+
"""
228+
remove_multi_edges(g::GNNGraph; aggr=+)
229+
230+
Remove multiple edges (also called parallel edges or repeated edges) from graph `g`.
231+
Possible edge features are aggregated according to `aggr`, that can take value
232+
`+`,`min`, `max` or `mean`.
233+
234+
See also [`remove_self_loops`](@ref), [`has_multi_edges`](@ref), and [`to_bidirected`](@ref).
235+
"""
236+
function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +)
237+
s, t = edge_index(g)
238+
w = get_edge_weight(g)
239+
edata = g.edata
240+
num_edges = g.num_edges
241+
idxs, idxmax = edge_encoding(s, t, g.num_nodes)
242+
243+
perm = sortperm(idxs)
244+
idxs = idxs[perm]
245+
s, t = s[perm], t[perm]
246+
edata = getobs(edata, perm)
247+
w = isnothing(w) ? nothing : getobs(w, perm)
248+
idxs = [-1; idxs]
249+
mask = idxs[2:end] .> idxs[1:(end - 1)]
250+
if !all(mask)
251+
s, t = s[mask], t[mask]
252+
idxs = similar(s, num_edges)
253+
idxs .= 1:num_edges
254+
idxs .= idxs .- cumsum(.!mask)
255+
num_edges = length(s)
256+
w = _scatter(aggr, w, idxs, num_edges)
257+
edata = _scatter(aggr, edata, idxs, num_edges)
258+
end
259+
260+
return GNNGraph((s, t, w),
261+
g.num_nodes, num_edges, g.num_graphs,
262+
g.graph_indicator,
263+
g.ndata, edata, g.gdata)
264+
end
265+
278266
"""
279267
remove_nodes(g::GNNGraph, nodes_to_remove::AbstractVector)
280268

0 commit comments

Comments
 (0)