|  | 
|  | 1 | +#TODO add graph_type = GRAPH_TYPE to all constructor calls | 
|  | 2 | + | 
| 1 | 3 | @testset "Constructor array TemporalSnapshotsGNNGraph" begin | 
| 2 | 4 |     snapshots = [rand_graph(10, 20) for i in 1:5] | 
| 3 | 5 |     tg = TemporalSnapshotsGNNGraph(snapshots) | 
|  | 
| 12 | 14 |     @test tg.num_snapshots == 5 | 
| 13 | 15 | end | 
| 14 | 16 | 
 | 
|  | 17 | + | 
| 15 | 18 | @testset "==" begin | 
| 16 | 19 |     snapshots = [rand_graph(10, 20) for i in 1:5] | 
| 17 | 20 |     tsg1 = TemporalSnapshotsGNNGraph(snapshots) | 
|  | 
| 41 | 44 | end | 
| 42 | 45 | 
 | 
| 43 | 46 | @testset "getproperty" begin | 
| 44 |  | -    x = rand(10) | 
|  | 47 | +    x = rand(Float32, 10) | 
| 45 | 48 |     snapshots = [rand_graph(10, 20, ndata = x) for i in 1:5] | 
| 46 | 49 |     tsg = TemporalSnapshotsGNNGraph(snapshots) | 
| 47 | 50 |     @test tsg.tgdata == DataStore() | 
| @@ -111,18 +114,31 @@ end | 
| 111 | 114 | @testset "show" begin | 
| 112 | 115 |     snapshots = [rand_graph(10, 20) for i in 1:5] | 
| 113 | 116 |     tsg = TemporalSnapshotsGNNGraph(snapshots) | 
| 114 |  | -    @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with no data" | 
| 115 |  | -    @test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5) with no data" | 
|  | 117 | +    @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5)" | 
|  | 118 | +    @test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5)" | 
| 116 | 119 |     @test sprint(show, MIME("text/plain"), tsg; context=:compact =>  false) == "TemporalSnapshotsGNNGraph:\n  num_nodes: [10, 10, 10, 10, 10]\n  num_edges: [20, 20, 20, 20, 20]\n  num_snapshots: 5" | 
| 117 |  | -    tsg.tgdata.x=rand(4) | 
| 118 |  | -    @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with x: 4-element data" | 
|  | 120 | +    tsg.tgdata.x = rand(Float32, 4) | 
|  | 121 | +    @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5)" | 
|  | 122 | +end | 
|  | 123 | + | 
|  | 124 | +@testset "broadcastable" begin | 
|  | 125 | +    snapshots = [rand_graph(10, 20) for i in 1:5] | 
|  | 126 | +    tsg = TemporalSnapshotsGNNGraph(snapshots) | 
|  | 127 | +    f(g) = g isa GNNGraph | 
|  | 128 | +    @test f.(tsg) == trues(5) | 
|  | 129 | +end | 
|  | 130 | + | 
|  | 131 | +@testset "iterate" begin | 
|  | 132 | +    snapshots = [rand_graph(10, 20) for i in 1:5] | 
|  | 133 | +    tsg = TemporalSnapshotsGNNGraph(snapshots) | 
|  | 134 | +    @test [g for g in tsg] isa Vector{<:GNNGraph} | 
| 119 | 135 | end | 
| 120 | 136 | 
 | 
| 121 | 137 | if TEST_GPU | 
| 122 | 138 |     @testset "gpu" begin | 
| 123 |  | -        snapshots = [rand_graph(10, 20; ndata = rand(5,10)) for i in 1:5] | 
|  | 139 | +        snapshots = [rand_graph(10, 20; ndata = rand(Float32, 5,10)) for i in 1:5] | 
| 124 | 140 |         tsg = TemporalSnapshotsGNNGraph(snapshots) | 
| 125 |  | -        tsg.tgdata.x = rand(5) | 
|  | 141 | +        tsg.tgdata.x = rand(Float32, 5) | 
| 126 | 142 |         dev = CUDADevice() #TODO replace with `gpu_device()` | 
| 127 | 143 |         tsg = tsg |> dev | 
| 128 | 144 |         @test tsg.snapshots[1].ndata.x isa CuArray | 
|  | 
0 commit comments