
Commit 54ab41f

implement DataStore (#232)
* implementation of DataStore
* fix tests
* keyword arg constructor
* cleanup
* ADAM -> Adam
* fix some tests
* fix some tests
* add some docs
* heterograph compatibility
* comment out show tests
* tests passing locally
* more docs
* more tests
* try with non diff
* nothing is now the default
1 parent 5ae657d commit 54ab41f

21 files changed: +507 −168 lines

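For orientation before the per-file diffs, here is a minimal sketch of the `DataStore` API this commit introduces. It is inferred from the docstring added in `src/GNNGraphs/datastore.jl` below; the variable names are illustrative:

```julia
using GraphNeuralNetworks  # exports `DataStore` after this commit

# a store whose feature arrays must all have 3 observations (last dimension)
ds = DataStore(3, x = rand(Float32, 2, 3), y = rand(Float32, 3))

ds.x                     # property access, same as ds[:x]
ds.z = rand(Float32, 3)  # features can be added after construction
map(v -> 2 .* v, ds)     # returns a new DataStore with the function applied to each array
```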
Project.toml

Lines changed: 7 additions & 6 deletions
@@ -1,13 +1,14 @@
 name = "GraphNeuralNetworks"
 uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 authors = ["Carlo Lucibello and contributors"]
-version = "0.5.2"
+version = "0.6.0"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -29,12 +30,12 @@ Adapt = "3"
 CUDA = "3.3"
 ChainRulesCore = "1"
 DataStructures = "0.18"
-Flux = "0.13.4"
-Functors = "0.2, 0.3, 0.4"
+Flux = "0.13.9"
+Functors = "0.4.1"
 Graphs = "1.4"
-KrylovKit = "0.5, 0.6"
-MLDatasets = "0.6, 0.7"
-MLUtils = "0.2.3, 0.3"
+KrylovKit = "0.6"
+MLDatasets = "0.7"
+MLUtils = "0.3"
 MacroTools = "0.5"
 NNlib = "0.8"
 NNlibCUDA = "0.2"

docs/pluto_output/node_classification_pluto.md

Lines changed: 2 additions & 2 deletions
@@ -206,7 +206,7 @@ end</code></pre>
 <pre class='language-julia'><code class='language-julia'>begin
 mlp = MLP(num_features, num_classes, hidden_channels)
 ps_mlp = Flux.params(mlp)
-opt_mlp = ADAM(1e-3)
+opt_mlp = Adam(1e-3)
 epochs = 2000
 train(mlp, g.ndata.features, epochs, opt_mlp, ps_mlp)
 end</code></pre>
@@ -294,7 +294,7 @@ end</code></pre>

 <pre class='language-julia'><code class='language-julia'>begin
 ps_gcn = Flux.params(gcn)
-opt_gcn = ADAM(1e-2)
+opt_gcn = Adam(1e-2)
 train(gcn, g, x, epochs, ps_gcn, opt_gcn)
 end</code></pre>

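The `ADAM` → `Adam` change tracks the optimiser renaming in recent Flux releases (Flux 0.13 exports `Adam`; `ADAM` is a deprecated alias). A hedged sketch of the training pattern used in this notebook (and in the matching Pluto source further down), with `model`, `loss`, `x`, and `y` assumed defined:

```julia
using Flux

opt = Adam(1e-3)         # formerly ADAM(1e-3)
ps = Flux.params(model)  # `model` is assumed defined
for epoch in 1:100
    # `loss`, `x`, `y` are placeholders for the notebook's objective and data
    gs = gradient(() -> loss(model, x, y), ps)
    Flux.Optimise.update!(opt, ps, gs)
end
```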
docs/src/api/gnngraph.md

Lines changed: 8 additions & 0 deletions
@@ -25,6 +25,14 @@ Pages = ["gnngraph.jl"]
 Private = false
 ```

+## DataStore
+
+```@autodocs
+Modules = [GraphNeuralNetworks.GNNGraphs]
+Pages = ["datastore.jl"]
+Private = false
+```
+
 ## Query

 ```@autodocs

docs/src/gnngraph.md

Lines changed: 8 additions & 2 deletions
@@ -74,8 +74,12 @@ false

 One or more arrays can be associated to nodes, edges, and (sub)graphs of a `GNNGraph`.
 They will be stored in the fields `g.ndata`, `g.edata`, and `g.gdata` respectively.
-The data fields are `NamedTuple`s. The arrays they contain have last dimension
-equal to `num_nodes` (in `ndata`), `num_edges` (in `edata`), or `num_graphs` (in `gdata`) respectively.
+
+The data fields are [`DataStore`](@ref) objects, and conveniently
+offer an interface similar to both dictionaries and named tuples.
+Datastores support the addition of new features after creation time.
+
+The arrays contained in the datastores have last dimension equal to `num_nodes` (in `ndata`), `num_edges` (in `edata`), or `num_graphs` (in `gdata`) respectively.

 ```julia
 # Create a graph with a single feature array `x` associated to nodes
@@ -88,6 +92,8 @@ g = rand_graph(10, 60, ndata = rand(Float32, 32, 10))

 g.ndata.x # `:x` is the default name for node features

+g.ndata.z = rand(Float32, 3, 10) # add new feature array `z`
+
 # For convenience, we can access the features through the shortcut
 g.x

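A short sketch of the new behaviour described above (the feature names `z` and `w` are illustrative; sizes follow the last-dimension rule):

```julia
using GraphNeuralNetworks

g = rand_graph(10, 60, ndata = rand(Float32, 32, 10))

g.ndata.x                         # the default name for node features
g.ndata.z = rand(Float32, 3, 10)  # new in 0.6: add node features after construction
g.edata.w = rand(Float32, 60)     # likewise for edges (last dimension == num_edges)
```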
docs/tutorials/introductory_tutorials/graph_classification_pluto.jl

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ end
 begin
 using Flux
 using Flux: onecold, onehotbatch, logitcrossentropy
-using Flux.Data: DataLoader
+using Flux: DataLoader
 using GraphNeuralNetworks
 using MLDatasets
 using MLUtils

docs/tutorials/introductory_tutorials/node_classification_pluto.jl

Lines changed: 2 additions & 2 deletions
@@ -317,7 +317,7 @@ end
 begin
 mlp = MLP(num_features, num_classes, hidden_channels)
 ps_mlp = Flux.params(mlp)
-opt_mlp = ADAM(1e-3)
+opt_mlp = Adam(1e-3)
 epochs = 2000
 train(mlp, g.ndata.features, epochs, opt_mlp, ps_mlp)
 end
@@ -335,7 +335,7 @@ accuracy(mlp, g.ndata.features, y, .!train_mask)
 # ╠═╡ show_logs = false
 begin
 ps_gcn = Flux.params(gcn)
-opt_gcn = ADAM(1e-2)
+opt_gcn = Adam(1e-2)
 train(gcn, g, x, epochs, ps_gcn, opt_gcn)
 end

examples/graph_classification_tudataset.jl

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 # An example of graph classification

 using Flux
-using Flux:onecold, onehotbatch
+using Flux: onecold, onehotbatch
 using Flux.Losses: logitbinarycrossentropy
-using Flux.Data: DataLoader
+using Flux: DataLoader
 using GraphNeuralNetworks
 using MLDatasets: TUDataset
 using Statistics, Random

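`DataLoader` now lives in `MLUtils` and is re-exported by Flux, hence `using Flux: DataLoader` replacing `using Flux.Data: DataLoader`. A minimal sketch of its behaviour, with a random matrix standing in for the TUDataset graphs:

```julia
using Flux: DataLoader

X = rand(Float32, 3, 10)  # 10 observations along the last dimension
loader = DataLoader(X, batchsize = 4, shuffle = true)
for xb in loader
    @assert size(xb, 1) == 3  # batches of 4, 4, and 2 observations
end
```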
src/GNNGraphs/GNNGraphs.jl

Lines changed: 4 additions & 0 deletions
@@ -15,6 +15,10 @@ using ChainRulesCore
 using LinearAlgebra, Random, Statistics
 import MLUtils
 using MLUtils: getobs, numobs
+import Functors
+
+include("datastore.jl")
+export DataStore

 include("gnngraph.jl")
 export GNNGraph,

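With the include and export in place, `DataStore` can be pulled directly from the submodule; a quick smoke test (assuming the package at this commit):

```julia
using GraphNeuralNetworks.GNNGraphs: DataStore

ds = DataStore(2, x = rand(Float32, 4, 2))  # n = 2 observations enforced
```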
src/GNNGraphs/datastore.jl

Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
+"""
+    DataStore([n, data])
+    DataStore([n,] k1 = x1, k2 = x2, ...)
+
+A container for data, with an optional number of observations `n` enforcing
+`numobs(x) == n` for each feature array contained in the data store.
+
+At construction time, the data can be provided as any iterable of pairs
+of symbols and arrays or as keyword arguments:
+
+```julia-repl
+julia> ds = DataStore(3, x = rand(2, 3), y = rand(3))
+DataStore(3) with 2 elements:
+ y = 3-element Vector{Float64}
+ x = 2×3 Matrix{Float64}
+
+julia> ds = DataStore(3, Dict(:x => rand(2, 3), :y => rand(3))); # equivalent to above
+
+julia> ds = DataStore(3, (x = rand(2, 3), y = rand(30)))
+ERROR: AssertionError: DataStore: data[y] has 30 observations, but n = 3
+Stacktrace:
+ [1] DataStore(n::Int64, data::Dict{Symbol, Any})
+   @ GraphNeuralNetworks.GNNGraphs ~/.julia/dev/GraphNeuralNetworks/src/GNNGraphs/datastore.jl:54
+ [2] DataStore(n::Int64, data::NamedTuple{(:x, :y), Tuple{Matrix{Float64}, Vector{Float64}}})
+   @ GraphNeuralNetworks.GNNGraphs ~/.julia/dev/GraphNeuralNetworks/src/GNNGraphs/datastore.jl:73
+ [3] top-level scope
+   @ REPL[13]:1
+
+julia> ds = DataStore(x = rand(2, 3), y = rand(30)) # no checks
+DataStore() with 2 elements:
+ y = 30-element Vector{Float64}
+ x = 2×3 Matrix{Float64}
+```
+
+The `DataStore` has an interface similar to both dictionaries and named tuples.
+Data can be accessed and added using either the indexing or the property syntax:
+
+```julia-repl
+julia> ds = DataStore(x = ones(2, 3), y = zeros(3))
+DataStore() with 2 elements:
+ y = 3-element Vector{Float64}
+ x = 2×3 Matrix{Float64}
+
+julia> ds.x # same as `ds[:x]`
+2×3 Matrix{Float64}:
+ 1.0  1.0  1.0
+ 1.0  1.0  1.0
+
+julia> ds.z = zeros(3) # add new feature array `z`, same as `ds[:z] = zeros(3)`
+3-element Vector{Float64}:
+ 0.0
+ 0.0
+ 0.0
+```
+
+The `DataStore` can be iterated over, and the keys and values can be accessed
+using `keys(ds)` and `values(ds)`. `map(f, ds)` applies the function `f`
+to each feature array:
+
+```julia-repl
+julia> ds = DataStore(a = zeros(2), b = zeros(2));
+
+julia> ds2 = map(x -> x .+ 1, ds);
+
+julia> ds2.a
+2-element Vector{Float64}:
+ 1.0
+ 1.0
+```
+"""
+struct DataStore
+    _n::Int # either -1 or numobs(data)
+    _data::Dict{Symbol, Any}
+
+    function DataStore(n::Int, data::Dict{Symbol, Any})
+        if n >= 0
+            for (k, v) in data
+                @assert numobs(v) == n "DataStore: data[$k] has $(numobs(v)) observations, but n = $n"
+            end
+        end
+        return new(n, data)
+    end
+end
+
+@functor DataStore
+
+DataStore(data) = DataStore(-1, data)
+DataStore(n::Int, data::NamedTuple) = DataStore(n, Dict{Symbol, Any}(pairs(data)))
+DataStore(n::Int, data) = DataStore(n, Dict{Symbol, Any}(data))
+
+DataStore(; kws...) = DataStore(-1; kws...)
+DataStore(n::Int; kws...) = DataStore(n, Dict{Symbol, Any}(kws...))
+
+getdata(ds::DataStore) = getfield(ds, :_data)
+getn(ds::DataStore) = getfield(ds, :_n)
+# setn!(ds::DataStore, n::Int) = setfield!(ds, :n, n)
+
+function Base.getproperty(ds::DataStore, s::Symbol)
+    if s === :_n
+        return getn(ds)
+    elseif s === :_data
+        return getdata(ds)
+    else
+        return getdata(ds)[s]
+    end
+end
+
+function Base.setproperty!(ds::DataStore, s::Symbol, x)
+    @assert s != :_n "cannot set _n directly"
+    @assert s != :_data "cannot set _data directly"
+    if getn(ds) > 0
+        @assert numobs(x) == getn(ds) "expected (numobs(x) == getn(ds)) but got $(numobs(x)) != $(getn(ds))"
+    end
+    return getdata(ds)[s] = x
+end
+
+Base.getindex(ds::DataStore, s::Symbol) = getproperty(ds, s)
+Base.setindex!(ds::DataStore, x, s::Symbol) = setproperty!(ds, s, x)
+
+function Base.show(io::IO, ds::DataStore)
+    len = length(ds)
+    n = getn(ds)
+    if n < 0
+        print(io, "DataStore()")
+    else
+        print(io, "DataStore($(getn(ds)))")
+    end
+    if len > 0
+        print(io, " with $(length(getdata(ds))) element")
+        len > 1 && print(io, "s")
+        print(io, ":")
+        for (k, v) in getdata(ds)
+            print(io, "\n $(k) = $(summary(v))")
+        end
+    end
+end
+
+Base.iterate(ds::DataStore) = iterate(getdata(ds))
+Base.iterate(ds::DataStore, state) = iterate(getdata(ds), state)
+Base.keys(ds::DataStore) = keys(getdata(ds))
+Base.values(ds::DataStore) = values(getdata(ds))
+Base.length(ds::DataStore) = length(getdata(ds))
+Base.haskey(ds::DataStore, k) = haskey(getdata(ds), k)
+Base.get(ds::DataStore, k, default) = get(getdata(ds), k, default)
+Base.pairs(ds::DataStore) = pairs(getdata(ds))
+Base.:(==)(ds1::DataStore, ds2::DataStore) = getdata(ds1) == getdata(ds2)
+Base.isempty(ds::DataStore) = isempty(getdata(ds))
+Base.delete!(ds::DataStore, k) = delete!(getdata(ds), k)
+
+function Base.map(f, ds::DataStore)
+    d = getdata(ds)
+    newd = Dict{Symbol, Any}(k => f(v) for (k, v) in d)
+    return DataStore(getn(ds), newd)
+end
+
+MLUtils.numobs(ds::DataStore) = numobs(getdata(ds))
+
+function MLUtils.getobs(ds::DataStore, i::Int)
+    newdata = getobs(getdata(ds), i)
+    return DataStore(-1, newdata)
+end
+
+function MLUtils.getobs(ds::DataStore, i::AbstractVector{T}) where T <: Union{Integer, Bool}
+    newdata = getobs(getdata(ds), i)
+    n = getn(ds)
+    if n > -1
+        if length(ds) > 0
+            n = numobs(newdata)
+        else
+            # if newdata is empty, then we can't get the number of observations from it
+            n = T == Bool ? sum(i) : length(i)
+        end
+    end
+    if !(newdata isa Dict{Symbol, Any})
+        newdata = Dict{Symbol, Any}(newdata)
+    end
+    return DataStore(n, newdata)
+end
+
+function cat_features(ds1::DataStore, ds2::DataStore)
+    n1, n2 = getn(ds1), getn(ds2)
+    n1 = n1 > 0 ? n1 : 1
+    n2 = n2 > 0 ? n2 : 1
+    return DataStore(n1 + n2, cat_features(getdata(ds1), getdata(ds2)))
+end
+
+function cat_features(dss::AbstractVector{DataStore}; kws...)
+    ns = getn.(dss)
+    ns = map(n -> n > 0 ? n : 1, ns)
+    return DataStore(sum(ns), cat_features(getdata.(dss); kws...))
+end
+
+# DataStore is always already normalized
+normalize_graphdata(ds::DataStore; kws...) = ds
+
+_gather(x::DataStore, i) = map(x -> _gather(x, i), x)
+
+function _scatter(aggr, src::DataStore, idx, n)
+    newdata = _scatter(aggr, getdata(src), idx, n)
+    if !(newdata isa Dict{Symbol, Any})
+        newdata = Dict{Symbol, Any}(newdata)
+    end
+    return DataStore(n, newdata)
+end
+
+function Base.hash(ds::D, h::UInt) where {D <: DataStore}
+    fs = (getfield(ds, k) for k in fieldnames(D))
+    return foldl((h, f) -> hash(f, h), fs, init = hash(D, h))
+end

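A sketch exercising the `MLUtils` integration defined above; the behaviour follows from the `numobs`/`getobs` methods in this file:

```julia
using GraphNeuralNetworks.GNNGraphs: DataStore
using MLUtils: getobs, numobs

ds = DataStore(5, x = rand(Float32, 2, 5), y = rand(Float32, 5))

numobs(ds)                # 5
sub = getobs(ds, [1, 3])  # new DataStore holding observations 1 and 3
numobs(sub)               # 2
getobs(ds, 2)             # single observation: result carries no size check (n = -1)
```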
src/GNNGraphs/gatherscatter.jl

Lines changed: 2 additions & 0 deletions
@@ -1,11 +1,13 @@
 _gather(x::NamedTuple, i) = map(x -> _gather(x, i), x)
+_gather(x::Dict, i) = Dict(k => _gather(v, i) for (k, v) in x)
 _gather(x::Tuple, i) = map(x -> _gather(x, i), x)
 _gather(x::AbstractArray, i) = NNlib.gather(x, i)
 _gather(x::Nothing, i) = nothing

 _scatter(aggr, src::Nothing, idx, n) = nothing
 _scatter(aggr, src::NamedTuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
 _scatter(aggr, src::Tuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
+_scatter(aggr, src::Dict, idx, n) = Dict(k => _scatter(aggr, v, idx, n) for (k, v) in src)

 function _scatter(aggr,
                   src::AbstractArray,

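The new `Dict` methods mirror the existing `NamedTuple` ones, letting a `DataStore` (backed by a `Dict{Symbol, Any}`) flow through message passing. A hedged sketch of the `NNlib` semantics they delegate to:

```julia
using NNlib

x = Float32[1 2 3 4;
            5 6 7 8]        # 2×4 matrix, one column per node

NNlib.gather(x, [1, 1, 3])  # 2×3: columns 1, 1, and 3

y = ones(Float32, 2, 3)
NNlib.scatter(+, y, [1, 1, 3]; dstsize = (2, 4))  # accumulate columns back into 4 slots
```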