|
| 1 | +"""" |
| 2 | + DataStore([n, data]) |
| 3 | + DataStore([n,] k1 = x1, k2 = x2, ...) |
| 4 | +
|
| 5 | +A container for data, with optional metadata `n` enforcing |
| 6 | +`numobs(x) == n` for each feature array contained in the data store. |
| 7 | +
|
| 8 | +At construction time, the data can be provided as any iterables of pairs |
| 9 | +of symbols and arrays or as keyword arguments: |
| 10 | +
|
| 11 | +```julia-repl |
| 12 | +julia> ds = DataStore(3, x = rand(2, 3), y = rand(3)) |
| 13 | +DataStore(3) with 2 elements: |
| 14 | + y = 3-element Vector{Float64} |
| 15 | + x = 2×3 Matrix{Float64} |
| 16 | +
|
| 17 | +julia> ds = DataStore(3, Dict(:x => rand(2, 3), :y => rand(3))); # equivalent to above |
| 18 | +
|
| 19 | +julia> ds = DataStore(3, (x = rand(2, 3), y = rand(30))) |
| 20 | +ERROR: AssertionError: DataStore: data[y] has 30 observations, but n = 3 |
| 21 | +Stacktrace: |
| 22 | + [1] DataStore(n::Int64, data::Dict{Symbol, Any}) |
| 23 | + @ GraphNeuralNetworks.GNNGraphs ~/.julia/dev/GraphNeuralNetworks/src/GNNGraphs/datastore.jl:54 |
| 24 | + [2] DataStore(n::Int64, data::NamedTuple{(:x, :y), Tuple{Matrix{Float64}, Vector{Float64}}}) |
| 25 | + @ GraphNeuralNetworks.GNNGraphs ~/.julia/dev/GraphNeuralNetworks/src/GNNGraphs/datastore.jl:73 |
| 26 | + [3] top-level scope |
| 27 | + @ REPL[13]:1 |
| 28 | +
|
| 29 | +julia> ds = DataStore(x = rand(2, 3), y = rand(30)) # no checks |
| 30 | +DataStore() with 2 elements: |
| 31 | + y = 30-element Vector{Float64} |
| 32 | + x = 2×3 Matrix{Float64} |
| 33 | +``` |
| 34 | +
|
| 35 | +The `DataStore` as an interface similar to both dictionaries and named tuples. |
| 36 | +Data can be accessed and added using either the indexing or the property syntax: |
| 37 | +
|
| 38 | +```julia-repl |
| 39 | +julia> ds = DataStore(x = ones(2, 3), y = zeros(3)) |
| 40 | +DataStore() with 2 elements: |
| 41 | + y = 3-element Vector{Float64} |
| 42 | + x = 2×3 Matrix{Float64} |
| 43 | +
|
| 44 | +julia> ds.x # same as `ds[:x]` |
| 45 | +2×3 Matrix{Float64}: |
| 46 | + 1.0 1.0 1.0 |
| 47 | + 1.0 1.0 1.0 |
| 48 | +
|
| 49 | +julia> ds.z = zeros(3) # Add new feature array `z`. Same as `ds[:z] = rand(3)` |
| 50 | +3-element Vector{Float64}: |
| 51 | +0.0 |
| 52 | +0.0 |
| 53 | +0.0 |
| 54 | +``` |
| 55 | +
|
| 56 | +The `DataStore` can be iterated over, and the keys and values can be accessed |
| 57 | +using `keys(ds)` and `values(ds)`. `map(f, ds)` applies the function `f` |
| 58 | +to each feature array: |
| 59 | +
|
| 60 | +```julia-repl |
| 61 | +julia> ds = DataStore(a=zeros(2), b=zeros(2)); |
| 62 | +
|
| 63 | +julia> ds2 = map(x -> x .+ 1, ds) |
| 64 | +
|
| 65 | +julia> ds2.a |
| 66 | +2-element Vector{Float64}: |
| 67 | + 1.0 |
| 68 | + 1.0 |
| 69 | +``` |
| 70 | +""" |
| 71 | +struct DataStore |
| 72 | + _n::Int # either -1 or numobs(data) |
| 73 | + _data::Dict{Symbol, Any} |
| 74 | + |
| 75 | + function DataStore(n::Int, data::Dict{Symbol,Any}) |
| 76 | + if n >= 0 |
| 77 | + for (k, v) in data |
| 78 | + @assert numobs(v) == n "DataStore: data[$k] has $(numobs(v)) observations, but n = $n" |
| 79 | + end |
| 80 | + end |
| 81 | + return new(n, data) |
| 82 | + end |
| 83 | +end |
| 84 | + |
| 85 | +@functor DataStore |
| 86 | + |
| 87 | +DataStore(data) = DataStore(-1, data) |
| 88 | +DataStore(n::Int, data::NamedTuple) = DataStore(n, Dict{Symbol,Any}(pairs(data))) |
| 89 | +DataStore(n::Int, data) = DataStore(n, Dict{Symbol,Any}(data)) |
| 90 | + |
| 91 | +DataStore(; kws...) = DataStore(-1; kws...) |
| 92 | +DataStore(n::Int; kws...) = DataStore(n, Dict{Symbol,Any}(kws...)) |
| 93 | + |
| 94 | +getdata(ds::DataStore) = getfield(ds, :_data) |
| 95 | +getn(ds::DataStore) = getfield(ds, :_n) |
| 96 | +# setn!(ds::DataStore, n::Int) = setfield!(ds, :n, n) |
| 97 | + |
| 98 | +function Base.getproperty(ds::DataStore, s::Symbol) |
| 99 | + if s === :_n |
| 100 | + return getn(ds) |
| 101 | + elseif s === :_data |
| 102 | + return getdata(ds) |
| 103 | + else |
| 104 | + return getdata(ds)[s] |
| 105 | + end |
| 106 | +end |
| 107 | + |
| 108 | +function Base.setproperty!(ds::DataStore, s::Symbol, x) |
| 109 | + @assert s != :_n "cannot set _n directly" |
| 110 | + @assert s != :_data "cannot set _data directly" |
| 111 | + if getn(ds) > 0 |
| 112 | + @assert numobs(x) == getn(ds) "expected (numobs(x) == getn(ds)) but got $(numobs(x)) != $(getn(ds))" |
| 113 | + end |
| 114 | + return getdata(ds)[s] = x |
| 115 | +end |
| 116 | + |
| 117 | +Base.getindex(ds::DataStore, s::Symbol) = getproperty(ds, s) |
| 118 | +Base.setindex!(ds::DataStore, s::Symbol, x) = setproperty!(ds, s, x) |
| 119 | + |
| 120 | +function Base.show(io::IO, ds::DataStore) |
| 121 | + len = length(ds) |
| 122 | + n = getn(ds) |
| 123 | + if n < 0 |
| 124 | + print(io, "DataStore()") |
| 125 | + else |
| 126 | + print(io, "DataStore($(getn(ds)))") |
| 127 | + end |
| 128 | + if len > 0 |
| 129 | + print(io, " with $(length(getdata(ds))) element") |
| 130 | + len > 1 && print(io, "s") |
| 131 | + print(io, ":") |
| 132 | + for (k, v) in getdata(ds) |
| 133 | + print(io, "\n $(k) = $(summary(v))") |
| 134 | + end |
| 135 | + end |
| 136 | +end |
| 137 | + |
| 138 | +Base.iterate(ds::DataStore) = iterate(getdata(ds)) |
| 139 | +Base.iterate(ds::DataStore, state) = iterate(getdata(ds), state) |
| 140 | +Base.keys(ds::DataStore) = keys(getdata(ds)) |
| 141 | +Base.values(ds::DataStore) = values(getdata(ds)) |
| 142 | +Base.length(ds::DataStore) = length(getdata(ds)) |
| 143 | +Base.haskey(ds::DataStore, k) = haskey(getdata(ds), k) |
| 144 | +Base.get(ds::DataStore, k, default) = get(getdata(ds), k, default) |
| 145 | +Base.pairs(ds::DataStore) = pairs(getdata(ds)) |
| 146 | +Base.:(==)(ds1::DataStore, ds2::DataStore) = getdata(ds1) == getdata(ds2) |
| 147 | +Base.isempty(ds::DataStore) = isempty(getdata(ds)) |
| 148 | +Base.delete!(ds::DataStore, k) = delete!(getdata(ds), k) |
| 149 | + |
| 150 | +function Base.map(f, ds::DataStore) |
| 151 | + d = getdata(ds) |
| 152 | + newd = Dict{Symbol, Any}(k => f(v) for (k, v) in d) |
| 153 | + return DataStore(getn(ds), newd) |
| 154 | +end |
| 155 | + |
| 156 | +MLUtils.numobs(ds::DataStore) = numobs(getdata(ds)) |
| 157 | + |
| 158 | +function MLUtils.getobs(ds::DataStore, i::Int) |
| 159 | + newdata = getobs(getdata(ds), i) |
| 160 | + return DataStore(-1, newdata) |
| 161 | +end |
| 162 | + |
| 163 | +function MLUtils.getobs(ds::DataStore, i::AbstractVector{T}) where T <: Union{Integer,Bool} |
| 164 | + newdata = getobs(getdata(ds), i) |
| 165 | + n = getn(ds) |
| 166 | + if n > -1 |
| 167 | + if length(ds) > 0 |
| 168 | + n = numobs(newdata) |
| 169 | + else |
| 170 | + # if newdata is empty, then we can't get the number of observations from it |
| 171 | + n = T == Bool ? sum(i) : length(i) |
| 172 | + end |
| 173 | + end |
| 174 | + if !(newdata isa Dict{Symbol, Any}) |
| 175 | + newdata = Dict{Symbol, Any}(newdata) |
| 176 | + end |
| 177 | + return DataStore(n, newdata) |
| 178 | +end |
| 179 | + |
| 180 | +function cat_features(ds1::DataStore, ds2::DataStore) |
| 181 | + n1, n2 = getn(ds1), getn(ds2) |
| 182 | + n1 = n1 > 0 ? n1 : 1 |
| 183 | + n2 = n2 > 0 ? n2 : 1 |
| 184 | + return DataStore(n1 + n2, cat_features(getdata(ds1), getdata(ds2))) |
| 185 | +end |
| 186 | + |
| 187 | +function cat_features(dss::AbstractVector{DataStore}; kws...) |
| 188 | + ns = getn.(dss) |
| 189 | + ns = map(n -> n > 0 ? n : 1, ns) |
| 190 | + return DataStore(sum(ns), cat_features(getdata.(dss); kws...)) |
| 191 | +end |
| 192 | + |
| 193 | +# DataStore is always already normalized |
| 194 | +normalize_graphdata(ds::DataStore; kws...) = ds |
| 195 | + |
| 196 | +_gather(x::DataStore, i) = map(x -> _gather(x, i), x) |
| 197 | + |
| 198 | +function _scatter(aggr, src::DataStore, idx, n) |
| 199 | + newdata = _scatter(aggr, getdata(src), idx, n) |
| 200 | + if !(newdata isa Dict{Symbol, Any}) |
| 201 | + newdata = Dict{Symbol, Any}(newdata) |
| 202 | + end |
| 203 | + return DataStore(n, newdata) |
| 204 | +end |
| 205 | + |
| 206 | +function Base.hash(ds::D, h::UInt) where {D <: DataStore} |
| 207 | + fs = (getfield(ds, k) for k in fieldnames(D)) |
| 208 | + return foldl((h, f) -> hash(f, h), fs, init=hash(D, h)) |
| 209 | +end |
0 commit comments