Skip to content

Commit 637549e

Browse files
convenience accessors (#231)
* convenience accessors * docstring * tests
1 parent 22bb801 commit 637549e

File tree

4 files changed

+50
-14
lines changed

4 files changed

+50
-14
lines changed

docs/src/gnngraph.md

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,32 +74,35 @@ false
7474

7575
One or more arrays can be associated to nodes, edges, and (sub)graphs of a `GNNGraph`.
7676
They will be stored in the fields `g.ndata`, `g.edata`, and `g.gdata` respectivaly.
77-
The data fields are `NamedTuple`s. The array they contain must have last dimension
78-
equal to `num_nodes` (in `ndata`), `num_edges` (in `edata`), or `num_graphs` (in `gdata`).
77+
The data fields are `NamedTuple`s. The arrays they contain have last dimension
78+
equal to `num_nodes` (in `ndata`), `num_edges` (in `edata`), or `num_graphs` (in `gdata`) respectively.
7979

8080
```julia
8181
# Create a graph with a single feature array `x` associated to nodes
82-
g = GNNGraph(erdos_renyi(10, 30), ndata = (; x = rand(Float32, 32, 10)))
82+
g = rand_graph(10, 60, ndata = (; x = rand(Float32, 32, 10)))
8383

8484
g.ndata.x # access the features
8585

8686
# Equivalent definition passing directly the array
87-
g = GNNGraph(erdos_renyi(10, 30), ndata = rand(Float32, 32, 10))
87+
g = rand_graph(10, 60, ndata = rand(Float32, 32, 10))
8888

8989
g.ndata.x # `:x` is the default name for node features
9090

91+
# For convinience, we can access the features through the shortcut
92+
g.x
93+
9194
# You can have multiple feature arrays
92-
g = GNNGraph(erdos_renyi(10, 30), ndata = (; x=rand(Float32, 32, 10), y=rand(Float32, 10)))
95+
g = rand_graph(10, 60, ndata = (; x=rand(Float32, 32, 10), y=rand(Float32, 10)))
9396

94-
g.ndata.y, g.ndata.x
97+
g.ndata.y, g.ndata.x # or g.x, g.y
9598

9699
# Attach an array with edge features.
97100
# Since `GNNGraph`s are directed, the number of edges
98101
# will be double that of the original Graphs' undirected graph.
99102
g = GNNGraph(erdos_renyi(10, 30), edata = rand(Float32, 60))
100103
@assert g.num_edges == 60
101104

102-
g.edata.e
105+
g.edata.e # or g.e
103106

104107
# If we pass only half of the edge features, they will be copied
105108
# on the reversed edges.
@@ -110,8 +113,8 @@ g = GNNGraph(erdos_renyi(10, 30), edata = rand(Float32, 30))
110113
# but replacing node data
111114
g′ = GNNGraph(g, ndata =(; z = ones(Float32, 16, 10)))
112115

113-
g.ndata.z
114-
g.edata.e
116+
g.z
117+
g.e
115118
```
116119

117120
## Edge weights

docs/src/index.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ using Flux.Data: DataLoader
3232
all_graphs = GNNGraph[]
3333

3434
for _ in 1:1000
35-
g = GNNGraph(random_regular_graph(10, 4),
35+
g = rand_graph(10, 40,
3636
ndata=(; x = randn(Float32, 16,10)), # input node features
3737
gdata=(; y = randn(Float32))) # regression target
3838
push!(all_graphs, g)
@@ -65,14 +65,14 @@ that are glued together into a single `GNNGraph` using the [`MLUtils.batch`](@re
6565
`collate=true` option.
6666

6767
```julia
68-
train_graphs, test_graphs = MLUtils.split(all_graphs, at=0.8)
68+
train_graphs, test_graphs = MLUtils.splitobs(all_graphs, at=0.8)
6969

7070
train_loader = DataLoader(train_graphs,
7171
batchsize=32, shuffle=true, collate=true)
7272
test_loader = DataLoader(test_graphs,
7373
batchsize=32, shuffle=false, collate=true)
7474

75-
loss(g::GNNGraph) = mean((vec(model(g, g.ndata.x)) - g.gdata.y).^2)
75+
loss(g::GNNGraph) = mean((vec(model(g, g.x)) - g.y).^2)
7676

7777
loss(loader) = mean(loss(g |> device) for g in loader)
7878

src/GNNGraphs/gnngraph.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ g = GNNGraph(g, ndata = (x=rand(100, g.num_nodes), y=rand(g.num_nodes)))
9393
# Add node features and edge features with default names `x` and `e`
9494
g = GNNGraph(g, ndata = rand(100, g.num_nodes), edata = rand(16, g.num_edges))
9595
96-
g.ndata.x
97-
g.ndata.e
96+
g.ndata.x # or just g.x
97+
g.ndata.e # or just g.e
9898
9999
# Send to gpu
100100
g = g |> gpu
@@ -273,3 +273,21 @@ function Base.hash(g::T, h::UInt) where T<:GNNGraph
273273
fs = (getfield(g, k) for k in fieldnames(typeof(g)) if k !== :graph_indicator)
274274
return foldl((h, f) -> hash(f, h), fs, init=hash(T, h))
275275
end
276+
277+
function Base.getproperty(g::GNNGraph, s::Symbol)
278+
if s in fieldnames(GNNGraph)
279+
return getfield(g, s)
280+
end
281+
if (s in keys(g.ndata)) + (s in keys(g.edata)) + (s in keys(g.gdata)) > 1
282+
throw(ArgumentError("Ambiguous property name $s"))
283+
end
284+
if s in keys(g.ndata)
285+
return g.ndata[s]
286+
elseif s in keys(g.edata)
287+
return g.edata[s]
288+
elseif s in keys(g.gdata)
289+
return g.gdata[s]
290+
else
291+
throw(ArgumentError("$(s) is not a field of GNNGraph"))
292+
end
293+
end

test/GNNGraphs/gnngraph.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@
204204
@test g.ndata.x === X
205205
@test g.edata.e === E
206206
@test g.gdata.u === U
207+
@test g.x === g.ndata.x
208+
@test g.e === g.edata.e
209+
@test g.u === g.gdata.u
207210

208211
# Check no args
209212
g = GNNGraph(g)
@@ -220,6 +223,12 @@
220223
@test g.ndata.x2 2X
221224
@test g.edata.e2 2E
222225
@test g.gdata.u2 2U
226+
@test g.x === g.ndata.x
227+
@test g.e === g.edata.e
228+
@test g.u === g.gdata.u
229+
@test g.x2 === g.ndata.x2
230+
@test g.e2 === g.edata.e2
231+
@test g.u2 === g.gdata.u2
223232

224233
# Dimension checks
225234
@test_throws AssertionError GNNGraph(erdos_renyi(10, 30), edata=rand(29), graph_type=GRAPH_T)
@@ -249,6 +258,12 @@
249258
# Error for non-array ndata
250259
@test_throws AssertionError rand_graph(10, 30, ndata="ciao", graph_type=GRAPH_T)
251260
@test_throws AssertionError rand_graph(10, 30, ndata=1, graph_type=GRAPH_T)
261+
262+
# Error for Ambiguous getproperty
263+
g = rand_graph(10, 20, ndata=rand(2,10), edata=(; x = rand(3,20)), graph_type=GRAPH_T)
264+
@test size(g.ndata.x) == (2,10)
265+
@test size(g.edata.x) == (3,20)
266+
@test_throws ArgumentError g.x
252267
end
253268

254269
@testset "MLUtils and DataLoader compat" begin

0 commit comments

Comments
 (0)