|
| 1 | +#TODO add graph_type = GRAPH_TYPE to all constructor calls |
| 2 | + |
1 | 3 | @testset "Constructor array TemporalSnapshotsGNNGraph" begin |
2 | 4 | snapshots = [rand_graph(10, 20) for i in 1:5] |
3 | 5 | tg = TemporalSnapshotsGNNGraph(snapshots) |
|
12 | 14 | @test tg.num_snapshots == 5 |
13 | 15 | end |
14 | 16 |
|
| 17 | + |
15 | 18 | @testset "==" begin |
16 | 19 | snapshots = [rand_graph(10, 20) for i in 1:5] |
17 | 20 | tsg1 = TemporalSnapshotsGNNGraph(snapshots) |
|
41 | 44 | end |
42 | 45 |
|
43 | 46 | @testset "getproperty" begin |
44 | | - x = rand(10) |
| 47 | + x = rand(Float32, 10) |
45 | 48 | snapshots = [rand_graph(10, 20, ndata = x) for i in 1:5] |
46 | 49 | tsg = TemporalSnapshotsGNNGraph(snapshots) |
47 | 50 | @test tsg.tgdata == DataStore() |
@@ -111,18 +114,31 @@ end |
111 | 114 | @testset "show" begin |
112 | 115 | snapshots = [rand_graph(10, 20) for i in 1:5] |
113 | 116 | tsg = TemporalSnapshotsGNNGraph(snapshots) |
114 | | - @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with no data" |
115 | | - @test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5) with no data" |
| 117 | + @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5)" |
| 118 | + @test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5)" |
116 | 119 | @test sprint(show, MIME("text/plain"), tsg; context=:compact => false) == "TemporalSnapshotsGNNGraph:\n num_nodes: [10, 10, 10, 10, 10]\n num_edges: [20, 20, 20, 20, 20]\n num_snapshots: 5" |
117 | | - tsg.tgdata.x=rand(4) |
118 | | - @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with x: 4-element data" |
| 120 | + tsg.tgdata.x = rand(Float32, 4) |
| 121 | + @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5)" |
| 122 | +end |
| 123 | + |
| 124 | +@testset "broadcastable" begin |
| 125 | + snapshots = [rand_graph(10, 20) for i in 1:5] |
| 126 | + tsg = TemporalSnapshotsGNNGraph(snapshots) |
| 127 | + f(g) = g isa GNNGraph |
| 128 | + @test f.(tsg) == trues(5) |
| 129 | +end |
| 130 | + |
| 131 | +@testset "iterate" begin |
| 132 | + snapshots = [rand_graph(10, 20) for i in 1:5] |
| 133 | + tsg = TemporalSnapshotsGNNGraph(snapshots) |
| 134 | + @test [g for g in tsg] isa Vector{<:GNNGraph} |
119 | 135 | end |
120 | 136 |
|
121 | 137 | if TEST_GPU |
122 | 138 | @testset "gpu" begin |
123 | | - snapshots = [rand_graph(10, 20; ndata = rand(5,10)) for i in 1:5] |
| 139 | + snapshots = [rand_graph(10, 20; ndata = rand(Float32, 5,10)) for i in 1:5] |
124 | 140 | tsg = TemporalSnapshotsGNNGraph(snapshots) |
125 | | - tsg.tgdata.x = rand(5) |
| 141 | + tsg.tgdata.x = rand(Float32, 5) |
126 | 142 | dev = CUDADevice() #TODO replace with `gpu_device()` |
127 | 143 | tsg = tsg |> dev |
128 | 144 | @test tsg.snapshots[1].ndata.x isa CuArray |
|
0 commit comments