Skip to content

Commit 12fee79

Browse files
Add getproperty function to TemporalSnapshotsGNNGraph (#297)
* Add `getproperty` for a vector of `Datastore` * Add `getproperty` for a `TemporalSnapshotGNNGraph` * Add `getproperty` docstring * Remove comment * Add test `getproperty` * Add test `getproperty` tsg * Fix typo * Replace tgs with tg * Add spaces Co-authored-by: Carlo Lucibello <[email protected]> * Remove docstring `Base.getproperty` * Add more test `Datastore` --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 65d6e5f commit 12fee79

File tree

4 files changed

+40
-4
lines changed

4 files changed

+40
-4
lines changed

src/GNNGraphs/datastore.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,16 @@ function Base.getproperty(ds::DataStore, s::Symbol)
105105
end
106106
end
107107

108+
function Base.getproperty(vds::Vector{DataStore}, s::Symbol)
109+
if s === :_n
110+
return [getn(ds) for ds in vds]
111+
elseif s === :_data
112+
return [getdata(ds) for ds in vds]
113+
else
114+
return [getdata(ds)[s] for ds in vds]
115+
end
116+
end
117+
108118
function Base.setproperty!(ds::DataStore, s::Symbol, x)
109119
@assert s != :_n "cannot set _n directly"
110120
@assert s != :_data "cannot set _data directly"

src/GNNGraphs/temporalsnapshotsgnngraph.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,20 @@ end
182182
# return tg
183183
# end
184184

185+
function Base.getproperty(tg::TemporalSnapshotsGNNGraph, prop::Symbol)
186+
if prop fieldnames(TemporalSnapshotsGNNGraph)
187+
return getfield(tg, prop)
188+
elseif prop == :ndata
189+
return [s.ndata for s in tg.snapshots]
190+
elseif prop == :edata
191+
return [s.edata for s in tg.snapshots]
192+
elseif prop == :gdata
193+
return [s.gdata for s in tg.snapshots]
194+
else
195+
return [getproperty(s,prop) for s in tg.snapshots]
196+
end
197+
end
198+
185199
function Base.show(io::IO, tsg::TemporalSnapshotsGNNGraph)
186200
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ")
187201
print_feature_t(io, tsg.tgdata)
@@ -205,7 +219,6 @@ function Base.show(io::IO, ::MIME"text/plain", tsg::TemporalSnapshotsGNNGraph)
205219
end
206220
end
207221

208-
209222
function print_feature_t(io::IO, feature)
210223
if !isempty(feature)
211224
if length(keys(feature)) == 1

test/GNNGraphs/datastore.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ end
2020
@test_throws DimensionMismatch ds.z=rand(12)
2121
ds.z = [1:10;]
2222
@test ds.z == [1:10;]
23+
vec = [DataStore(10, (:x => x,)), DataStore(10, (:x => x, :y => rand(2, 10)))]
24+
@test vec.x == [x, x]
25+
@test_throws KeyError vec.z
26+
@test vec._n == [10, 10]
27+
@test vec._data == [Dict(:x => x), Dict(:x => x, :y => vec[2].y)]
2328
end
2429

2530
@testset "map" begin
@@ -31,7 +36,7 @@ end
3136
@test_throws AssertionError ds2=map(x -> [x; x], ds)
3237
end
3338

34-
@testset """getdata / getn""" begin
39+
@testset "getdata / getn" begin
3540
ds = DataStore(10, (:x => rand(10), :y => rand(2, 10)))
3641
@test getdata(ds) == getfield(ds, :_data)
3742
@test_throws KeyError ds.data

test/GNNGraphs/temporalsnapshotsgnngraph.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ end
2424
@test tsg[[1,2]] == TemporalSnapshotsGNNGraph([10,10], [20,20], 2, snapshots[1:2], tsg.tgdata)
2525
end
2626

27+
@testset "getproperty" begin
28+
x = rand(10)
29+
snapshots = [rand_graph(10, 20, ndata = x) for i in 1:5]
30+
tsg = TemporalSnapshotsGNNGraph(snapshots)
31+
@test tsg.tgdata == DataStore()
32+
@test tsg.x == tsg.ndata.x == [x for i in 1:5]
33+
@test_throws KeyError tsg.ndata.w
34+
@test_throws ArgumentError tsg.w
35+
end
36+
2737
@testset "add/remove_snapshot" begin
2838
snapshots = [rand_graph(10, 20) for i in 1:5]
2939
tsg = TemporalSnapshotsGNNGraph(snapshots)
@@ -91,5 +101,3 @@ end
91101
tsg.tgdata.x=rand(4)
92102
@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with x: 4-element data"
93103
end
94-
95-
# @test sprint(show, MIME("text/plain"), rand_graph(10, 20); context=:compact => true) == "GNNGraph(10, 20) with no data"

0 commit comments

Comments
 (0)