Skip to content

Commit 8dbac17

Browse files
feat: add induced_subgraph functionality (#499)
* feat: add induced_subgraph functionality * fix: fix tests * fix: fix tests * Update GNNGraphs/src/sampling.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update GNNGraphs/src/GNNGraphs.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update GNNGraphs/src/sampling.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update GNNGraphs/src/sampling.jl Co-authored-by: Carlo Lucibello <[email protected]> * feat: add edata&ndata support for induced_subgraph * chore: export induced_subgraph * fix: fix naming for induced_subgraph * fix: fix typo * fix: revert naming for induced_subgraph * fix: fix test * fix: don't export induced subgraph * fix: don't export induced subgraph * fix: amend docstring * fix: amend docstring * fix: fix docstring * fix: add Graphs.induced_subgraph to docs --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 3fe2c76 commit 8dbac17

File tree

5 files changed

+122
-2
lines changed

5 files changed

+122
-2
lines changed

GNNGraphs/src/GNNGraphs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using SparseArrays
44
using Functors: @functor
55
import Graphs
66
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
7-
has_self_loops, is_directed
7+
has_self_loops, is_directed, induced_subgraph
88
import NearestNeighbors
99
import NNlib
1010
import StatsBase

GNNGraphs/src/sampling.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,87 @@ function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1;
116116
end
117117
return gnew
118118
end
119+
120+
121+
"""
122+
induced_subgraph(graph, nodes)
123+
124+
Generates a subgraph from the original graph using the provided `nodes`.
125+
The function includes the nodes' neighbors and creates edges between nodes that are connected in the original graph.
126+
If a node has no neighbors, an isolated node will be added to the subgraph.
127+
Returns A new `GNNGraph` containing the subgraph with the specified nodes and their features.
128+
129+
# Arguments
130+
131+
- `graph`. The original GNNGraph containing nodes, edges, and node features.
132+
- `nodes``. A vector of node indices to include in the subgraph.
133+
134+
# Examples
135+
136+
```julia
137+
julia> s = [1, 2]
138+
2-element Vector{Int64}:
139+
1
140+
2
141+
142+
julia> t = [2, 3]
143+
2-element Vector{Int64}:
144+
2
145+
3
146+
147+
julia> graph = GNNGraph((s, t), ndata = (; x=rand(Float32, 32, 3), y=rand(Float32, 3)), edata = rand(Float32, 2))
148+
GNNGraph:
149+
num_nodes: 3
150+
num_edges: 2
151+
ndata:
152+
y = 3-element Vector{Float32}
153+
x = 32×3 Matrix{Float32}
154+
edata:
155+
e = 2-element Vector{Float32}
156+
157+
julia> nodes = [1, 2]
158+
2-element Vector{Int64}:
159+
1
160+
2
161+
162+
julia> subgraph = Graphs.induced_subgraph(graph, nodes)
163+
GNNGraph:
164+
num_nodes: 2
165+
num_edges: 1
166+
ndata:
167+
y = 2-element Vector{Float32}
168+
x = 32×2 Matrix{Float32}
169+
edata:
170+
e = 1-element Vector{Float32}
171+
```
172+
"""
173+
function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})
174+
if isempty(nodes)
175+
return GNNGraph() # Return empty graph if no nodes are provided
176+
end
177+
178+
node_map = Dict(node => i for (i, node) in enumerate(nodes))
179+
180+
# Collect edges to add
181+
source = Int[]
182+
target = Int[]
183+
eindices = Int[]
184+
for node in nodes
185+
neighbors = Graphs.neighbors(graph, node, dir = :in)
186+
for neighbor in neighbors
187+
if neighbor in keys(node_map)
188+
push!(target, node_map[node])
189+
push!(source, node_map[neighbor])
190+
191+
eindex = findfirst(x -> x == [neighbor, node], edge_index(graph))
192+
push!(eindices, eindex)
193+
end
194+
end
195+
end
196+
197+
# Extract features for the new nodes
198+
new_ndata = getobs(graph.ndata, nodes)
199+
new_edata = getobs(graph.edata, eindices)
200+
201+
return GNNGraph(source, target, num_nodes = length(node_map), ndata = new_ndata, edata = new_edata)
202+
end

GNNGraphs/test/sampling.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,36 @@ if GRAPH_T == :coo
4545
@test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID]
4646
@test length(union(sg.ndata.NID)) == length(sg.ndata.NID)
4747
end
48+
49+
@testset "induced_subgraph" begin
50+
s = [1, 2]
51+
t = [2, 3]
52+
53+
graph = GNNGraph((s, t), ndata = (; x=rand(Float32, 32, 3), y=rand(Float32, 3)), edata = rand(Float32, 2))
54+
55+
nodes = [1, 2, 3]
56+
subgraph = Graphs.induced_subgraph(graph, nodes)
57+
58+
@test subgraph.num_nodes == 3
59+
@test subgraph.num_edges == 2
60+
@test subgraph.ndata.x == graph.ndata.x
61+
@test subgraph.ndata.y == graph.ndata.y
62+
@test subgraph.edata == graph.edata
63+
64+
nodes = [1, 2]
65+
subgraph = Graphs.induced_subgraph(graph, nodes)
66+
67+
@test subgraph.num_nodes == 2
68+
@test subgraph.num_edges == 1
69+
@test subgraph.ndata == getobs(graph.ndata, [1, 2])
70+
@test isapprox(getobs(subgraph.edata.e, 1), getobs(graph.edata.e, 1); atol=1e-6)
71+
72+
graph = GNNGraph(2)
73+
graph = add_edges(graph, ([2], [1]))
74+
nodes = [1]
75+
subgraph = Graphs.induced_subgraph(graph, nodes)
76+
77+
@test subgraph.num_nodes == 1
78+
@test subgraph.num_edges == 0
79+
end
4880
end

GraphNeuralNetworks/docs/src/api/gnngraph.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,7 @@ Modules = [GNNGraphs]
8888
Pages = ["sampling.jl"]
8989
Private = false
9090
```
91+
92+
```@docs
93+
Graphs.induced_subgraph(::GNNGraph, ::Vector{Int})
94+
```

GraphNeuralNetworks/src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using ChainRulesCore
1111
using Reexport
1212
using MLUtils: zeros_like
1313

14-
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
14+
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
1515
check_num_nodes, check_num_edges,
1616
EType, NType # for heteroconvs
1717

0 commit comments

Comments
 (0)